Synaptopus/tests/test_hopfield_build.py

58 lines
1.5 KiB
Python

from __future__ import annotations
from synaptopus.hopfield_build import (
HopfieldGridShape,
accumulate_sequence_transitions,
apply_grid_inhibition,
clear_diagonal,
grid_index,
)
def test_accumulate_sequence_transitions_builds_symmetric_weights() -> None:
    """Accumulated transition weights appear in both directions.

    Builds weights for a 3x3 grid from the single sequence ``(1, 2, 3)``
    using a transition offset of 1 and an increment of -0.5, then checks
    that the entry linking grid cells (0, 0) and (1, 1) equals the
    increment whichever cell is used as the first index.
    """
    increment = -0.5
    grid = HopfieldGridShape(row_count=3, column_count=3)
    matrix = accumulate_sequence_transitions(
        grid,
        sequences=[(1, 2, 3)],
        transition_offsets=(1,),
        weight_increment=increment,
    )

    first_cell = grid_index(0, 0, grid)
    second_cell = grid_index(1, 1, grid)

    # Symmetry: the same value must be stored at [a][b] and at [b][a].
    forward = matrix[first_cell][second_cell]
    assert forward == increment
    backward = matrix[second_cell][first_cell]
    assert backward == increment
def test_apply_grid_inhibition_matches_row_and_column_structure() -> None:
    """Inhibition values land on same-row and same-column neighbours.

    Starts from an all-zero square matrix with ``shape.size`` entries per
    side for a 3-row, 2-column grid and applies a row inhibition of -0.2
    and a column inhibition of -0.1. Seen from cell (1, 0), the cell
    sharing its column, (0, 0), must carry the column inhibition and the
    cell sharing its row, (1, 1), must carry the row inhibition.
    """
    row_penalty = -0.2
    column_penalty = -0.1

    grid = HopfieldGridShape(row_count=3, column_count=2)
    # One all-zero row per unit, each row as wide as the number of units.
    zero_matrix = tuple((0.0,) * grid.size for _ in range(grid.size))

    result = apply_grid_inhibition(
        zero_matrix,
        grid,
        row_inhibition=row_penalty,
        column_inhibition=column_penalty,
    )

    anchor = grid_index(1, 0, grid)
    column_neighbour = grid_index(0, 0, grid)
    row_neighbour = grid_index(1, 1, grid)

    anchor_row = result[anchor]
    assert anchor_row[column_neighbour] == column_penalty
    assert anchor_row[row_neighbour] == row_penalty
def test_clear_diagonal_zeros_self_connections() -> None:
    """``clear_diagonal`` zeroes the diagonal and keeps entry [0][1].

    Uses a 2x2 matrix with distinct non-zero values so that a zeroed
    diagonal entry cannot be confused with an untouched one.
    """
    matrix = ((1.0, 2.0), (3.0, 4.0))

    result = clear_diagonal(matrix)

    # Every self-connection (entry [i][i]) must be reset to zero.
    for unit in range(2):
        assert result[unit][unit] == 0.0
    # The checked off-diagonal entry keeps its original value.
    assert result[0][1] == 2.0