# TriuneCadence/composer_ans/experiments.py
from __future__ import annotations
from dataclasses import asdict
import csv
from pathlib import Path
from .pipeline import CompositionPipeline
from .reporting import build_run_report, save_run_report_json
from .types import CompositionRunReport
def run_parameter_sweep(
    *,
    thes_root: str | Path,
    output_dir: str | Path,
    notes: int,
    parameter_sets: list[dict[str, object]],
) -> list[CompositionRunReport]:
    """Compose once per parameter set and persist a report for each run.

    For every mapping in ``parameter_sets`` a pipeline is built from the
    legacy data under ``thes_root``, a composition of ``notes`` notes is
    produced, and the resulting report is written to ``output_dir`` as
    ``run_NNN.json``. A ``summary.csv`` covering all runs is written last.

    Recognized per-run keys (with defaults): ``object_threshold`` (3),
    ``art_vigilance`` (0.9), ``art_vigilance_decay`` (0.99), and
    ``max_attempts_per_note`` (500).

    Returns the list of run reports in sweep order.
    """
    data_root = Path(thes_root)
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    collected: list[CompositionRunReport] = []
    for run_number, config in enumerate(parameter_sets, start=1):
        pipeline = CompositionPipeline.from_legacy_data_with_options(
            data_root,
            object_threshold=int(config.get("object_threshold", 3)),
            art_vigilance=float(config.get("art_vigilance", 0.9)),
            art_vigilance_decay=float(config.get("art_vigilance_decay", 0.99)),
        )
        record = pipeline.compose(
            max_notes=notes,
            max_attempts_per_note=int(config.get("max_attempts_per_note", 500)),
        )
        report = build_run_report(
            record,
            # Record the request alongside the sweep parameters so each JSON
            # file is self-describing.
            parameters={
                "notes_requested": notes,
                **config,
            },
        )
        # Zero-padded index keeps the report files lexically sorted on disk.
        save_run_report_json(report, str(out_dir / f"run_{run_number:03d}.json"))
        collected.append(report)

    _write_summary_csv(out_dir / "summary.csv", collected)
    return collected
def _write_summary_csv(path: Path, reports: list[CompositionRunReport]) -> None:
if not reports:
return
rows = []
for report in reports:
row = asdict(report)
row["notes"] = " ".join(str(note) for note in report.notes)
row["per_note_seconds"] = " ".join(f"{value:.6f}" for value in report.per_note_seconds)
row["parameters"] = str(report.parameters)
rows.append(row)
fieldnames = list(rows[0].keys())
with path.open("w", encoding="utf-8", newline="") as handle:
writer = csv.DictWriter(handle, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)