from __future__ import annotations from dataclasses import dataclass import json from pathlib import Path from .backprop import BackpropNetwork, BackpropResult from .classical_rules import ClassicalInstructor from .encoding import encode_note_sequence, encode_sequence_one_hot from .io.legacy_files import load_salieri_config, load_salieri_weights from .types import LegacyPaths, NoteSequence @dataclass(frozen=True) class SalieriResult: notes: NoteSequence target: int raw_output: float is_classical: bool error: float network_result: BackpropResult class SalieriCritic: def __init__(self, *, network: BackpropNetwork, instructor: ClassicalInstructor) -> None: self.network = network self.instructor = instructor @classmethod def from_legacy_paths(cls, root: str | Path) -> "SalieriCritic": paths = LegacyPaths(root=Path(root)) config = load_salieri_config(paths.salieri_config) weights = load_salieri_weights(paths.salieri_weights) instructor = ClassicalInstructor.from_sequence_file(paths.sequence_data) network = BackpropNetwork.from_legacy(config=config, legacy_weights=weights) return cls(network=network, instructor=instructor) def evaluate_and_train( self, notes: list[int] | tuple[int, ...], *, target: int | None = None, ) -> SalieriResult: sequence = encode_note_sequence(notes) encoded = tuple(float(value) for value in encode_sequence_one_hot(sequence)) training_target = self.instructor(sequence) if target is None else int(target) network_result = self.network.train_step(encoded, (float(training_target),)) raw_output = network_result.outputs[0] return SalieriResult( notes=sequence, target=training_target, raw_output=raw_output, is_classical=raw_output > 0.5, error=network_result.error, network_result=network_result, ) def to_dict(self) -> dict[str, object]: return { "network": self.network.to_dict(), "sequences": list(self.instructor.sequences), } @classmethod def from_dict(cls, data: dict[str, object]) -> "SalieriCritic": network = BackpropNetwork.from_dict(data["network"]) # type: ignore[arg-type] instructor = ClassicalInstructor(sequences=tuple(data["sequences"])) # type: ignore[arg-type] return cls(network=network, instructor=instructor) def save_json(self, path: str) -> None: with open(path, "w", encoding="utf-8") as handle: json.dump(self.to_dict(), handle, indent=2) @classmethod def load_json(cls, path: str) -> "SalieriCritic": with open(path, "r", encoding="utf-8") as handle: data = json.load(handle) return cls.from_dict(data)