"""Encode fixed-length note sequences into integer feature vectors.

Notes are integers in ``0..NOTE_VOCABULARY_SIZE`` (inclusive). A note value
of ``0`` produces an all-zero slot in the one-hot encoding (presumably a
rest / empty position -- confirm against ``.types``); values ``1..N`` set
exactly one bit in their slot.
"""

from __future__ import annotations

from .types import ART_INPUT_LENGTH, NOTE_VOCABULARY_SIZE, NoteSequence, SEQUENCE_LENGTH


def encode_note_sequence(notes: list[int] | tuple[int, ...]) -> NoteSequence:
    """Validate *notes* and return them as a tuple of ints.

    Args:
        notes: Exactly ``SEQUENCE_LENGTH`` note values.

    Returns:
        A tuple with each element converted through ``int()``.

    Raises:
        ValueError: If the length is not ``SEQUENCE_LENGTH``, or if any note
            falls outside ``0..NOTE_VOCABULARY_SIZE`` (inclusive).
    """
    if len(notes) != SEQUENCE_LENGTH:
        raise ValueError(f"expected {SEQUENCE_LENGTH} notes, got {len(notes)}")
    encoded = tuple(int(note) for note in notes)
    for note in encoded:
        # Upper bound is inclusive: note N maps to one-hot offset N - 1.
        if not 0 <= note <= NOTE_VOCABULARY_SIZE:
            raise ValueError(f"note out of range: {note}")
    return encoded


def encode_sequence_one_hot(notes: list[int] | tuple[int, ...]) -> tuple[int, ...]:
    """Return a flat one-hot encoding of *notes*.

    The result has ``SEQUENCE_LENGTH * NOTE_VOCABULARY_SIZE`` entries: one
    slot of ``NOTE_VOCABULARY_SIZE`` entries per sequence position. A note
    ``n > 0`` sets entry ``n - 1`` of its slot to 1; a note of 0 leaves the
    slot all zeros.

    Raises:
        ValueError: Propagated from :func:`encode_note_sequence`.
    """
    encoded = encode_note_sequence(notes)
    vector = [0] * (SEQUENCE_LENGTH * NOTE_VOCABULARY_SIZE)
    for index, note in enumerate(encoded):
        if note > 0:
            vector[index * NOTE_VOCABULARY_SIZE + (note - 1)] = 1
    return tuple(vector)


def encode_art_input(
    notes: list[int] | tuple[int, ...],
    *,
    is_classical: bool,
) -> tuple[int, ...]:
    """Build the ART input vector: one-hot notes plus a trailing class flag.

    Args:
        notes: Exactly ``SEQUENCE_LENGTH`` note values.
        is_classical: Appended as a final element, 1 if true else 0.

    Returns:
        A tuple of length ``ART_INPUT_LENGTH``.

    Raises:
        ValueError: Propagated from :func:`encode_note_sequence`.
        AssertionError: If the assembled length differs from
            ``ART_INPUT_LENGTH``. This points to inconsistent constants in
            ``.types`` (presumably ``SEQUENCE_LENGTH * NOTE_VOCABULARY_SIZE
            + 1`` is expected -- confirm), not to bad caller input.
    """
    vector = list(encode_sequence_one_hot(notes))
    vector.append(1 if is_classical else 0)
    # Raised explicitly rather than via `assert` so the check survives `-O`.
    if len(vector) != ART_INPUT_LENGTH:
        raise AssertionError(f"unexpected ART input length: {len(vector)}")
    return tuple(vector)