# VeriBib-rs/scripts/plot_eval.py
#
# 45 lines
# 1.3 KiB
# Python

import sys, csv
import matplotlib.pyplot as plt
def _read_metrics(csv_path):
    """Read parallel (threshold, precision, recall) lists from *csv_path*.

    The CSV must have a header row containing ``threshold``, ``precision`` and
    ``recall`` columns; other columns are ignored. A missing column raises
    ``KeyError`` and a non-numeric cell raises ``ValueError``.

    The rows are returned sorted by ascending threshold so that line plots
    connect neighbouring operating points regardless of the row order in the
    file.
    """
    points = []
    # utf-8-sig transparently strips a leading BOM (common in spreadsheet
    # exports), which would otherwise be glued onto the first header name.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            points.append(
                (
                    float(row["threshold"]),
                    float(row["precision"]),
                    float(row["recall"]),
                )
            )
    # Stable sort on the threshold only; rows with equal thresholds keep
    # their file order.
    points.sort(key=lambda point: point[0])
    thr = [point[0] for point in points]
    prec = [point[1] for point in points]
    rec = [point[2] for point in points]
    return thr, prec, rec


def _save_pr_curve(rec, prec, path, dpi):
    """Save a precision-vs-recall line plot to *path*."""
    fig, ax = plt.subplots()
    try:
        ax.plot(rec, prec, marker="o")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title("Precision-Recall Curve")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(path, dpi=dpi)
    finally:
        # pyplot keeps every figure alive until closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)


def _save_threshold_curve(thr, prec, rec, path, dpi):
    """Save precision and recall plotted against the confidence threshold."""
    fig, ax = plt.subplots()
    try:
        ax.plot(thr, prec, marker="o", label="Precision")
        ax.plot(thr, rec, marker="o", label="Recall")
        ax.set_xlabel("Confidence threshold")
        ax.set_ylabel("Score")
        ax.set_title("Precision/Recall vs Threshold")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(path, dpi=dpi)
    finally:
        plt.close(fig)


def main(csv_path, out_prefix="curves", dpi=160):
    """Render precision/recall evaluation curves from a metrics CSV.

    Args:
        csv_path: Path to a CSV with ``threshold``, ``precision`` and
            ``recall`` columns (one row per operating point).
        out_prefix: Output path prefix; the function writes
            ``{out_prefix}_pr.png`` (precision vs recall) and
            ``{out_prefix}_thr.png`` (precision and recall vs threshold).
        dpi: Resolution of the saved PNG files.

    Returns:
        Tuple ``(pr_path, thr_path)`` naming the two files written.

    Raises:
        OSError: If the CSV cannot be opened or an image cannot be written.
        KeyError: If a required column is missing from the CSV header.
        ValueError: If a required cell is not a valid float.
    """
    thr, prec, rec = _read_metrics(csv_path)
    pr_path = f"{out_prefix}_pr.png"
    thr_path = f"{out_prefix}_thr.png"
    _save_pr_curve(rec, prec, pr_path, dpi)
    _save_threshold_curve(thr, prec, rec, thr_path, dpi)
    return pr_path, thr_path
if __name__ == "__main__":
    # Command-line entry: metrics.csv is required, out_prefix is optional.
    if len(sys.argv) < 2:
        # sys.exit(str) prints the message to stderr and exits with status 1,
        # so usage errors do not pollute captured stdout.
        sys.exit("Usage: python scripts/plot_eval.py metrics.csv [out_prefix]")
    csv_path = sys.argv[1]
    out_prefix = sys.argv[2] if len(sys.argv) > 2 else "curves"
    main(csv_path, out_prefix)