from __future__ import annotations from dataclasses import dataclass import json from .art1 import ART1Network, ART1Params, ART1Result from .encoding import encode_art_input, encode_note_sequence from .types import ART_CATEGORY_LIMIT, ART_INPUT_LENGTH, NoteSequence @dataclass(frozen=True) class BeethovenResult: notes: NoteSequence is_classical: bool art_result: ART1Result class BeethovenCategorizer: def __init__(self, network: ART1Network | None = None) -> None: self.network = network or ART1Network( ART1Params( max_categories=ART_CATEGORY_LIMIT, input_length=ART_INPUT_LENGTH, ) ) def categorize( self, notes: list[int] | tuple[int, ...], *, is_classical: bool, ) -> BeethovenResult: sequence = encode_note_sequence(notes) input_vector = encode_art_input(sequence, is_classical=is_classical) art_result = self.network.categorize(input_vector) return BeethovenResult( notes=sequence, is_classical=is_classical, art_result=art_result, ) @classmethod def with_params( cls, *, max_categories: int = ART_CATEGORY_LIMIT, input_length: int = ART_INPUT_LENGTH, vigilance: float = 0.9, vigilance_decay: float = 0.99, ) -> "BeethovenCategorizer": return cls( network=ART1Network( ART1Params( max_categories=max_categories, input_length=input_length, vigilance=vigilance, vigilance_decay=vigilance_decay, ) ) ) def to_dict(self) -> dict[str, object]: return {"network": self.network.to_dict()} @classmethod def from_dict(cls, data: dict[str, object]) -> "BeethovenCategorizer": return cls(network=ART1Network.from_dict(data["network"])) # type: ignore[arg-type] def save_json(self, path: str) -> None: with open(path, "w", encoding="utf-8") as handle: json.dump(self.to_dict(), handle, indent=2) @classmethod def load_json(cls, path: str) -> "BeethovenCategorizer": with open(path, "r", encoding="utf-8") as handle: data = json.load(handle) return cls.from_dict(data)