TriuneCadence/composer_ans/io/legacy_files.py

109 lines
3.5 KiB
Python

from __future__ import annotations
from pathlib import Path
import struct
from composer_ans.types import (
HOPFIELD_WEIGHT_DIMENSION,
LegacyBPWeights,
LegacyPaths,
SALIERI_NODE_COUNT,
SalieriConfig,
)
def load_legacy_paths(root: str | Path) -> LegacyPaths:
    """Wrap a root directory in a ``LegacyPaths`` record.

    Args:
        root: Location of the legacy tree, as a string or ``Path``.

    Returns:
        A ``LegacyPaths`` whose ``root`` is ``root`` converted to ``Path``.
    """
    root_path = Path(root)
    return LegacyPaths(root=root_path)
def load_sequence_table(path: str | Path) -> tuple[str, ...]:
    """Read the non-empty entries of an ASCII sequence table.

    Each line is cleaned of surrounding whitespace and of trailing ``\\x1a``
    characters (presumably the DOS end-of-file marker left by legacy tools);
    lines that are empty after cleaning are dropped.

    Args:
        path: Location of the table, as a string or ``Path``.

    Returns:
        The cleaned entries, in file order.

    Raises:
        UnicodeDecodeError: If the file contains non-ASCII bytes.
    """
    entries = []
    for raw_line in Path(path).read_text(encoding="ascii").splitlines():
        # ``str.strip()`` does not treat "\x1a" as whitespace, so blanks that
        # sit in front of a trailing marker need a second ``rstrip()`` once
        # the marker itself has been removed.
        entry = raw_line.strip().rstrip("\x1a").rstrip()
        if entry:
            entries.append(entry)
    return tuple(entries)
def load_salieri_config(path: str | Path) -> SalieriConfig:
    """Parse a ``!``-directive configuration file into a ``SalieriConfig``.

    Only lines whose stripped text starts with ``!`` are considered. The
    character after ``!`` (case-insensitive) is the directive code and the
    remainder of the line, stripped, is its payload. A ``!Z`` directive stops
    parsing; a repeated code overwrites the earlier payload.

    Args:
        path: Location of the ASCII configuration file.

    Returns:
        A ``SalieriConfig`` built from the ``L``, ``A``, ``I``, ``H``, ``O``,
        ``T``, ``E``, ``D``, ``R`` and ``W`` directives.

    Raises:
        KeyError: If one of the required directives is absent.
        ValueError: If a numeric payload cannot be converted.
    """
    directives: dict[str, str] = {}
    text = Path(path).read_text(encoding="ascii")
    for entry in map(str.strip, text.splitlines()):
        if entry[:1] != "!":
            continue
        tag = entry[1:2].upper()
        if tag == "Z":
            break
        directives[tag] = entry[2:].strip()

    field = directives.__getitem__
    return SalieriConfig(
        learning_rate=float(field("L")),
        alpha=float(field("A")),
        n_input=int(field("I")),
        n_hidden=int(field("H")),
        n_output=int(field("O")),
        # Only the first whitespace-separated token of !T is the iteration count.
        training_iterations=int(field("T").split()[0]),
        error_tolerance=float(field("E")),
        data_file=field("D"),
        report_file=field("R"),
        weight_file=field("W"),
    )
def load_salieri_weights(path: str | Path) -> LegacyBPWeights:
    """Parse a ``!``-directive weight file into a ``LegacyBPWeights``.

    Recognised directives (case-insensitive code after ``!``):
    ``V`` sets the vector length, each ``W`` contributes one weight row,
    ``T`` sets the theta vector (the last one wins) and ``Z`` stops parsing.
    Lines that do not start with ``!`` are ignored.

    Args:
        path: Location of the ASCII weight file.

    Returns:
        A ``LegacyBPWeights`` holding the vector length, the square weight
        matrix and the theta vector.

    Raises:
        ValueError: If ``!V`` or ``!T`` is missing, if the number or width of
            the weight rows differs from the vector length, if the theta
            count differs from it, or if a payload is not numeric.
    """
    size = None
    matrix_rows: list[tuple[float, ...]] = []
    theta_row: tuple[float, ...] | None = None

    text = Path(path).read_text(encoding="ascii")
    for entry in (raw.strip() for raw in text.splitlines()):
        if entry[:1] != "!":
            continue
        tag, body = entry[1:2].upper(), entry[2:].strip()
        if tag == "Z":
            break
        if tag == "V":
            size = int(body)
        elif tag in ("W", "T"):
            numbers = tuple(float(token) for token in body.split())
            if tag == "W":
                matrix_rows.append(numbers)
            else:
                theta_row = numbers

    if size is None:
        raise ValueError("missing !V in weight file")
    row_count = len(matrix_rows)
    if row_count != size:
        raise ValueError(f"expected {size} weight rows, got {row_count}")
    if not all(len(row) == size for row in matrix_rows):
        raise ValueError("weight matrix is not square")
    if theta_row is None:
        raise ValueError("missing !T in weight file")
    theta_count = len(theta_row)
    if theta_count != size:
        raise ValueError(f"expected {size} theta values, got {theta_count}")

    return LegacyBPWeights(
        vector_length=size,
        weights=tuple(matrix_rows),
        thetas=theta_row,
    )
def load_hopfield_weight_matrix(
    path: str | Path,
    *,
    dimension: int | None = None,
) -> tuple[tuple[float, ...], ...]:
    """Load a square matrix of little-endian 32-bit floats from a binary file.

    The file must contain exactly ``dimension * dimension`` values stored
    row after row with no header or padding.

    Args:
        path: Location of the binary weight file.
        dimension: Number of rows and columns. Defaults to
            ``HOPFIELD_WEIGHT_DIMENSION`` when omitted.

    Returns:
        The matrix as a tuple of ``dimension`` rows, each a tuple of
        ``dimension`` floats.

    Raises:
        ValueError: If ``dimension`` is not positive or the file size does
            not match ``dimension * dimension * 4`` bytes.
    """
    if dimension is None:
        dimension = HOPFIELD_WEIGHT_DIMENSION
    if dimension <= 0:
        raise ValueError(f"dimension must be positive, got {dimension}")

    # One compiled format describes a whole row, so iterating over the buffer
    # yields finished row tuples without an intermediate flat tuple.
    row_format = struct.Struct(f"<{dimension}f")
    data = Path(path).read_bytes()
    expected_size = row_format.size * dimension
    if len(data) != expected_size:
        raise ValueError(f"expected {expected_size} bytes, got {len(data)}")
    return tuple(row_format.iter_unpack(data))
def extract_active_hopfield_submatrix(
    matrix: tuple[tuple[float, ...], ...],
    *,
    active_size: int | None = None,
) -> tuple[tuple[float, ...], ...]:
    """Return the top-left ``active_size`` x ``active_size`` corner of a matrix.

    Args:
        matrix: Row-major matrix to crop.
        active_size: Number of leading rows and columns to keep. Defaults to
            ``SALIERI_NODE_COUNT - 21`` when omitted.

    Returns:
        The first ``active_size`` entries of each of the first
        ``active_size`` rows. Rows or columns beyond the input's extent are
        simply absent, as with ordinary slicing.

    Raises:
        ValueError: If the resolved size is negative.
    """
    if active_size is None:
        # NOTE(review): the meaning of the offset 21 is not visible in this
        # module; confirm against the legacy format before changing it.
        active_size = SALIERI_NODE_COUNT - 21
    if active_size < 0:
        # A negative bound would silently slice from the end instead.
        raise ValueError(f"active size must be non-negative, got {active_size}")
    return tuple(tuple(row[:active_size]) for row in matrix[:active_size])