45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
|
|
import sys, csv
|
|
import matplotlib.pyplot as plt
|
|
|
|
def main(csv_path, out_prefix="curves"):
    """Plot precision/recall curves from an evaluation metrics CSV.

    The CSV must have a header row with (at least) the columns
    ``threshold``, ``precision`` and ``recall``; each data row is one
    operating point.  Two PNG images are written:

    * ``{out_prefix}_pr.png``  -- precision plotted against recall;
    * ``{out_prefix}_thr.png`` -- precision and recall plotted against the
      confidence threshold.

    Points are sorted by threshold before plotting so the connecting lines
    follow the operating points in order even if the CSV rows are shuffled.

    Args:
        csv_path: Path to the metrics CSV file.
        out_prefix: Prefix (may include a directory) for the output images.

    Returns:
        Tuple ``(pr_path, thr_path)`` with the paths of the written images.

    Raises:
        KeyError: If a required column is missing from the CSV header.
        ValueError: If a metrics cell cannot be parsed as a float.
    """
    thr, prec, rec = _read_metrics(csv_path)

    pr_path = f"{out_prefix}_pr.png"
    thr_path = f"{out_prefix}_thr.png"

    _plot_precision_recall(rec, prec, pr_path)
    _plot_vs_threshold(thr, prec, rec, thr_path)

    return pr_path, thr_path


def _read_metrics(csv_path):
    """Return (thresholds, precisions, recalls) lists from *csv_path*, sorted by threshold."""
    points = []
    # "utf-8-sig" transparently strips a leading BOM (e.g. from spreadsheet
    # exports) that would otherwise corrupt the first header name.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row in reader:
            try:
                point = (
                    float(row["threshold"]),
                    float(row["precision"]),
                    float(row["recall"]),
                )
            except (TypeError, ValueError) as err:
                # TypeError covers short rows, where DictReader fills
                # missing cells with None.
                raise ValueError(
                    f"{csv_path}, line {reader.line_num}: "
                    f"cannot parse metrics row {row!r}"
                ) from err
            points.append(point)

    # Stable sort on threshold only, so rows with equal thresholds keep
    # their file order.
    points.sort(key=lambda p: p[0])
    thr = [p[0] for p in points]
    prec = [p[1] for p in points]
    rec = [p[2] for p in points]
    return thr, prec, rec


def _plot_precision_recall(rec, prec, path):
    """Save a precision-vs-recall line plot to *path*."""
    fig, ax = plt.subplots()
    try:
        ax.plot(rec, prec, marker="o")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title("Precision-Recall Curve")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)


def _plot_vs_threshold(thr, prec, rec, path):
    """Save precision and recall plotted against the confidence threshold to *path*."""
    fig, ax = plt.subplots()
    try:
        ax.plot(thr, prec, marker="o", label="Precision")
        ax.plot(thr, rec, marker="o", label="Recall")
        ax.set_xlabel("Confidence threshold")
        ax.set_ylabel("Score")
        ax.set_title("Precision/Recall vs Threshold")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        plt.close(fig)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: plot_eval.py metrics.csv [out_prefix]
    if len(sys.argv) < 2:
        # sys.exit(str) writes the message to stderr and exits with
        # status 1, keeping usage errors off stdout.
        sys.exit("Usage: python scripts/plot_eval.py metrics.csv [out_prefix]")

    csv_path = sys.argv[1]
    out_prefix = sys.argv[2] if len(sys.argv) > 2 else "curves"
    main(csv_path, out_prefix)
|