alice/bench/plot_curiosity.py

220 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import argparse, os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# ---------- Style helpers ----------
# Okabe-Ito palette: eight colors chosen to stay distinguishable under the
# common forms of color-vision deficiency. Used when the palette is "hc".
OKABE_ITO = [
    "#000000",  # black
    "#E69F00",  # orange
    "#56B4E9",  # sky blue
    "#009E73",  # bluish green
    "#F0E442",  # yellow
    "#0072B2",  # blue
    "#D55E00",  # vermillion
    "#CC79A7",  # reddish purple
]
def ensure_dir(d: str):
    """Create directory ``d`` and any missing parents.

    An already-existing directory is not an error.
    """
    os.makedirs(name=d, exist_ok=True)
def apply_accessible_style(high_contrast: bool, font_scale: float, palette: str, large_fonts: bool):
    """Apply a readable, colorblind-safe seaborn/matplotlib theme globally.

    Args:
        high_contrast: Use bold titles/labels, thicker lines and larger markers.
        font_scale: Base seaborn font scale.
        palette: ``"hc"`` selects the Okabe-Ito palette; any other value
            selects seaborn's ``"colorblind"`` palette.
        large_fonts: Force the "talk" context, raise the font scale to at
            least 2.2, and use the same thicker lines as ``high_contrast``.

    The function mutates global seaborn / matplotlib rc state and returns None.
    """
    emphasized = large_fonts or high_contrast

    # Base theme. Style, context and font scale are set in ONE call:
    # `sns.set(...)` is an alias of `set_theme`, so a follow-up
    # `sns.set(font_scale=...)` would reset style and context to their defaults.
    ctx = "talk" if (large_fonts or font_scale >= 1.3) else "notebook"
    effective_scale = max(font_scale, 2.2) if large_fonts else font_scale
    sns.set_theme(style="whitegrid", context=ctx, font_scale=effective_scale)

    # Palette
    if palette == "hc":
        sns.set_palette(OKABE_ITO)
    else:
        try:
            sns.set_palette("colorblind")
        except Exception:
            # Deliberate best-effort: keep whatever palette set_theme installed.
            pass

    # Matplotlib rc overrides for readability (applied after the theme so
    # they take precedence over the seaborn style defaults).
    plt.rcParams.update({
        "figure.facecolor": "white",
        "axes.facecolor": "white",
        "savefig.facecolor": "white",
        "axes.edgecolor": "black",
        "axes.grid": True,
        "grid.color": "#D0D0D0",
        "grid.linewidth": 0.9 if emphasized else 0.8,
        "legend.frameon": True,
        "legend.framealpha": 0.95,
        "legend.facecolor": "white",
        "legend.edgecolor": "#333333",
        "axes.titleweight": "bold" if high_contrast else "normal",
        "axes.labelweight": "bold" if emphasized else "regular",
        "lines.linewidth": 3.2 if emphasized else 2.0,
        "lines.markersize": 8.5 if emphasized else 6.0,
        "xtick.major.size": 6 if emphasized else 5,
        "ytick.major.size": 6 if emphasized else 5,
    })
def load_csv(path: str) -> pd.DataFrame:
    """Load the curiosity-demo CSV and normalize its column types.

    Args:
        path: Path (or any object accepted by ``pandas.read_csv``) of the CSV.

    Returns:
        The loaded frame. Known numeric columns that are present are coerced
        with ``pd.to_numeric(errors="coerce")`` (unparseable cells become NaN).
        If a ``family`` column is present it becomes an ordered categorical:
        "informative" and "uninformative" come first (when present), followed
        by any other labels in order of first appearance.
    """
    df = pd.read_csv(path)

    # Coerce numeric columns; bad cells become NaN instead of raising.
    numeric_cols = (
        "segment_index", "peek_rate", "avg_reward_per_box_step", "batch",
        "steps_per_segment", "S", "A", "gamma", "alpha", "epsilon",
        "cost_pass", "cost_peek", "cost_eat", "seed",
    )
    for col in numeric_cols:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")

    # Keep family as categorical with a stable order.
    if "family" in df.columns:
        preferred = ["informative", "uninformative"]
        present = df["family"].dropna().unique().tolist()
        categories = [name for name in preferred if name in present]
        # Labels outside the preferred list are kept (appended) rather than
        # dropped: pd.Categorical turns any value missing from `categories`
        # into NaN, which would silently remove those rows from every plot.
        categories += [name for name in present if name not in preferred]
        df["family"] = pd.Categorical(df["family"], categories=categories, ordered=True)
    return df
# Seaborn 0.12/0.13 compatibility: prefer errorbar=('ci',95), fallback to ci=95
def _barplot_with_ci(df: pd.DataFrame, x: str, y: str, title: str,
                     annotate: bool, value_fmt: str):
    """Draw a mean bar plot with a 95% confidence interval on the current axes.

    Tries the newer ``errorbar=('ci', 95)`` keyword first and falls back to
    the older ``ci=95`` keyword when seaborn rejects it with ``TypeError``.
    Sets the title, clears the x label, tightens the layout and, when
    ``annotate`` is true, labels each bar using ``value_fmt``.
    """
    shared_kwargs = dict(data=df, x=x, y=y, estimator=np.mean)
    try:
        axes = sns.barplot(errorbar=('ci', 95), **shared_kwargs)
    except TypeError:
        axes = sns.barplot(ci=95, **shared_kwargs)

    plt.title(title)
    plt.xlabel("")
    plt.tight_layout()

    if not annotate:
        return
    _annotate_bars(axes, fmt=value_fmt)
