220 lines
9.2 KiB
Python
220 lines
9.2 KiB
Python
from __future__ import annotations

import argparse
import inspect
import os
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
|
||
|
||
# ---------- Style helpers ----------
|
||
|
||
# Okabe–Ito colorblind-safe palette; selected when the palette argument is "hc".
OKABE_ITO = [
    "#000000",  # black
    "#E69F00",  # orange
    "#56B4E9",  # sky blue
    "#009E73",  # bluish green
    "#F0E442",  # yellow
    "#0072B2",  # blue
    "#D55E00",  # vermillion
    "#CC79A7",  # reddish purple
]
|
||
|
||
def ensure_dir(d: str) -> None:
    """Create directory ``d`` (including missing parents); do nothing if it already exists."""
    target = d
    os.makedirs(name=target, exist_ok=True)
|
||
|
||
def apply_accessible_style(high_contrast: bool, font_scale: float, palette: str, large_fonts: bool):
    """Apply a readable, colorblind-safe global plotting theme.

    Sets the seaborn theme (whitegrid style, "talk" or "notebook" context,
    font scale), the color palette, and a set of matplotlib rcParams for
    contrast and legibility. This mutates global plotting state, so call it
    once before creating figures.

    Args:
        high_contrast: Bold axis titles and labels; thicker lines, markers,
            ticks and grid lines.
        font_scale: Seaborn font scale. Values >= 1.3 also select the "talk"
            context.
        palette: "hc" selects the Okabe–Ito colors in ``OKABE_ITO``; any other
            value tries seaborn's "colorblind" palette.
        large_fonts: Use the "talk" context with a font scale of at least 2.2,
            plus bold labels and thicker lines, markers and ticks.
    """
    emphasize = large_fonts or high_contrast
    context = "talk" if (large_fonts or font_scale >= 1.3) else "notebook"
    scale = max(font_scale, 2.2) if large_fonts else font_scale

    # Style, context and font scale must be set in ONE call: sns.set() is an
    # alias of set_theme() whose defaults (style="darkgrid",
    # context="notebook") would override the choices made here.
    sns.set_theme(style="whitegrid", context=context, font_scale=scale)

    # Palette
    if palette == "hc":
        sns.set_palette(OKABE_ITO)
    else:
        try:
            sns.set_palette("colorblind")
        except ValueError as err:
            # Best effort: keep the palette already active, but report it.
            warnings.warn(
                f"Could not apply seaborn 'colorblind' palette ({err}); "
                "keeping the current palette."
            )

    # Matplotlib rc overrides for readability; applied after set_theme so
    # they take precedence over the theme's values.
    plt.rcParams.update({
        "figure.facecolor": "white",
        "axes.facecolor": "white",
        "savefig.facecolor": "white",
        "axes.edgecolor": "black",
        "axes.grid": True,
        "grid.color": "#D0D0D0",
        "grid.linewidth": 0.9 if emphasize else 0.8,
        "legend.frameon": True,
        "legend.framealpha": 0.95,
        "legend.facecolor": "white",
        "legend.edgecolor": "#333333",
        "axes.titleweight": "bold" if high_contrast else "normal",
        "axes.labelweight": "bold" if emphasize else "regular",
        "lines.linewidth": 3.2 if emphasize else 2.0,
        "lines.markersize": 8.5 if emphasize else 6.0,
        "xtick.major.size": 6 if emphasize else 5,
        "ytick.major.size": 6 if emphasize else 5,
    })
|
||
|
||
def load_csv(path: str) -> pd.DataFrame:
    """Read the curiosity-demo CSV and normalise column types.

    Known numeric columns (when present) are coerced with
    ``pd.to_numeric(errors="coerce")``, so malformed cells become NaN instead
    of leaving the whole column as object dtype. If a ``family`` column
    exists, it becomes an ordered Categorical. "informative" then
    "uninformative" come first (when present), followed by any other observed
    labels in sorted order. Unexpected families are kept rather than silently
    turned into NaN.

    Args:
        path: CSV location; anything accepted by ``pd.read_csv``.

    Returns:
        The loaded DataFrame.
    """
    df = pd.read_csv(path)

    # Coerce numeric columns; non-numeric cells become NaN.
    num_cols = ["segment_index", "peek_rate", "avg_reward_per_box_step", "batch",
                "steps_per_segment", "S", "A", "gamma", "alpha", "epsilon",
                "cost_pass", "cost_peek", "cost_eat", "seed"]
    for c in num_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    # Keep family as categorical with a stable order.
    if "family" in df.columns:
        preferred = ["informative", "uninformative"]
        present = set(df["family"].dropna().unique().tolist())
        cats = [x for x in preferred if x in present]
        # Labels outside the preferred list would otherwise become NaN.
        cats += sorted(present.difference(preferred), key=str)
        df["family"] = pd.Categorical(df["family"], categories=cats, ordered=True)

    return df
