208 lines
8.0 KiB
Python
208 lines
8.0 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import json
|
|
|
|
|
|
@dataclass(frozen=True)
class ART1Params:
    """Immutable configuration consumed by ``ART1Network``.

    Attributes:
        max_categories: Number of category slots the network allocates.
        input_length: Length every input vector must have.
        vigilance: Starting vigilance, i.e. the minimum match ratio an
            input must reach against a category's expected vector.
        initial_bottom_up: Value every bottom-up weight starts with.
        initial_top_down: Value every top-down weight starts with.
        vigilance_decay: Factor the running vigilance is multiplied by
            when an input cannot be placed in a full network.
    """

    # Required sizing parameters (no defaults).
    max_categories: int
    input_length: int

    # Tunables with defaults.
    vigilance: float = 0.9
    initial_bottom_up: float = 0.1
    initial_top_down: float = 0.9
    vigilance_decay: float = 0.99
|
|
|
|
|
|
@dataclass(frozen=True)
class ART1Category:
    """Read-only snapshot of one category slot of an ``ART1Network``.

    Attributes:
        bottom_up: Bottom-up weights, one per input position.
        top_down: Top-down template weights, one per input position.
        committed: Whether the slot has been committed to a pattern.
    """

    bottom_up: tuple[float, ...]
    top_down: tuple[float, ...]
    committed: bool
|
|
|
|
|
|
@dataclass(frozen=True)
class ART1Result:
    """Outcome of a single ``ART1Network.categorize`` call.

    Attributes:
        winner: Index of the category slot the input was assigned to.
        matched: Match flag reported by the network.
        new_category: True when the input committed a fresh slot.
        delta_vigilance: True when the network's vigilance was changed
            while handling this input.
        committed_categories: Number of committed slots after the call.
        vigilance: The network's vigilance after the call.
        expected_vector: Binary vector associated with the winning slot.
    """

    # Which slot won and how the decision was reached.
    winner: int
    matched: bool
    new_category: bool
    delta_vigilance: bool

    # Network state reported alongside the decision.
    committed_categories: int
    vigilance: float
    expected_vector: tuple[int, ...]