def _annotate_bars(ax: plt.Axes, fmt: str = ".3f"):
    """Annotate each bar with its height (value).

    Assumes a simple single-hue bar plot. Labels are centered horizontally on
    the bar and placed just beyond the bar tip: above it for non-negative
    heights, below it for negative heights.

    Args:
        ax: Axes whose ``patches`` are the bars to label.
        fmt: Format spec passed to ``format(height, fmt)``.
    """
    # Offset proportional to the axis span so the gap looks the same at any scale.
    ymin, ymax = ax.get_ylim()
    offset = 0.01 * (ymax - ymin)
    label_size = max(10, plt.rcParams['font.size'] * 0.9)

    for patch in ax.patches:
        height = patch.get_height()
        if np.isnan(height):
            continue
        x_center = patch.get_x() + patch.get_width() / 2
        if height < 0:
            # A negative bar extends downward: put the label under its tip
            # instead of overlapping the bar body.
            y_text, v_align = height - offset, "top"
        else:
            y_text, v_align = height + offset, "bottom"
        ax.text(x_center, y_text, format(height, fmt),
                ha="center", va=v_align, fontsize=label_size,
                fontweight="bold")
# ---------- Plotters ----------
def plot_peek_rate_by_segment(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """Save a line plot of ``peek_rate`` against ``segment_index``, one line per family.

    Args:
        df: Frame with ``segment_index``, ``peek_rate`` and ``family`` columns.
        outdir: Directory the figure is written to.
        dpi: Resolution passed to ``savefig``.
        fmt: File extension / format of the saved figure.
        transparent: Whether to save with a transparent background.

    Returns:
        Path of the saved figure.
    """
    out_path = os.path.join(outdir, f"peek_rate_by_segment.{fmt}")
    fig, ax = plt.subplots(figsize=(10.5, 5.2))
    try:
        sns.lineplot(data=df, x="segment_index", y="peek_rate", hue="family",
                     marker="o", ax=ax)
        ax.set_title("Peek rate by segment")
        ax.set_xlabel("Segment")
        ax.set_ylabel("Peek rate (fraction of actions)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=dpi, transparent=transparent)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return out_path
def plot_reward_by_segment(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """Save a line plot of ``avg_reward_per_box_step`` against ``segment_index`` per family.

    Args:
        df: Frame with ``segment_index``, ``avg_reward_per_box_step`` and
            ``family`` columns.
        outdir: Directory the figure is written to.
        dpi: Resolution passed to ``savefig``.
        fmt: File extension / format of the saved figure.
        transparent: Whether to save with a transparent background.

    Returns:
        Path of the saved figure.
    """
    out_path = os.path.join(outdir, f"avg_reward_by_segment.{fmt}")
    fig, ax = plt.subplots(figsize=(10.5, 5.2))
    try:
        sns.lineplot(data=df, x="segment_index", y="avg_reward_per_box_step",
                     hue="family", marker="o", ax=ax)
        ax.set_title("Average reward per box-step by segment")
        ax.set_xlabel("Segment")
        ax.set_ylabel("Avg reward per box-step")
        fig.tight_layout()
        fig.savefig(out_path, dpi=dpi, transparent=transparent)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return out_path
def plot_summary_bars(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool,
                      annotate: bool, value_fmt: str):
    """Save two per-family mean bar charts (with 95% CI): peek rate and average reward.

    Args:
        df: Frame with ``family``, ``peek_rate`` and
            ``avg_reward_per_box_step`` columns.
        outdir: Directory the figures are written to.
        dpi: Resolution passed to ``savefig``.
        fmt: File extension / format of the saved figures.
        transparent: Whether to save with a transparent background.
        annotate: Whether to print the numeric value on each bar.
        value_fmt: Format spec used for the bar labels.

    Returns:
        ``(peek_rate_path, avg_reward_path)``.
    """
    # (column, title, y-axis label, output file stem)
    panels = (
        ("peek_rate", "Mean peek rate by family (95% CI)",
         "Peek rate", "summary_peek_rate"),
        ("avg_reward_per_box_step", "Mean avg reward per box-step by family (95% CI)",
         "Avg reward per box-step", "summary_avg_reward"),
    )

    saved = []
    for column, title, ylabel, stem in panels:
        fig = plt.figure(figsize=(7.4, 5.4))
        try:
            _barplot_with_ci(df, x="family", y=column, title=title,
                             annotate=annotate, value_fmt=value_fmt)
            plt.ylabel(ylabel)
            # Re-run the layout AFTER the y label is set so it is accounted
            # for in both figures (previously only the second one did this).
            plt.tight_layout()
            out_path = os.path.join(outdir, f"{stem}.{fmt}")
            plt.savefig(out_path, dpi=dpi, transparent=transparent)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
        saved.append(out_path)

    p1, p2 = saved
    return p1, p2
def plot_reward_vs_peek(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """Save a scatter of average reward vs. peek rate with one linear trend line per family.

    Args:
        df: Frame with ``peek_rate``, ``avg_reward_per_box_step`` and
            ``family`` columns.
        outdir: Directory the figure is written to.
        dpi: Resolution passed to ``savefig``.
        fmt: File extension / format of the saved figure.
        transparent: Whether to save with a transparent background.

    Returns:
        Path of the saved figure.
    """
    x_col, y_col = "peek_rate", "avg_reward_per_box_step"
    out_path = os.path.join(outdir, f"reward_vs_peek_scatter.{fmt}")

    family = df["family"]
    if isinstance(family.dtype, pd.CategoricalDtype):
        levels = list(family.cat.categories)
    else:
        levels = list(pd.unique(family.dropna()))
    # One explicit color per family, shared by points and trend line, so the
    # line color never depends on the state of the axes' property cycle.
    colors = dict(zip(levels, sns.color_palette(n_colors=len(levels))))

    fig, ax = plt.subplots(figsize=(8.0, 6.4))
    try:
        sns.scatterplot(data=df, x=x_col, y=y_col, hue="family",
                        hue_order=levels or None, palette=colors or None,
                        s=80, edgecolor="k", linewidth=0.6, ax=ax)

        # Trend lines per family (no CIs to keep it uncluttered).
        for level in levels:
            subset = df.loc[family == level, [x_col, y_col]].dropna()
            if len(subset) < 2:
                # A line fit needs at least two points; regplot raises on an
                # empty subset (e.g. a CSV that contains only one family).
                continue
            sns.regplot(data=subset, x=x_col, y=y_col,
                        scatter=False, ci=None, truncate=True,
                        color=colors[level], line_kws={"linewidth": 3}, ax=ax)

        ax.set_title("Reward vs. Peek rate")
        ax.set_xlabel("Peek rate")
        ax.set_ylabel("Avg reward per box-step")
        fig.tight_layout()
        fig.savefig(out_path, dpi=dpi, transparent=transparent)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return out_path
# ---------- CLI ----------
def main(argv=None):
    """Command-line entry point: load the CSV, print a summary, save all figures.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser(description="Plot curiosity demo CSV with accessible styling.")
    ap.add_argument("--in", dest="inp", type=str, required=True, help="Input CSV from run_curiosity_demo.py")
    ap.add_argument("--outdir", type=str, default="results/figs", help="Directory to save figures")
    ap.add_argument("--high_contrast", action="store_true", help="Use high-contrast, bold styling")
    ap.add_argument("--large_fonts", action="store_true", help="Use extra-large fonts and thicker lines")
    ap.add_argument("--font_scale", type=float, default=1.6, help="Base font scale (ignored if --large_fonts is bigger)")
    ap.add_argument("--palette", type=str, default="auto", choices=["auto","hc"], help="Color palette: auto=colorblind, hc=OkabeIto")
    ap.add_argument("--dpi", type=int, default=180, help="Figure DPI")
    ap.add_argument("--format", type=str, default="png", choices=["png","pdf","svg"], help="Output format")
    ap.add_argument("--transparent", action="store_true", help="Save figures with transparent background")
    ap.add_argument("--no_annotate", action="store_true", help="Disable numeric labels on bar charts")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as "%%"; a bare "%" makes `--help` raise ValueError.
    ap.add_argument("--value_fmt", type=str, default=".3f", help="Number format for bar labels (e.g., .2f, .1%% not supported)")
    args = ap.parse_args(argv)

    ensure_dir(args.outdir)
    apply_accessible_style(high_contrast=args.high_contrast,
                           font_scale=args.font_scale,
                           palette=args.palette,
                           large_fonts=args.large_fonts)

    df = load_csv(args.inp)
    print(f"Loaded {len(df)} rows from {args.inp}")

    # Console summary (accessible). observed=True limits the groups of a
    # categorical key to the categories that occur in the data and avoids
    # pandas' FutureWarning about the changing default.
    grp = df.groupby("family", observed=True).agg(
        mean_peek=("peek_rate","mean"),
        std_peek=("peek_rate","std"),
        mean_reward=("avg_reward_per_box_step","mean"),
        std_reward=("avg_reward_per_box_step","std"),
        n=("peek_rate","count")
    )
    print("\nSummary by family:\n", grp)

    annotate = not args.no_annotate
    paths = [
        plot_peek_rate_by_segment(df, args.outdir, args.dpi, args.format, args.transparent),
        plot_reward_by_segment(df, args.outdir, args.dpi, args.format, args.transparent),
    ]
    p1, p2 = plot_summary_bars(df, args.outdir, args.dpi, args.format, args.transparent,
                               annotate=annotate, value_fmt=args.value_fmt)
    paths.extend([p1, p2])
    paths.append(plot_reward_vs_peek(df, args.outdir, args.dpi, args.format, args.transparent))

    print("\nSaved figures:")
    for p in paths:
        print(" -", p)


if __name__ == "__main__":
    main()