# Synaptopus/src/synaptopus/art1.py

from __future__ import annotations
from dataclasses import dataclass
import json
@dataclass(frozen=True)
class ART1Params:
    """Immutable hyper-parameters for an ART1 network.

    Attributes:
        max_categories: Number of category slots the network allocates up
            front; must be >= 1.
        input_length: Number of elements in every input vector; must be >= 1.
        vigilance: Initial vigilance, i.e. the minimum fraction of the input's
            active bits that a category's template must share with it.
        initial_bottom_up: Bottom-up weight given to not-yet-committed slots.
        initial_top_down: Top-down weight given to not-yet-committed slots.
        vigilance_decay: Factor the network's vigilance is multiplied by when
            every slot is committed and none matched the input.
    """

    max_categories: int
    input_length: int
    vigilance: float = 0.9
    initial_bottom_up: float = 0.1
    initial_top_down: float = 0.9
    vigilance_decay: float = 0.99

    def __post_init__(self) -> None:
        """Reject sizes a network cannot operate with.

        Raises:
            ValueError: If ``max_categories`` or ``input_length`` is below 1.
        """
        # Without at least one slot and one input position there is nothing
        # for categorize() to choose from or compare, so fail fast here with a
        # clear message instead of deep inside the network.
        if self.max_categories < 1:
            raise ValueError(
                f"max_categories must be >= 1, got {self.max_categories}"
            )
        if self.input_length < 1:
            raise ValueError(f"input_length must be >= 1, got {self.input_length}")
@dataclass(frozen=True)
class ART1Category:
    """Read-only snapshot of a single category slot of an ART1 network."""

    # Weights dotted with an input to rank candidate categories.
    bottom_up: tuple[float, ...]
    # Template weights compared against an input in the vigilance test.
    top_down: tuple[float, ...]
    # Whether an input has been assigned to this slot yet.
    committed: bool
@dataclass(frozen=True)
class ART1Result:
    """Immutable record of what happened when one input was categorized."""

    # Index of the category slot the input was assigned to.
    winner: int
    # Resonance flag reported for the winning category.
    matched: bool
    # True when the winner is a slot that was committed for this input.
    new_category: bool
    # True when the network's vigilance was lowered while handling the input.
    delta_vigilance: bool
    # Number of committed slots once the input has been processed.
    committed_categories: int
    # Network vigilance once the input has been processed.
    vigilance: float
    # Binary expectation reported for the winning category.
    expected_vector: tuple[int, ...]
