makiling's picture
Upload folder using huggingface_hub
5f58699 verified
raw
history blame
2.22 kB
"""Plotting helpers using Matplotlib."""
from __future__ import annotations
from pathlib import Path
from typing import Iterable
import matplotlib.pyplot as plt
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.metrics import precision_recall_curve, roc_curve
def plot_reliability_curve(
    y_true: Iterable[float],
    y_score: Iterable[float],
    *,
    path: str | Path,
    n_bins: int = 10,
) -> None:
    """Save a reliability (calibration) curve plot to ``path``.

    Args:
        y_true: Ground-truth labels. Any iterable is accepted; one-shot
            iterators such as generators are materialised first.
        y_score: Predicted scores/probabilities, aligned with ``y_true``.
        path: Destination file passed to ``Figure.savefig``. Missing parent
            directories are created.
        n_bins: Number of bins forwarded to
            ``sklearn.calibration.calibration_curve``.

    Raises:
        Exceptions raised by scikit-learn for invalid inputs (e.g. mismatched
        lengths) or by Matplotlib while saving propagate to the caller; the
        figure is closed in either case.
    """
    # scikit-learn's input validation expects array-likes, not generators,
    # so materialise anything without a length before handing it over.
    y_true = np.asarray(y_true if hasattr(y_true, "__len__") else list(y_true))
    y_score = np.asarray(
        y_score if hasattr(y_score, "__len__") else list(y_score)
    )
    prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=n_bins)

    fig, ax = plt.subplots(figsize=(4, 4))
    # pyplot keeps every figure registered until it is closed; close it even
    # when plotting or saving fails so repeated calls cannot leak figures.
    try:
        ax.plot(prob_pred, prob_true, marker="o", label="Model")
        # Diagonal reference: a perfectly calibrated model lies on y = x.
        ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect")
        ax.set_xlabel("Predicted probability")
        ax.set_ylabel("Observed frequency")
        ax.set_title("Reliability curve")
        ax.legend()
        fig.tight_layout()
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out)
    finally:
        plt.close(fig)
def plot_precision_recall(
    y_true: Iterable[float],
    y_score: Iterable[float],
    *,
    path: str | Path,
) -> None:
    """Save a precision-recall curve plot to ``path``.

    Args:
        y_true: Ground-truth labels. Any iterable is accepted; one-shot
            iterators such as generators are materialised first.
        y_score: Predicted scores, aligned with ``y_true``.
        path: Destination file passed to ``Figure.savefig``. Missing parent
            directories are created.

    Raises:
        Exceptions raised by scikit-learn for invalid inputs (e.g. mismatched
        lengths) or by Matplotlib while saving propagate to the caller; the
        figure is closed in either case.
    """
    # scikit-learn's input validation expects array-likes, not generators,
    # so materialise anything without a length before handing it over.
    y_true = np.asarray(y_true if hasattr(y_true, "__len__") else list(y_true))
    y_score = np.asarray(
        y_score if hasattr(y_score, "__len__") else list(y_score)
    )
    precision, recall, _ = precision_recall_curve(y_true, y_score)

    fig, ax = plt.subplots(figsize=(4, 4))
    # pyplot keeps every figure registered until it is closed; close it even
    # when plotting or saving fails so repeated calls cannot leak figures.
    try:
        ax.plot(recall, precision, label="Model")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title("Precision-Recall curve")
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        # The curve is labelled, so show the legend (as the other plots do).
        ax.legend()
        fig.tight_layout()
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out)
    finally:
        plt.close(fig)
def plot_roc_curve(
    y_true: Iterable[float],
    y_score: Iterable[float],
    *,
    path: str | Path,
) -> None:
    """Save an ROC curve plot to ``path``.

    Args:
        y_true: Ground-truth labels. Any iterable is accepted; one-shot
            iterators such as generators are materialised first.
        y_score: Predicted scores, aligned with ``y_true``.
        path: Destination file passed to ``Figure.savefig``. Missing parent
            directories are created.

    Raises:
        Exceptions raised by scikit-learn for invalid inputs (e.g. mismatched
        lengths) or by Matplotlib while saving propagate to the caller; the
        figure is closed in either case.
    """
    # scikit-learn's input validation expects array-likes, not generators,
    # so materialise anything without a length before handing it over.
    y_true = np.asarray(y_true if hasattr(y_true, "__len__") else list(y_true))
    y_score = np.asarray(
        y_score if hasattr(y_score, "__len__") else list(y_score)
    )
    fpr, tpr, _ = roc_curve(y_true, y_score)

    fig, ax = plt.subplots(figsize=(4, 4))
    # pyplot keeps every figure registered until it is closed; close it even
    # when plotting or saving fails so repeated calls cannot leak figures.
    try:
        ax.plot(fpr, tpr, label="Model")
        # Diagonal baseline: the ROC curve of a random (uninformative) scorer.
        ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Chance")
        ax.set_xlabel("False positive rate")
        ax.set_ylabel("True positive rate")
        ax.set_title("ROC curve")
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.legend()
        fig.tight_layout()
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out)
    finally:
        plt.close(fig)