"""Binary Adaptive Resonance Theory (ART1) categorisation network.

The network owns a fixed pool of ``max_categories`` category slots. A committed
slot stores a binary top-down prototype plus a bottom-up weight vector equal to
that prototype divided by its number of ones. :meth:`ART1Network.categorize`
presents one binary input and either resonates with a matching category,
commits a free slot, or -- when every slot is taken and none matches -- decays
the vigilance once and forces the best-scoring category to absorb the input.
"""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass

# Top-down weights at or above this value read as "1" when a prototype is
# binarised. Shared by _expected_vector and _resonate so the prototype that is
# compared during the vigilance test is the same one that gets updated.
_TOP_DOWN_THRESHOLD = 0.5


@dataclass(frozen=True)
class ART1Params:
    """Immutable configuration for an :class:`ART1Network`.

    Attributes:
        max_categories: Number of category slots the network can commit.
        input_length: Required length of every input vector.
        vigilance: Initial minimum ratio ``|input AND prototype| / |input|``
            a category must reach to resonate with an input.
        initial_bottom_up: Bottom-up weight stored in uncommitted slots.
        initial_top_down: Top-down weight stored in uncommitted slots.
        vigilance_decay: Factor applied to the vigilance when every slot is
            committed and no category matches the input.
    """

    max_categories: int
    input_length: int
    vigilance: float = 0.9
    initial_bottom_up: float = 0.1
    initial_top_down: float = 0.9
    vigilance_decay: float = 0.99


@dataclass(frozen=True)
class ART1Category:
    """Read-only snapshot of one category slot.

    Attributes:
        bottom_up: Weights used to score inputs when choosing a winner.
        top_down: Stored prototype (binary once the slot is committed).
        committed: Whether the slot has learned a pattern yet.
    """

    bottom_up: tuple[float, ...]
    top_down: tuple[float, ...]
    committed: bool


@dataclass(frozen=True)
class ART1Result:
    """Outcome of a single :meth:`ART1Network.categorize` call.

    Attributes:
        winner: Index of the category slot that absorbed the input.
        matched: Always ``True`` for results returned by ``categorize``.
        new_category: ``True`` when a free slot was committed for the input.
        delta_vigilance: ``True`` when the vigilance was decayed during the call.
        committed_categories: Number of committed slots after the call.
        vigilance: Network vigilance after the call.
        expected_vector: The winner's binarised prototype -- the one compared
            against the input for an ordinary match, or the freshly learned one
            for a newly committed or forced category.
    """

    winner: int
    matched: bool
    new_category: bool
    delta_vigilance: bool
    committed_categories: int
    vigilance: float
    expected_vector: tuple[int, ...]


