22 lines
833 B
Python
22 lines
833 B
Python
from composer_ans.encoding import encode_art_input, encode_note_sequence, encode_sequence_one_hot
|
|
|
|
|
|
def test_encode_note_sequence_validates_shape_and_range() -> None:
    """A valid five-note sequence is returned as an equal tuple.

    NOTE(review): despite the test name, only the accepted (happy) path is
    exercised here; rejection of wrong-length or out-of-range input is not
    asserted. TODO: add negative cases once the exception type raised by
    ``encode_note_sequence`` is confirmed.
    """
    expected = (1, 2, 3, 4, 5)
    notes = list(expected)

    encoded = encode_note_sequence(notes)

    assert encoded == expected
|
|
|
|
|
|
def test_encode_sequence_one_hot_matches_pascal_layout() -> None:
    """Each note maps to its own 8-wide slot in the flat output tuple.

    As pinned below, a note value ``k`` in 1..8 sets position ``k - 1`` of
    its slot, and a note value of 0 leaves the slot all zeros. The "pascal
    layout" presumably refers to a reference (Pascal) implementation whose
    encoding this must match -- TODO confirm.
    """
    slot_width = 8
    notes = [1, 0, 8, 2, 0]
    expected_slots = [
        (1, 0, 0, 0, 0, 0, 0, 0),  # note 1 -> first position
        (0, 0, 0, 0, 0, 0, 0, 0),  # note 0 -> empty slot
        (0, 0, 0, 0, 0, 0, 0, 1),  # note 8 -> last position
        (0, 1, 0, 0, 0, 0, 0, 0),  # note 2 -> second position
        (0, 0, 0, 0, 0, 0, 0, 0),  # note 0 -> empty slot
    ]

    vector = encode_sequence_one_hot(notes)

    assert len(vector) == 40
    for slot_index, want in enumerate(expected_slots):
        start = slot_index * slot_width
        assert vector[start:start + slot_width] == want
|
|
|
|
|
|
def test_encode_art_input_appends_classicality_bit() -> None:
    """The ART input ends with a 1 when ``is_classical=True``.

    The expected length of 41 is presumably the 40-entry one-hot note
    encoding plus one trailing flag -- TODO confirm against
    ``encode_art_input``.
    """
    note_sequence = [1, 2, 3, 4, 5]

    encoded = encode_art_input(note_sequence, is_classical=True)

    assert len(encoded) == 41
    classicality_bit = encoded[-1]
    assert classicality_bit == 1
|