# Synaptopus/tests/test_hopfield.py
#
# 53 lines
# 1.5 KiB
# Python

from __future__ import annotations
from synaptopus.hopfield import HopfieldNetwork, HopfieldParams
def test_hopfield_zero_matrix_runs_on_arbitrary_grid_shape() -> None:
    """Run an all-zero weight matrix on a 3x2 grid and check the output shape.

    The weight matrix is square with one row and one column per grid cell.
    The test requires at least one iteration to be reported and the outputs
    to have the same row/column counts as the input grid.
    """
    grid = (
        (0.8, 0.2),
        (0.1, 0.9),
        (0.4, 0.3),
    )
    row_count = len(grid)
    col_count = len(grid[0])
    cell_count = row_count * col_count

    # Square cell_count x cell_count matrix filled with zeros.
    zero_weights = tuple((0.0,) * cell_count for _ in range(cell_count))

    outcome = HopfieldNetwork(weight_matrix=zero_weights).run(grid)

    assert outcome.iterations > 0
    produced = outcome.state.outputs
    assert len(produced) == row_count
    assert len(produced[0]) == col_count
def test_hopfield_respects_initial_activation_shape() -> None:
    """Seed a run with a 2x2 activation grid and check the resulting shape.

    Uses a 4x4 all-zero weight matrix (the 2x2 input grid has four cells)
    and default ``HopfieldParams``. Only the dimensions of the resulting
    activations are asserted, not their values.
    """
    uniform_inputs = ((0.5, 0.5), (0.5, 0.5))
    seed_activations = ((0.1, 0.2), (0.3, 0.4))
    zero_weights = tuple((0.0,) * 4 for _ in range(4))

    net = HopfieldNetwork(weight_matrix=zero_weights, params=HopfieldParams())
    outcome = net.run(uniform_inputs, initial_activations=seed_activations)

    activations = outcome.state.activations
    assert len(activations) == 2
    assert len(activations[0]) == 2
def test_hopfield_round_trips_through_dict() -> None:
    """Serialise a network with ``to_dict`` and rebuild it with ``from_dict``.

    The network uses a 4x4 identity weight matrix and non-default
    ``epsilon`` / ``weight_scale`` parameters. The rebuilt network must
    compare equal to the original on both ``weight_matrix`` and ``params``.
    """
    identity = tuple(
        tuple(1.0 if row == col else 0.0 for col in range(4))
        for row in range(4)
    )
    custom_params = HopfieldParams(epsilon=0.01, weight_scale=0.5)
    original = HopfieldNetwork(weight_matrix=identity, params=custom_params)

    payload = original.to_dict()
    rebuilt = HopfieldNetwork.from_dict(payload)

    assert rebuilt.weight_matrix == original.weight_matrix
    assert rebuilt.params == original.params