Source code for yellowbrick.model_selection.cross_validation

# yellowbrick.model_selection.cross_validation
# Implements cross-validation score plotting for model selection.
#
# Author:   Prema Damodaran Roman
# Created:  Wed June 6 2018 13:32:00 -0500
# Author:   Rebecca Bilbro
# Updated:  Fri Aug 10 13:15:43 2018 -0500
#
# Copyright (C) 2018 The scikit-yb developers
# For license information, see LICENSE.txt
#
# ID: cross_validation.py [7f47800] pdamo24@gmail.com $

"""
Implements cross-validation score plotting for model selection.
"""

##########################################################################
# Imports
##########################################################################

import numpy as np
import matplotlib.ticker as ticker

from yellowbrick.base import ModelVisualizer
from sklearn.model_selection import cross_val_score


##########################################################################
# CVScores Visualizer
##########################################################################


class CVScores(ModelVisualizer):
    """
    CVScores displays cross-validated scores as a bar chart, with the
    average of the scores plotted as a horizontal line.

    Parameters
    ----------
    estimator : a scikit-learn estimator
        An object that implements ``fit`` and ``predict``, can be a
        classifier, regressor, or clusterer so long as there is also a valid
        associated scoring metric.

        Note that the object is cloned for each validation.

    ax : matplotlib.Axes object, optional
        The axes object to plot the figure on.

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use scikit-learn's default strategy (5-fold as of
          scikit-learn 0.22; 3-fold in earlier releases),
        - integer, to specify the number of folds.
        - An object to be used as a cross-validation generator.
        - An iterable yielding train/test splits.

        See the scikit-learn
        `cross-validation guide <https://goo.gl/FS3VU6>`_
        for more information on the possible strategies that can be used here.

    scoring : string, callable or None, optional, default: None
        A string or scorer callable object / function with signature
        ``scorer(estimator, X, y)``.

        See scikit-learn `cross-validation guide <https://goo.gl/FS3VU6>`_
        for more information on the possible metrics that can be used.

    color : string
        Specify color for barchart. The same color is used for the dashed
        mean-score line.

    kwargs : dict
        Keyword arguments that are passed to the base class and may influence
        the visualization as defined in other Visualizers.

    Attributes
    ----------
    cv_scores_ : ndarray shape (n_splits, )
        The cross-validated scores from each subsection of the data

    cv_scores_mean_ : float
        Average cross-validated score across all subsections of the data

    Examples
    --------
    >>> from sklearn import datasets, svm
    >>> iris = datasets.load_iris()
    >>> clf = svm.SVC(kernel='linear', C=1)
    >>> X = iris.data
    >>> y = iris.target
    >>> visualizer = CVScores(estimator=clf, cv=5, scoring='f1_macro')
    >>> visualizer.fit(X,y)
    >>> visualizer.show()

    Notes
    -----
    This visualizer is a wrapper for
    `sklearn.model_selection.cross_val_score <https://goo.gl/4v7dfL>`_.

    Refer to the scikit-learn
    `cross-validation guide <https://goo.gl/FS3VU6>`_
    for more details.
    """

    def __init__(self, estimator, ax=None, cv=None, scoring=None, color=None, **kwargs):
        super(CVScores, self).__init__(estimator, ax=ax, **kwargs)

        # Hyperparameters are stored verbatim under their argument names.
        self.cv = cv
        self.color = color
        self.scoring = scoring

    def fit(self, X, y, **kwargs):
        """
        Computes the cross-validated scores of the wrapped estimator on the
        specified data, stores them (and their mean) on the visualizer, and
        draws the bar chart.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training vector, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like, shape (n_samples) or (n_samples, n_features), optional
            Target relative to X for classification or regression;
            None for unsupervised learning.

        kwargs : dict
            Accepted for signature compatibility; not used by this method.

        Returns
        -------
        self : instance
            The visualizer, with ``cv_scores_`` and ``cv_scores_mean_`` set.
        """
        self.cv_scores_ = cross_val_score(
            self.estimator, X, y, cv=self.cv, scoring=self.scoring
        )
        self.cv_scores_mean_ = self.cv_scores_.mean()

        self.draw()

        return self

    def draw(self, **kwargs):
        """
        Creates the bar chart of the cross-validated scores generated from the
        fit method and places a dashed horizontal line that represents the
        average value of the scores.

        Parameters
        ----------
        kwargs : dict
            ``width`` (default 0.3) sets the bar width and ``linewidth``
            (default 1) sets the width of the mean-score line.

        Returns
        -------
        ax : matplotlib.Axes
            The axes the chart was drawn on.
        """
        width = kwargs.pop("width", 0.3)
        linewidth = kwargs.pop("linewidth", 1)

        # One bar per cross-validation split, numbered from 1.
        xvals = np.arange(1, len(self.cv_scores_) + 1, 1)
        self.ax.bar(xvals, self.cv_scores_, width=width, color=self.color)

        # The label on the mean line is what populates the legend in finalize.
        self.ax.axhline(
            self.cv_scores_mean_,
            color=self.color,
            label="Mean score = {:0.3f}".format(self.cv_scores_mean_),
            linestyle="--",
            linewidth=linewidth,
        )

        return self.ax

    def finalize(self, **kwargs):
        """
        Add the title, legend, and other visual final touches to the plot.

        Parameters
        ----------
        kwargs : dict
            ``loc`` (default "best") and ``edgecolor`` (default "k") are
            forwarded to the legend.
        """
        # Set the title of the figure
        self.set_title("Cross Validation Scores for {}".format(self.name))

        # Add the legend
        loc = kwargs.pop("loc", "best")
        edgecolor = kwargs.pop("edgecolor", "k")
        self.ax.legend(frameon=True, loc=loc, edgecolor=edgecolor)

        # Set spacing between the x ticks so each split gets its own tick
        self.ax.xaxis.set_major_locator(ticker.MultipleLocator(1))

        # Set the axis labels. Each bar is one cross-validation split (see
        # draw), so the x axis is labelled by split, not training instances.
        self.ax.set_xlabel("Cross-Validation Split")
        self.ax.set_ylabel("Score")
