import matplotlib.pyplot as plt

from yellowbrick.datasets import load_spam
from yellowbrick.classifier import PrecisionRecallCurve
from sklearn.model_selection import train_test_split as tts
from sklearn.linear_model import RidgeClassifier, LogisticRegression

# Load the spam dataset and hold out 20% of it for scoring.
X, y = load_spam()

X_train, X_test, y_train, y_test = tts(
    X, y, test_size=0.2, shuffle=True, random_state=0
)

# Models to compare. One precision-recall panel is drawn per entry, so
# adding or removing a model here is all that is needed to change the grid.
models = [
    RidgeClassifier(random_state=0),
    # NOTE(review): the default lbfgs solver may warn about convergence on
    # unscaled features; consider raising max_iter or scaling the inputs.
    LogisticRegression(random_state=0),
]

# Size the grid from the model list instead of hard-coding it.
# squeeze=False keeps `axes` a 2-D array even for a single model, so
# `axes.flat` is always iterable. For two models figsize is (8, 4).
_, axes = plt.subplots(
    ncols=len(models), figsize=(4 * len(models), 4), squeeze=False
)

# Fit each visualizer on the training split and score it on the held-out
# split. finalize() completes each panel without displaying it; every panel
# is shown together once, after the loop.
for model, ax in zip(models, axes.flat):
    viz = PrecisionRecallCurve(model, ax=ax)
    viz.fit(X_train, y_train)
    viz.score(X_test, y_test)
    viz.finalize()

# Keep titles and labels of adjacent panels from overlapping.
plt.tight_layout()
plt.show()