53 lines
1.5 KiB
Python
53 lines
1.5 KiB
Python
from __future__ import annotations
|
|
|
|
from synaptopus.hopfield import HopfieldNetwork, HopfieldParams
|
|
|
|
|
|
def test_hopfield_zero_matrix_runs_on_arbitrary_grid_shape() -> None:
    """Run a network with an all-zero weight matrix on a non-square 3x2 grid.

    Checks that ``run`` performs at least one iteration and that the output
    grid has the same rows x cols shape as the input grid.
    """
    inputs = (
        (0.8, 0.2),
        (0.1, 0.9),
        (0.4, 0.3),
    )
    rows, cols = len(inputs), len(inputs[0])
    # Square matrix with one row/column per grid cell (rows * cols); this
    # assumes the network indexes the grid in flattened form — confirm
    # against HopfieldNetwork.
    size = rows * cols
    weights = tuple(tuple(0.0 for _ in range(size)) for _ in range(size))

    result = HopfieldNetwork(weight_matrix=weights).run(inputs)

    assert result.iterations > 0
    # Expected dimensions come from the fixture so they cannot drift from it,
    # and every row is checked so a ragged output grid is caught.
    assert len(result.state.outputs) == rows
    assert all(len(row) == cols for row in result.state.outputs)
|
|
|
|
|
|
def test_hopfield_respects_initial_activation_shape() -> None:
    """Seed ``run`` with explicit initial activations and check the resulting shape.

    The final activation grid must have the same shape, row for row, as the
    ``initial_activations`` passed in.
    """
    inputs = (
        (0.5, 0.5),
        (0.5, 0.5),
    )
    initial_activations = (
        (0.1, 0.2),
        (0.3, 0.4),
    )
    # One weight row/column per grid cell, derived from the grid rather than
    # hard-coded.
    size = len(inputs) * len(inputs[0])
    weights = tuple(tuple(0.0 for _ in range(size)) for _ in range(size))
    network = HopfieldNetwork(weight_matrix=weights, params=HopfieldParams())

    result = network.run(inputs, initial_activations=initial_activations)

    activations = result.state.activations
    assert len(activations) == len(initial_activations)
    # Compare every row against its seed row, not just the first one.
    assert all(
        len(row) == len(seed)
        for row, seed in zip(activations, initial_activations)
    )
|
|
|
|
|
|
def test_hopfield_round_trips_through_dict() -> None:
    """Serialize with ``to_dict`` and rebuild with ``from_dict``.

    The rebuilt network must carry an equal weight matrix and equal params.
    """
    # 4x4 identity matrix: 1.0 on the diagonal, 0.0 elsewhere.
    identity = tuple(
        tuple(1.0 if row == col else 0.0 for col in range(4))
        for row in range(4)
    )
    original = HopfieldNetwork(
        weight_matrix=identity,
        params=HopfieldParams(epsilon=0.01, weight_scale=0.5),
    )

    payload = original.to_dict()
    rebuilt = HopfieldNetwork.from_dict(payload)

    assert rebuilt.weight_matrix == original.weight_matrix
    assert rebuilt.params == original.params
|