"""Backpropagation network for composer analysis (TriuneCadence/composer_ans/backprop.py)."""

from __future__ import annotations
from dataclasses import dataclass
import json
import math
import random
from typing import Iterable
from .types import LegacyBPWeights, SalieriConfig
def sigmoid(range_value: float, slope_mod: float, shift: float, x: float) -> float:
    """Scaled, shifted logistic: range_value / (1 + exp(-slope_mod * x)) - shift.

    The exponent is clamped to [-80, 80] so math.exp can never overflow,
    which pins the extreme outputs to (almost exactly) the asymptotes.
    """
    exponent = -(slope_mod * x)
    if exponent > 80.0:
        exponent = 80.0
    elif exponent < -80.0:
        exponent = -80.0
    return range_value / (1.0 + math.exp(exponent)) - shift
@dataclass(frozen=True)
class BackpropNodeState:
    """Immutable snapshot of a single node after a forward (and optionally
    backward) pass."""

    node_type: str             # "input", "hidden", or "output"
    net_input: float           # weighted sum fed into the node (pre-activation)
    delta: float               # backprop error term; 0.0 for inference-only passes
    theta: float               # node bias at snapshot time
    range_value: float = 1.0   # sigmoid output range
    shift: float = 0.0         # sigmoid vertical shift
@dataclass(frozen=True)
class BackpropResult:
    """Immutable result of one network pass."""

    outputs: tuple[float, ...]                    # activations of the output nodes
    error: float                                  # max absolute output error (0.0 for predict)
    node_states: tuple[BackpropNodeState, ...]    # per-node snapshots, flat index order
