"""Plot curiosity demo CSV results with accessible, colorblind-safe styling.

The script reads a CSV, prints a per-family summary to the console and writes
five figures (two line plots, two bar charts and one scatter plot) into an
output directory.
"""
from __future__ import annotations

import argparse
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- Style helpers ----------

OKABE_ITO = [
    "#000000", "#E69F00", "#56B4E9", "#009E73",
    "#F0E442", "#0072B2", "#D55E00", "#CC79A7",
]

# Preferred display order for the known families; any other family found in
# the data is kept and ordered after these.
FAMILY_ORDER = ("informative", "uninformative")

# Columns coerced to numeric by load_csv() when they are present.
NUMERIC_COLUMNS = (
    "segment_index", "peek_rate", "avg_reward_per_box_step", "batch",
    "steps_per_segment", "S", "A", "gamma", "alpha", "epsilon",
    "cost_pass", "cost_peek", "cost_eat", "seed",
)

# Columns the plotters and the console summary read unconditionally.
REQUIRED_COLUMNS = (
    "family", "segment_index", "peek_rate", "avg_reward_per_box_step",
)


def ensure_dir(d: str) -> None:
    """Create directory *d* (and parents) if it does not exist yet.

    An empty string means "current directory" for ``os.path.join`` and is
    therefore accepted as a no-op instead of letting ``os.makedirs`` raise.
    """
    if d:
        os.makedirs(d, exist_ok=True)


def apply_accessible_style(high_contrast: bool, font_scale: float,
                           palette: str, large_fonts: bool) -> None:
    """Apply a readable, colorblind-safe theme.

    Args:
        high_contrast: Use bold titles/labels and heavier lines and markers.
        font_scale: Seaborn font scale; raised to at least 2.2 when
            *large_fonts* is set.
        palette: ``"hc"`` selects the Okabe-Ito palette, anything else the
            seaborn ``"colorblind"`` palette.
        large_fonts: Use the "talk" context, larger fonts and heavier lines.
    """
    context = "talk" if (large_fonts or font_scale >= 1.3) else "notebook"
    effective_scale = max(font_scale, 2.2) if large_fonts else font_scale

    # One call only: ``sns.set`` is an alias of ``set_theme`` and would reset
    # style and context to their defaults if it were called separately just
    # to set the font scale.
    sns.set_theme(style="whitegrid", context=context,
                  font_scale=effective_scale)

    if palette == "hc":
        sns.set_palette(OKABE_ITO)
    else:
        try:
            sns.set_palette("colorblind")
        except Exception:
            # Deliberate best effort: keep the palette set_theme installed.
            pass

    # Matplotlib rc overrides for readability.
    emphasized = large_fonts or high_contrast
    plt.rcParams.update({
        "figure.facecolor": "white",
        "axes.facecolor": "white",
        "savefig.facecolor": "white",
        "axes.edgecolor": "black",
        "axes.grid": True,
        "grid.color": "#D0D0D0",
        "grid.linewidth": 0.9 if emphasized else 0.8,
        "legend.frameon": True,
        "legend.framealpha": 0.95,
        "legend.facecolor": "white",
        "legend.edgecolor": "#333333",
        "axes.titleweight": "bold" if high_contrast else "normal",
        "axes.labelweight": "bold" if emphasized else "regular",
        "lines.linewidth": 3.2 if emphasized else 2.0,
        "lines.markersize": 8.5 if emphasized else 6.0,
        "xtick.major.size": 6 if emphasized else 5,
        "ytick.major.size": 6 if emphasized else 5,
    })


def load_csv(path: str) -> pd.DataFrame:
    """Read the results CSV at *path* and normalise its column types.

    Known numeric columns are coerced with ``errors="coerce"`` (unparseable
    cells become NaN). If a ``family`` column exists it becomes an ordered
    categorical: known families first in ``FAMILY_ORDER`` order, then any
    other families found in the data (sorted by their string form) so that
    unexpected values are plotted rather than silently turned into NaN.
    """
    df = pd.read_csv(path)

    for col in NUMERIC_COLUMNS:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")

    if "family" in df.columns:
        present = set(df["family"].dropna().unique().tolist())
        known = [name for name in FAMILY_ORDER if name in present]
        extras = sorted(present.difference(FAMILY_ORDER), key=str)
        df["family"] = pd.Categorical(df["family"],
                                      categories=known + extras,
                                      ordered=True)
    return df


def _family_levels(df: pd.DataFrame) -> list:
    """Return the families that occur in *df*, in plotting order."""
    family = df["family"]
    observed = set(family.dropna().unique().tolist())
    if isinstance(family.dtype, pd.CategoricalDtype):
        return [cat for cat in family.cat.categories if cat in observed]
    return sorted(observed, key=str)


