TriuneCadence/composer_ans/encoding.py

35 lines
1.2 KiB
Python

from __future__ import annotations
from .types import ART_INPUT_LENGTH, NOTE_VOCABULARY_SIZE, NoteSequence, SEQUENCE_LENGTH
def encode_note_sequence(notes: list[int] | tuple[int, ...]) -> NoteSequence:
    """Validate *notes* and return them as a tuple of ints.

    Every element is first coerced with ``int``; only after the whole
    sequence has been converted are the values range-checked, and the first
    offending value (in sequence order) is the one reported.

    Raises:
        ValueError: if ``len(notes)`` differs from ``SEQUENCE_LENGTH``, or if
            any converted value lies outside ``0..NOTE_VOCABULARY_SIZE``
            (both bounds inclusive).
    """
    count = len(notes)
    if count != SEQUENCE_LENGTH:
        raise ValueError(f"expected {SEQUENCE_LENGTH} notes, got {count}")

    values = tuple(map(int, notes))

    # Locate the first value outside the inclusive range, if there is one.
    offender = next(
        (value for value in values if value < 0 or value > NOTE_VOCABULARY_SIZE),
        None,
    )
    if offender is not None:
        raise ValueError(f"note out of range: {offender}")
    return values
def encode_sequence_one_hot(notes: list[int] | tuple[int, ...]) -> tuple[int, ...]:
    """Return a flat one-hot encoding of *notes*.

    The input is validated by ``encode_note_sequence`` first, so any
    ``ValueError`` it raises propagates unchanged. The result is the
    concatenation of one ``NOTE_VOCABULARY_SIZE``-wide slot per note: a note
    value ``k > 0`` sets offset ``k - 1`` of its slot to 1, while a note
    value of 0 leaves its slot all zeros.
    """
    pitches = encode_note_sequence(notes)

    flat: list[int] = []
    for pitch in pitches:
        slot = [0] * NOTE_VOCABULARY_SIZE
        if pitch > 0:
            # Note values are 1-based inside a slot; 0 sets no bit at all.
            slot[pitch - 1] = 1
        flat.extend(slot)
    return tuple(flat)
def encode_art_input(
    notes: list[int] | tuple[int, ...],
    *,
    is_classical: bool,
) -> tuple[int, ...]:
    """Build the ART input vector for *notes*.

    The vector is the one-hot encoding produced by
    ``encode_sequence_one_hot`` followed by a single trailing flag: 1 when
    *is_classical* is truthy, otherwise 0.

    Raises:
        ValueError: propagated from the note validation in the encoders.
        AssertionError: if the assembled vector's length differs from
            ``ART_INPUT_LENGTH``.
    """
    one_hot = encode_sequence_one_hot(notes)
    if is_classical:
        style_flag = 1
    else:
        style_flag = 0

    features = (*one_hot, style_flag)

    # Sanity check that the encoders and ART_INPUT_LENGTH agree on the size.
    size = len(features)
    if size != ART_INPUT_LENGTH:
        raise AssertionError(f"unexpected ART input length: {size}")
    return features