import csv
import sys

import matplotlib.pyplot as plt


def _read_metrics(csv_path):
    """Read threshold/precision/recall columns from *csv_path*.

    Returns three parallel lists ``(thresholds, precisions, recalls)`` sorted
    by ascending threshold (ties keep file order). Sorting lets line plots
    connect points in threshold order even when the CSV rows are shuffled.

    Raises KeyError if a required column is missing, and ValueError if a cell
    cannot be parsed as a float.
    """
    rows = []
    # "utf-8-sig" transparently strips a leading BOM (e.g. from Excel
    # exports); without it the first header would read "\ufeffthreshold"
    # and the lookup below would fail with a KeyError.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            rows.append((
                float(row["threshold"]),
                float(row["precision"]),
                float(row["recall"]),
            ))
    rows.sort(key=lambda r: r[0])
    thr = [r[0] for r in rows]
    prec = [r[1] for r in rows]
    rec = [r[2] for r in rows]
    return thr, prec, rec


def main(csv_path, out_prefix="curves"):
    """Plot precision/recall curves from an evaluation metrics CSV.

    The CSV must contain ``threshold``, ``precision`` and ``recall`` columns.
    Two PNG files are written at 160 dpi:

    * ``{out_prefix}_pr.png``  - precision vs. recall
    * ``{out_prefix}_thr.png`` - precision and recall vs. confidence threshold

    Raises KeyError if a required column is missing, and ValueError if a cell
    cannot be parsed as a float.
    """
    thr, prec, rec = _read_metrics(csv_path)

    # Precision-Recall
    fig, ax = plt.subplots()
    try:
        ax.plot(rec, prec, marker='o')
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title('Precision-Recall Curve')
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(f"{out_prefix}_pr.png", dpi=160)
    finally:
        # Close explicitly so repeated calls don't accumulate open figures.
        plt.close(fig)

    # Precision & Recall vs Threshold
    fig, ax = plt.subplots()
    try:
        ax.plot(thr, prec, marker='o', label="Precision")
        ax.plot(thr, rec, marker='o', label="Recall")
        ax.set_xlabel('Confidence threshold')
        ax.set_ylabel('Score')
        ax.set_title('Precision/Recall vs Threshold')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(f"{out_prefix}_thr.png", dpi=160)
    finally:
        plt.close(fig)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Usage errors belong on stderr so stdout stays clean for pipelines.
        print("Usage: python scripts/plot_eval.py metrics.csv [out_prefix]",
              file=sys.stderr)
        sys.exit(1)
    csv_path = sys.argv[1]
    out_prefix = sys.argv[2] if len(sys.argv) > 2 else "curves"
    main(csv_path, out_prefix)