79 lines
2.8 KiB
Python
79 lines
2.8 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import json
|
|
from pathlib import Path
|
|
|
|
from .backprop import BackpropNetwork, BackpropResult
|
|
from .classical_rules import ClassicalInstructor
|
|
from .encoding import encode_note_sequence, encode_sequence_one_hot
|
|
from .io.legacy_files import load_salieri_config, load_salieri_weights
|
|
from .types import LegacyPaths, NoteSequence
|
|
|
|
|
|
@dataclass(frozen=True)
class SalieriResult:
    """Immutable record of one ``SalieriCritic.evaluate_and_train`` call.

    Instances are frozen: fields cannot be reassigned after construction.
    """

    # Note sequence as returned by ``encode_note_sequence``.
    notes: NoteSequence
    # Training target used for the step (instructor verdict or caller override).
    target: int
    # First element of the network's outputs for this step.
    raw_output: float
    # True when ``raw_output`` is strictly greater than 0.5.
    is_classical: bool
    # Error value reported by the network for this step.
    error: float
    # Full result object returned by the network's ``train_step``.
    network_result: BackpropResult
|
|
|
|
|
|
class SalieriCritic:
    """Pair a backprop network with an instructor that supplies training targets.

    The instructor is called as ``instructor(sequence)`` to obtain a target
    whenever the caller does not provide one, and its ``sequences`` attribute
    is persisted alongside the network state.
    """

    def __init__(self, *, network: BackpropNetwork, instructor: ClassicalInstructor) -> None:
        """Store the collaborating network and instructor (keyword-only)."""
        self.network = network
        self.instructor = instructor

    @classmethod
    def from_legacy_paths(cls, root: str | Path) -> "SalieriCritic":
        """Build a critic from the legacy file layout rooted at *root*.

        Loads the Salieri config and weights and the instructor's sequence
        file from the locations exposed by ``LegacyPaths``.
        """
        paths = LegacyPaths(root=Path(root))
        config = load_salieri_config(paths.salieri_config)
        weights = load_salieri_weights(paths.salieri_weights)
        instructor = ClassicalInstructor.from_sequence_file(paths.sequence_data)
        network = BackpropNetwork.from_legacy(config=config, legacy_weights=weights)
        return cls(network=network, instructor=instructor)

    def evaluate_and_train(
        self,
        notes: list[int] | tuple[int, ...],
        *,
        target: int | None = None,
    ) -> SalieriResult:
        """Run one ``train_step`` of the network on *notes* and report the outcome.

        Args:
            notes: Raw notes; normalised through ``encode_note_sequence``.
            target: Explicit training target. When ``None`` the instructor is
                asked for one; otherwise the value is coerced with ``int()``.

        Returns:
            A ``SalieriResult`` whose ``is_classical`` flag is true when the
            first network output is strictly greater than 0.5.
        """
        sequence = encode_note_sequence(notes)
        encoded = tuple(float(value) for value in encode_sequence_one_hot(sequence))
        training_target = self.instructor(sequence) if target is None else int(target)
        network_result = self.network.train_step(encoded, (float(training_target),))
        raw_output = network_result.outputs[0]
        return SalieriResult(
            notes=sequence,
            target=training_target,
            raw_output=raw_output,
            is_classical=raw_output > 0.5,
            error=network_result.error,
            network_result=network_result,
        )

    def to_dict(self) -> dict[str, object]:
        """Return a plain-dict snapshot: network state plus instructor sequences."""
        return {
            "network": self.network.to_dict(),
            "sequences": list(self.instructor.sequences),
        }

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "SalieriCritic":
        """Rebuild a critic from a mapping shaped like the output of ``to_dict``.

        Raises:
            KeyError: If ``"network"`` or ``"sequences"`` is missing.
        """
        network = BackpropNetwork.from_dict(data["network"])  # type: ignore[arg-type]
        # NOTE(review): after a JSON round-trip each stored sequence is a list
        # (JSON arrays decode to lists); only the outer container is turned
        # back into a tuple here. Confirm ClassicalInstructor accepts that.
        instructor = ClassicalInstructor(sequences=tuple(data["sequences"]))  # type: ignore[arg-type]
        return cls(network=network, instructor=instructor)

    def save_json(self, path: str | Path) -> None:
        """Write ``to_dict()`` to *path* as UTF-8 JSON with two-space indent.

        Raises:
            TypeError: If the snapshot contains a value JSON cannot encode.
            OSError: If the file cannot be opened or written.
        """
        # Serialize first: opening with "w" truncates the file immediately, so
        # a failure while building or encoding the snapshot would otherwise
        # destroy a previously saved critic and leave a partial file behind.
        payload = json.dumps(self.to_dict(), indent=2)
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(payload)

    @classmethod
    def load_json(cls, path: str | Path) -> "SalieriCritic":
        """Read a JSON file produced by ``save_json`` and rebuild the critic."""
        with open(path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        return cls.from_dict(data)
|