58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
from __future__ import annotations
|
|
|
|
from synaptopus.hopfield_build import (
|
|
HopfieldGridShape,
|
|
accumulate_sequence_transitions,
|
|
apply_grid_inhibition,
|
|
clear_diagonal,
|
|
grid_index,
|
|
)
|
|
|
|
|
|
def test_accumulate_sequence_transitions_builds_symmetric_weights() -> None:
    """One accumulated transition must appear with equal weight in both directions."""
    grid = HopfieldGridShape(row_count=3, column_count=3)
    increment = -0.5

    weights = accumulate_sequence_transitions(
        grid,
        sequences=[(1, 2, 3)],
        transition_offsets=(1,),
        weight_increment=increment,
    )

    # The test expects cells (0, 0) and (1, 1) to be the pair coupled by this
    # sequence/offset combination.
    # NOTE(review): presumably sequence value k maps to row k - 1 and the
    # sequence position to the column — confirm against hopfield_build.
    source_cell = grid_index(0, 0, grid)
    target_cell = grid_index(1, 1, grid)

    # Check the forward entry first, then its mirror, so symmetry is enforced.
    for row_cell, column_cell in (
        (source_cell, target_cell),
        (target_cell, source_cell),
    ):
        assert weights[row_cell][column_cell] == increment
|
|
|
|
|
|
def test_apply_grid_inhibition_matches_row_and_column_structure() -> None:
    """Row-mates get row inhibition; column-mates get column inhibition."""
    grid = HopfieldGridShape(row_count=3, column_count=2)
    size = grid.size
    # Start from an all-zero square matrix so only the inhibition shows up.
    zero_weights = tuple(tuple([0.0] * size) for _ in range(size))

    inhibited = apply_grid_inhibition(
        zero_weights,
        grid,
        row_inhibition=-0.2,
        column_inhibition=-0.1,
    )

    focus = grid_index(1, 0, grid)
    expected_by_neighbour = (
        # Row 0, column 0: other row, same column -> column inhibition.
        (grid_index(0, 0, grid), -0.1),
        # Row 1, column 1: same row, other column -> row inhibition.
        (grid_index(1, 1, grid), -0.2),
    )

    for neighbour, expected in expected_by_neighbour:
        assert inhibited[focus][neighbour] == expected
|
|
|
|
|
|
def test_clear_diagonal_zeros_self_connections() -> None:
    """clear_diagonal zeros self-connections and keeps every other weight intact."""
    # Asymmetric off-diagonal values (2.0 vs 3.0) let the test detect a
    # transposed or symmetrised result as well as a clobbered entry.
    weights = (
        (1.0, 2.0),
        (3.0, 4.0),
    )

    cleared = clear_diagonal(weights)

    # The matrix shape must be preserved.
    assert [len(row) for row in cleared] == [2, 2]

    # Diagonal (self-connection) entries are zeroed.
    assert cleared[0][0] == 0.0
    assert cleared[1][1] == 0.0

    # Both off-diagonal entries survive unchanged. Checking only [0][1]
    # would miss an implementation that damages the lower triangle.
    assert cleared[0][1] == 2.0
    assert cleared[1][0] == 3.0
|