78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import json
|
|
|
|
from .art1 import ART1Network, ART1Params, ART1Result
|
|
from .encoding import encode_art_input, encode_note_sequence
|
|
from .types import ART_CATEGORY_LIMIT, ART_INPUT_LENGTH, NoteSequence
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class BeethovenResult:
    """Immutable outcome of categorizing one note sequence.

    Instances are frozen (attribute assignment raises
    ``dataclasses.FrozenInstanceError``) and, being frozen with generated
    equality, are hashable and compare by value.
    """

    # The note sequence. When built by BeethovenCategorizer.categorize this is
    # the output of encode_note_sequence, not the caller's raw input.
    notes: NoteSequence
    # The classical/non-classical flag supplied alongside the notes.
    is_classical: bool
    # The value returned by the ART1 network's categorize() for this input.
    art_result: ART1Result
|
|
|
|
|
|
class BeethovenCategorizer:
    """Categorize note sequences with an ART1 network.

    Each call to :meth:`categorize` encodes the notes together with an
    ``is_classical`` flag into an ART1 input vector and hands it to the wrapped
    :class:`ART1Network`. The wrapped network can be round-tripped through a
    plain dict (:meth:`to_dict` / :meth:`from_dict`) or a JSON file
    (:meth:`save_json` / :meth:`load_json`).
    """

    def __init__(self, network: ART1Network | None = None) -> None:
        """Wrap *network*, or build a default network when it is ``None``.

        Args:
            network: An existing network to reuse. When ``None``, a new network
                is created with ``max_categories=ART_CATEGORY_LIMIT`` and
                ``input_length=ART_INPUT_LENGTH``; all other ``ART1Params``
                fields keep that class's own defaults.
        """
        # NOTE(review): the default network does not pass vigilance /
        # vigilance_decay, so it uses ART1Params' defaults, which may differ
        # from the 0.9 / 0.99 defaults of with_params() -- confirm intended.
        #
        # Compare against None explicitly: ``network or ...`` would silently
        # discard a caller-supplied network that happens to be falsy (e.g. if
        # ART1Network defines __len__/__bool__ and is still empty).
        if network is None:
            network = ART1Network(
                ART1Params(
                    max_categories=ART_CATEGORY_LIMIT,
                    input_length=ART_INPUT_LENGTH,
                )
            )
        self.network = network

    def categorize(
        self,
        notes: list[int] | tuple[int, ...],
        *,
        is_classical: bool,
    ) -> BeethovenResult:
        """Encode *notes* and categorize them with the wrapped network.

        Args:
            notes: Raw note values; normalized by ``encode_note_sequence``.
            is_classical: Flag encoded into the ART input vector by
                ``encode_art_input``.

        Returns:
            A :class:`BeethovenResult` holding the *encoded* sequence, the
            flag, and whatever ``self.network.categorize`` returned.
        """
        sequence = encode_note_sequence(notes)
        input_vector = encode_art_input(sequence, is_classical=is_classical)
        art_result = self.network.categorize(input_vector)
        return BeethovenResult(
            notes=sequence,
            is_classical=is_classical,
            art_result=art_result,
        )

    @classmethod
    def with_params(
        cls,
        *,
        max_categories: int = ART_CATEGORY_LIMIT,
        input_length: int = ART_INPUT_LENGTH,
        vigilance: float = 0.9,
        vigilance_decay: float = 0.99,
    ) -> "BeethovenCategorizer":
        """Build a categorizer around a new network with explicit parameters.

        Every keyword is forwarded unchanged to :class:`ART1Params`.

        Returns:
            A new categorizer wrapping a freshly constructed ART1 network.
        """
        params = ART1Params(
            max_categories=max_categories,
            input_length=input_length,
            vigilance=vigilance,
            vigilance_decay=vigilance_decay,
        )
        return cls(network=ART1Network(params))

    def to_dict(self) -> dict[str, object]:
        """Return ``{"network": self.network.to_dict()}`` for serialization."""
        return {"network": self.network.to_dict()}

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "BeethovenCategorizer":
        """Rebuild a categorizer from the output of :meth:`to_dict`.

        Raises:
            KeyError: If *data* has no ``"network"`` entry.
        """
        return cls(network=ART1Network.from_dict(data["network"]))  # type: ignore[arg-type]

    def save_json(self, path: str) -> None:
        """Write :meth:`to_dict` output to *path* as UTF-8 JSON (indent=2).

        The payload is serialized in memory before the file is opened, so an
        error in ``to_dict()`` or in JSON encoding leaves any existing file at
        *path* untouched instead of truncating it.
        """
        payload = json.dumps(self.to_dict(), indent=2)
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(payload)

    @classmethod
    def load_json(cls, path: str) -> "BeethovenCategorizer":
        """Load a categorizer from a JSON file produced by :meth:`save_json`.

        Raises:
            KeyError: If the decoded object has no ``"network"`` entry.
        """
        with open(path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        return cls.from_dict(data)
|