35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
from __future__ import annotations
|
|
|
|
from .types import ART_INPUT_LENGTH, NOTE_VOCABULARY_SIZE, NoteSequence, SEQUENCE_LENGTH
|
|
|
|
|
|
def encode_note_sequence(notes: list[int] | tuple[int, ...]) -> NoteSequence:
    """Validate *notes* and return them as a tuple of plain ints.

    Every entry must be an integer value in ``0..NOTE_VOCABULARY_SIZE``
    (inclusive), and there must be exactly ``SEQUENCE_LENGTH`` entries.

    Raises:
        ValueError: if the length is wrong, if an entry is not an integral
            value (e.g. ``2.7``), or if an entry is outside the valid range.
            ``int()`` may also raise for entries it cannot convert at all.
    """
    if len(notes) != SEQUENCE_LENGTH:
        raise ValueError(f"expected {SEQUENCE_LENGTH} notes, got {len(notes)}")

    encoded = tuple(int(note) for note in notes)

    for original, value in zip(notes, encoded):
        # int() truncates toward zero (int(2.7) == 2), so a lossy conversion
        # would silently encode a different note. Only accept entries whose
        # integer form is equal to what the caller passed (3, 3.0, True, ...).
        if value != original:
            raise ValueError(f"note is not an integer: {original!r}")
        if not 0 <= value <= NOTE_VOCABULARY_SIZE:
            raise ValueError(f"note out of range: {value}")

    return encoded
|
|
|
|
|
|
def encode_sequence_one_hot(notes: list[int] | tuple[int, ...]) -> tuple[int, ...]:
    """Flatten *notes* into a one-hot vector with one slot group per position.

    The notes are validated by ``encode_note_sequence`` first. The result holds
    ``SEQUENCE_LENGTH * NOTE_VOCABULARY_SIZE`` entries. Position ``i`` owns the
    ``NOTE_VOCABULARY_SIZE`` slots that start at ``i * NOTE_VOCABULARY_SIZE``.
    A note value ``k > 0`` sets slot ``k - 1`` of its group to 1. A value of 0
    leaves the whole group at 0.
    """
    width = NOTE_VOCABULARY_SIZE
    flat = [0 for _ in range(SEQUENCE_LENGTH * width)]

    for position, value in enumerate(encode_note_sequence(notes)):
        if value <= 0:
            # 0 encodes "no note": the group stays all zeros.
            continue
        flat[position * width + value - 1] = 1

    return tuple(flat)
|
|
|
|
|
|
def encode_art_input(
    notes: list[int] | tuple[int, ...],
    *,
    is_classical: bool,
) -> tuple[int, ...]:
    """Build the ART input: the one-hot note encoding followed by a style flag.

    The trailing element is 1 when *is_classical* is truthy and 0 otherwise.

    Raises:
        AssertionError: if the assembled vector's length is not
            ``ART_INPUT_LENGTH``, which means the module's size constants
            disagree with each other.
    """
    art_vector = (*encode_sequence_one_hot(notes), 1 if is_classical else 0)

    if len(art_vector) == ART_INPUT_LENGTH:
        return art_vector

    raise AssertionError(f"unexpected ART input length: {len(art_vector)}")
|