from __future__ import annotations import argparse, csv, os from datetime import datetime import numpy as np from alice_fast.batched_belt import BatchedBelt from alice_fast.kernels import PASS, PEEK, EAT from alice_tools.sequence import ( gellermann_k, debruijn, tile_or_trim, audit_sequence, ) """ Curiosity demo with CSV logging. Two puzzle families: 0 = Informative: PEEK (non-advancing) makes EAT good; without PEEK, EAT is bad. 1 = Uninformative: PEEK costs but does not change EAT value. We encode "non-advancing" by augmenting state: S=2 states per puzzle: 0=unpeeked, 1=peeked. PEEK: 0->1, 1->1 (information state only) EAT: returns to 0; reward depends on family+state PASS: resets to unpeeked (small cost). """ # ---------- Puzzle builders ---------- def build_tables_informative(): S, A = 2, 3 tt = np.zeros((S, A), dtype=np.int32) tt[:, PASS] = 0 tt[0, PEEK] = 1 tt[1, PEEK] = 1 tt[:, EAT] = 0 rt = np.zeros((S, A, S), dtype=np.float32) base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32) rt[0, EAT, 0] = -0.25 # uninformed 'eat' is risky/bad rt[1, EAT, 0] = 1.0 # informed 'eat' is good return S, A, tt, rt, base_costs def build_tables_uninformative(): S, A = 2, 3 tt = np.zeros((S, A), dtype=np.int32) tt[:, PASS] = 0 tt[0, PEEK] = 1 tt[1, PEEK] = 1 tt[:, EAT] = 0 rt = np.zeros((S, A, S), dtype=np.float32) base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32) # same payoff whether peeked or not rt[0, EAT, 0] = 0.30 rt[1, EAT, 0] = 0.30 return S, A, tt, rt, base_costs # ---------- Run one scheduled segment ---------- def run_segment(belt: BatchedBelt, steps: int): total_reward = 0.0 total_peeks = 0 total_actions = 0 for _ in range(steps): out = belt.step_learn() total_reward += float(out["rewards"].sum()) total_peeks += int(np.sum(out["actions"] == PEEK)) total_actions += out["actions"].size return { "avg_reward_per_box_step": total_reward / total_actions, "peek_rate": total_peeks / total_actions } def ensure_parent(path: str): os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) # ---------- CLI / main ---------- def main(): ap = argparse.ArgumentParser(description="Curiosity demo with CSV logging and configurable k-ary schedules.") ap.add_argument("--out", type=str, default=None, help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)") ap.add_argument("--segments", type=int, default=20, help="Number of segments") ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment") ap.add_argument("--batch", type=int, default=4096, help="Batch size") ap.add_argument("--seed", type=int, default=7, help="Base RNG seed") # Sequence controls ap.add_argument("--seq_mode", type=str, default="gellermann", choices=["gellermann", "debruijn"], help="Scheduling mode for families across segments") ap.add_argument("--run_cap", type=int, default=3, help="Max run length per symbol (Gellermann)") ap.add_argument("--exact_counts", action="store_true", help="(Gellermann) Force exactly equal counts (rounds n down to multiple of k)") ap.add_argument("--half_balance", action="store_true", help="(Gellermann) Enforce first-half counts <= ceil(target/2) per symbol") ap.add_argument("--debruijn_order", type=int, default=4, help="Order m for de Bruijn(k=2, m); final length is tiled/trimmed to --segments") args = ap.parse_args() if args.out is None: stamp = datetime.now().strftime("%Y%m%d-%H%M%S") args.out = f"results/curiosity_demo_{stamp}.csv" ensure_parent(args.out) # Build families (same S,A; different reward tables) S0, A0, tt0, rt0, costs0 = build_tables_informative() S1, A1, tt1, rt1, costs1 = build_tables_uninformative() assert (S0, A0) == (S1, A1) S, A = S0, A0 belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed) belt_uninf = BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1) # Schedule the N segments with a k-ary sequence over families {0,1} if args.seq_mode == "gellermann": seq = gellermann_k( n=args.segments, k=2, run_cap=args.run_cap, seed=args.seed, exact_counts=args.exact_counts, half_balance=args.half_balance ) else: base = debruijn(k=2, m=args.debruijn_order) seq = tile_or_trim(base, args.segments) audit = audit_sequence(seq, k=2) print("Sequence (0=informative, 1=uninformative):", seq.tolist()) print("Audit:", audit) # CSV header header = [ "segment_index", "family", "peek_rate", "avg_reward_per_box_step", "batch", "steps_per_segment", "S", "A", "gamma", "alpha", "epsilon", "cost_pass", "cost_peek", "cost_eat", "seed", "seq_mode", "run_cap", "exact_counts", "half_balance", "debruijn_order" ] with open(args.out, "w", newline="") as f: w = csv.writer(f) w.writerow(header) for i, sym in enumerate(seq): if sym == 0: res = run_segment(belt_inf, args.steps_per_segment) fam = "informative" c = costs0 else: res = run_segment(belt_uninf, args.steps_per_segment) fam = "uninformative" c = costs1 row = [ i, fam, f"{res['peek_rate']:.6f}", f"{res['avg_reward_per_box_step']:.6f}", args.batch, args.steps_per_segment, S, A, 0.97, 0.2, 0.05, float(c[0]), float(c[1]), float(c[2]), args.seed, args.seq_mode, args.run_cap, int(bool(args.exact_counts)), int(bool(args.half_balance)), args.debruijn_order ] w.writerow(row) print(f"Seg {i:02d} [{fam[:5].upper()}] " f"peek_rate={res['peek_rate']:.3f} " f"avg_reward/step={res['avg_reward_per_box_step']:.4f}") print(f"\nWrote CSV → {args.out}") if __name__ == "__main__": main()