class BackpropNetwork:
    """Three-layer (input / hidden / output) feedforward network trained with
    classic backpropagation plus momentum.

    Nodes live in one flat index space: [0, n_input) are inputs, the next
    n_hidden are hidden, the rest are outputs.  ``weights[dst][src]`` is the
    weight on the edge src -> dst; only edges enabled in ``connectivity``
    take part in the computation.
    """

    def __init__(
        self,
        *,
        n_input: int,
        n_hidden: int,
        n_output: int,
        learning_rate: float,
        alpha: float,
        weights: list[list[float]],
        thetas: list[float],
    ) -> None:
        """Build a network from an explicit node_count x node_count weight
        matrix and per-node biases (``thetas``).

        ``alpha`` is the momentum coefficient applied to each parameter's
        previous update.
        """
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_output = n_output
        self.learning_rate = learning_rate
        self.alpha = alpha
        self.node_count = n_input + n_hidden + n_output
        self.weights = weights
        self.thetas = thetas
        # Momentum state: the update applied on the previous train_step.
        self.last_weight_updates = [
            [0.0] * self.node_count for _ in range(self.node_count)
        ]
        self.last_theta_updates = [0.0] * self.node_count
        self.node_types = self._build_node_types()
        self.connectivity = self._build_connectivity()

    @classmethod
    def random(
        cls,
        *,
        n_input: int,
        n_hidden: int,
        n_output: int,
        learning_rate: float = 0.5,
        alpha: float = 0.5,
        rng: random.Random | None = None,
    ) -> "BackpropNetwork":
        """Create a network with uniform(-1, 1) weights, zero biases on the
        input nodes, and small gaussian biases elsewhere.

        Pass ``rng`` for reproducibility; otherwise a fresh Random is used.
        """
        generator = rng or random.Random()
        node_count = n_input + n_hidden + n_output
        weights = [
            [generator.uniform(-1.0, 1.0) for _ in range(node_count)]
            for _ in range(node_count)
        ]
        thetas = [0.0] * n_input + [
            generator.gauss(0.0, 0.25) for _ in range(n_hidden + n_output)
        ]
        return cls(
            n_input=n_input,
            n_hidden=n_hidden,
            n_output=n_output,
            learning_rate=learning_rate,
            alpha=alpha,
            weights=weights,
            thetas=thetas,
        )

    @classmethod
    def from_legacy(
        cls,
        *,
        config: SalieriConfig,
        legacy_weights: LegacyBPWeights,
    ) -> "BackpropNetwork":
        """Build a network from a legacy config/weights pair, copying each
        weight row so the legacy structures are never mutated."""
        return cls(
            n_input=config.n_input,
            n_hidden=config.n_hidden,
            n_output=config.n_output,
            learning_rate=config.learning_rate,
            alpha=config.alpha,
            weights=[list(row) for row in legacy_weights.weights],
            thetas=list(legacy_weights.thetas),
        )

    @staticmethod
    def _activation(x: float) -> float:
        """Logistic activation with the argument clamped to +/-80 so
        math.exp cannot overflow (equivalent to sigmoid(1.0, 1.0, 0.0, x))."""
        clamped = max(min(x, 80.0), -80.0)
        return 1.0 / (1.0 + math.exp(-clamped))

    def _forward(
        self, input_values: tuple[float, ...]
    ) -> tuple[list[float], list[float]]:
        """Run a forward pass and return (net_inputs, activations).

        Nodes are evaluated in flat index order, which is topological for
        the layered connectivity built by _build_connectivity, so every
        activations[src] is already computed when it is consumed.
        """
        net_inputs = [0.0] * self.node_count
        activations = [0.0] * self.node_count
        for idx in range(self.node_count):
            if self.node_types[idx] == "input":
                # Input nodes pass their value through unchanged.
                net_inputs[idx] = input_values[idx]
                activations[idx] = input_values[idx]
                continue
            total = sum(
                activations[src] * self.weights[idx][src]
                for src in range(self.node_count)
                if self.connectivity[idx][src]
            )
            net_inputs[idx] = total
            activations[idx] = self._activation(total + self.thetas[idx])
        return net_inputs, activations

    def predict(self, inputs: Iterable[float]) -> BackpropResult:
        """Forward pass only.

        Returns a BackpropResult whose error and per-node deltas are 0.0
        (no backward pass is performed, no parameters change).

        Raises ValueError when the number of inputs differs from n_input.
        """
        input_values = tuple(float(value) for value in inputs)
        if len(input_values) != self.n_input:
            raise ValueError(f"expected {self.n_input} inputs, got {len(input_values)}")
        net_inputs, activations = self._forward(input_values)
        outputs = tuple(activations[self.n_input + self.n_hidden :])
        node_states = tuple(
            BackpropNodeState(
                node_type=self.node_types[idx],
                net_input=net_inputs[idx],
                delta=0.0,
                theta=self.thetas[idx],
            )
            for idx in range(self.node_count)
        )
        return BackpropResult(outputs=outputs, error=0.0, node_states=node_states)

    def train_step(self, inputs: Iterable[float], targets: Iterable[float]) -> BackpropResult:
        """One forward/backward pass with an in-place parameter update.

        Returns a BackpropResult whose error is the maximum absolute output
        error measured BEFORE the update, with the per-node deltas in the
        node states.

        Raises ValueError when inputs or targets have the wrong length.
        """
        input_values = tuple(float(value) for value in inputs)
        target_values = tuple(float(value) for value in targets)
        # Validate inputs too (predict already does); a silent mismatch here
        # would train on truncated data or raise an opaque IndexError.
        if len(input_values) != self.n_input:
            raise ValueError(f"expected {self.n_input} inputs, got {len(input_values)}")
        if len(target_values) != self.n_output:
            raise ValueError(f"expected {self.n_output} targets, got {len(target_values)}")
        net_inputs, activations = self._forward(input_values)

        # Backward pass: reverse index order is reverse-topological for the
        # layered connectivity, so output deltas exist before hidden ones.
        deltas = [0.0] * self.node_count
        output_start = self.n_input + self.n_hidden
        max_error = 0.0
        for idx in range(self.node_count - 1, -1, -1):
            activation = activations[idx]
            if self.node_types[idx] == "output":
                raw_error = target_values[idx - output_start] - activation
                max_error = max(max_error, abs(raw_error))
                # Logistic derivative: a * (1 - a).
                deltas[idx] = raw_error * activation * (1.0 - activation)
            elif self.node_types[idx] == "hidden":
                downstream = sum(
                    deltas[dst] * self.weights[dst][idx]
                    for dst in range(self.node_count)
                    if self.connectivity[dst][idx]
                )
                deltas[idx] = activation * (1.0 - activation) * downstream

        # Bias updates with momentum (input deltas stay 0.0, so input
        # biases never move).
        for idx in range(self.node_count):
            theta_update = (
                self.learning_rate * deltas[idx]
                + self.alpha * self.last_theta_updates[idx]
            )
            self.last_theta_updates[idx] = theta_update
            self.thetas[idx] += theta_update

        # Weight updates with momentum.  The delta rule for edge src -> dst
        # is lr * delta[dst] * activation[src]; the previous implementation
        # used delta[src] * activation[dst], swapping the two endpoints.
        for dst in range(self.node_count):
            for src in range(self.node_count):
                if not self.connectivity[dst][src]:
                    continue
                update = self.learning_rate * deltas[dst] * activations[src]
                update += self.alpha * self.last_weight_updates[dst][src]
                self.last_weight_updates[dst][src] = update
                self.weights[dst][src] += update

        outputs = tuple(activations[output_start:])
        node_states = tuple(
            BackpropNodeState(
                node_type=self.node_types[idx],
                net_input=net_inputs[idx],
                delta=deltas[idx],
                theta=self.thetas[idx],
            )
            for idx in range(self.node_count)
        )
        return BackpropResult(outputs=outputs, error=max_error, node_states=node_states)

    def _build_node_types(self) -> list[str]:
        """Per-node role labels in flat index order."""
        return (
            ["input"] * self.n_input
            + ["hidden"] * self.n_hidden
            + ["output"] * self.n_output
        )

    def _build_connectivity(self) -> list[list[bool]]:
        """Edge mask: connectivity[dst][src] is True only for the layered
        edges input -> hidden and hidden -> output."""
        connectivity = [
            [False] * self.node_count for _ in range(self.node_count)
        ]
        hidden_start = self.n_input
        output_start = self.n_input + self.n_hidden
        for dst in range(hidden_start, output_start):
            for src in range(self.n_input):
                connectivity[dst][src] = True
        for dst in range(output_start, self.node_count):
            for src in range(hidden_start, output_start):
                connectivity[dst][src] = True
        return connectivity

    def to_dict(self) -> dict[str, object]:
        """JSON-serializable snapshot of hyperparameters, parameters, and
        momentum state (lists are shared, not copied)."""
        return {
            "n_input": self.n_input,
            "n_hidden": self.n_hidden,
            "n_output": self.n_output,
            "learning_rate": self.learning_rate,
            "alpha": self.alpha,
            "weights": self.weights,
            "thetas": self.thetas,
            "last_weight_updates": self.last_weight_updates,
            "last_theta_updates": self.last_theta_updates,
        }

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "BackpropNetwork":
        """Rebuild a network from to_dict() output; momentum state is
        restored when present, otherwise left at zero."""
        network = cls(
            n_input=int(data["n_input"]),
            n_hidden=int(data["n_hidden"]),
            n_output=int(data["n_output"]),
            learning_rate=float(data["learning_rate"]),
            alpha=float(data["alpha"]),
            weights=[[float(value) for value in row] for row in data["weights"]],  # type: ignore[index]
            thetas=[float(value) for value in data["thetas"]],  # type: ignore[index]
        )
        network.last_weight_updates = [
            [float(value) for value in row]
            for row in data.get("last_weight_updates", network.last_weight_updates)  # type: ignore[arg-type]
        ]
        network.last_theta_updates = [
            float(value)
            for value in data.get("last_theta_updates", network.last_theta_updates)  # type: ignore[arg-type]
        ]
        return network

    def save_json(self, path: str) -> None:
        """Write to_dict() to *path* as pretty-printed JSON."""
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(self.to_dict(), handle, indent=2)

    @classmethod
    def load_json(cls, path: str) -> "BackpropNetwork":
        """Load a network previously written by save_json."""
        with open(path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        return cls.from_dict(data)