# TriuneCadence/tests/test_backprop.py
#
# 50 lines
# 1.5 KiB
# Python

from pathlib import Path
from composer_ans.backprop import BackpropNetwork
from composer_ans.io.legacy_files import load_salieri_config, load_salieri_weights
from composer_ans.salieri import SalieriCritic
# Location of the legacy data files read by the tests below (S61.DAT, S61.WT):
# the "THES" folder inside the directory one level above this file's directory.
THES = Path(__file__).resolve().parents[1] / "THES"
def test_generic_backprop_predict_and_train_step() -> None:
    """Exercise predict and a single train_step on a small random network.

    Builds a network with 2 inputs, 2 hidden nodes and 1 output, runs one
    forward pass and one training step on the same sample, and checks the
    size and range of what comes back.
    """
    sample = (0.0, 1.0)
    net = BackpropNetwork.random(
        n_input=2, n_hidden=2, n_output=1, learning_rate=0.5, alpha=0.1
    )

    # Forward pass first, then the training step, both on the same input.
    forward = net.predict(sample)
    step = net.train_step(sample, (1.0,))

    # Both calls must report exactly one output value.
    assert len(forward.outputs) == 1
    assert len(step.outputs) == 1

    # Only index into the outputs after their length has been asserted.
    output_value = step.outputs[0]
    assert 0.0 <= output_value <= 1.0
    assert step.error >= 0.0

    # At least one non-input node must carry a non-zero delta after training.
    non_input_nodes = (
        node for node in step.node_states if node.node_type != "input"
    )
    assert any(node.delta != 0.0 for node in non_input_nodes)
def test_backprop_loads_legacy_salieri_network() -> None:
    """Build a network from the legacy S61 files and run one prediction."""
    legacy_config = load_salieri_config(THES / "S61.DAT")
    legacy_weights = load_salieri_weights(THES / "S61.WT")
    net = BackpropNetwork.from_legacy(
        config=legacy_config,
        legacy_weights=legacy_weights,
    )

    # Feed an all-zero input vector sized from the loaded configuration.
    zero_input = tuple(0.0 for _ in range(legacy_config.n_input))
    prediction = net.predict(zero_input)

    assert net.node_count == 61
    assert len(prediction.outputs) == 1

    # Only index into the outputs after their length has been asserted.
    only_output = prediction.outputs[0]
    assert 0.0 <= only_output <= 1.0
def test_salieri_wrapper_runs_on_thesis_sequence_window() -> None:
    """Run SalieriCritic.evaluate_and_train once on a five-element window."""
    window = (1, 4, 5, 1, 0)
    salieri = SalieriCritic.from_legacy_paths(THES)
    verdict = salieri.evaluate_and_train(window)

    # The reported target must be one of the two values 0 or 1.
    assert verdict.target in (0, 1)

    raw = verdict.raw_output
    assert 0.0 <= raw <= 1.0

    assert isinstance(verdict.is_classical, bool)