50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
from pathlib import Path
|
|
|
|
from composer_ans.backprop import BackpropNetwork
|
|
from composer_ans.io.legacy_files import load_salieri_config, load_salieri_weights
|
|
from composer_ans.salieri import SalieriCritic
|
|
|
|
|
|
# Directory named "THES" located one level above the directory containing this
# file. The tests below read the legacy files "S61.DAT" and "S61.WT" from it.
THES = Path(__file__).resolve().parents[1] / "THES"
|
|
|
|
|
|
def test_generic_backprop_predict_and_train_step() -> None:
    """Exercise ``predict`` and ``train_step`` on a small random network.

    A network created via ``BackpropNetwork.random`` (n_input=2, n_hidden=2,
    n_output=1) is queried once and then trained once on the same input
    pattern. Only structural properties of the results are checked, never
    specific numeric values, because the network is randomly initialised.
    """
    sample = (0.0, 1.0)
    net = BackpropNetwork.random(
        n_input=2, n_hidden=2, n_output=1, learning_rate=0.5, alpha=0.1
    )

    # Predict first, then train: the order matches the original scenario.
    before = net.predict(sample)
    after = net.train_step(sample, (1.0,))

    # Both calls must yield exactly one output value.
    assert len(before.outputs) == 1
    assert len(after.outputs) == 1

    # The trained output is expected in the closed unit interval and the
    # reported error must not be negative.
    trained_output = after.outputs[0]
    assert 0.0 <= trained_output <= 1.0
    assert after.error >= 0.0

    # At least one non-input node must carry a nonzero delta after training.
    non_input_deltas = (
        node.delta for node in after.node_states if node.node_type != "input"
    )
    assert any(delta != 0.0 for delta in non_input_deltas)
|
|
|
|
|
|
def test_backprop_loads_legacy_salieri_network() -> None:
    """Build a network from the legacy S61 files and run one prediction.

    The configuration is read from ``THES / "S61.DAT"`` and the weights from
    ``THES / "S61.WT"``. The resulting network is fed an all-zero input of
    length ``config.n_input``; the test checks the node count and that the
    single output lies in the closed interval [0, 1].
    """
    legacy_config = load_salieri_config(THES / "S61.DAT")
    legacy_weights = load_salieri_weights(THES / "S61.WT")
    net = BackpropNetwork.from_legacy(
        config=legacy_config,
        legacy_weights=legacy_weights,
    )

    # One 0.0 per input declared by the loaded configuration.
    zero_input = tuple(0.0 for _ in range(legacy_config.n_input))
    prediction = net.predict(zero_input)

    assert net.node_count == 61
    assert len(prediction.outputs) == 1
    assert 0.0 <= prediction.outputs[0] <= 1.0
|
|
|
|
|
|
def test_salieri_wrapper_runs_on_thesis_sequence_window() -> None:
    """Run ``SalieriCritic.evaluate_and_train`` on one five-element window.

    The critic is constructed from the legacy files under ``THES``. The test
    checks that the result's target is 0 or 1, that its raw output lies in
    the closed interval [0, 1], and that ``is_classical`` is a real ``bool``.
    """
    window = (1, 4, 5, 1, 0)
    salieri = SalieriCritic.from_legacy_paths(THES)

    verdict = salieri.evaluate_and_train(window)

    assert verdict.target in (0, 1)
    assert 0.0 <= verdict.raw_output <= 1.0
    assert isinstance(verdict.is_classical, bool)
|