class ART1Network:
    """Binary ART1 (Adaptive Resonance Theory) network with a fixed slot budget.

    ``params.max_categories`` category slots are allocated up front. Each slot
    holds a ``bottom_up`` weight list (dotted with the input to rank
    candidates), a ``top_down`` template (checked against the input in the
    vigilance test) and a ``committed`` flag that is set the first time an
    input is assigned to the slot.
    """

    # Top-down weights at or above this value count as "on". Shared by
    # _expected_vector and _resonate so the template the vigilance test sees is
    # exactly the one resonance preserves. Templates written by this class only
    # ever hold 0.0 or 1.0; the threshold also keeps an all-zero template from
    # being read as all ones.
    _TEMPLATE_THRESHOLD = 0.5

    def __init__(self, params: ART1Params) -> None:
        self.params = params
        # Current vigilance; starts at params.vigilance and is multiplied by
        # params.vigilance_decay whenever all slots are committed and none
        # matched an input.
        self.vigilance = params.vigilance
        # Mutable per-slot state; exposed read-only through ``categories``.
        self._categories = [
            {
                "bottom_up": [params.initial_bottom_up] * params.input_length,
                "top_down": [params.initial_top_down] * params.input_length,
                "committed": False,
            }
            for _ in range(params.max_categories)
        ]

    @property
    def committed_categories(self) -> int:
        """Number of slots that have been assigned at least one input."""
        return sum(1 for category in self._categories if category["committed"])

    @property
    def categories(self) -> tuple[ART1Category, ...]:
        """Immutable snapshot of every slot, committed or not."""
        return tuple(
            ART1Category(
                bottom_up=tuple(category["bottom_up"]),
                top_down=tuple(category["top_down"]),
                committed=bool(category["committed"]),
            )
            for category in self._categories
        )

    def categorize(self, input_vector: tuple[int, ...] | list[int]) -> ART1Result:
        """Present one binary input to the network and learn from it.

        Committed categories are tried in order of bottom-up score. The first
        one whose template passes the vigilance test learns the input: its
        template becomes the intersection of template and input.

        If none passes, the input is stored in the first free slot. If every
        slot is already committed, the vigilance is multiplied by
        ``params.vigilance_decay`` and the best-scoring category is forced to
        learn the input.

        Args:
            input_vector: ``params.input_length`` values, each 0 or 1 after
                ``int()`` coercion.

        Returns:
            An ART1Result describing the winning slot and the network state
            after learning.

        Raises:
            ValueError: If the input has the wrong length or a non-binary value.
        """
        vector = self._validate_input(input_vector)
        # Committed slots still in the running; each one failing the vigilance
        # test is dropped (mismatch reset) and the search continues.
        eligible = {
            index
            for index, category in enumerate(self._categories)
            if category["committed"]
        }
        while eligible:
            winner = self._choose_winner(vector, eligible)
            # Template as it was *before* learning, reported back to callers.
            expected_vector = self._expected_vector(winner)
            if self._match(vector, expected_vector):
                self._resonate(winner, vector)
                return ART1Result(
                    winner=winner,
                    matched=True,
                    new_category=False,
                    delta_vigilance=False,
                    committed_categories=self.committed_categories,
                    vigilance=self.vigilance,
                    expected_vector=expected_vector,
                )
            eligible.discard(winner)

        free_index = self._first_uncommitted()
        if free_index is not None:
            self._commit_category(free_index, vector)
            return ART1Result(
                winner=free_index,
                matched=True,
                new_category=True,
                delta_vigilance=False,
                committed_categories=self.committed_categories,
                vigilance=self.vigilance,
                expected_vector=vector,
            )

        # No free slot means every slot is committed and none matched: relax
        # the vigilance and force the best-scoring category to absorb the input.
        # NOTE(review): matched=True is reported here although this category
        # never passed the vigilance test; confirm callers rely on that before
        # changing it.
        self.vigilance *= self.params.vigilance_decay
        winner = self._choose_winner(vector, set(range(len(self._categories))))
        self._resonate(winner, vector)
        return ART1Result(
            winner=winner,
            matched=True,
            new_category=False,
            delta_vigilance=True,
            committed_categories=self.committed_categories,
            vigilance=self.vigilance,
            expected_vector=self._expected_vector(winner),
        )

    def _validate_input(self, input_vector: tuple[int, ...] | list[int]) -> tuple[int, ...]:
        """Coerce ``input_vector`` to a tuple of ints and check it is usable.

        Raises:
            ValueError: On a length mismatch or any value other than 0 or 1.
        """
        vector = tuple(int(value) for value in input_vector)
        if len(vector) != self.params.input_length:
            raise ValueError(
                f"expected input length {self.params.input_length}, got {len(vector)}"
            )
        # Templates and the vigilance test only understand 0/1; any other
        # value would be stored in the template but ignored by matching.
        if any(value not in (0, 1) for value in vector):
            raise ValueError(f"ART1 inputs must be binary (0/1), got {vector}")
        return vector

    def _first_uncommitted(self) -> int | None:
        """Return the lowest uncommitted slot index, or None when all are used."""
        return next(
            (
                index
                for index, category in enumerate(self._categories)
                if not category["committed"]
            ),
            None,
        )

    def _choose_winner(self, vector: tuple[int, ...], eligible: set[int]) -> int:
        """Return the eligible slot with the highest bottom-up score.

        The score is the dot product of ``vector`` with the slot's bottom-up
        weights. Raises ValueError if ``eligible`` is empty.
        """

        def score(index: int) -> float:
            weights = self._categories[index]["bottom_up"]
            return sum(value * weight for value, weight in zip(vector, weights))

        # max() keeps the first maximal item and sorted() yields the lowest
        # index first, so ties resolve to the lowest index.
        return max(sorted(eligible), key=score)

    def _expected_vector(self, category_index: int) -> tuple[int, ...]:
        """Binarize a slot's top-down template with ``_TEMPLATE_THRESHOLD``."""
        top_down = self._categories[category_index]["top_down"]
        return tuple(
            1 if weight >= self._TEMPLATE_THRESHOLD else 0 for weight in top_down
        )

    def _match(self, vector: tuple[int, ...], expected_vector: tuple[int, ...]) -> bool:
        """Vigilance test against a binarized template.

        The test passes when the fraction of the input's active bits that are
        also active in the template reaches ``self.vigilance``.
        """
        ones_in_input = sum(vector)
        if ones_in_input == 0:
            # The overlap of an all-zero input is always 0 and the ratio is
            # undefined. Treat it as resonating only with an equally empty
            # template, so repeated empty inputs share one slot instead of
            # each claiming a new one.
            return not any(expected_vector)
        overlap = sum(
            1 for left, right in zip(vector, expected_vector) if left == 1 and right == 1
        )
        return overlap / ones_in_input >= self.vigilance

    def _commit_category(self, category_index: int, vector: tuple[int, ...]) -> None:
        """Mark a slot committed and store ``vector`` as its template.

        The bottom-up weights are the input normalized by its number of ones.
        The ``max(1, ...)`` guards an all-zero input.
        """
        category = self._categories[category_index]
        category["committed"] = True
        category["top_down"] = [float(value) for value in vector]
        ones = max(1, sum(vector))
        category["bottom_up"] = [float(value) / ones for value in vector]

    def _resonate(self, category_index: int, vector: tuple[int, ...]) -> None:
        """Learn ``vector``: shrink the template to its intersection with the input.

        The bottom-up weights become that intersection normalized by its
        number of ones; ``max(1, ...)`` guards an empty intersection.
        """
        category = self._categories[category_index]
        intersected = [
            1 if weight >= self._TEMPLATE_THRESHOLD and value == 1 else 0
            for weight, value in zip(category["top_down"], vector)
        ]
        category["top_down"] = [float(value) for value in intersected]
        ones = max(1, sum(intersected))
        category["bottom_up"] = [float(value) / ones for value in intersected]

    def to_dict(self) -> dict[str, object]:
        """Return a JSON-serializable snapshot of the network.

        The weight lists are copied, so mutating the returned dict does not
        alter the network.
        """
        return {
            "params": {
                "max_categories": self.params.max_categories,
                "input_length": self.params.input_length,
                "vigilance": self.params.vigilance,
                "initial_bottom_up": self.params.initial_bottom_up,
                "initial_top_down": self.params.initial_top_down,
                "vigilance_decay": self.params.vigilance_decay,
            },
            "vigilance": self.vigilance,
            "categories": [
                {
                    "bottom_up": list(category["bottom_up"]),
                    "top_down": list(category["top_down"]),
                    "committed": category["committed"],
                }
                for category in self._categories
            ],
        }

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "ART1Network":
        """Rebuild a network from a dict shaped like ``to_dict()`` output.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If the stored categories do not agree with
                ``max_categories`` / ``input_length``.
        """
        params_data = data["params"]  # type: ignore[index]
        network = cls(
            ART1Params(
                max_categories=int(params_data["max_categories"]),  # type: ignore[index]
                input_length=int(params_data["input_length"]),  # type: ignore[index]
                vigilance=float(params_data["vigilance"]),  # type: ignore[index]
                initial_bottom_up=float(params_data["initial_bottom_up"]),  # type: ignore[index]
                initial_top_down=float(params_data["initial_top_down"]),  # type: ignore[index]
                vigilance_decay=float(params_data["vigilance_decay"]),  # type: ignore[index]
            )
        )
        network.vigilance = float(data["vigilance"])
        categories = [
            {
                "bottom_up": [float(value) for value in category["bottom_up"]],
                "top_down": [float(value) for value in category["top_down"]],
                "committed": bool(category["committed"]),
            }
            for category in data["categories"]  # type: ignore[index]
        ]
        # Every other method indexes slots by position and weights by input
        # position, so a mismatched file would fail much later or misbehave
        # silently; reject it while loading instead.
        params = network.params
        if len(categories) != params.max_categories:
            raise ValueError(
                f"expected {params.max_categories} categories, got {len(categories)}"
            )
        for index, category in enumerate(categories):
            for key in ("bottom_up", "top_down"):
                if len(category[key]) != params.input_length:
                    raise ValueError(
                        f"category {index} {key} has length {len(category[key])}, "
                        f"expected {params.input_length}"
                    )
        network._categories = categories
        return network

    def save_json(self, path: str) -> None:
        """Write ``to_dict()`` to ``path`` as indented UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(self.to_dict(), handle, indent=2)

    @classmethod
    def load_json(cls, path: str) -> "ART1Network":
        """Load a network previously written by ``save_json``."""
        with open(path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        return cls.from_dict(data)