|
||
|
||
# Seaborn version compatibility: errorbar=('ci', 95) on >= 0.12, ci=95 before.
def _barplot_with_ci(df: pd.DataFrame, x: str, y: str, title: str,
                     annotate: bool, value_fmt: str):
    """Draw mean bars with 95% CI error bars on the current axes.

    The CI keyword is chosen from ``sns.barplot``'s signature rather than by
    trial and error. On seaborn releases without ``errorbar``, the unknown
    keyword is forwarded to matplotlib, which fails with AttributeError (not
    TypeError), so an ``except TypeError`` fallback would never trigger.

    Args:
        df: Source data.
        x: Column for the bar categories.
        y: Column whose per-category mean is plotted.
        title: Axes title.
        annotate: If true, write each bar's value at its tip.
        value_fmt: ``format()`` spec for the bar labels.

    Callers that change labels afterwards should call ``tight_layout`` again.
    """
    if "errorbar" in inspect.signature(sns.barplot).parameters:
        ci_kwargs = {"errorbar": ("ci", 95)}
    else:
        ci_kwargs = {"ci": 95}
    ax = sns.barplot(data=df, x=x, y=y, estimator=np.mean, **ci_kwargs)
    plt.title(title)
    plt.xlabel("")

    if annotate:
        _annotate_bars(ax, fmt=value_fmt)
    # Lay out after the annotation text has been added.
    plt.tight_layout()
|
||
|
||
def _annotate_bars(ax: plt.Axes, fmt: str = ".3f"):
    """Label each bar with its value, just beyond the bar's tip.

    Bars with height >= 0 are labelled above the tip and negative bars below
    it, so a label never sits inside the bar body. Bars with a non-finite
    height (e.g. NaN) are skipped. Assumes a simple single-hue bar plot in
    which every patch in ``ax.patches`` is a bar.

    Args:
        ax: Axes containing the bars.
        fmt: ``format()`` spec applied to each height, e.g. ".3f".
    """
    # Offset proportional to the current y-axis span.
    ymin, ymax = ax.get_ylim()
    offset = 0.01 * (ymax - ymin)
    fontsize = max(10, plt.rcParams["font.size"] * 0.9)

    for patch in ax.patches:
        height = patch.get_height()
        if not np.isfinite(height):
            continue
        x_mid = patch.get_x() + patch.get_width() / 2
        if height >= 0:
            y_text, valign = height + offset, "bottom"
        else:
            # Negative bars grow downward; place the label below their tip.
            y_text, valign = height - offset, "top"
        ax.text(x_mid, y_text, format(height, fmt),
                ha="center", va=valign, fontsize=fontsize, fontweight="bold")
