Class Balance

One of the biggest challenges for classification models is an imbalance of classes in the training data. Severe class imbalances may be masked by relatively good F1 and accuracy scores: a classifier that simply guesses the majority class can score well without making any evaluation of the underrepresented class.
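To make the masking effect concrete, here is a minimal sketch in plain Python. The labels (950 negatives, 50 positives) and the always-guess-the-majority baseline are assumptions made up for illustration; none of this is part of Yellowbrick.

```python
from collections import Counter

# Hypothetical, highly imbalanced labels: 95% "negative", 5% "positive".
y_true = ["negative"] * 950 + ["positive"] * 50

# A trivial baseline that always predicts the most frequent class.
majority_class = Counter(y_true).most_common(1)[0][0]
y_pred = [majority_class] * len(y_true)

# Overall accuracy looks strong even though nothing was learned.
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Recall on the minority class exposes the problem.
true_positives = sum(t == p == "positive" for t, p in zip(y_true, y_pred))
minority_recall = true_positives / y_true.count("positive")

print(f"accuracy: {accuracy:.2f}")                # accuracy: 0.95
print(f"minority recall: {minority_recall:.2f}")  # minority recall: 0.00
```

The baseline never identifies a single minority sample, yet its accuracy alone would suggest a reasonable model.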

There are several techniques for dealing with class imbalance, such as stratified sampling, downsampling the majority class, and weighting. Before any of these actions can be taken, however, it is important to understand what the class balance is in the training data. The ClassBalance visualizer supports this by creating a bar chart of the support for each class, that is, the frequency with which each class appears in the dataset.

from yellowbrick.datasets import load_game
from yellowbrick.target import ClassBalance

# Load the classification data set
data = load_game()

# Specify the target
y = data["outcome"]

visualizer = ClassBalance(labels=["draw", "loss", "win"])
visualizer.fit(y)
visualizer.poof()
[Figure: bar chart of the support for each class (class_balance.png)]

The resulting figure allows us to diagnose the severity of the balance issue. In this figure we can see that the "win" class dominates the other two classes. One potential solution might be to create a binary classifier, "win" vs "not win", by combining the "loss" and "draw" classes into one class.
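The relabeling idea can be sketched in a few lines of plain Python. The target values below are hypothetical and chosen only to show the mechanics; whether merging actually helps depends on how the combined class compares with "win" in the real data.

```python
from collections import Counter

# Hypothetical three-class target, with "win" as the dominant class.
y = ["win"] * 6 + ["loss"] * 3 + ["draw"] * 1

# Merge "loss" and "draw" into a single "not win" class.
y_binary = ["win" if label == "win" else "not win" for label in y]

print(Counter(y))         # support before merging
print(Counter(y_binary))  # support after merging
```

The relabeled target can then be passed to the visualizer's fit method in the same way as the original one.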

Warning

The ClassBalance visualizer interface changed in version 0.9. A classification model is no longer required to instantiate the visualizer; it can operate on data only. Additionally, the signature of the fit method has changed from fit(X, y=None) to fit(y_train, y_test=None), so passing in X is no longer required.

If a class imbalance must be maintained during evaluation (e.g. the event being classified is actually as rare as the frequency implies), then stratified sampling should be used to create train and test splits. This ensures that the test data has roughly the same proportion of classes as the training data. scikit-learn supports this through the stratify argument of train_test_split, and its cross-validation utilities use stratified folds by default for classifiers when cv is given as an integer. Even so, it can be useful to compare the support of each class in both splits.
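As an illustration of what stratified sampling does, here is a minimal pure-Python sketch. The function name stratified_split and the sample labels are hypothetical; in practice, scikit-learn's train_test_split accepts a stratify argument for the same purpose.

```python
import random
from collections import Counter, defaultdict

def stratified_split(y, test_size=0.2, seed=0):
    """Return (train_idx, test_idx), splitting each class separately."""
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for i, label in enumerate(y):
        by_class[label].append(i)

    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Take the same fraction of every class for the test split.
        n_test = round(len(indices) * test_size)
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return train_idx, test_idx

# Hypothetical imbalanced target: 90 "common" samples, 10 "rare" samples.
y = ["common"] * 90 + ["rare"] * 10
train_idx, test_idx = stratified_split(y, test_size=0.2)

print(Counter(y[i] for i in train_idx))  # class counts in the train split
print(Counter(y[i] for i in test_idx))   # class counts in the test split
```

Because each class is split on its own, both splits keep the 9:1 ratio of the full target, which a purely random split does not guarantee.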

The ClassBalance visualizer has a “compare” mode, where the train and test data can be passed to fit(), creating a side-by-side bar chart instead of a single bar chart as follows:

from sklearn.model_selection import train_test_split
from yellowbrick.target import ClassBalance

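# Load the classification data set
# NOTE: load_data is not imported in this snippet; it is assumed to be a
# helper defined elsewhere that returns the occupancy data set.
data = load_data('occupancy')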

# Specify the features of interest and the target
features = ["temperature", "relative_humidity", "light", "C02", "humidity"]
classes = ["unoccupied", "occupied"]

# Extract the instances and target
X = data[features]
y = data["occupancy"]

# Create the train and test data
_, _, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Instantiate the visualizer (no classification model is required)
visualizer = ClassBalance(labels=classes)

visualizer.fit(y_train, y_test)
visualizer.poof()
[Figure: side-by-side bar chart of class support in the train and test splits (class_balance_compare.png)]

This visualization allows us to do a quick check to ensure that the proportion of each class is roughly similar in both splits. It should be a first stop, particularly when evaluation metrics are highly variable across different splits.
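The same check can also be done numerically. The sketch below is an assumption-laden illustration rather than a Yellowbrick feature: the helper names class_proportions and max_proportion_gap and the sample splits are hypothetical.

```python
from collections import Counter

def class_proportions(y):
    """Map each class to its share of the samples in y."""
    counts = Counter(y)
    total = len(y)
    return {label: count / total for label, count in counts.items()}

def max_proportion_gap(y_train, y_test):
    """Largest absolute difference in class share between two splits."""
    p_train = class_proportions(y_train)
    p_test = class_proportions(y_test)
    # A class missing from one split counts as a share of zero there.
    labels = set(p_train) | set(p_test)
    return max(abs(p_train.get(l, 0.0) - p_test.get(l, 0.0)) for l in labels)

# Hypothetical splits: the test split over-represents class 1.
y_train = [0] * 80 + [1] * 20
y_test = [0] * 15 + [1] * 10

gap = max_proportion_gap(y_train, y_test)
print(f"largest gap in class share: {gap:.2f}")
```

A gap close to zero suggests the splits are comparable; what counts as too large is a judgment call that depends on the problem.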

API Reference

Class balance visualizer for showing per-class support.

class yellowbrick.target.class_balance.ClassBalance(ax=None, labels=None, **kwargs)

Bases: yellowbrick.target.base.TargetVisualizer

One of the biggest challenges for classification models is an imbalance of classes in the training data. The ClassBalance visualizer shows the relationship of the support for each class in both the training and test data by displaying how frequently each class occurs as a bar graph.

The ClassBalance visualizer can be displayed in two modes:

  1. Balance mode: show the frequency of each class in the dataset.
  2. Compare mode: show the relationship of support in train and test data.

These modes are determined by what is passed to the fit() method.

Parameters:
ax : matplotlib Axes, default: None

The axes to plot the figure on. If None is passed in, the current axes will be used (or generated if required).

labels: list, optional

A list of class names for the x-axis if the target is already encoded. Ensure that the labels are ordered lexicographically with respect to the values in the target. A common use case is to pass LabelEncoder.classes_ as this parameter. If not specified, the labels in the data will be used.

kwargs: dict, optional

Keyword arguments passed to the super class. Here, they are used to colorize the bars in the chart.
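The ordering requirement on the labels parameter is easiest to see with a small example. This sketch uses plain Python to mimic an integer encoding in which codes follow the sorted order of the unique values, which is the ordering LabelEncoder.classes_ uses; the raw values are hypothetical.

```python
# Hypothetical raw target values.
raw = ["win", "draw", "loss", "win", "win"]

# Sorted unique values, in the same order as an encoder's classes.
classes = sorted(set(raw))                  # ['draw', 'loss', 'win']

# Integer encoding: each value is replaced by its position in `classes`.
encoded = [classes.index(v) for v in raw]   # [2, 0, 1, 2, 2]

# With labels given in this order, position i lines up with code i.
for code, name in enumerate(classes):
    print(code, name, encoded.count(code))
```

If the labels were supplied in a different order, for example ["win", "loss", "draw"], the bar for code 0 would be named "win" even though it counts the "draw" samples.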

Attributes:
classes_ : array-like

The actual unique classes discovered in the target.

support_ : array of shape (n_classes,) or (2, n_classes)

A table representing the support of each class in the target. It is a vector when in balance mode, or a table with two rows in compare mode.

draw()

Renders the class balance chart on the specified axes from support.

finalize(**kwargs)

Finalize executes any subclass-specific axes finalization steps. The user calls poof and poof calls finalize.

Parameters:
kwargs: generic keyword arguments.

fit(y_train, y_test=None)

Fit the visualizer to the target variables, which must be 1D vectors containing discrete (classification) data. Fit has two modes:

  1. Balance mode: if only y_train is specified
  2. Compare mode: if both train and test are specified

In balance mode, the bar chart is displayed with each class as its own color. In compare mode, a side-by-side bar chart is displayed colored by train or test respectively.

Parameters:
y_train : array-like

Array or list of shape (n,) that contains discrete data.

y_test : array-like, optional

Array or list of shape (m,) that contains discrete data. If specified, the bar chart will be drawn in compare mode.
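The two modes of fit() and the corresponding shapes of support_ can be pictured with the following sketch. It is illustrative only and is not Yellowbrick's implementation: the function name compute_support is hypothetical, and giving a class that is absent from one split a count of zero is an assumption.

```python
from collections import Counter

def compute_support(y_train, y_test=None):
    """Return (classes, support) for balance mode or compare mode."""
    if y_test is None:
        # Balance mode: one count per class, shape (n_classes,).
        classes = sorted(set(y_train))
        counts = Counter(y_train)
        return classes, [counts[c] for c in classes]

    # Compare mode: one row per split, shape (2, n_classes).
    classes = sorted(set(y_train) | set(y_test))
    train_counts, test_counts = Counter(y_train), Counter(y_test)
    return classes, [
        [train_counts[c] for c in classes],  # a missing class counts as 0
        [test_counts[c] for c in classes],
    ]

# Balance mode: only the training target is given.
print(compute_support(["a", "a", "b"]))

# Compare mode: both targets are given.
print(compute_support(["a", "a", "b"], ["a", "b", "b", "c"]))
```

The first call yields a single vector of counts, and the second yields a two-row table whose columns are aligned on the union of the classes seen in either split.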