# alice/bench/run_curiosity_demo.py

from __future__ import annotations
import argparse, csv, os
from datetime import datetime
import numpy as np
from alice_fast.batched_belt import BatchedBelt
from alice_fast.kernels import PASS, PEEK, EAT
from alice_tools.sequence import gellermann_k, audit_sequence
"""
Curiosity demo with CSV logging.
Two puzzle families:
0 = Informative: PEEK (non-advancing) makes EAT good; without PEEK, EAT is bad.
1 = Uninformative: PEEK costs but does not change EAT value.
We encode "non-advancing" by augmenting state:
S=2 states per puzzle: 0=unpeeked, 1=peeked.
PEEK: 0->1, 1->1 (information state only)
EAT: returns to 0; reward depends on family+state
PASS: resets to unpeeked (small cost).
"""
def build_tables_informative():
    """Build the tables for the *informative* puzzle family.

    Peeking pays off in this family. EAT from the peeked state (1) yields
    +1.0, while EAT from the unpeeked state (0) yields -0.25.

    Returns:
        A tuple ``(S, A, tt, rt, base_costs)``:
        - ``S``: number of states (2; 0=unpeeked, 1=peeked).
        - ``A``: number of actions (3).
        - ``tt``: int32 array of shape (S, A); ``tt[s, a]`` is the next state.
        - ``rt``: float32 array of shape (S, A, S); ``rt[s, a, s2]`` is the
          reward for taking ``a`` in ``s`` and landing in ``s2``.
        - ``base_costs``: float32 array of shape (A,), indexed by action id
          (PASS/PEEK/EAT). How BatchedBelt applies it is not visible here;
          presumably it is added to every step's reward. TODO confirm.
    """
    S, A = 2, 3

    tt = np.zeros((S, A), dtype=np.int32)
    tt[:, PASS] = 0  # PASS drops any peeked information
    tt[:, PEEK] = 1  # PEEK only changes the information state: 0->1, 1->1
    tt[:, EAT] = 0   # EAT resets to unpeeked

    rt = np.zeros((S, A, S), dtype=np.float32)
    rt[0, EAT, 0] = -0.25  # uninformed 'eat' is risky/bad
    rt[1, EAT, 0] = 1.0    # informed 'eat' is good

    # Index by action id (like tt/rt above) rather than by position, so the
    # costs stay attached to the right action whatever PASS/PEEK/EAT equal.
    base_costs = np.zeros(A, dtype=np.float32)
    base_costs[PASS] = -0.02
    base_costs[PEEK] = -0.05
    base_costs[EAT] = 0.0
    return S, A, tt, rt, base_costs
def build_tables_uninformative():
    """Build the tables for the *uninformative* puzzle family.

    Peeking costs the same as in the informative family, but it does not
    change the payoff. EAT yields +0.30 from either state.

    Returns:
        A tuple ``(S, A, tt, rt, base_costs)``, with the same layout as
        ``build_tables_informative``:
        - ``S``: number of states (2; 0=unpeeked, 1=peeked).
        - ``A``: number of actions (3).
        - ``tt``: int32 array of shape (S, A); ``tt[s, a]`` is the next state.
        - ``rt``: float32 array of shape (S, A, S); ``rt[s, a, s2]`` is the
          reward for taking ``a`` in ``s`` and landing in ``s2``.
        - ``base_costs``: float32 array of shape (A,), indexed by action id
          (PASS/PEEK/EAT).
    """
    S, A = 2, 3

    tt = np.zeros((S, A), dtype=np.int32)
    tt[:, PASS] = 0  # PASS drops any peeked information
    tt[:, PEEK] = 1  # PEEK only changes the information state: 0->1, 1->1
    tt[:, EAT] = 0   # EAT resets to unpeeked

    rt = np.zeros((S, A, S), dtype=np.float32)
    rt[0, EAT, 0] = 0.30  # same payoff whether peeked or not
    rt[1, EAT, 0] = 0.30

    # Index by action id (like tt/rt above) rather than by position, so the
    # costs stay attached to the right action whatever PASS/PEEK/EAT equal.
    base_costs = np.zeros(A, dtype=np.float32)
    base_costs[PASS] = -0.02
    base_costs[PEEK] = -0.05
    base_costs[EAT] = 0.0
    return S, A, tt, rt, base_costs
def run_segment(belt: BatchedBelt, steps: int):
    """Advance ``belt`` by ``steps`` learning steps and summarize them.

    Each ``belt.step_learn()`` call must return a mapping with ``"rewards"``
    and ``"actions"`` arrays. Here they are summed and counted element-wise;
    presumably there is one entry per box in the batch. TODO confirm against
    BatchedBelt.

    Args:
        belt: the belt to step via ``step_learn()``.
        steps: number of ``step_learn()`` calls to make.

    Returns:
        A dict with:
        - ``"avg_reward_per_box_step"``: total reward / total actions taken.
        - ``"peek_rate"``: fraction of actions equal to PEEK.
        Both are NaN when no actions were observed (``steps <= 0`` or empty
        batches), instead of raising ZeroDivisionError.
    """
    total_reward = 0.0
    total_peeks = 0
    total_actions = 0
    for _ in range(steps):
        out = belt.step_learn()
        actions = out["actions"]
        total_reward += float(out["rewards"].sum())
        total_peeks += int(np.count_nonzero(actions == PEEK))
        total_actions += int(actions.size)

    if total_actions == 0:
        # Averages over zero observations are undefined; report NaN
        # rather than crashing the whole run.
        nan = float("nan")
        return {"avg_reward_per_box_step": nan, "peek_rate": nan}

    return {
        "avg_reward_per_box_step": total_reward / total_actions,
        "peek_rate": total_peeks / total_actions,
    }