|
||
|
||
# ---------- Plotters ----------
|
||
|
||
def plot_peek_rate_by_segment(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """Plot peek rate against segment index, one line per family, and save it.

    Args:
        df: Data with ``segment_index``, ``peek_rate`` and ``family`` columns.
        outdir: Directory the figure is written to.
        dpi: Resolution passed to ``savefig``.
        fmt: File extension / output format, e.g. "png".
        transparent: Save with a transparent background.

    Returns:
        Path of the saved figure.
    """
    fig, ax = plt.subplots(figsize=(10.5, 5.2))
    try:
        sns.lineplot(data=df, x="segment_index", y="peek_rate", hue="family",
                     marker="o", ax=ax)
        ax.set_title("Peek rate by segment")
        ax.set_xlabel("Segment")
        ax.set_ylabel("Peek rate (fraction of actions)")
        fig.tight_layout()
        p = os.path.join(outdir, f"peek_rate_by_segment.{fmt}")
        fig.savefig(p, dpi=dpi, transparent=transparent)
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    return p
|
||
|
||
def plot_reward_by_segment(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """Plot average reward per box-step against segment index, per family, and save it.

    Args:
        df: Data with ``segment_index``, ``avg_reward_per_box_step`` and
            ``family`` columns.
        outdir: Directory the figure is written to.
        dpi: Resolution passed to ``savefig``.
        fmt: File extension / output format, e.g. "png".
        transparent: Save with a transparent background.

    Returns:
        Path of the saved figure.
    """
    fig, ax = plt.subplots(figsize=(10.5, 5.2))
    try:
        sns.lineplot(data=df, x="segment_index", y="avg_reward_per_box_step",
                     hue="family", marker="o", ax=ax)
        ax.set_title("Average reward per box-step by segment")
        ax.set_xlabel("Segment")
        ax.set_ylabel("Avg reward per box-step")
        fig.tight_layout()
        p = os.path.join(outdir, f"avg_reward_by_segment.{fmt}")
        fig.savefig(p, dpi=dpi, transparent=transparent)
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    return p
|
||
|
||
def plot_summary_bars(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool,
                      annotate: bool, value_fmt: str):
    """Save two bar charts of per-family means with 95% CI error bars.

    The first chart shows ``peek_rate``; the second shows
    ``avg_reward_per_box_step``.

    Args:
        df: Data with ``family`` and the two plotted columns.
        outdir: Directory the figures are written to.
        dpi: Resolution passed to ``savefig``.
        fmt: File extension / output format, e.g. "png".
        transparent: Save with a transparent background.
        annotate: If true, label each bar with its mean.
        value_fmt: ``format()`` spec for the bar labels.

    Returns:
        ``(peek_rate_path, avg_reward_path)``.
    """
    # (y column, title, y label, output file stem)
    specs = [
        ("peek_rate", "Mean peek rate by family (95% CI)",
         "Peek rate", "summary_peek_rate"),
        ("avg_reward_per_box_step", "Mean avg reward per box-step by family (95% CI)",
         "Avg reward per box-step", "summary_avg_reward"),
    ]
    paths = []
    for y, title, ylabel, stem in specs:
        fig = plt.figure(figsize=(7.4, 5.4))
        try:
            _barplot_with_ci(df, x="family", y=y, title=title,
                             annotate=annotate, value_fmt=value_fmt)
            plt.ylabel(ylabel)
            # The helper already ran tight_layout before this y label existed;
            # re-run it so the label is not clipped in the saved file.
            fig.tight_layout()
            path = os.path.join(outdir, f"{stem}.{fmt}")
            fig.savefig(path, dpi=dpi, transparent=transparent)
        finally:
            plt.close(fig)
        paths.append(path)
    return paths[0], paths[1]
|
||
|
||
def plot_reward_vs_peek(df: pd.DataFrame, outdir: str, dpi: int, fmt: str, transparent: bool):
    """Scatter reward against peek rate by family, with a linear trend line per family.

    Scatter points and trend lines share one explicit family-to-color mapping,
    so each line matches its points. Families are taken from the data (the
    categories of a Categorical ``family`` column, otherwise its unique
    non-null values) instead of being hard-coded. A family with fewer than
    two distinct peek-rate values gets no trend line; this includes families
    absent from the data.

    Args:
        df: Data with ``peek_rate``, ``avg_reward_per_box_step`` and ``family``.
        outdir: Directory the figure is written to.
        dpi: Resolution passed to ``savefig``.
        fmt: File extension / output format, e.g. "png".
        transparent: Save with a transparent background.

    Returns:
        Path of the saved figure.
    """
    x_col, y_col = "peek_rate", "avg_reward_per_box_step"
    fam = df["family"]
    if isinstance(fam.dtype, pd.CategoricalDtype):
        levels = list(fam.cat.categories)
    else:
        levels = list(pd.unique(fam.dropna()))
    colors = dict(zip(levels, sns.color_palette(n_colors=len(levels)))) if levels else {}

    fig, ax = plt.subplots(figsize=(8.0, 6.4))
    try:
        sns.scatterplot(data=df, x=x_col, y=y_col, hue="family", palette=colors or None,
                        s=80, edgecolor="k", linewidth=0.6, ax=ax)
        # Trend lines per family (no CIs to keep it uncluttered).
        for level in levels:
            sub = df.loc[fam == level, [x_col, y_col]].dropna()
            if sub[x_col].nunique() < 2:
                continue  # a line fit needs at least two distinct x values
            sns.regplot(data=sub, x=x_col, y=y_col, scatter=False, ci=None,
                        truncate=True, color=colors[level],
                        line_kws={"linewidth": 3}, ax=ax)
        ax.set_title("Reward vs. Peek rate")
        ax.set_xlabel("Peek rate")
        ax.set_ylabel("Avg reward per box-step")
        fig.tight_layout()
        p = os.path.join(outdir, f"reward_vs_peek_scatter.{fmt}")
        fig.savefig(p, dpi=dpi, transparent=transparent)
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    return p
|
||
|
||
# ---------- CLI ----------
|
||
|
||
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    ap = argparse.ArgumentParser(description="Plot curiosity demo CSV with accessible styling.")
    ap.add_argument("--in", dest="inp", type=str, required=True, help="Input CSV from run_curiosity_demo.py")
    ap.add_argument("--outdir", type=str, default="results/figs", help="Directory to save figures")
    ap.add_argument("--high_contrast", action="store_true", help="Use high-contrast, bold styling")
    ap.add_argument("--large_fonts", action="store_true", help="Use extra-large fonts and thicker lines")
    ap.add_argument("--font_scale", type=float, default=1.6, help="Base font scale (ignored if --large_fonts is bigger)")
    ap.add_argument("--palette", type=str, default="auto", choices=["auto", "hc"], help="Color palette: auto=colorblind, hc=Okabe–Ito")
    ap.add_argument("--dpi", type=int, default=180, help="Figure DPI")
    ap.add_argument("--format", type=str, default="png", choices=["png", "pdf", "svg"], help="Output format")
    ap.add_argument("--transparent", action="store_true", help="Save figures with transparent background")
    ap.add_argument("--no_annotate", action="store_true", help="Disable numeric labels on bar charts")
    # argparse %-formats help strings, so a literal percent sign must be "%%";
    # a bare "%" makes --help raise ValueError.
    ap.add_argument("--value_fmt", type=str, default=".3f", help="Number format for bar labels (e.g., .2f or .1%%)")
    return ap


def main():
    """Command-line entry point.

    Parses arguments, applies the accessible style, loads the CSV, prints a
    per-family summary, saves every figure to ``--outdir``, and prints the
    saved paths.
    """
    ap = _build_parser()
    args = ap.parse_args()

    # Reject a bad label format now, not after several figures have been saved.
    try:
        format(0.0, args.value_fmt)
    except ValueError as err:
        ap.error(f"invalid --value_fmt {args.value_fmt!r}: {err}")

    ensure_dir(args.outdir)
    apply_accessible_style(high_contrast=args.high_contrast,
                           font_scale=args.font_scale,
                           palette=args.palette,
                           large_fonts=args.large_fonts)

    df = load_csv(args.inp)
    print(f"Loaded {len(df)} rows from {args.inp}")

    # Every plot below reads these columns; fail with a clear message instead
    # of a KeyError deep inside pandas/seaborn.
    required = ["family", "segment_index", "peek_rate", "avg_reward_per_box_step"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise SystemExit(f"{args.inp}: missing required column(s): {', '.join(missing)}")

    # Console summary (accessible). observed=True limits the summary to
    # families present in the data and avoids pandas' categorical-groupby
    # FutureWarning.
    grp = df.groupby("family", observed=True).agg(
        mean_peek=("peek_rate", "mean"),
        std_peek=("peek_rate", "std"),
        mean_reward=("avg_reward_per_box_step", "mean"),
        std_reward=("avg_reward_per_box_step", "std"),
        n=("peek_rate", "count"),
    )
    print("\nSummary by family:\n", grp)

    annotate = not args.no_annotate

    paths = []
    paths.append(plot_peek_rate_by_segment(df, args.outdir, args.dpi, args.format, args.transparent))
    paths.append(plot_reward_by_segment(df, args.outdir, args.dpi, args.format, args.transparent))
    p1, p2 = plot_summary_bars(df, args.outdir, args.dpi, args.format, args.transparent,
                               annotate=annotate, value_fmt=args.value_fmt)
    paths.extend([p1, p2])
    paths.append(plot_reward_vs_peek(df, args.outdir, args.dpi, args.format, args.transparent))

    print("\nSaved figures:")
    for p in paths:
        print(" -", p)
|
||
|
||
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|