def _save_current_figure(outdir: str, stem: str, fmt: str, dpi: int,
                         transparent: bool) -> str:
    """Tighten, save and close the current figure; return the file path."""
    path = os.path.join(outdir, f"{stem}.{fmt}")
    try:
        plt.tight_layout()
        plt.savefig(path, dpi=dpi, transparent=transparent)
    finally:
        # Always release the figure, even if saving failed.
        plt.close()
    return path


# Seaborn 0.12/0.13 compatibility: prefer errorbar=('ci', 95), fall back to
# ci=95 on older releases.
def _barplot_with_ci(df: pd.DataFrame, x: str, y: str, title: str,
                     annotate: bool, value_fmt: str) -> None:
    """Draw a mean bar plot with 95% CI on the current figure.

    Sets the title, clears the x label and, when *annotate* is true, writes
    each bar's value next to it using *value_fmt*.
    """
    try:
        ax = sns.barplot(data=df, x=x, y=y, estimator=np.mean,
                         errorbar=("ci", 95))
    except (TypeError, AttributeError):
        # Older seaborn does not know ``errorbar``; depending on the version
        # the unknown keyword surfaces as TypeError or, once forwarded to
        # matplotlib, as AttributeError.
        ax = sns.barplot(data=df, x=x, y=y, estimator=np.mean, ci=95)
    plt.title(title)
    plt.xlabel("")
    if annotate:
        _annotate_bars(ax, fmt=value_fmt)
    plt.tight_layout()


def _annotate_bars(ax: plt.Axes, fmt: str = ".3f") -> None:
    """Annotate each bar with its height (value).

    Assumes a simple single-hue bar plot. Labels sit just above positive
    bars and just below negative ones so they never land inside the bar.
    """
    # Offset proportional to the axis span keeps the gap visually constant.
    ymin, ymax = ax.get_ylim()
    offset = 0.01 * (ymax - ymin)
    fontsize = max(10, plt.rcParams["font.size"] * 0.9)

    for patch in ax.patches:
        height = patch.get_height()
        if np.isnan(height):
            continue
        x_center = patch.get_x() + patch.get_width() / 2
        if height < 0:
            y_text, valign = height - offset, "top"
        else:
            y_text, valign = height + offset, "bottom"
        ax.text(x_center, y_text, format(height, fmt),
                ha="center", va=valign,
                fontsize=fontsize, fontweight="bold")


# ---------- Plotters ----------

def plot_peek_rate_by_segment(df: pd.DataFrame, outdir: str, dpi: int,
                              fmt: str, transparent: bool) -> str:
    """Save a line plot of peek rate per segment, one line per family.

    Returns the path of the written figure.
    """
    plt.figure(figsize=(10.5, 5.2))
    sns.lineplot(data=df, x="segment_index", y="peek_rate",
                 hue="family", marker="o")
    plt.title("Peek rate by segment")
    plt.xlabel("Segment")
    plt.ylabel("Peek rate (fraction of actions)")
    return _save_current_figure(outdir, "peek_rate_by_segment", fmt, dpi,
                                transparent)


def plot_reward_by_segment(df: pd.DataFrame, outdir: str, dpi: int,
                           fmt: str, transparent: bool) -> str:
    """Save a line plot of average reward per box-step per segment.

    Returns the path of the written figure.
    """
    plt.figure(figsize=(10.5, 5.2))
    sns.lineplot(data=df, x="segment_index", y="avg_reward_per_box_step",
                 hue="family", marker="o")
    plt.title("Average reward per box-step by segment")
    plt.xlabel("Segment")
    plt.ylabel("Avg reward per box-step")
    return _save_current_figure(outdir, "avg_reward_by_segment", fmt, dpi,
                                transparent)


def plot_summary_bars(df: pd.DataFrame, outdir: str, dpi: int, fmt: str,
                      transparent: bool, annotate: bool, value_fmt: str):
    """Save the two per-family summary bar charts (mean with 95% CI).

    Returns a ``(peek_rate_path, avg_reward_path)`` tuple.
    """
    plt.figure(figsize=(7.4, 5.4))
    _barplot_with_ci(df, x="family", y="peek_rate",
                     title="Mean peek rate by family (95% CI)",
                     annotate=annotate, value_fmt=value_fmt)
    plt.ylabel("Peek rate")
    # The save helper re-runs tight_layout so the y label set above is
    # accounted for in both figures.
    p1 = _save_current_figure(outdir, "summary_peek_rate", fmt, dpi,
                              transparent)

    plt.figure(figsize=(7.4, 5.4))
    _barplot_with_ci(df, x="family", y="avg_reward_per_box_step",
                     title="Mean avg reward per box-step by family (95% CI)",
                     annotate=annotate, value_fmt=value_fmt)
    plt.ylabel("Avg reward per box-step")
    p2 = _save_current_figure(outdir, "summary_avg_reward", fmt, dpi,
                              transparent)
    return p1, p2


