from __future__ import annotations import argparse, csv, os from datetime import datetime import numpy as np from alice_fast.batched_belt import BatchedBelt from alice_fast.kernels import PASS, PEEK, EAT from alice_tools.sequence import gellermann_k, audit_sequence """ Curiosity demo with CSV logging. Two puzzle families: 0 = Informative: PEEK (non-advancing) makes EAT good; without PEEK, EAT is bad. 1 = Uninformative: PEEK costs but does not change EAT value. We encode "non-advancing" by augmenting state: S=2 states per puzzle: 0=unpeeked, 1=peeked. PEEK: 0->1, 1->1 (information state only) EAT: returns to 0; reward depends on family+state PASS: resets to unpeeked (small cost). """ def build_tables_informative(): S, A = 2, 3 tt = np.zeros((S, A), dtype=np.int32) tt[:, PASS] = 0 tt[0, PEEK] = 1 tt[1, PEEK] = 1 tt[:, EAT] = 0 rt = np.zeros((S, A, S), dtype=np.float32) base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32) rt[0, EAT, 0] = -0.25 # uninformed 'eat' is risky/bad rt[1, EAT, 0] = 1.0 # informed 'eat' is good return S, A, tt, rt, base_costs def build_tables_uninformative(): S, A = 2, 3 tt = np.zeros((S, A), dtype=np.int32) tt[:, PASS] = 0 tt[0, PEEK] = 1 tt[1, PEEK] = 1 tt[:, EAT] = 0 rt = np.zeros((S, A, S), dtype=np.float32) base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32) rt[0, EAT, 0] = 0.30 # same payoff whether peeked or not rt[1, EAT, 0] = 0.30 return S, A, tt, rt, base_costs def run_segment(belt: BatchedBelt, steps: int): total_reward = 0.0 total_peeks = 0 total_actions = 0 for _ in range(steps): out = belt.step_learn() total_reward += float(out["rewards"].sum()) total_peeks += int(np.sum(out["actions"] == PEEK)) total_actions += out["actions"].size return { "avg_reward_per_box_step": total_reward / total_actions, "peek_rate": total_peeks / total_actions } def ensure_parent(path: str): os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) def main(): ap = argparse.ArgumentParser() ap.add_argument("--out", type=str, default=None, help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)") ap.add_argument("--segments", type=int, default=20, help="Number of segments") ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment") ap.add_argument("--batch", type=int, default=4096, help="Batch size") ap.add_argument("--seed", type=int, default=7, help="Base RNG seed") args = ap.parse_args() if args.out is None: stamp = datetime.now().strftime("%Y%m%d-%H%M%S") args.out = f"results/curiosity_demo_{stamp}.csv" ensure_parent(args.out) # Build families S0, A0, tt0, rt0, costs0 = build_tables_informative() S1, A1, tt1, rt1, costs1 = build_tables_uninformative() assert (S0, A0) == (S1, A1) S, A = S0, A0 # Two belts (same shape, different reward tables) belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed) belt_uninf= BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1) # k=2 families, balanced, limited runs seq = gellermann_k(n=args.segments, k=2, run_cap=3, seed=args.seed) audit = audit_sequence(seq, k=2) print("Sequence (0=informative, 1=uninformative):", seq.tolist()) print("Audit:", audit) # CSV header header = [ "segment_index", "family", "peek_rate", "avg_reward_per_box_step", "batch", "steps_per_segment", "S", "A", "gamma", "alpha", "epsilon", "cost_pass", "cost_peek", "cost_eat", "seed" ] with open(args.out, "w", newline="") as f: w = csv.writer(f) w.writerow(header) for i, sym in enumerate(seq): if sym == 0: res = run_segment(belt_inf, args.steps_per_segment) fam = "informative" c = costs0 else: res = run_segment(belt_uninf, args.steps_per_segment) fam = "uninformative" c = costs1 row = [ i, fam, f"{res['peek_rate']:.6f}", f"{res['avg_reward_per_box_step']:.6f}", args.batch, args.steps_per_segment, S, A, 0.97, 0.2, 0.05, float(c[0]), float(c[1]), float(c[2]), args.seed ] w.writerow(row) print(f"Seg {i:02d} [{fam[:5].upper()}] peek_rate={res['peek_rate']:.3f} " f"avg_reward/step={res['avg_reward_per_box_step']:.4f}") print(f"\nWrote CSV → {args.out}") if __name__ == "__main__": main()