Spaces:
Sleeping
Sleeping
| """Plotting helpers using Matplotlib.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Iterable | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from sklearn.calibration import calibration_curve | |
| from sklearn.metrics import precision_recall_curve, roc_curve | |
| def plot_reliability_curve( | |
| y_true: Iterable[float], | |
| y_score: Iterable[float], | |
| *, | |
| path: str | Path, | |
| n_bins: int = 10, | |
| ) -> None: | |
| """Save a reliability curve plot.""" | |
| prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=n_bins) | |
| fig, ax = plt.subplots(figsize=(4, 4)) | |
| ax.plot(prob_pred, prob_true, marker="o", label="Model") | |
| ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect") | |
| ax.set_xlabel("Predicted probability") | |
| ax.set_ylabel("Observed frequency") | |
| ax.set_title("Reliability curve") | |
| ax.legend() | |
| fig.tight_layout() | |
| Path(path).parent.mkdir(parents=True, exist_ok=True) | |
| fig.savefig(path) | |
| plt.close(fig) | |
| def plot_precision_recall( | |
| y_true: Iterable[float], | |
| y_score: Iterable[float], | |
| *, | |
| path: str | Path, | |
| ) -> None: | |
| """Save a precision-recall curve.""" | |
| precision, recall, _ = precision_recall_curve(y_true, y_score) | |
| fig, ax = plt.subplots(figsize=(4, 4)) | |
| ax.plot(recall, precision, label="Model") | |
| ax.set_xlabel("Recall") | |
| ax.set_ylabel("Precision") | |
| ax.set_title("Precision-Recall curve") | |
| ax.set_xlim([0, 1]) | |
| ax.set_ylim([0, 1]) | |
| fig.tight_layout() | |
| Path(path).parent.mkdir(parents=True, exist_ok=True) | |
| fig.savefig(path) | |
| plt.close(fig) | |
| def plot_roc_curve( | |
| y_true: Iterable[float], | |
| y_score: Iterable[float], | |
| *, | |
| path: str | Path, | |
| ) -> None: | |
| """Save an ROC curve plot.""" | |
| fpr, tpr, _ = roc_curve(y_true, y_score) | |
| fig, ax = plt.subplots(figsize=(4, 4)) | |
| ax.plot(fpr, tpr, label="Model") | |
| ax.plot([0, 1], [0, 1], linestyle="--", color="gray") | |
| ax.set_xlabel("False positive rate") | |
| ax.set_ylabel("True positive rate") | |
| ax.set_title("ROC curve") | |
| ax.set_xlim([0, 1]) | |
| ax.set_ylim([0, 1]) | |
| fig.tight_layout() | |
| Path(path).parent.mkdir(parents=True, exist_ok=True) | |
| fig.savefig(path) | |
| plt.close(fig) | |