class ART1Network:
    """ART1 network with a fixed number of category slots.

    Slots are kept as plain dicts with ``"bottom_up"``, ``"top_down"`` and
    ``"committed"`` keys so they serialise straight to JSON; use
    :attr:`categories` for an immutable snapshot.
    """

    def __init__(self, params: ART1Params) -> None:
        """Create a network whose slots all start uncommitted."""
        self.params = params
        self.vigilance = params.vigilance
        self._categories = [
            {
                "bottom_up": [params.initial_bottom_up] * params.input_length,
                "top_down": [params.initial_top_down] * params.input_length,
                "committed": False,
            }
            for _ in range(params.max_categories)
        ]

    @property
    def committed_categories(self) -> int:
        """Number of slots that have learned a pattern."""
        return sum(1 for category in self._categories if category["committed"])

    @property
    def categories(self) -> tuple[ART1Category, ...]:
        """Immutable copies of every slot, committed or not, in index order."""
        return tuple(
            ART1Category(
                bottom_up=tuple(category["bottom_up"]),
                top_down=tuple(category["top_down"]),
                committed=bool(category["committed"]),
            )
            for category in self._categories
        )

    def categorize(self, input_vector: tuple[int, ...] | list[int]) -> ART1Result:
        """Present one binary input and let the network learn from it.

        Search order:

        1. Committed categories are tried best bottom-up score first (ties go
           to the lowest index); the first whose prototype passes the
           vigilance test resonates with the input.
        2. If none passes and a slot is free, the lowest-indexed free slot is
           committed with the input as its prototype.
        3. Otherwise the vigilance is multiplied by ``vigilance_decay`` once and
           the best-scoring committed category is forced to resonate.

        Args:
            input_vector: Sequence of 0/1 values of length ``params.input_length``.

        Returns:
            An :class:`ART1Result` describing the outcome.

        Raises:
            ValueError: If the input has the wrong length or contains values
                other than 0 and 1, or if the network has no category slots.
        """
        vector = tuple(int(value) for value in input_vector)
        if len(vector) != self.params.input_length:
            raise ValueError(
                f"expected input length {self.params.input_length}, got {len(vector)}"
            )
        # Non-binary values would corrupt the weights: _match only counts
        # entries equal to 1, while sum(vector) would count their magnitude.
        if any(value not in (0, 1) for value in vector):
            raise ValueError(f"expected a binary (0/1) input vector, got {vector}")
        if not self._categories:
            # Without slots there is nothing to commit or force a match onto.
            raise ValueError("network has no category slots to learn into")

        # Search committed categories best-first; a winner that fails the
        # vigilance test is reset (removed) and the next best is tried.
        eligible = self._committed_indices()
        while eligible:
            winner = self._choose_winner(vector, eligible)
            expected_vector = self._expected_vector(winner)
            if self._match(vector, expected_vector):
                self._resonate(winner, vector)
                return self._result(
                    winner,
                    new_category=False,
                    delta_vigilance=False,
                    expected_vector=expected_vector,
                )
            eligible.remove(winner)

        # Nothing matched: learn the input in the first free slot, if any.
        # (Committed slots are not necessarily a prefix, e.g. after from_dict.)
        free_slot = self._first_uncommitted()
        if free_slot is not None:
            self._commit_category(free_slot, vector)
            return self._result(
                free_slot,
                new_category=True,
                delta_vigilance=False,
                expected_vector=vector,
            )

        # Every slot is committed and none matched: relax vigilance once and
        # force the best-scoring category to absorb the input.
        self.vigilance *= self.params.vigilance_decay
        winner = self._choose_winner(vector, self._committed_indices())
        self._resonate(winner, vector)
        return self._result(
            winner,
            new_category=False,
            delta_vigilance=True,
            expected_vector=self._expected_vector(winner),
        )

    def _committed_indices(self) -> set[int]:
        """Return the indices of all committed slots."""
        return {
            index
            for index, category in enumerate(self._categories)
            if category["committed"]
        }

    def _first_uncommitted(self) -> int | None:
        """Return the lowest uncommitted slot index, or ``None`` if all are used."""
        return next(
            (
                index
                for index, category in enumerate(self._categories)
                if not category["committed"]
            ),
            None,
        )

    def _result(
        self,
        winner: int,
        *,
        new_category: bool,
        delta_vigilance: bool,
        expected_vector: tuple[int, ...],
    ) -> ART1Result:
        """Build an :class:`ART1Result` from the network state after learning."""
        return ART1Result(
            winner=winner,
            matched=True,
            new_category=new_category,
            delta_vigilance=delta_vigilance,
            committed_categories=self.committed_categories,
            vigilance=self.vigilance,
            expected_vector=tuple(expected_vector),
        )

    def _choose_winner(self, vector: tuple[int, ...], eligible: set[int]) -> int:
        """Return the eligible index with the highest bottom-up activation.

        Ties are broken in favour of the lowest index, because ``max`` keeps
        the first maximal element of the ascending ``sorted`` order.
        """
        def activation(index: int) -> float:
            weights = self._categories[index]["bottom_up"]
            return sum(bit * weight for bit, weight in zip(vector, weights))

        return max(sorted(eligible), key=activation)

    def _expected_vector(self, category_index: int) -> tuple[int, ...]:
        """Binarise a slot's top-down prototype with the shared threshold.

        Using the same threshold as :meth:`_resonate` ensures an all-zero
        prototype reads as all zeros rather than as "expect everything".
        """
        top_down = self._categories[category_index]["top_down"]
        return tuple(1 if value >= _TOP_DOWN_THRESHOLD else 0 for value in top_down)

    def _match(self, vector: tuple[int, ...], expected_vector: tuple[int, ...]) -> bool:
        """Return whether the input passes the vigilance test for a prototype.

        The match ratio is ``|input AND expected| / |input|``. For an all-zero
        input that ratio is 0/0, so such an input resonates only with an
        all-zero prototype (otherwise repeated zero inputs would each consume
        a new slot).
        """
        ones_in_input = sum(vector)
        overlap = sum(
            1 for left, right in zip(vector, expected_vector) if left == 1 and right == 1
        )
        if ones_in_input == 0:
            return not any(expected_vector)
        return (overlap / ones_in_input) >= self.vigilance

    def _commit_category(self, category_index: int, vector: tuple[int, ...]) -> None:
        """Store ``vector`` as the prototype of a free slot and mark it committed."""
        category = self._categories[category_index]
        category["committed"] = True
        category["top_down"] = [float(value) for value in vector]
        ones = max(1, sum(vector))  # avoid division by zero for an all-zero input
        category["bottom_up"] = [float(value) / ones for value in vector]

    def _resonate(self, category_index: int, vector: tuple[int, ...]) -> None:
        """Shrink a slot's prototype to its intersection with ``vector``."""
        category = self._categories[category_index]
        intersected = [
            1 if top_down >= _TOP_DOWN_THRESHOLD and bit == 1 else 0
            for top_down, bit in zip(category["top_down"], vector)
        ]
        category["top_down"] = [float(value) for value in intersected]
        ones = max(1, sum(intersected))  # avoid division by zero when empty
        category["bottom_up"] = [float(value) / ones for value in intersected]

    def to_dict(self) -> dict[str, object]:
        """Return a JSON-serialisable copy of the network state.

        Category lists are copied so that mutating the returned structure
        cannot corrupt the live network.
        """
        return {
            "params": asdict(self.params),
            "vigilance": self.vigilance,
            "categories": [
                {
                    "bottom_up": list(category["bottom_up"]),
                    "top_down": list(category["top_down"]),
                    "committed": category["committed"],
                }
                for category in self._categories
            ],
        }

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "ART1Network":
        """Rebuild a network from the structure produced by :meth:`to_dict`.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If the number of stored categories differs from
                ``max_categories`` or a weight vector's length differs from
                ``input_length``.
        """
        params_data = data["params"]  # type: ignore[index]
        params = ART1Params(
            max_categories=int(params_data["max_categories"]),  # type: ignore[index]
            input_length=int(params_data["input_length"]),  # type: ignore[index]
            vigilance=float(params_data["vigilance"]),  # type: ignore[index]
            initial_bottom_up=float(params_data["initial_bottom_up"]),  # type: ignore[index]
            initial_top_down=float(params_data["initial_top_down"]),  # type: ignore[index]
            vigilance_decay=float(params_data["vigilance_decay"]),  # type: ignore[index]
        )
        categories = [
            {
                "bottom_up": [float(value) for value in category["bottom_up"]],
                "top_down": [float(value) for value in category["top_down"]],
                "committed": bool(category["committed"]),
            }
            for category in data["categories"]  # type: ignore[index]
        ]
        # Fail at load time instead of with an IndexError or silent truncation later.
        if len(categories) != params.max_categories:
            raise ValueError(
                f"expected {params.max_categories} categories, got {len(categories)}"
            )
        for index, category in enumerate(categories):
            for key in ("bottom_up", "top_down"):
                length = len(category[key])  # type: ignore[arg-type]
                if length != params.input_length:
                    raise ValueError(
                        f"category {index} {key} has length {length}, "
                        f"expected {params.input_length}"
                    )
        network = cls(params)
        network.vigilance = float(data["vigilance"])  # type: ignore[arg-type]
        network._categories = categories
        return network

    def save_json(self, path: str) -> None:
        """Write :meth:`to_dict` output to ``path`` as indented UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(self.to_dict(), handle, indent=2)

    @classmethod
    def load_json(cls, path: str) -> "ART1Network":
        """Load a network previously written by :meth:`save_json`."""
        with open(path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        return cls.from_dict(data)