Source code for yellowbrick.classifier.classification_report

# yellowbrick.classifier.classification_report
# Visual classification report for classifier scoring.
#
# Author:   Rebecca Bilbro
# Author:   Benjamin Bengfort
# Author:   Neal Humphrey
# Author:   Allyssa Riley
# Author:   Larry Gray
# Created:  Wed May 3 18:15:42 2017 -0400
#
# Copyright (C) 2017 The scikit-yb developers
# For license information, see LICENSE.txt
#
# ID: classification_report.py [5388065] neal@nhumphrey.com $

"""
Visual classification report for classifier scoring.
"""

##########################################################################
## Imports
##########################################################################

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import precision_recall_fscore_support

from yellowbrick.style import find_text_color
from yellowbrick.style.palettes import color_sequence
from yellowbrick.exceptions import YellowbrickValueError
from yellowbrick.classifier.base import ClassificationScoreVisualizer

##########################################################################
## Classification Report
##########################################################################

PERCENT = "percent"
CMAP_UNDERCOLOR = "w"
CMAP_OVERCOLOR = "#2a7d4f"
SCORES_KEYS = ("precision", "recall", "f1", "support")


[docs]class ClassificationReport(ClassificationScoreVisualizer): """ Classification report that shows the precision, recall, F1, and support scores for the model. Integrates numerical scores as well as a color-coded heatmap. Parameters ---------- estimator : estimator A scikit-learn estimator that should be a classifier. If the model is not a classifier, an exception is raised. If the internal model is not fitted, it is fit when the visualizer is fitted, unless otherwise specified by ``is_fitted``. ax : matplotlib Axes, default: None The axes to plot the figure on. If not specified the current axes will be used (or generated if required). classes : list of str, defult: None The class labels to use for the legend ordered by the index of the sorted classes discovered in the ``fit()`` method. Specifying classes in this manner is used to change the class names to a more specific format or to label encoded integer classes. Some visualizers may also use this field to filter the visualization for specific classes. For more advanced usage specify an encoder rather than class labels. cmap : string, default: ``'YlOrRd'`` Specify a colormap to define the heatmap of the predicted class against the actual class in the classification report. support: {True, False, None, 'percent', 'count'}, default: None Specify if support will be displayed. It can be further defined by whether support should be reported as a raw count or percentage. encoder : dict or LabelEncoder, default: None A mapping of classes to human readable labels. Often there is a mismatch between desired class labels and those contained in the target variable passed to ``fit()`` or ``score()``. The encoder disambiguates this mismatch ensuring that classes are labeled correctly in the visualization. is_fitted : bool or str, default="auto" Specify if the wrapped estimator is already fitted. If False, the estimator will be fit when the visualizer is fit, otherwise, the estimator will not be modified. If "auto" (default), a helper method will check if the estimator is fitted before fitting it again. force_model : bool, default: False Do not check to ensure that the underlying estimator is a classifier. This will prevent an exception when the visualizer is initialized but may result in unexpected or unintended behavior. colorbar : bool, default: True Specify if the color bar should be present fontsize : int or None, default: None Specify the font size of the x and y labels kwargs : dict Keyword arguments passed to the visualizer base classes. Examples -------- >>> from yellowbrick.classifier import ClassificationReport >>> from sklearn.linear_model import LogisticRegression >>> viz = ClassificationReport(LogisticRegression()) >>> viz.fit(X_train, y_train) >>> viz.score(X_test, y_test) >>> viz.show() Attributes ---------- classes_ : ndarray of shape (n_classes,) The class labels observed while fitting. class_count_ : ndarray of shape (n_classes,) Number of samples encountered for each class during fitting. score_ : float An evaluation metric of the classifier on test data produced when ``score()`` is called. This metric is between 0 and 1 -- higher scores are generally better. For classifiers, this score is usually accuracy, but ensure you check the underlying model for more details about the score. scores_ : dict of dicts Outer dictionary composed of precision, recall, f1, and support scores with inner dictionaries specifiying the values for each class listed. """ def __init__( self, estimator, ax=None, classes=None, cmap="YlOrRd", support=None, encoder=None, is_fitted="auto", force_model=False, colorbar=True, fontsize=None, **kwargs ): super(ClassificationReport, self).__init__( estimator, ax=ax, classes=classes, encoder=encoder, is_fitted=is_fitted, force_model=force_model, **kwargs ) self.colorbar = colorbar self.support = support self.cmap = color_sequence(cmap) self.cmap.set_over(color=CMAP_OVERCOLOR) self.cmap.set_under(color=CMAP_UNDERCOLOR) self._displayed_scores = [key for key in SCORES_KEYS] self.fontsize=fontsize if support not in {None, True, False, "percent", "count"}: raise YellowbrickValueError( "'{}' is an invalid argument for support, use None, True, " "False, 'percent', or 'count'".format(support) ) if not support: self._displayed_scores.remove("support")
[docs] def score(self, X, y): """ Generates the Scikit-Learn classification report. Parameters ---------- X : ndarray or DataFrame of shape n x m A matrix of n instances with m features y : ndarray or Series of length n An array or series of target or class values Returns ------- score_ : float Global accuracy score """ # Call super to check if fitted and to compute self.score_ super(ClassificationReport, self).score(X, y) # Labels must be a data type that works with np.isnan labels = range(len(self.classes_)) y_pred = self.predict(X) scores = precision_recall_fscore_support(y, y_pred, labels=labels) # Calculate the percentage for the support metric # and store the percent in place of raw support counts self.support_score_ = scores[-1] scores = list(scores) scores[-1] = scores[-1] / scores[-1].sum() # Create a mapping composed of precision, recall, F1, and support # to their respective values scores = map(lambda s: dict(zip(self.classes_, s)), scores) self.scores_ = dict(zip(SCORES_KEYS, scores)) # Remove support scores if not required if not self.support: self.scores_.pop("support") self.draw() return self.score_
[docs] def draw(self): """ Renders the classification report across each axis. """ # Create display grid cr_display = np.zeros((len(self.classes_), len(self._displayed_scores))) # For each class row, append columns for precision, recall, f1, and support for idx, cls in enumerate(self.classes_): for jdx, metric in enumerate(self._displayed_scores): cr_display[idx, jdx] = self.scores_[metric][cls] # Set up the dimensions of the pcolormesh # NOTE: pcolormesh accepts grids that are (N+1,M+1) X, Y = ( np.arange(len(self.classes_) + 1), np.arange(len(self._displayed_scores) + 1), ) self.ax.set_ylim(bottom=0, top=cr_display.shape[0]) self.ax.set_xlim(left=0, right=cr_display.shape[1]) # Get the human readable labels labels = self._labels() if labels is None: labels = self.classes_ # Fetch the grid labels from the classes in correct order; set ticks. xticklabels = self._displayed_scores yticklabels = labels[::-1] yticks = np.arange(len(labels)) + 0.5 xticks = np.arange(len(self._displayed_scores)) + 0.5 self.ax.set(yticks=yticks, xticks=xticks) self.ax.set_xticklabels( xticklabels, rotation=45, fontsize=self.fontsize ) self.ax.set_yticklabels(yticklabels, fontsize=self.fontsize) # Set data labels in the grid, enumerating over class, metric pairs # NOTE: X and Y are one element longer than the classification report # so skip the last element to label the grid correctly. for x in X[:-1]: for y in Y[:-1]: # Extract the value and the text label value = cr_display[x, y] svalue = "{:0.3f}".format(value) # change the svalue for support (when y == 3) because we want # to label it as the actual support value, not the percentage if y == 3: if self.support != PERCENT: svalue = self.support_score_[x] # Determine the grid and text colors base_color = self.cmap(value) text_color = find_text_color(base_color) # Add the label to the middle of the grid cx, cy = x + 0.5, y + 0.5 self.ax.text(cy, cx, svalue, va="center", ha="center", color=text_color) # Draw the heatmap with colors bounded by the min and max of the grid # NOTE: I do not understand why this is Y, X instead of X, Y it works # in this order but raises an exception with the other order. g = self.ax.pcolormesh( Y, X, cr_display, vmin=0, vmax=1, cmap=self.cmap, edgecolor="w" ) # Add the color bar if self.colorbar: plt.colorbar(g, ax=self.ax) # TODO: Could use self.fig now else: pass # Return the axes being drawn on return self.ax
[docs] def finalize(self, **kwargs): """ Adds a title and sets the axis labels correctly. Also calls tight layout to ensure that no parts of the figure are cut off in the final visualization. Parameters ---------- kwargs: generic keyword arguments. Notes ----- Generally this method is called from show and not directly by the user. """ # Set the title of the classifiation report self.set_title("{} Classification Report".format(self.name)) # Set the tick marks appropriately self.ax.set_xticks(np.arange(len(self._displayed_scores)) + 0.5) self.ax.set_yticks(np.arange(len(self.classes_)) + 0.5) self.ax.set_xticklabels(self._displayed_scores, rotation=45) self.ax.set_yticklabels(self.classes_) self.fig.tight_layout()
[docs]def classification_report( estimator, X_train, y_train, X_test=None, y_test=None, ax=None, classes=None, cmap="YlOrRd", support=None, encoder=None, is_fitted="auto", force_model=False, show=True, colorbar=True, fontsize=None, **kwargs ): """Classification Report Displays precision, recall, F1, and support scores for the model. Integrates numerical scores as well as color-coded heatmap. Parameters ---------- estimator : estimator A scikit-learn estimator that should be a classifier. If the model is not a classifier, an exception is raised. If the internal model is not fitted, it is fit when the visualizer is fitted, unless otherwise specified by ``is_fitted``. X_train : ndarray or DataFrame of shape n x m A feature array of n instances with m features the model is trained on. Used to fit the visualizer and also to score the visualizer if test splits are not directly specified. y_train : ndarray or Series of length n An array or series of target or class values. Used to fit the visualizer and also to score the visualizer if test splits are not specified. X_test : ndarray or DataFrame of shape n x m, default: None An optional feature array of n instances with m features that the model is scored on if specified, using X_train as the training data. y_test : ndarray or Series of length n, default: None An optional array or series of target or class values that serve as actual labels for X_test for scoring purposes. ax : matplotlib Axes, default: None The axes to plot the figure on. If not specified the current axes will be used (or generated if required). classes : list of str, defult: None The class labels to use for the legend ordered by the index of the sorted classes discovered in the ``fit()`` method. Specifying classes in this manner is used to change the class names to a more specific format or to label encoded integer classes. Some visualizers may also use this field to filter the visualization for specific classes. For more advanced usage specify an encoder rather than class labels. cmap : string, default: ``'YlOrRd'`` Specify a colormap to define the heatmap of the predicted class against the actual class in the classification report. support: {True, False, None, 'percent', 'count'}, default: None Specify if support will be displayed. It can be further defined by whether support should be reported as a raw count or percentage. encoder : dict or LabelEncoder, default: None A mapping of classes to human readable labels. Often there is a mismatch between desired class labels and those contained in the target variable passed to ``fit()`` or ``score()``. The encoder disambiguates this mismatch ensuring that classes are labeled correctly in the visualization. is_fitted : bool or str, default='auto' Specify if the wrapped estimator is already fitted. If False, the estimator will be fit when the visualizer is fit, otherwise, the estimator will not be modified. If 'auto' (default), a helper method will check if the estimator is fitted before fitting it again. force_model : bool, default: False Do not check to ensure that the underlying estimator is a classifier. This will prevent an exception when the visualizer is initialized but may result in unexpected or unintended behavior. show: bool, default: True If True, calls ``show()``, which in turn calls ``plt.show()`` however you cannot call ``plt.savefig`` from this signature, nor ``clear_figure``. If False, simply calls ``finalize()`` colorbar : bool, default: True Specify if the color bar should be present fontsize : int or None, default: None Specify the font size of the x and y labels kwargs : dict Keyword arguments passed to the visualizer base classes. Returns ------- viz : ClassificationReport Returns the fitted, finalized visualizer """ # Instantiate the visualizer visualizer = ClassificationReport( estimator=estimator, ax=ax, classes=classes, cmap=cmap, support=support, encoder=encoder, is_fitted=is_fitted, force_model=force_model, colorbar=colorbar, fontsize=fontsize, **kwargs ) # Fit and transform the visualizer (calls draw) visualizer.fit(X_train, y_train) # Score the visualizer if X_test is not None and y_test is not None: visualizer.score(X_test, y_test) elif X_test is not None or y_test is not None: raise YellowbrickValueError( "both X_test and y_test are required if one is specified" ) else: visualizer.score(X_train, y_train) # Draw the final visualization if show: visualizer.show() else: visualizer.finalize() # Return the visualizer return visualizer