31 lines
981 B
Python
31 lines
981 B
Python
from pathlib import Path
|
|
|
|
from composer_ans.beethoven import BeethovenCategorizer
|
|
from composer_ans.salieri import SalieriCritic
|
|
|
|
|
|
# Location of the "THES" fixtures: a sibling of this file's parent directory,
# i.e. <grandparent of this module>/THES, resolved to an absolute path.
THES = Path(__file__).resolve().parents[1].joinpath("THES")
|
|
|
|
|
|
def test_salieri_round_trip_json(tmp_path: Path) -> None:
    """Save a SalieriCritic to JSON, reload it, and use the reloaded copy.

    The critic is created with ``SalieriCritic.from_legacy_paths(THES)``.
    ``evaluate_and_train`` is called on it once with a fixed five-element
    pattern, and it is written to ``salieri.json`` inside pytest's
    per-test temporary directory. The critic is then rebuilt with
    ``SalieriCritic.load_json``. The rebuilt critic must accept the same
    pattern and yield a ``raw_output`` in the closed interval [0, 1].
    """
    pattern = (1, 4, 5, 1, 0)
    json_file = tmp_path / "salieri.json"

    original = SalieriCritic.from_legacy_paths(THES)
    original.evaluate_and_train(pattern)
    # The persistence API is called with a str path here, not a Path object.
    original.save_json(str(json_file))

    reloaded = SalieriCritic.load_json(str(json_file))
    outcome = reloaded.evaluate_and_train(pattern)
    # NOTE(review): only the output range is checked, so this would also pass
    # if no state survived the round trip. Consider comparing against the
    # pre-save critic if evaluate_and_train is deterministic (confirm).
    assert 0.0 <= outcome.raw_output <= 1.0
|
|
|
|
|
|
def test_beethoven_round_trip_json(tmp_path: Path) -> None:
    """Save a BeethovenCategorizer to JSON, reload it, and categorize again.

    A default-constructed categorizer is given one fixed five-element
    pattern with ``is_classical=True``. It is written to
    ``beethoven.json`` in pytest's per-test temporary directory and
    rebuilt with ``BeethovenCategorizer.load_json``. Categorizing the same
    pattern with the rebuilt instance must report at least one committed
    category in ``art_result.committed_categories``.
    """
    pattern = (1, 4, 5, 1, 0)
    json_file = tmp_path / "beethoven.json"

    original = BeethovenCategorizer()
    original.categorize(pattern, is_classical=True)
    # The persistence API is called with a str path here, not a Path object.
    original.save_json(str(json_file))

    reloaded = BeethovenCategorizer.load_json(str(json_file))
    outcome = reloaded.categorize(pattern, is_classical=True)
    assert outcome.art_result.committed_categories >= 1
|