Updated use of Gellermann sequence features.
This commit is contained in:
parent
4564f53577
commit
88208d0fa4
|
|
@ -2,9 +2,15 @@ from __future__ import annotations
|
||||||
import argparse, csv, os
|
import argparse, csv, os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from alice_fast.batched_belt import BatchedBelt
|
from alice_fast.batched_belt import BatchedBelt
|
||||||
from alice_fast.kernels import PASS, PEEK, EAT
|
from alice_fast.kernels import PASS, PEEK, EAT
|
||||||
from alice_tools.sequence import gellermann_k, audit_sequence
|
from alice_tools.sequence import (
|
||||||
|
gellermann_k,
|
||||||
|
debruijn,
|
||||||
|
tile_or_trim,
|
||||||
|
audit_sequence,
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Curiosity demo with CSV logging.
|
Curiosity demo with CSV logging.
|
||||||
|
|
@ -20,6 +26,8 @@ We encode "non-advancing" by augmenting state:
|
||||||
PASS: resets to unpeeked (small cost).
|
PASS: resets to unpeeked (small cost).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# ---------- Puzzle builders ----------
|
||||||
|
|
||||||
def build_tables_informative():
|
def build_tables_informative():
|
||||||
S, A = 2, 3
|
S, A = 2, 3
|
||||||
tt = np.zeros((S, A), dtype=np.int32)
|
tt = np.zeros((S, A), dtype=np.int32)
|
||||||
|
|
@ -35,6 +43,7 @@ def build_tables_informative():
|
||||||
rt[1, EAT, 0] = 1.0 # informed 'eat' is good
|
rt[1, EAT, 0] = 1.0 # informed 'eat' is good
|
||||||
return S, A, tt, rt, base_costs
|
return S, A, tt, rt, base_costs
|
||||||
|
|
||||||
|
|
||||||
def build_tables_uninformative():
|
def build_tables_uninformative():
|
||||||
S, A = 2, 3
|
S, A = 2, 3
|
||||||
tt = np.zeros((S, A), dtype=np.int32)
|
tt = np.zeros((S, A), dtype=np.int32)
|
||||||
|
|
@ -46,10 +55,13 @@ def build_tables_uninformative():
|
||||||
rt = np.zeros((S, A, S), dtype=np.float32)
|
rt = np.zeros((S, A, S), dtype=np.float32)
|
||||||
base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
|
base_costs = np.array([-0.02, -0.05, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
rt[0, EAT, 0] = 0.30 # same payoff whether peeked or not
|
# same payoff whether peeked or not
|
||||||
|
rt[0, EAT, 0] = 0.30
|
||||||
rt[1, EAT, 0] = 0.30
|
rt[1, EAT, 0] = 0.30
|
||||||
return S, A, tt, rt, base_costs
|
return S, A, tt, rt, base_costs
|
||||||
|
|
||||||
|
# ---------- Run one scheduled segment ----------
|
||||||
|
|
||||||
def run_segment(belt: BatchedBelt, steps: int):
|
def run_segment(belt: BatchedBelt, steps: int):
|
||||||
total_reward = 0.0
|
total_reward = 0.0
|
||||||
total_peeks = 0
|
total_peeks = 0
|
||||||
|
|
@ -67,13 +79,29 @@ def run_segment(belt: BatchedBelt, steps: int):
|
||||||
def ensure_parent(path: str):
|
def ensure_parent(path: str):
|
||||||
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
|
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
|
||||||
|
|
||||||
|
# ---------- CLI / main ----------
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser(description="Curiosity demo with CSV logging and configurable k-ary schedules.")
|
||||||
ap.add_argument("--out", type=str, default=None, help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)")
|
ap.add_argument("--out", type=str, default=None,
|
||||||
|
help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)")
|
||||||
ap.add_argument("--segments", type=int, default=20, help="Number of segments")
|
ap.add_argument("--segments", type=int, default=20, help="Number of segments")
|
||||||
ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment")
|
ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment")
|
||||||
ap.add_argument("--batch", type=int, default=4096, help="Batch size")
|
ap.add_argument("--batch", type=int, default=4096, help="Batch size")
|
||||||
ap.add_argument("--seed", type=int, default=7, help="Base RNG seed")
|
ap.add_argument("--seed", type=int, default=7, help="Base RNG seed")
|
||||||
|
|
||||||
|
# Sequence controls
|
||||||
|
ap.add_argument("--seq_mode", type=str, default="gellermann",
|
||||||
|
choices=["gellermann", "debruijn"],
|
||||||
|
help="Scheduling mode for families across segments")
|
||||||
|
ap.add_argument("--run_cap", type=int, default=3, help="Max run length per symbol (Gellermann)")
|
||||||
|
ap.add_argument("--exact_counts", action="store_true",
|
||||||
|
help="(Gellermann) Force exactly equal counts (rounds n down to multiple of k)")
|
||||||
|
ap.add_argument("--half_balance", action="store_true",
|
||||||
|
help="(Gellermann) Enforce first-half counts <= ceil(target/2) per symbol")
|
||||||
|
ap.add_argument("--debruijn_order", type=int, default=4,
|
||||||
|
help="Order m for de Bruijn(k=2, m); final length is tiled/trimmed to --segments")
|
||||||
|
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
if args.out is None:
|
if args.out is None:
|
||||||
|
|
@ -81,18 +109,25 @@ def main():
|
||||||
args.out = f"results/curiosity_demo_{stamp}.csv"
|
args.out = f"results/curiosity_demo_{stamp}.csv"
|
||||||
ensure_parent(args.out)
|
ensure_parent(args.out)
|
||||||
|
|
||||||
# Build families
|
# Build families (same S,A; different reward tables)
|
||||||
S0, A0, tt0, rt0, costs0 = build_tables_informative()
|
S0, A0, tt0, rt0, costs0 = build_tables_informative()
|
||||||
S1, A1, tt1, rt1, costs1 = build_tables_uninformative()
|
S1, A1, tt1, rt1, costs1 = build_tables_uninformative()
|
||||||
assert (S0, A0) == (S1, A1)
|
assert (S0, A0) == (S1, A1)
|
||||||
S, A = S0, A0
|
S, A = S0, A0
|
||||||
|
|
||||||
# Two belts (same shape, different reward tables)
|
belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed)
|
||||||
belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed)
|
belt_uninf = BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1)
|
||||||
belt_uninf= BatchedBelt(S, A, tt1, rt1, costs1, batch_size=args.batch, gamma=0.97, alpha=0.2, epsilon=0.05, seed=args.seed+1)
|
|
||||||
|
# Schedule the N segments with a k-ary sequence over families {0,1}
|
||||||
|
if args.seq_mode == "gellermann":
|
||||||
|
seq = gellermann_k(
|
||||||
|
n=args.segments, k=2, run_cap=args.run_cap,
|
||||||
|
seed=args.seed, exact_counts=args.exact_counts, half_balance=args.half_balance
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
base = debruijn(k=2, m=args.debruijn_order)
|
||||||
|
seq = tile_or_trim(base, args.segments)
|
||||||
|
|
||||||
# k=2 families, balanced, limited runs
|
|
||||||
seq = gellermann_k(n=args.segments, k=2, run_cap=3, seed=args.seed)
|
|
||||||
audit = audit_sequence(seq, k=2)
|
audit = audit_sequence(seq, k=2)
|
||||||
print("Sequence (0=informative, 1=uninformative):", seq.tolist())
|
print("Sequence (0=informative, 1=uninformative):", seq.tolist())
|
||||||
print("Audit:", audit)
|
print("Audit:", audit)
|
||||||
|
|
@ -103,7 +138,8 @@ def main():
|
||||||
"batch", "steps_per_segment", "S", "A",
|
"batch", "steps_per_segment", "S", "A",
|
||||||
"gamma", "alpha", "epsilon",
|
"gamma", "alpha", "epsilon",
|
||||||
"cost_pass", "cost_peek", "cost_eat",
|
"cost_pass", "cost_peek", "cost_eat",
|
||||||
"seed"
|
"seed",
|
||||||
|
"seq_mode", "run_cap", "exact_counts", "half_balance", "debruijn_order"
|
||||||
]
|
]
|
||||||
with open(args.out, "w", newline="") as f:
|
with open(args.out, "w", newline="") as f:
|
||||||
w = csv.writer(f)
|
w = csv.writer(f)
|
||||||
|
|
@ -125,15 +161,17 @@ def main():
|
||||||
args.batch, args.steps_per_segment, S, A,
|
args.batch, args.steps_per_segment, S, A,
|
||||||
0.97, 0.2, 0.05,
|
0.97, 0.2, 0.05,
|
||||||
float(c[0]), float(c[1]), float(c[2]),
|
float(c[0]), float(c[1]), float(c[2]),
|
||||||
args.seed
|
args.seed,
|
||||||
|
args.seq_mode, args.run_cap, int(bool(args.exact_counts)), int(bool(args.half_balance)), args.debruijn_order
|
||||||
]
|
]
|
||||||
w.writerow(row)
|
w.writerow(row)
|
||||||
|
|
||||||
print(f"Seg {i:02d} [{fam[:5].upper()}] peek_rate={res['peek_rate']:.3f} "
|
print(f"Seg {i:02d} [{fam[:5].upper()}] "
|
||||||
|
f"peek_rate={res['peek_rate']:.3f} "
|
||||||
f"avg_reward/step={res['avg_reward_per_box_step']:.4f}")
|
f"avg_reward/step={res['avg_reward_per_box_step']:.4f}")
|
||||||
|
|
||||||
print(f"\nWrote CSV → {args.out}")
|
print(f"\nWrote CSV → {args.out}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue