"""Tests for the legacy data-file loaders in ``composer_ans.io.legacy_files``."""
import math
from pathlib import Path

from composer_ans.io.legacy_files import (
    extract_active_hopfield_submatrix,
    load_hopfield_weight_matrix,
    load_salieri_config,
    load_salieri_weights,
    load_sequence_table,
)


# Fixture directory named "THES", located next to the directory that contains
# this test file (``parents[0]`` is this file's directory, ``parents[1]`` its parent).
THES = Path(__file__).resolve().parents[1].joinpath("THES")


def test_load_sequence_table() -> None:
    """SEQUENCE.DAT yields 14 entries, the final two being "251" and "258"."""
    table_path = THES / "SEQUENCE.DAT"
    table = load_sequence_table(table_path)

    assert len(table) == 14

    last_two = table[-2:]
    assert last_two == ("251", "258")


def test_load_salieri_config() -> None:
    """S61.DAT parses into the expected training, layer-size and file settings."""
    cfg = load_salieri_config(THES / "S61.DAT")

    # Expected attribute values, checked in the same order as they are listed.
    expected = {
        "learning_rate": 0.5,
        "alpha": 0.5,
        "n_input": 40,
        "n_hidden": 20,
        "n_output": 1,
        "training_iterations": 1,
        "error_tolerance": 0.1,
        "data_file": "s61.dat",
        "weight_file": "s61.wt",
    }
    for attr_name, want in expected.items():
        assert getattr(cfg, attr_name) == want


def test_load_salieri_weights() -> None:
    """S61.WT loads as a 61x61 weight matrix plus 61 theta values.

    Also spot-checks that the first five weights of row 0 are zero.
    """
    weights = load_salieri_weights(THES / "S61.WT")

    assert weights.vector_length == 61
    assert len(weights.weights) == 61
    # Check the width of every row, not only the first, so a ragged or
    # truncated row produced by the parser is caught. The set form reports
    # all distinct widths on failure.
    assert {len(row) for row in weights.weights} == {61}
    assert len(weights.thetas) == 61
    assert weights.weights[0][:5] == (0.0, 0.0, 0.0, 0.0, 0.0)


def test_load_hopfield_weight_matrix() -> None:
    """HTN.DAT loads as a 64x64 matrix whose active submatrix is 40x40.

    Also spot-checks the first two entries of row 0.
    """
    matrix = load_hopfield_weight_matrix(THES / "HTN.DAT")

    # Validate the full matrix before deriving anything from it, so a
    # malformed load fails on a clear shape assertion rather than somewhere
    # inside extract_active_hopfield_submatrix. Every row is checked.
    assert len(matrix) == 64
    assert {len(row) for row in matrix} == {64}

    active = extract_active_hopfield_submatrix(matrix)
    assert len(active) == 40
    assert {len(row) for row in active} == {40}

    assert matrix[0][0] == 0.0
    # NOTE(review): -0.35199999809265137 looks like -0.352 stored as a
    # float32. Compare with an absolute tolerance instead of exact equality.
    assert math.isclose(
        matrix[0][1], -0.35199999809265137, rel_tol=0.0, abs_tol=1e-7
    )