swanlab.pr_curve

python
pr_curve(
    y_true: Union[List, np.ndarray],
    y_pred_proba: Union[List, np.ndarray],
    title: Optional[Union[str, bool]] = None,
) -> None
| Parameter | Description |
| --- | --- |
| y_true | (Union[List, np.ndarray]) True labels: the actual class labels (0 or 1) of a binary classification problem |
| y_pred_proba | (Union[List, np.ndarray]) Prediction probabilities: the model's predicted probability of the positive class for each sample (range 0-1) |
| title | (Optional[Union[str, bool]]) Chart title. Pass True to show the title, False to hide it, or a string to set a custom title. Defaults to None |
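
The parameter constraints above can be checked before logging. The following is a minimal sketch, assuming NumPy is installed; `check_pr_inputs` is a hypothetical helper written for illustration and is not part of SwanLab.

```python
import numpy as np


def check_pr_inputs(y_true, y_pred_proba):
    """Hypothetical helper (not part of SwanLab): validate pr_curve inputs."""
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba, dtype=float)

    # Both inputs are expected to be flat sequences of equal length.
    if y_true.ndim != 1 or y_pred_proba.ndim != 1:
        raise ValueError("y_true and y_pred_proba must be one-dimensional")
    if y_true.shape != y_pred_proba.shape:
        raise ValueError("y_true and y_pred_proba must have the same length")

    # Labels must be binary (0 or 1).
    if not np.isin(y_true, [0, 1]).all():
        raise ValueError("y_true must contain only 0 and 1")

    # Probabilities must lie in the closed interval [0, 1].
    if ((y_pred_proba < 0) | (y_pred_proba > 1)).any():
        raise ValueError("y_pred_proba must lie in [0, 1]")

    return y_true, y_pred_proba


labels, probs = check_pr_inputs([0, 1, 1, 0], [0.2, 0.9, 0.6, 0.4])
```

The returned arrays can then be passed to `swanlab.pr_curve` in place of the raw inputs.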

Introduction

Draws a PR (precision-recall) curve for evaluating binary classification models. The curve plots precision against recall as the decision threshold varies.

PR curves are particularly useful for imbalanced datasets: precision and recall are both computed from the positive class and ignore true negatives, so they give a clearer picture of model performance on a minority class.
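
To make the threshold dependence concrete, here is a small pure-Python sketch that computes one point of a PR curve per threshold. The function `precision_recall_at` and the sample data are made up for illustration; this is not SwanLab's implementation.

```python
def precision_recall_at(y_true, y_pred_proba, threshold):
    """Precision and recall when scores >= threshold are predicted positive."""
    tp = fp = fn = 0
    for label, score in zip(y_true, y_pred_proba):
        predicted_positive = score >= threshold
        if predicted_positive and label == 1:
            tp += 1  # true positive
        elif predicted_positive and label == 0:
            fp += 1  # false positive
        elif not predicted_positive and label == 1:
            fn += 1  # false negative
    # Convention assumed here: precision is 1.0 when nothing is predicted positive.
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Illustrative data: three positives, three negatives.
y_true = [0, 0, 1, 1, 0, 1]
y_pred_proba = [0.1, 0.4, 0.35, 0.8, 0.6, 0.9]

for threshold in (0.3, 0.5, 0.7):
    precision, recall = precision_recall_at(y_true, y_pred_proba, threshold)
    print(f"threshold={threshold:.1f}  precision={precision:.2f}  recall={recall:.2f}")
```

On this data, raising the threshold from 0.3 to 0.7 lifts precision from 0.60 to 1.00 while recall drops from 1.00 to 0.67, which is the trade-off the curve visualizes.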

Basic Usage

python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab

# Generate sample data
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train model
model = xgb.XGBClassifier(eval_metric='logloss')  # use_label_encoder is deprecated in recent XGBoost releases
model.fit(X_train, y_train)

# Get prediction probabilities
y_pred_proba = model.predict_proba(X_test)[:, 1]

# Initialize SwanLab
swanlab.init(project="PR-Curve-Demo", experiment_name="PR-Curve-Example")

# Log PR curve
swanlab.log({
    "pr_curve": swanlab.pr_curve(y_test, y_pred_proba, title=True)
})

swanlab.finish()

Custom Title

python
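# Reuses y_test and y_pred_proba from Basic Usage; call these between swanlab.init() and swanlab.finish()
# Don't show title (default)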
pr_curve = swanlab.pr_curve(y_test, y_pred_proba, title=False)
swanlab.log({"pr_curve_no_title": pr_curve})

# Show title
pr_curve = swanlab.pr_curve(y_test, y_pred_proba, title=True)
swanlab.log({"pr_curve_with_title": pr_curve})

# Custom title
pr_curve = swanlab.pr_curve(y_test, y_pred_proba, title="demo")
swanlab.log({"pr_curve_with_custom_title": pr_curve})

Notes

  1. Data Format: y_true and y_pred_proba can be lists or numpy arrays
  2. Binary Classification: This function is specifically for binary classification problems
  3. Probability Values: y_pred_proba should be the model's predicted probability for the positive class, in the range 0 to 1
  4. Dependencies: Requires the scikit-learn and pyecharts packages to be installed
  5. AUC Calculation: The function automatically calculates the area under the PR curve (AUC), but does not display it in the title by default
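
If you want the PR AUC as a number, for example to log it as a scalar alongside the chart, it can be computed separately. The sketch below uses scikit-learn's `precision_recall_curve` and `auc` (trapezoidal rule). It is an assumption that this matches the value computed inside `swanlab.pr_curve`; the internal calculation may differ.

```python
from sklearn.metrics import auc, precision_recall_curve

# Illustrative data: a classifier that ranks both positives above both negatives.
y_true = [0, 0, 1, 1]
y_pred_proba = [0.1, 0.2, 0.8, 0.9]

# precision and recall have one more entry than the thresholds array.
precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)

# Area under the curve with recall on the x-axis and precision on the y-axis.
pr_auc = auc(recall, precision)
print(f"PR AUC: {pr_auc:.3f}")
```

scikit-learn also provides `average_precision_score`, a step-wise summary of the same curve that can give a slightly different value than the trapezoidal area.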