Updated use of gellerman sequence features.

This commit is contained in:
Diane Blackwood 2025-09-20 21:00:54 -04:00
parent 4564f53577
commit 88208d0fa4
1 changed files with 52 additions and 14 deletions

View File

@ -2,9 +2,15 @@ from __future__ import annotations
import argparse, csv, os
from datetime import datetime
import numpy as np
from alice_fast.batched_belt import BatchedBelt
from alice_fast.kernels import PASS, PEEK, EAT
from alice_tools.sequence import gellermann_k, audit_sequence
from alice_tools.sequence import (
gellermann_k,
debruijn,
tile_or_trim,
audit_sequence,
)
"""
Curiosity demo with CSV logging.
@ -20,6 +26,8 @@ We encode "non-advancing" by augmenting state:
PASS: resets to unpeeked (small cost).
"""
# ---------- Puzzle builders ----------
def build_tables_informative():
S, A = 2, 3
tt = np.zeros((S, A), dtype=np.int32)
@ -35,6 +43,7 @@ def build_tables_informative():
rt[1, EAT, 0] = 1.0 # informed 'eat' is good
return S, A, tt, rt, base_costs
def build_tables_uninformative():
S, A = 2, 3
tt = np.zeros((S, A), dtype=np.int32)
@ -46,10 +55,13 @@ def build_tables_uninformative():
rt = np.zeros((S, A, S), dtype=np.float32)
base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
rt[0, EAT, 0] = 0.30 # same payoff whether peeked or not
# same payoff whether peeked or not
rt[0, EAT, 0] = 0.30
rt[1, EAT, 0] = 0.30
return S, A, tt, rt, base_costs
# ---------- Run one scheduled segment ----------
def run_segment(belt: BatchedBelt, steps: int):
total_reward = 0.0
total_peeks = 0
@ -67,13 +79,29 @@ def run_segment(belt: BatchedBelt, steps: int):
def ensure_parent(path: str):
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
# ---------- CLI / main ----------
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--out", type=str, default=None, help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)")
ap = argparse.ArgumentParser(description="Curiosity demo with CSV logging and configurable k-ary schedules.")
ap.add_argument("--out", type=str, default=None,
help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)")
ap.add_argument("--segments", type=int, default=20, help="Number of segments")
ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment")
ap.add_argument("--batch", type=int, default=4096, help="Batch size")
ap.add_argument("--seed", type=int, default=7, help="Base RNG seed")
# Sequence controls
ap.add_argument("--seq_mode", type=str, default="gellermann",
choices=["gellermann", "debruijn"],
help="Scheduling mode for families across segments")
ap.add_argument("--run_cap", type=int, default=3, help="Max run length per symbol (Gellermann)")
ap.add_argument("--exact_counts", action="store_true",
help="(Gellermann) Force exactly equal counts (rounds n down to multiple of k)")
ap.add_argument("--half_balance", action="store_true",
help="(Gellermann) Enforce first-half counts <= ceil(target/2) per symbol")
ap.add_argument("--debruijn_order", type=int, default=4,
help="Order m for de Bruijn(k=2, m); final length is tiled/trimmed to --segments")
args = ap.parse_args()
if args.out is None:
@ -81,18 +109,25 @@ def main():
args.out = f"results/curiosity_demo_{stamp}.csv"
ensure_parent(args.out)
# Build families
# Build families (same S,A; different reward tables)
S0, A0, tt0, rt0, costs0 = build_tables_informative()
S1, A1, tt1, rt1, costs1 = build_tables_uninformative()
assert (S0, A0) == (S1, A1)
S, A = S0, A0
# Two belts (same shape, different reward tables)
belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed)
belt_uninf= BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1)
belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed)
belt_uninf = BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1)
# Schedule the N segments with a k-ary sequence over families {0,1}
if args.seq_mode == "gellermann":
seq = gellermann_k(
n=args.segments, k=2, run_cap=args.run_cap,
seed=args.seed, exact_counts=args.exact_counts, half_balance=args.half_balance
)
else:
base = debruijn(k=2, m=args.debruijn_order)
seq = tile_or_trim(base, args.segments)
# k=2 families, balanced, limited runs
seq = gellermann_k(n=args.segments, k=2, run_cap=3, seed=args.seed)
audit = audit_sequence(seq, k=2)
print("Sequence (0=informative, 1=uninformative):", seq.tolist())
print("Audit:", audit)
@ -103,7 +138,8 @@ def main():
"batch", "steps_per_segment", "S", "A",
"gamma", "alpha", "epsilon",
"cost_pass", "cost_peek", "cost_eat",
"seed"
"seed",
"seq_mode", "run_cap", "exact_counts", "half_balance", "debruijn_order"
]
with open(args.out, "w", newline="") as f:
w = csv.writer(f)
@ -125,15 +161,17 @@ def main():
args.batch, args.steps_per_segment, S, A,
0.97, 0.2, 0.05,
float(c[0]), float(c[1]), float(c[2]),
args.seed
args.seed,
args.seq_mode, args.run_cap, int(bool(args.exact_counts)), int(bool(args.half_balance)), args.debruijn_order
]
w.writerow(row)
print(f"Seg {i:02d} [{fam[:5].upper()}] peek_rate={res['peek_rate']:.3f} "
print(f"Seg {i:02d} [{fam[:5].upper()}] "
f"peek_rate={res['peek_rate']:.3f} "
f"avg_reward/step={res['avg_reward_per_box_step']:.4f}")
print(f"\nWrote CSV → {args.out}")
if __name__ == "__main__":
main()