def ensure_parent(path: str):
    """Make sure the directory that will hold *path* exists.

    *path* is resolved to an absolute path first, so a bare filename maps to
    the current working directory. Any missing intermediate directories are
    created, and an already-existing directory is not an error.
    """
    target_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(target_dir, exist_ok=True)
def main():
    """Run the curiosity demo and log one CSV row per segment.

    Steps:
    1. Build one BatchedBelt per puzzle family.
    2. Draw a family sequence with ``gellermann_k`` (k=2, run_cap=3).
    3. For each segment, run ``--steps_per_segment`` learning steps on the
       selected family's belt.
    4. Write that segment's peek rate and average reward to ``--out``.

    By default the output is a timestamped file under ``results/``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", type=str, default=None, help="CSV output path (default: results/curiosity_demo_YYYYmmdd-HHMMSS.csv)")
    ap.add_argument("--segments", type=int, default=20, help="Number of segments")
    ap.add_argument("--steps_per_segment", type=int, default=1000, help="Steps per segment")
    ap.add_argument("--batch", type=int, default=4096, help="Batch size")
    ap.add_argument("--seed", type=int, default=7, help="Base RNG seed")
    ap.add_argument("--gamma", type=float, default=0.97, help="gamma passed to BatchedBelt (default: 0.97)")
    ap.add_argument("--alpha", type=float, default=0.2, help="alpha passed to BatchedBelt (default: 0.2)")
    ap.add_argument("--epsilon", type=float, default=0.05, help="epsilon passed to BatchedBelt (default: 0.05)")
    args = ap.parse_args()

    # Reject inputs that would otherwise fail deep inside the run
    # (e.g. zero steps -> division by zero when averaging).
    for name in ("segments", "steps_per_segment", "batch"):
        if getattr(args, name) <= 0:
            ap.error(f"--{name} must be a positive integer")
    for name in ("gamma", "epsilon"):
        if not 0.0 <= getattr(args, name) <= 1.0:
            ap.error(f"--{name} must be in [0, 1]")

    if args.out is None:
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        args.out = f"results/curiosity_demo_{stamp}.csv"
    ensure_parent(args.out)

    # Build families
    S0, A0, tt0, rt0, costs0 = build_tables_informative()
    S1, A1, tt1, rt1, costs1 = build_tables_uninformative()
    if (S0, A0) != (S1, A1):
        # Explicit check rather than `assert`, which `python -O` strips.
        raise RuntimeError(
            f"family table shapes differ: (S, A)={(S0, A0)} vs {(S1, A1)}"
        )
    S, A = S0, A0

    # Learner settings shared by both belts. They are defined once so the
    # belts and the CSV rows below cannot drift apart.
    learner_kwargs = dict(
        batch_size=args.batch,
        gamma=args.gamma,
        alpha=args.alpha,
        epsilon=args.epsilon,
    )
    # Two belts (same shape, different reward tables)
    belt_inf = BatchedBelt(S, A, tt0, rt0, costs0, seed=args.seed, **learner_kwargs)
    belt_uninf = BatchedBelt(S, A, tt1, rt1, costs1, seed=args.seed + 1, **learner_kwargs)

    # k=2 families, balanced, limited runs
    seq = gellermann_k(n=args.segments, k=2, run_cap=3, seed=args.seed)
    audit = audit_sequence(seq, k=2)
    print("Sequence (0=informative, 1=uninformative):", seq.tolist())
    print("Audit:", audit)

    header = [
        "segment_index", "family", "peek_rate", "avg_reward_per_box_step",
        "batch", "steps_per_segment", "S", "A",
        "gamma", "alpha", "epsilon",
        "cost_pass", "cost_peek", "cost_eat",
        "seed",
    ]
    with open(args.out, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        for i, sym in enumerate(seq):
            # Symbol 0 selects the informative family; anything else the
            # uninformative one (matches the printed sequence legend).
            if sym == 0:
                fam, belt, c = "informative", belt_inf, costs0
            else:
                fam, belt, c = "uninformative", belt_uninf, costs1
            res = run_segment(belt, args.steps_per_segment)
            w.writerow([
                i, fam,
                f"{res['peek_rate']:.6f}", f"{res['avg_reward_per_box_step']:.6f}",
                args.batch, args.steps_per_segment, S, A,
                args.gamma, args.alpha, args.epsilon,
                # Look costs up by action id so each column matches its label.
                float(c[PASS]), float(c[PEEK]), float(c[EAT]),
                args.seed,
            ])
            print(f"Seg {i:02d} [{fam[:5].upper()}] peek_rate={res['peek_rate']:.3f} "
                  f"avg_reward/step={res['avg_reward_per_box_step']:.4f}")
    print(f"\nWrote CSV → {args.out}")


if __name__ == "__main__":
    main()