TriuneCadence/composer_ans/cli.py

82 lines
3.0 KiB
Python

from __future__ import annotations

import argparse
from collections.abc import Sequence
from pathlib import Path

from .beethoven import BeethovenCategorizer
from .pipeline import CompositionPipeline
from .reporting import build_run_report, save_run_report_json
from .salieri import SalieriCritic
def build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the ``triune-cadence`` command.

    Every option carries help text so ``--help`` is self-describing, and
    options with a default show it via ``%(default)s``.  Path-valued options
    are kept as plain strings (``default=None`` when omitted) because the
    caller converts or forwards them itself.
    """
    parser = argparse.ArgumentParser(
        prog="triune-cadence",
        description=(
            "Compose notes with the composition pipeline, print timing and "
            "entropy statistics, and optionally load or save the models and "
            "the run report as JSON."
        ),
    )

    composition = parser.add_argument_group("composition options")
    composition.add_argument(
        "--thes-root",
        default="THES",
        metavar="DIR",
        help="directory of legacy THES data the pipeline is built from "
        "(default: %(default)s)",
    )
    composition.add_argument(
        "--notes",
        type=int,
        default=16,
        help="maximum number of notes to compose (default: %(default)s)",
    )
    composition.add_argument(
        "--object-threshold",
        type=int,
        default=3,
        help="object threshold passed to the pipeline (default: %(default)s)",
    )
    composition.add_argument(
        "--max-attempts-per-note",
        type=int,
        default=500,
        help="maximum composition attempts allowed per note "
        "(default: %(default)s)",
    )
    composition.add_argument(
        "--art-vigilance",
        type=float,
        default=0.9,
        help="ART vigilance passed to the pipeline (default: %(default)s)",
    )
    composition.add_argument(
        "--art-vigilance-decay",
        type=float,
        default=0.99,
        help="ART vigilance decay passed to the pipeline "
        "(default: %(default)s)",
    )

    files = parser.add_argument_group("model and report files")
    files.add_argument(
        "--save-salieri",
        metavar="PATH",
        help="after composing, save the Salieri critic as JSON to PATH",
    )
    files.add_argument(
        "--save-beethoven",
        metavar="PATH",
        help="after composing, save the Beethoven categorizer as JSON to PATH",
    )
    files.add_argument(
        "--load-salieri",
        metavar="PATH",
        help="replace the pipeline's Salieri critic with one loaded from "
        "the JSON file at PATH before composing",
    )
    files.add_argument(
        "--load-beethoven",
        metavar="PATH",
        help="replace the pipeline's Beethoven categorizer with one loaded "
        "from the JSON file at PATH before composing",
    )
    files.add_argument(
        "--save-report",
        metavar="PATH",
        help="save the run report as JSON to PATH",
    )
    return parser
def _print_summary(report) -> None:
    """Print the composed notes, per-note timings, and entropy metrics.

    *report* is the object returned by ``build_run_report``; only the
    attributes read below are used.
    """
    print("notes:", " ".join(str(note) for note in report.notes))
    print(
        "per_note_seconds:",
        " ".join(f"{elapsed:.6f}" for elapsed in report.per_note_seconds),
    )
    print(f"total_seconds: {report.total_seconds:.6f}")
    # A mean over zero timed notes is undefined, so the line is skipped then.
    if report.per_note_seconds:
        mean_seconds = sum(report.per_note_seconds) / len(report.per_note_seconds)
        print(f"mean_note_seconds: {mean_seconds:.6f}")
    print(f"unigram_entropy_bits: {report.unigram_entropy_bits:.4f}")
    print(f"conditional_entropy_bits: {report.conditional_entropy_bits:.4f}")
    print(f"normalized_entropy: {report.normalized_entropy:.4f}")
    print(f"predictability: {report.predictability:.4f}")
    print(f"redundancy: {report.redundancy:.4f}")


def main(argv: Sequence[str] | None = None) -> int:
    """Run one composition from the command line and print its statistics.

    Args:
        argv: Command-line arguments excluding the program name.  ``None``
            (the default) lets argparse read ``sys.argv[1:]``, so existing
            ``main()`` callers behave exactly as before; passing a list makes
            the command drivable from code and tests.

    Returns:
        ``0`` once the summary has been printed and any requested files have
        been written.  Errors raised while loading, composing, or saving are
        not caught and propagate to the caller.
    """
    args = build_parser().parse_args(argv)
    root = Path(args.thes_root)
    pipeline = CompositionPipeline.from_legacy_data_with_options(
        root,
        object_threshold=args.object_threshold,
        art_vigilance=args.art_vigilance,
        art_vigilance_decay=args.art_vigilance_decay,
    )
    # Optional pre-saved models replace the ones the pipeline was built with.
    if args.load_salieri:
        pipeline.salieri = SalieriCritic.load_json(args.load_salieri)
    if args.load_beethoven:
        pipeline.beethoven = BeethovenCategorizer.load_json(args.load_beethoven)

    record = pipeline.compose(
        max_notes=args.notes,
        max_attempts_per_note=args.max_attempts_per_note,
    )
    report = build_run_report(
        record,
        parameters={
            "thes_root": str(root),
            "notes_requested": args.notes,
            "object_threshold": args.object_threshold,
            "max_attempts_per_note": args.max_attempts_per_note,
            "art_vigilance": args.art_vigilance,
            "art_vigilance_decay": args.art_vigilance_decay,
        },
    )
    _print_summary(report)

    # Persistence happens only after the summary has been printed.
    if args.save_salieri:
        pipeline.salieri.save_json(args.save_salieri)
    if args.save_beethoven:
        pipeline.beethoven.save_json(args.save_beethoven)
    if args.save_report:
        save_run_report_json(report, args.save_report)
    return 0