def plot_reward_vs_peek(df: pd.DataFrame, outdir: str, dpi: int, fmt: str,
                        transparent: bool) -> str:
    """Save a reward-vs-peek-rate scatter plot with one trend line per family.

    Trend lines are drawn without confidence bands to keep the plot
    uncluttered. A family is skipped for the trend line when it has fewer
    than two usable points or no variation in peek rate, because a
    regression is undefined there. Returns the path of the written figure.
    """
    x_col, y_col = "peek_rate", "avg_reward_per_box_step"
    plt.figure(figsize=(8.0, 6.4))

    levels = _family_levels(df)
    # Explicit family -> colour mapping so scatter points and trend lines
    # always agree, even when a family's trend line is skipped.
    colors = sns.color_palette(n_colors=len(levels)) if levels else []
    palette = dict(zip(levels, colors))

    sns.scatterplot(data=df, x=x_col, y=y_col, hue="family",
                    hue_order=levels or None, palette=palette or None,
                    s=80, edgecolor="k", linewidth=0.6)

    for level in levels:
        subset = df.loc[df["family"] == level, [x_col, y_col]].dropna()
        if len(subset) < 2 or subset[x_col].nunique() < 2:
            continue
        sns.regplot(data=subset, x=x_col, y=y_col, scatter=False, ci=None,
                    truncate=True, color=palette[level],
                    line_kws={"linewidth": 3})

    plt.title("Reward vs. Peek rate")
    plt.xlabel("Peek rate")
    plt.ylabel("Avg reward per box-step")
    return _save_current_figure(outdir, "reward_vs_peek_scatter", fmt, dpi,
                                transparent)


# ---------- CLI ----------

def main() -> None:
    """Parse CLI arguments, print a per-family summary and write all figures."""
    ap = argparse.ArgumentParser(description="Plot curiosity demo CSV with accessible styling.")
    ap.add_argument("--in", dest="inp", type=str, required=True, help="Input CSV from run_curiosity_demo.py")
    ap.add_argument("--outdir", type=str, default="results/figs", help="Directory to save figures")
    ap.add_argument("--high_contrast", action="store_true", help="Use high-contrast, bold styling")
    ap.add_argument("--large_fonts", action="store_true", help="Use extra-large fonts and thicker lines")
    ap.add_argument("--font_scale", type=float, default=1.6, help="Base font scale (ignored if --large_fonts is bigger)")
    ap.add_argument("--palette", type=str, default="auto", choices=["auto", "hc"], help="Color palette: auto=colorblind, hc=Okabe–Ito")
    ap.add_argument("--dpi", type=int, default=180, help="Figure DPI")
    ap.add_argument("--format", type=str, default="png", choices=["png", "pdf", "svg"], help="Output format")
    ap.add_argument("--transparent", action="store_true", help="Save figures with transparent background")
    ap.add_argument("--no_annotate", action="store_true", help="Disable numeric labels on bar charts")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as %%; otherwise --help raises ValueError.
    ap.add_argument("--value_fmt", type=str, default=".3f", help="Number format for bar labels (e.g., .2f, .1%% not supported)")
    args = ap.parse_args()

    annotate = not args.no_annotate
    if annotate:
        # Reject an invalid format spec up front rather than midway through
        # plotting, after some figures have already been written.
        try:
            format(0.0, args.value_fmt)
        except ValueError as exc:
            ap.error(f"invalid --value_fmt {args.value_fmt!r}: {exc}")

    ensure_dir(args.outdir)
    apply_accessible_style(high_contrast=args.high_contrast,
                           font_scale=args.font_scale,
                           palette=args.palette,
                           large_fonts=args.large_fonts)

    df = load_csv(args.inp)
    print(f"Loaded {len(df)} rows from {args.inp}")

    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        ap.error(f"{args.inp} is missing required column(s): {', '.join(missing)}")

    # Console summary (accessible). observed=False keeps the current pandas
    # default explicit and avoids the FutureWarning for categorical keys.
    grp = df.groupby("family", observed=False).agg(
        mean_peek=("peek_rate", "mean"),
        std_peek=("peek_rate", "std"),
        mean_reward=("avg_reward_per_box_step", "mean"),
        std_reward=("avg_reward_per_box_step", "std"),
        n=("peek_rate", "count"),
    )
    print("\nSummary by family:\n", grp)

    paths = [
        plot_peek_rate_by_segment(df, args.outdir, args.dpi, args.format, args.transparent),
        plot_reward_by_segment(df, args.outdir, args.dpi, args.format, args.transparent),
    ]
    paths.extend(plot_summary_bars(df, args.outdir, args.dpi, args.format,
                                   args.transparent, annotate=annotate,
                                   value_fmt=args.value_fmt))
    paths.append(plot_reward_vs_peek(df, args.outdir, args.dpi, args.format, args.transparent))

    print("\nSaved figures:")
    for p in paths:
        print(" -", p)


if __name__ == "__main__":
    main()