##########################################################################
# Quick Method
##########################################################################
def cv_scores(
    estimator, X, y, ax=None, cv=None, scoring=None, color=None, show=True, **kwargs
):
    """
    Displays cross validation scores as a bar chart and the
    average of the scores as a horizontal line

    This helper function is a quick wrapper to utilize the
    CVScores visualizer for one-off analysis.

    Parameters
    ----------
    estimator : a scikit-learn estimator
        An object that implements ``fit`` and ``predict``, can be a
        classifier, regressor, or clusterer so long as there is also a valid
        associated scoring metric.

        Note that the object is cloned for each validation.

    X : array-like, shape (n_samples, n_features)
        Training vector, where n_samples is the number of samples and
        n_features is the number of features.

    y : array-like, shape (n_samples) or (n_samples, n_features), optional
        Target relative to X for classification or regression;
        None for unsupervised learning.

    ax : matplotlib.Axes object, optional
        The axes object to plot the figure on.

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use scikit-learn's default strategy (5-fold as of
          scikit-learn 0.22; 3-fold in earlier releases),
        - integer, to specify the number of folds.
        - An object to be used as a cross-validation generator.
        - An iterable yielding train/test splits.

        See the scikit-learn
        `cross-validation guide <https://goo.gl/FS3VU6>`_
        for more information on the possible strategies that can be used here.

    scoring : string, callable or None, optional, default: None
        A string or scorer callable object / function with signature
        ``scorer(estimator, X, y)``.

        See scikit-learn `cross-validation guide <https://goo.gl/FS3VU6>`_
        for more information on the possible metrics that can be used.

    color : string
        Specify color for barchart

    show : bool, default: True
        If True, calls ``show()``, which in turn calls ``plt.show()`` however
        you cannot call ``plt.savefig`` from this signature, nor
        ``clear_figure``. If False, simply calls ``finalize()``

    kwargs : dict
        Keyword arguments that are passed to the base class and may influence
        the visualization as defined in other Visualizers.

    Returns
    -------
    visualizer : CVScores
        The fitted visualizer.
    """
    # Initialize the visualizer, forwarding the caller's color choice
    # (previously a hard-coded None was passed and ``color`` was ignored).
    visualizer = CVScores(
        estimator, ax=ax, cv=cv, scoring=scoring, color=color, **kwargs
    )

    # Fit and show the visualizer
    visualizer.fit(X, y)

    if show:
        visualizer.show()
    else:
        visualizer.finalize()

    # Return the visualizer object
    return visualizer