"""Tests for the generic backprop network and the legacy Salieri critic wrapper."""

from pathlib import Path

from composer_ans.backprop import BackpropNetwork
from composer_ans.io.legacy_files import load_salieri_config, load_salieri_weights
from composer_ans.salieri import SalieriCritic

# Legacy thesis data directory: a "THES" folder inside the parent of the
# directory that contains this test file (e.g. <repo>/THES for <repo>/tests/).
THES = Path(__file__).resolve().parents[1] / "THES"


def test_generic_backprop_predict_and_train_step() -> None:
    """A small random 2-2-1 network can predict and perform one training step."""
    network = BackpropNetwork.random(
        n_input=2,
        n_hidden=2,
        n_output=1,
        learning_rate=0.5,
        alpha=0.1,
    )

    predicted = network.predict((0.0, 1.0))
    trained = network.train_step((0.0, 1.0), (1.0,))

    # One output node was requested, so both passes must yield exactly one output.
    assert len(predicted.outputs) == 1
    assert len(trained.outputs) == 1
    # Outputs are expected to lie in [0, 1], the same contract asserted for
    # predict() on the legacy network below.
    assert 0.0 <= predicted.outputs[0] <= 1.0
    assert 0.0 <= trained.outputs[0] <= 1.0
    assert trained.error >= 0.0
    # A training step must propagate a non-zero error term to at least one
    # non-input node; otherwise backprop did nothing.
    assert any(
        state.delta != 0.0
        for state in trained.node_states
        if state.node_type != "input"
    )


def test_backprop_loads_legacy_salieri_network() -> None:
    """The legacy S61 config and weights load into a 61-node network that predicts."""
    config = load_salieri_config(THES / "S61.DAT")
    weights = load_salieri_weights(THES / "S61.WT")
    network = BackpropNetwork.from_legacy(config=config, legacy_weights=weights)

    # An all-zero input vector sized from the legacy config.
    result = network.predict((0.0,) * config.n_input)

    assert network.node_count == 61
    assert len(result.outputs) == 1
    assert 0.0 <= result.outputs[0] <= 1.0


def test_salieri_wrapper_runs_on_thesis_sequence_window() -> None:
    """SalieriCritic built from the THES directory evaluates and trains on one window."""
    critic = SalieriCritic.from_legacy_paths(THES)

    # A short sequence window; its element meaning is defined by SalieriCritic,
    # which is not visible here.
    result = critic.evaluate_and_train((1, 4, 5, 1, 0))

    assert result.target in (0, 1)
    assert 0.0 <= result.raw_output <= 1.0
    assert isinstance(result.is_classical, bool)