# TriuneCadence/composer_ans/salieri.py
#
# (file-viewer metadata: 79 lines, 2.8 KiB, Python)

from __future__ import annotations
from dataclasses import dataclass
import json
from pathlib import Path
from .backprop import BackpropNetwork, BackpropResult
from .classical_rules import ClassicalInstructor
from .encoding import encode_note_sequence, encode_sequence_one_hot
from .io.legacy_files import load_salieri_config, load_salieri_weights
from .types import LegacyPaths, NoteSequence
@dataclass(frozen=True)
class SalieriResult:
    """Immutable record of one scoring-plus-training pass by ``SalieriCritic``.

    Instances are frozen: assigning to any field raises
    ``dataclasses.FrozenInstanceError``.
    """

    # The normalised note sequence that was scored and trained on.
    notes: NoteSequence
    # Target used for the training step (instructor-derived or caller-supplied).
    target: int
    # First value of the network's outputs reported by the training step.
    raw_output: float
    # Whether the critic judged raw_output to be above its "classical" threshold.
    is_classical: bool
    # Error value reported by the network's training step.
    error: float
    # Complete result object returned by the network's training step.
    network_result: BackpropResult
class SalieriCritic:
    """Network-backed critic that flags note sequences as "classical".

    Combines a ``BackpropNetwork`` (scored and trained on every call to
    ``evaluate_and_train``) with a ``ClassicalInstructor`` that supplies the
    training target when the caller does not provide one.
    """

    def __init__(self, *, network: BackpropNetwork, instructor: ClassicalInstructor) -> None:
        self.network = network
        self.instructor = instructor

    @classmethod
    def from_legacy_paths(cls, root: str | Path) -> "SalieriCritic":
        """Build a critic from the legacy on-disk layout rooted at *root*.

        Reads the Salieri config and weights and the instructor's sequence
        data from the locations exposed by ``LegacyPaths``.
        """
        paths = LegacyPaths(root=Path(root))
        config = load_salieri_config(paths.salieri_config)
        weights = load_salieri_weights(paths.salieri_weights)
        instructor = ClassicalInstructor.from_sequence_file(paths.sequence_data)
        network = BackpropNetwork.from_legacy(config=config, legacy_weights=weights)
        return cls(network=network, instructor=instructor)

    def evaluate_and_train(
        self,
        notes: list[int] | tuple[int, ...],
        *,
        target: int | None = None,
        threshold: float = 0.5,
    ) -> SalieriResult:
        """Score *notes* with the network and run one training step on them.

        Args:
            notes: Raw note values, normalised via ``encode_note_sequence``.
            target: Training target. When ``None`` the instructor is asked
                for one; otherwise the value is coerced with ``int()``.
            threshold: ``is_classical`` is set when the network output is
                strictly greater than this value. Defaults to ``0.5``, the
                value that was previously hard-coded.

        Returns:
            A ``SalieriResult`` bundling the normalised notes, the target
            used, the first network output, the classification flag, the
            training error, and the raw ``train_step`` result.
        """
        sequence = encode_note_sequence(notes)
        # train_step receives the one-hot encoding as a tuple of floats.
        encoded = tuple(float(value) for value in encode_sequence_one_hot(sequence))
        training_target = self.instructor(sequence) if target is None else int(target)
        network_result = self.network.train_step(encoded, (float(training_target),))
        # NOTE(review): whether outputs are pre- or post-update depends on
        # BackpropNetwork.train_step; confirm before relying on either.
        raw_output = network_result.outputs[0]
        return SalieriResult(
            notes=sequence,
            target=training_target,
            raw_output=raw_output,
            is_classical=raw_output > threshold,
            error=network_result.error,
            network_result=network_result,
        )

    def to_dict(self) -> dict[str, object]:
        """Return a plain-dict snapshot of the network state and instructor sequences."""
        return {
            "network": self.network.to_dict(),
            "sequences": list(self.instructor.sequences),
        }

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "SalieriCritic":
        """Rebuild a critic from a mapping shaped like ``to_dict()`` output.

        Raises:
            KeyError: if ``"network"`` or ``"sequences"`` is missing.
        """
        network = BackpropNetwork.from_dict(data["network"])  # type: ignore[arg-type]
        raw_sequences = data["sequences"]
        # JSON has no tuple type, so sequences saved with save_json come back
        # as lists. Restore them to tuples so a reloaded critic holds the same
        # (hashable, tuple-comparable) entries as one built in memory.
        # NOTE(review): assumes sequences are tuple-typed NoteSequence values;
        # confirm against ClassicalInstructor.
        sequences = tuple(
            tuple(seq) if isinstance(seq, list) else seq
            for seq in raw_sequences  # type: ignore[attr-defined]
        )
        instructor = ClassicalInstructor(sequences=sequences)
        return cls(network=network, instructor=instructor)

    def save_json(self, path: str | Path) -> None:
        """Write ``to_dict()`` to *path* as UTF-8 JSON indented by two spaces.

        The document is serialised in memory before the file is opened. A
        serialisation error (for example a non-JSON value in the network
        state) is therefore raised without truncating an existing file at
        *path*.
        """
        text = json.dumps(self.to_dict(), indent=2)
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(text)

    @classmethod
    def load_json(cls, path: str | Path) -> "SalieriCritic":
        """Load a critic previously written by ``save_json``."""
        with open(path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        return cls.from_dict(data)