# yellowbrick.classifier.threshold
# Threshold classifier visualizer for Yellowbrick.
#
# Author: Nathan Danielsen <ndanielsen@gmail.com.com>
# Created: Wed April 26 20:17:29 2017 -0700
#
# Copyright (C) 2017 District Data Labs
# For license information, see LICENSE.txt
import bisect
import numpy as np
from scipy.stats import mstats
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve
from yellowbrick.exceptions import YellowbrickTypeError
from yellowbrick.style.colors import resolve_colors
from yellowbrick.base import ModelVisualizer
from yellowbrick.utils import isclassifier
##########################################################################
# Quick Methods
##########################################################################
def thresholdviz(model,
X,
y,
color=None,
n_trials=50,
test_size_percent=0.1,
quantiles=(0.1, 0.5, 0.9),
random_state=0,
**kwargs):
"""Quick method for ThresholdVisualizer.
Visualizes the bounds of precision, recall and queue rate at different
thresholds for binary targets after a given number of trials.
The visualization shows the threshold precentage on the x-axis which can be
compared against the queue rate, precision, and recall as percentages on
the y-axis. The default that each of the medium curves is set at the 90%%
central interval, but can be adjusted.
This visualization will help the user determine given their tolerances for
precision, queue and recall the appropriate threshold to set in their
application.
See also::
``http://blog.insightdatalabs.com/visualizing-classifier-thresholds/``
Parameters
----------
model : a Scikit-Learn classifier, required
Should be an instance of a classifier otherwise a will raise a
YellowbrickTypeError exception on instantiation.
color : string, default: None
Optional string or matplotlib cmap to colorize lines
Use either color to colorize the lines on a per class basis
n_trials : integer, default: 50
Number of trials to conduct via train_test_split
quantiles : sequence, default: (0.1, 0.5, .9)
Setting the quantiles for visualizing model variability using
scipy.stats.mstats.mquantiles
random_state : integer, default: None
Random state integer for sampling in train_test_split
kwargs : keyword arguments passed to the super class.
Returns
-------
ax : matplotlib axes
Returns the axes that the parallel coordinates were drawn on.
"""
# Instantiate the visualizer
visualizer = ThresholdVisualizer(
model,
color=color,
n_trials=n_trials,
test_size_percent=test_size_percent,
quantiles=quantiles,
random_state=random_state,
**kwargs)
# Fit and transform the visualizer (calls draw)
visualizer.fit_poof(X, y)
# Return the axes object on the visualizer
return visualizer.ax
##########################################################################
# Static ThresholdVisualizer Visualizer
##########################################################################
class ThresholdVisualizer(ModelVisualizer):
"""Visualizes the bounds of precision, recall and queue rate at different
thresholds for binary targets after a given number of trials.
The visualization shows the threshold precentage on the x-axis which can be
compared against the queue rate, precision, and recall as percentages on
the y-axis. The default that each of the medium curves is set at the 90%%
central interval, but can be adjusted.
This visualization will help the user determine given their tolerances for
precision, queue and recall the appropriate threshold to set in their
application.
See also::
``http://blog.insightdatalabs.com/visualizing-classifier-thresholds/``
Parameters
----------
model : a Scikit-Learn classifier, required
Should be an instance of a classifier otherwise a will raise a
YellowbrickTypeError exception on instantiation.
color : string, default: None
Optional string or matplotlib cmap to colorize lines
Use either color to colorize the lines on a per class basis
n_trials : integer, default: 50
Number of trials to conduct via train_test_split
quantiles : sequence, default: (0.1, 0.5, .9)
Setting the quantiles for visualizing model variability using
scipy.stats.mstats.mquantiles
random_state : integer, default: None
Random state integer for sampling in train_test_split
kwargs : keyword arguments passed to the super class.
"""
def __init__(self,
model,
n_trials=50,
test_size_percent=0.1,
quantiles=(0.1, 0.5, 0.9),
random_state=None,
**kwargs):
# Check to see if model is an instance of a classifier.
# Should return an error if it isn't.
if not isclassifier(model):
raise YellowbrickTypeError(
"This estimator is not a classifier; try a regression or clustering score visualizer instead!"
)
super(ThresholdVisualizer, self).__init__(model, **kwargs)
self.estimator = model
self.n_trials = n_trials
self.test_size_percent = test_size_percent
self.quantiles = quantiles
self.random_state = random_state
# to be set later
self.plot_data = None
def fit(self, X, y=None, **kwargs):
"""
Parameters
----------
X : ndarray or DataFrame of shape n x m
A matrix of n instances with m features
y : ndarray or Series of length n
An array or series of target or class values
kwargs: dict
keyword arguments passed to Scikit-Learn API.
Returns
-------
self : instance
Returns the instance of the visualizer
"""
self.plot_data = []
for _ in range(self.n_trials):
train_X, test_X, train_y, test_y = train_test_split(
X,
y,
test_size=self.test_size_percent,
random_state=self.random_state # defaults to None
)
self.estimator.fit(train_X, train_y)
# get prediction probabilities for each
predictions = self.estimator.predict_proba(test_X)[:, 1]
precision, recall, thresholds = precision_recall_curve(
test_y, predictions)
# add one to each so that thresh ends at 1
thresholds = np.append(thresholds, 1)
queue_rate = []
for threshold in thresholds:
queue_rate.append((predictions >= threshold).mean())
trial_data = {
'thresholds': thresholds,
'precision': precision,
'recall': recall,
'queue_rate': queue_rate
}
self.plot_data.append(trial_data)
return self.draw()
def draw(self, *kwargs):
"""
Renders the visualization
Parameters
----------
kwargs: dict
keyword arguments passed to Scikit-Learn API.
Returns
-------
self.ax : AxesSubplot of the visualizer
Returns the AxesSubplot instance of the visualizer
"""
# Set the colors from the supplied values or reasonable defaults
color_values = resolve_colors(n_colors=3, colors=self.color)
uniform_thresholds = np.linspace(0, 1, num=101)
uniform_precision_plots = []
uniform_recall_plots = []
uniform_queue_rate_plots = []
for data in self.plot_data:
uniform_precision = []
uniform_recall = []
uniform_queue_rate = []
for ut in uniform_thresholds:
index = bisect.bisect_left(data['thresholds'], ut)
uniform_precision.append(data['precision'][index])
uniform_recall.append(data['recall'][index])
uniform_queue_rate.append(data['queue_rate'][index])
uniform_precision_plots.append(uniform_precision)
uniform_recall_plots.append(uniform_recall)
uniform_queue_rate_plots.append(uniform_queue_rate)
uplots = (uniform_precision_plots, uniform_recall_plots, uniform_queue_rate_plots)
for uniform_plot, color in zip(uplots, color_values):
# Compute the lower, median, and upper plots
lower, median, upper = mstats.mquantiles(uniform_plot, prob=self.quantiles, axis=0)
# Draw the median line
self.ax.plot(uniform_thresholds, median, color=color)
# Draw the fill between the lower and upper bounds
self.ax.fill_between(uniform_thresholds, upper, lower, alpha=0.5, linewidth=0, color=color)
return self.ax
def finalize(self, **kwargs):
"""Finalize executes any subclass-specific axes finalization steps.
The user calls poof and poof calls finalize.
Parameters
----------
kwargs: generic keyword arguments.
"""
super(ThresholdVisualizer, self).finalize(**kwargs)
# Set the title
if self.title is None:
self.set_title("Threshold Plot of Binary Classifier")
self.ax.legend(
('precision', 'recall', 'queue_rate'), frameon=True, loc='best')
self.ax.set_xlabel('threshold')
self.ax.set_ylabel('percent')
def fit_poof(self, X, y=None, **kwargs):
"""Convience method to fit, draw and poof / finalize the visualizer in
one step after instantiation.
Parameters
----------
X : ndarray or DataFrame of shape n x m
A matrix of n instances with m features
y : ndarray or Series of length n
An array or series of target or class values
kwargs: dict
keyword arguments passed to Scikit-Learn API.
Returns
-------
self : instance
Returns the instance of the visualizer
"""
self.fit(X, y)
self.poof()
return self
ThreshViz = ThresholdVisualizer