"""Plotting helpers using Matplotlib.""" from __future__ import annotations from pathlib import Path from typing import Iterable import matplotlib.pyplot as plt import numpy as np from sklearn.calibration import calibration_curve from sklearn.metrics import precision_recall_curve, roc_curve def plot_reliability_curve( y_true: Iterable[float], y_score: Iterable[float], *, path: str | Path, n_bins: int = 10, ) -> None: """Save a reliability curve plot.""" prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=n_bins) fig, ax = plt.subplots(figsize=(4, 4)) ax.plot(prob_pred, prob_true, marker="o", label="Model") ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect") ax.set_xlabel("Predicted probability") ax.set_ylabel("Observed frequency") ax.set_title("Reliability curve") ax.legend() fig.tight_layout() Path(path).parent.mkdir(parents=True, exist_ok=True) fig.savefig(path) plt.close(fig) def plot_precision_recall( y_true: Iterable[float], y_score: Iterable[float], *, path: str | Path, ) -> None: """Save a precision-recall curve.""" precision, recall, _ = precision_recall_curve(y_true, y_score) fig, ax = plt.subplots(figsize=(4, 4)) ax.plot(recall, precision, label="Model") ax.set_xlabel("Recall") ax.set_ylabel("Precision") ax.set_title("Precision-Recall curve") ax.set_xlim([0, 1]) ax.set_ylim([0, 1]) fig.tight_layout() Path(path).parent.mkdir(parents=True, exist_ok=True) fig.savefig(path) plt.close(fig) def plot_roc_curve( y_true: Iterable[float], y_score: Iterable[float], *, path: str | Path, ) -> None: """Save an ROC curve plot.""" fpr, tpr, _ = roc_curve(y_true, y_score) fig, ax = plt.subplots(figsize=(4, 4)) ax.plot(fpr, tpr, label="Model") ax.plot([0, 1], [0, 1], linestyle="--", color="gray") ax.set_xlabel("False positive rate") ax.set_ylabel("True positive rate") ax.set_title("ROC curve") ax.set_xlim([0, 1]) ax.set_ylim([0, 1]) fig.tight_layout() Path(path).parent.mkdir(parents=True, exist_ok=True) fig.savefig(path) plt.close(fig)