From 88208d0fa4c444eacc60d320f43a7b62c219caa2 Mon Sep 17 00:00:00 2001 From: Diane Blackwood Date: Sat, 20 Sep 2025 21:00:54 -0400 Subject: [PATCH] Updated use of gellerman sequence features. --- bench/run_curiosity_demo.py | 66 +++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 14 deletions(-) diff --git a/bench/run_curiosity_demo.py b/bench/run_curiosity_demo.py index e619510..bbdab76 100644 --- a/bench/run_curiosity_demo.py +++ b/bench/run_curiosity_demo.py @@ -2,9 +2,15 @@ from __future__ import annotations import argparse, csv, os from datetime import datetime import numpy as np + from alice_fast.batched_belt import BatchedBelt from alice_fast.kernels import PASS, PEEK, EAT -from alice_tools.sequence import gellermann_k, audit_sequence +from alice_tools.sequence import ( + gellermann_k, + debruijn, + tile_or_trim, + audit_sequence, +) """ Curiosity demo with CSV logging. @@ -20,6 +26,8 @@ We encode "non-advancing" by augmenting state: PASS: resets to unpeeked (small cost). """ +# ---------- Puzzle builders ---------- + def build_tables_informative(): S, A = 2, 3 tt = np.zeros((S, A), dtype=np.int32) @@ -35,6 +43,7 @@ def build_tables_informative(): rt[1, EAT, 0] = 1.0 # informed 'eat' is good return S, A, tt, rt, base_costs + def build_tables_uninformative(): S, A = 2, 3 tt = np.zeros((S, A), dtype=np.int32) @@ -46,10 +55,13 @@ def build_tables_uninformative(): rt = np.zeros((S, A, S), dtype=np.float32) base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32) - rt[0, EAT, 0] = 0.30 # same payoff whether peeked or not + # same payoff whether peeked or not + rt[0, EAT, 0] = 0.30 rt[1, EAT, 0] = 0.30 return S, A, tt, rt, base_costs +# ---------- Run one scheduled segment ---------- + def run_segment(belt: BatchedBelt, steps: int): total_reward = 0.0 total_peeks = 0 @@ -67,13 +79,29 @@ def run_segment(belt: BatchedBelt, steps: int): def ensure_parent(path: str): os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) +# ---------- CLI / main ---------- + def main(): - ap = argparse.ArgumentParser() - ap.add_argument("--out", type=str, default=None, help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)") + ap = argparse.ArgumentParser(description="Curiosity demo with CSV logging and configurable k-ary schedules.") + ap.add_argument("--out", type=str, default=None, + help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)") ap.add_argument("--segments", type=int, default=20, help="Number of segments") ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment") ap.add_argument("--batch", type=int, default=4096, help="Batch size") ap.add_argument("--seed", type=int, default=7, help="Base RNG seed") + + # Sequence controls + ap.add_argument("--seq_mode", type=str, default="gellermann", + choices=["gellermann", "debruijn"], + help="Scheduling mode for families across segments") + ap.add_argument("--run_cap", type=int, default=3, help="Max run length per symbol (Gellermann)") + ap.add_argument("--exact_counts", action="store_true", + help="(Gellermann) Force exactly equal counts (rounds n down to multiple of k)") + ap.add_argument("--half_balance", action="store_true", + help="(Gellermann) Enforce first-half counts <= ceil(target/2) per symbol") + ap.add_argument("--debruijn_order", type=int, default=4, + help="Order m for de Bruijn(k=2, m); final length is tiled/trimmed to --segments") + args = ap.parse_args() if args.out is None: @@ -81,18 +109,25 @@ def main(): args.out = f"results/curiosity_demo_{stamp}.csv" ensure_parent(args.out) - # Build families + # Build families (same S,A; different reward tables) S0, A0, tt0, rt0, costs0 = build_tables_informative() S1, A1, tt1, rt1, costs1 = build_tables_uninformative() assert (S0, A0) == (S1, A1) S, A = S0, A0 - # Two belts (same shape, different reward tables) - belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed) - belt_uninf= BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1) + belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed) + belt_uninf = BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1) + + # Schedule the N segments with a k-ary sequence over families {0,1} + if args.seq_mode == "gellermann": + seq = gellermann_k( + n=args.segments, k=2, run_cap=args.run_cap, + seed=args.seed, exact_counts=args.exact_counts, half_balance=args.half_balance + ) + else: + base = debruijn(k=2, m=args.debruijn_order) + seq = tile_or_trim(base, args.segments) - # k=2 families, balanced, limited runs - seq = gellermann_k(n=args.segments, k=2, run_cap=3, seed=args.seed) audit = audit_sequence(seq, k=2) print("Sequence (0=informative, 1=uninformative):", seq.tolist()) print("Audit:", audit) @@ -103,7 +138,8 @@ def main(): "batch", "steps_per_segment", "S", "A", "gamma", "alpha", "epsilon", "cost_pass", "cost_peek", "cost_eat", - "seed" + "seed", + "seq_mode", "run_cap", "exact_counts", "half_balance", "debruijn_order" ] with open(args.out, "w", newline="") as f: w = csv.writer(f) @@ -125,15 +161,17 @@ def main(): args.batch, args.steps_per_segment, S, A, 0.97, 0.2, 0.05, float(c[0]), float(c[1]), float(c[2]), - args.seed + args.seed, + args.seq_mode, args.run_cap, int(bool(args.exact_counts)), int(bool(args.half_balance)), args.debruijn_order ] w.writerow(row) - print(f"Seg {i:02d} [{fam[:5].upper()}] peek_rate={res['peek_rate']:.3f} " + print(f"Seg {i:02d} [{fam[:5].upper()}] " + f"peek_rate={res['peek_rate']:.3f} " f"avg_reward/step={res['avg_reward_per_box_step']:.4f}") print(f"\nWrote CSV → {args.out}") + if __name__ == "__main__": main() -