TriuneCadence/composer_ans/beethoven.py

78 lines
2.3 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
import json
from .art1 import ART1Network, ART1Params, ART1Result
from .encoding import encode_art_input, encode_note_sequence
from .types import ART_CATEGORY_LIMIT, ART_INPUT_LENGTH, NoteSequence
@dataclass(frozen=True)
class BeethovenResult:
    """Immutable record of a single categorization.

    As produced by ``BeethovenCategorizer.categorize``, it bundles the encoded
    note sequence fed to the network, the classical flag used while encoding,
    and the network's raw result.
    """

    # Encoded sequence (output of ``encode_note_sequence`` when built by the
    # categorizer), not necessarily the caller's raw input.
    notes: NoteSequence
    # Flag that was folded into the ART input vector.
    is_classical: bool
    # Value returned by the ART1 network for this input.
    art_result: ART1Result
class BeethovenCategorizer:
def __init__(self, network: ART1Network | None = None) -> None:
self.network = network or ART1Network(
ART1Params(
max_categories=ART_CATEGORY_LIMIT,
input_length=ART_INPUT_LENGTH,
)
)
def categorize(
self,
notes: list[int] | tuple[int, ...],
*,
is_classical: bool,
) -> BeethovenResult:
sequence = encode_note_sequence(notes)
input_vector = encode_art_input(sequence, is_classical=is_classical)
art_result = self.network.categorize(input_vector)
return BeethovenResult(
notes=sequence,
is_classical=is_classical,
art_result=art_result,
)
@classmethod
def with_params(
cls,
*,
max_categories: int = ART_CATEGORY_LIMIT,
input_length: int = ART_INPUT_LENGTH,
vigilance: float = 0.9,
vigilance_decay: float = 0.99,
) -> "BeethovenCategorizer":
return cls(
network=ART1Network(
ART1Params(
max_categories=max_categories,
input_length=input_length,
vigilance=vigilance,
vigilance_decay=vigilance_decay,
)
)
)
def to_dict(self) -> dict[str, object]:
return {"network": self.network.to_dict()}
@classmethod
def from_dict(cls, data: dict[str, object]) -> "BeethovenCategorizer":
return cls(network=ART1Network.from_dict(data["network"])) # type: ignore[arg-type]
def save_json(self, path: str) -> None:
with open(path, "w", encoding="utf-8") as handle:
json.dump(self.to_dict(), handle, indent=2)
@classmethod
def load_json(cls, path: str) -> "BeethovenCategorizer":
with open(path, "r", encoding="utf-8") as handle:
data = json.load(handle)
return cls.from_dict(data)