63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import asdict
|
|
import csv
|
|
from pathlib import Path
|
|
|
|
from .pipeline import CompositionPipeline
|
|
from .reporting import build_run_report, save_run_report_json
|
|
from .types import CompositionRunReport
|
|
|
|
|
|
def run_parameter_sweep(
    *,
    thes_root: str | Path,
    output_dir: str | Path,
    notes: int,
    parameter_sets: list[dict[str, object]],
) -> list[CompositionRunReport]:
    """Run one composition per parameter set and persist a report for each.

    For every mapping in *parameter_sets*, a new ``CompositionPipeline`` is
    built from the legacy data under *thes_root* using that mapping's options.
    The pipeline composes up to *notes* notes, and the result is turned into a
    run report. Each report is saved as ``run_NNN.json`` (1-based index,
    zero-padded to three digits) inside *output_dir*. That directory, and any
    missing parents, is created first. After the last run, ``summary.csv`` for
    the whole sweep is produced in the same directory by ``_write_summary_csv``.

    Keys read from each parameter set (fallback value in parentheses):
        ``object_threshold`` (3, coerced with ``int``),
        ``art_vigilance`` (0.9, coerced with ``float``),
        ``art_vigilance_decay`` (0.99, coerced with ``float``),
        ``max_attempts_per_note`` (500, coerced with ``int``).
    Other keys are not used for composition. They are still copied into the
    report's ``parameters`` along with ``notes_requested``. Because the
    parameter set is unpacked last, a ``notes_requested`` key in it replaces
    the recorded value, though not the value passed to ``compose``.

    Returns:
        The run reports, in the same order as *parameter_sets*.
    """
    legacy_root = Path(thes_root)
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    collected: list[CompositionRunReport] = []
    for run_number, option_set in enumerate(parameter_sets, start=1):
        # A fresh pipeline per run: the construction options differ per set.
        pipeline = CompositionPipeline.from_legacy_data_with_options(
            legacy_root,
            object_threshold=int(option_set.get("object_threshold", 3)),
            art_vigilance=float(option_set.get("art_vigilance", 0.9)),
            art_vigilance_decay=float(option_set.get("art_vigilance_decay", 0.99)),
        )
        attempt_budget = int(option_set.get("max_attempts_per_note", 500))
        composition = pipeline.compose(
            max_notes=notes,
            max_attempts_per_note=attempt_budget,
        )

        recorded_parameters = {"notes_requested": notes, **option_set}
        run_report = build_run_report(composition, parameters=recorded_parameters)

        json_path = out_dir / f"run_{run_number:03d}.json"
        save_run_report_json(run_report, str(json_path))
        collected.append(run_report)

    _write_summary_csv(out_dir / "summary.csv", collected)
    return collected
|
|
|
|
|
|
def _write_summary_csv(path: Path, reports: list[CompositionRunReport]) -> None:
    """Write one CSV row per report to *path*.

    When *reports* is empty, nothing is written and no file is created.

    Each row is ``dataclasses.asdict(report)``, with three fields collapsed
    into a single cell each:

    * ``notes`` -- the ``str`` of each note, separated by single spaces;
    * ``per_note_seconds`` -- each value formatted with six decimals,
      separated by single spaces;
    * ``parameters`` -- ``str()`` of the parameters mapping.

    The header and column order come from the first report's keys. The file
    is written as UTF-8 with ``newline=""``, as the ``csv`` module expects.
    """
    if not reports:
        return

    # Overriding keys that already exist in the asdict() result keeps their
    # original position, so the column order stays the dataclass field order.
    table = [
        {
            **asdict(entry),
            "notes": " ".join(map(str, entry.notes)),
            "per_note_seconds": " ".join(
                f"{seconds:.6f}" for seconds in entry.per_note_seconds
            ),
            "parameters": str(entry.parameters),
        }
        for entry in reports
    ]
    header = list(table[0])

    with path.open("w", encoding="utf-8", newline="") as csv_file:
        csv_writer = csv.DictWriter(csv_file, fieldnames=header)
        csv_writer.writeheader()
        csv_writer.writerows(table)
|