|
|
|
|
|
|
class ART1Network:
    """ART1-style categorizer for binary vectors over a fixed pool of slots.

    Every slot stores bottom-up weights (used to score candidates), a
    top-down template (used for the vigilance test) and a ``committed``
    flag. The running ``vigilance`` starts at ``params.vigilance`` and is
    multiplied by ``params.vigilance_decay`` each time an input has to be
    forced into a network whose slots are all committed.
    """

    def __init__(self, params: ART1Params) -> None:
        """Allocate ``params.max_categories`` uncommitted slots."""
        self.params = params
        self.vigilance = params.vigilance
        self._categories = [
            {
                "bottom_up": [params.initial_bottom_up] * params.input_length,
                "top_down": [params.initial_top_down] * params.input_length,
                "committed": False,
            }
            for _ in range(params.max_categories)
        ]

    @property
    def committed_categories(self) -> int:
        """Number of slots currently marked as committed."""
        return sum(1 for category in self._categories if category["committed"])

    @property
    def categories(self) -> tuple[ART1Category, ...]:
        """Immutable snapshot of every slot, in slot order."""
        return tuple(
            ART1Category(
                bottom_up=tuple(category["bottom_up"]),
                top_down=tuple(category["top_down"]),
                committed=bool(category["committed"]),
            )
            for category in self._categories
        )

    def categorize(self, input_vector: tuple[int, ...] | list[int]) -> ART1Result:
        """Assign ``input_vector`` to a category, learning as a side effect.

        Committed slots are tried best-score first. The first one that
        passes the vigilance test resonates (its template is intersected
        with the input). If none passes, the first free slot is committed
        to the input. If no slot is free, vigilance is decayed once and the
        best-scoring slot is forced to resonate.

        Args:
            input_vector: Sequence of ``params.input_length`` values; each
                is coerced with ``int()``. The matching logic treats the
                value ``1`` as "on" (assumes callers pass 0/1 values).

        Returns:
            An ``ART1Result`` describing the winning slot and network state.

        Raises:
            ValueError: If the input length is wrong, or the network has
                no category slots at all.
        """
        vector = tuple(int(value) for value in input_vector)
        if len(vector) != self.params.input_length:
            raise ValueError(
                f"expected input length {self.params.input_length}, got {len(vector)}"
            )

        # Search phase: try committed slots until one passes vigilance.
        eligible = self._committed_indices()
        while eligible:
            winner = self._choose_winner(vector, eligible)
            expected_vector = self._expected_vector(winner)
            if self._match(vector, expected_vector):
                self._resonate(winner, vector)
                return self._result(
                    winner,
                    new_category=False,
                    delta_vigilance=False,
                    expected_vector=expected_vector,
                )
            eligible.remove(winner)

        # Nothing resonated: learn the input in the first free slot. Looking
        # the slot up (instead of using the committed count as an index)
        # keeps this correct even when committed slots are not contiguous,
        # e.g. after loading externally produced state.
        free_slot = self._first_uncommitted()
        if free_slot is not None:
            self._commit_category(free_slot, vector)
            return self._result(
                free_slot,
                new_category=True,
                delta_vigilance=False,
                expected_vector=vector,
            )

        committed = self._committed_indices()
        if not committed:
            # Fail before touching vigilance so the error leaves no side effect.
            raise ValueError("network has no category slots to assign the input to")

        # Network is full: relax vigilance once and force the best-scoring
        # slot to absorb the input without re-running the vigilance test.
        self.vigilance *= self.params.vigilance_decay
        winner = self._choose_winner(vector, committed)
        self._resonate(winner, vector)
        # On this path the expected vector is reported *after* learning.
        return self._result(
            winner,
            new_category=False,
            delta_vigilance=True,
            expected_vector=self._expected_vector(winner),
        )

    def _committed_indices(self) -> set[int]:
        """Return the indices of all committed slots."""
        return {
            index
            for index, category in enumerate(self._categories)
            if category["committed"]
        }

    def _first_uncommitted(self) -> int | None:
        """Return the lowest uncommitted slot index, or ``None`` if full."""
        for index, category in enumerate(self._categories):
            if not category["committed"]:
                return index
        return None

    def _result(
        self,
        winner: int,
        *,
        new_category: bool,
        delta_vigilance: bool,
        expected_vector: tuple[int, ...],
    ) -> ART1Result:
        """Build the result record from the current network state."""
        return ART1Result(
            winner=winner,
            # Every outcome is reported as matched, including forced resonance.
            matched=True,
            new_category=new_category,
            delta_vigilance=delta_vigilance,
            committed_categories=self.committed_categories,
            vigilance=self.vigilance,
            expected_vector=tuple(expected_vector),
        )

    def _choose_winner(self, vector: tuple[int, ...], eligible: set[int]) -> int:
        """Return the eligible slot with the highest bottom-up activation.

        Ties go to the lowest index because candidates are visited in
        ascending order and only a strictly better score replaces the best.
        ``eligible`` must be non-empty.
        """
        best_index = min(eligible)
        best_score = float("-inf")
        for index in sorted(eligible):
            weights = self._categories[index]["bottom_up"]
            score = sum(bit * weight for bit, weight in zip(vector, weights))
            if score > best_score:
                best_score = score
                best_index = index
        return best_index

    def _expected_vector(self, category_index: int) -> tuple[int, ...]:
        """Binarize a slot's top-down template.

        A position is expected (1) when its weight is positive and at least
        the template's mean weight. Requiring a positive weight keeps an
        all-zero template from turning into an all-ones expectation, which
        would pass the vigilance test for every input.
        """
        top_down = self._categories[category_index]["top_down"]
        if not top_down:
            return ()
        threshold = sum(top_down) / len(top_down)
        return tuple(1 if value > 0 and value >= threshold else 0 for value in top_down)

    def _match(self, vector: tuple[int, ...], expected_vector: tuple[int, ...]) -> bool:
        """Vigilance test: is enough of the input covered by the expectation?

        The match ratio is (# positions on in both) / (# ones in input).
        For an input with no ones the ratio is undefined; such an input
        matches only an expectation that is empty as well, so repeated
        empty inputs land in the same slot instead of each taking a new one.
        """
        ones_in_input = sum(vector)
        if ones_in_input == 0:
            return not any(expected_vector)
        overlap = sum(
            1 for left, right in zip(vector, expected_vector) if left == 1 and right == 1
        )
        return (overlap / ones_in_input) >= self.vigilance

    def _commit_category(self, category_index: int, vector: tuple[int, ...]) -> None:
        """Commit a slot to ``vector``: template = input, bottom-up = input / ones."""
        category = self._categories[category_index]
        category["committed"] = True
        category["top_down"] = [float(value) for value in vector]
        ones = max(1, sum(vector))  # avoid dividing by zero for an empty input
        category["bottom_up"] = [float(value) / ones for value in vector]

    def _resonate(self, category_index: int, vector: tuple[int, ...]) -> None:
        """Shrink a slot's template to its intersection with ``vector``."""
        category = self._categories[category_index]
        intersected = [
            1 if weight >= 0.5 and bit == 1 else 0
            for weight, bit in zip(category["top_down"], vector)
        ]
        category["top_down"] = [float(value) for value in intersected]
        ones = max(1, sum(intersected))  # avoid dividing by zero on empty overlap
        category["bottom_up"] = [float(value) / ones for value in intersected]

    def to_dict(self) -> dict[str, object]:
        """Return a JSON-serializable snapshot of parameters and state.

        The category lists are copies, so mutating the returned structure
        does not alter the network.
        """
        return {
            "params": {
                "max_categories": self.params.max_categories,
                "input_length": self.params.input_length,
                "vigilance": self.params.vigilance,
                "initial_bottom_up": self.params.initial_bottom_up,
                "initial_top_down": self.params.initial_top_down,
                "vigilance_decay": self.params.vigilance_decay,
            },
            "vigilance": self.vigilance,
            "categories": [
                {
                    "bottom_up": list(category["bottom_up"]),
                    "top_down": list(category["top_down"]),
                    "committed": category["committed"],
                }
                for category in self._categories
            ],
        }

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "ART1Network":
        """Rebuild a network from the structure produced by ``to_dict``.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If the number of categories or the length of any
                weight list disagrees with the stored parameters.
        """
        params_data = data["params"]  # type: ignore[index]
        params = ART1Params(
            max_categories=int(params_data["max_categories"]),  # type: ignore[index]
            input_length=int(params_data["input_length"]),  # type: ignore[index]
            vigilance=float(params_data["vigilance"]),  # type: ignore[index]
            initial_bottom_up=float(params_data["initial_bottom_up"]),  # type: ignore[index]
            initial_top_down=float(params_data["initial_top_down"]),  # type: ignore[index]
            vigilance_decay=float(params_data["vigilance_decay"]),  # type: ignore[index]
        )
        network = cls(params)
        network.vigilance = float(data["vigilance"])  # type: ignore[arg-type]

        categories = [
            {
                "bottom_up": [float(value) for value in category["bottom_up"]],
                "top_down": [float(value) for value in category["top_down"]],
                "committed": bool(category["committed"]),
            }
            for category in data["categories"]  # type: ignore[attr-defined]
        ]

        # Reject state that would only fail later with an IndexError.
        # max(0, ...) mirrors how __init__ treats non-positive sizes.
        expected_slots = max(0, params.max_categories)
        if len(categories) != expected_slots:
            raise ValueError(
                f"expected {expected_slots} categories, got {len(categories)}"
            )
        expected_length = max(0, params.input_length)
        for index, category in enumerate(categories):
            if (
                len(category["bottom_up"]) != expected_length
                or len(category["top_down"]) != expected_length
            ):
                raise ValueError(
                    f"category {index} weights must have length {expected_length}"
                )

        network._categories = categories
        return network

    def save_json(self, path: str) -> None:
        """Write the ``to_dict`` snapshot to ``path`` as UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(self.to_dict(), handle, indent=2)

    @classmethod
    def load_json(cls, path: str) -> "ART1Network":
        """Load a network previously written by ``save_json``."""
        with open(path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        return cls.from_dict(data)
|