48 lines
1.5 KiB
Python
48 lines
1.5 KiB
Python
from __future__ import annotations
|
|
|
|
from synaptopus.art1 import ART1Network, ART1Params
|
|
|
|
|
|
def test_art1_commits_first_category() -> None:
    """The very first pattern shown to a fresh network is committed as category 0."""
    params = ART1Params(max_categories=3, input_length=4, vigilance=0.9)
    net = ART1Network(params)
    pattern = (1, 0, 1, 0)

    outcome = net.categorize(pattern)

    # Nothing has been learned yet, so the input must open a brand-new category.
    assert outcome.winner == 0
    assert outcome.new_category is True
    assert outcome.committed_categories == 1
    # The freshly committed category's expected vector is the input itself.
    assert outcome.expected_vector == pattern
|
|
|
|
|
|
def test_art1_reuses_matching_category() -> None:
    """Re-presenting an already learned pattern selects its existing category."""
    pattern = (1, 0, 1, 0)
    net = ART1Network(ART1Params(max_categories=3, input_length=4, vigilance=0.9))
    net.categorize(pattern)  # first presentation commits a category

    outcome = net.categorize(pattern)

    # An identical input must resonate with category 0 instead of committing another.
    assert outcome.winner == 0
    assert outcome.new_category is False
    assert outcome.committed_categories == 1
|
|
|
|
|
|
def test_art1_commits_new_category_for_nonmatching_pattern() -> None:
    """A pattern disjoint from everything learned so far opens a second category."""
    learned, disjoint = (1, 0, 1, 0), (0, 1, 0, 1)
    net = ART1Network(ART1Params(max_categories=3, input_length=4, vigilance=0.9))
    net.categorize(learned)

    outcome = net.categorize(disjoint)

    # The two patterns share no active bits, so vigilance (0.9) cannot be met by
    # category 0 and the network must commit the next free slot.
    assert outcome.winner == 1
    assert outcome.new_category is True
    assert outcome.committed_categories == 2
|
|
|
|
|
|
def test_art1_round_trips_through_dict() -> None:
    """A network rebuilt via ``to_dict``/``from_dict`` matches the original.

    Both the learned state and the categorization behaviour must survive the
    round trip.
    """
    network = ART1Network(ART1Params(max_categories=2, input_length=4, vigilance=0.8))
    network.categorize((1, 1, 0, 0))

    restored = ART1Network.from_dict(network.to_dict())

    # Learned state must survive serialization.
    assert restored.vigilance == network.vigilance
    assert restored.committed_categories == network.committed_categories
    assert restored.categories == network.categories

    # Comparing a few attributes alone would not notice a serialized field that
    # was dropped yet still influences matching or committing, so also check
    # that both networks respond identically to the same inputs. The probes
    # reuse the learned pattern and then fill the remaining free category
    # (max_categories=2), staying within the network's capacity.
    for probe in ((1, 1, 0, 0), (0, 0, 1, 1)):
        expected = network.categorize(probe)
        actual = restored.categorize(probe)
        assert actual.winner == expected.winner
        assert actual.new_category == expected.new_category
        assert actual.committed_categories == expected.committed_categories
        assert actual.expected_vector == expected.expected_vector

    # Learning applied after the round trip must keep the two networks in sync.
    assert restored.committed_categories == network.committed_categories
    assert restored.categories == network.categories
|