# yellowbrick.text.correlation
# Implementation of word correlation for text visualization.
#
# Author: Patrick Deziel
# Created: Sun May 1 19:43:41 2022 -0600
#
# Copyright (C) 2022 The scikit-yb developers
# For license information, see LICENSE.txt
#
# ID: correlation.py [b652fc9] deziel.patrick@gmail.com $
"""
Implementation of word correlation for text visualization.
"""
##########################################################################
## Imports
##########################################################################
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from yellowbrick.style import find_text_color
from yellowbrick.text.base import TextVisualizer
from yellowbrick.style.palettes import color_sequence
from yellowbrick.exceptions import YellowbrickValueError
##########################################################################
## Word Correlation Plot Visualizer
##########################################################################
class WordCorrelationPlot(TextVisualizer):
    """
    Word correlation illustrates the extent to which words in a corpus appear in the
    same documents.

    WordCorrelationPlot visualizes the binary correlation between words across
    documents as a heatmap. The correlation is defined using the mean square
    contingency coefficient (phi-coefficient) between any two words m and n. The
    coefficient is a value between -1 and 1, inclusive. A value close to 1 or -1
    indicates strong positive or negative correlation between m and n, while a value
    close to 0 indicates little or no correlation. The constructor takes one required
    argument, which is the list of words or n-grams to be plotted.

    Parameters
    ----------
    words : list of str
        The list of words or n-grams to be plotted. The words must be present in the
        provided corpus on fit().

    ignore_case : bool, default: False
        If True, all words will be converted to lowercase before processing.

    ax : matplotlib Axes, default: None
        The axes to plot the figure on.

    cmap : str or cmap, default: "RdYlBu"
        Colormap to use for the heatmap.

    colorbar : bool, default: True
        If True, a colorbar will be added to the heatmap.

    fontsize : int, default: None
        Font size to use for the labels on the axes.

    kwargs : dict
        Pass any additional keyword arguments to the super class.

    Attributes
    ----------
    self.doc_term_matrix_ : array of shape (n_docs, n_features)
        The computed sparse document-term matrix containing binary values indicating
        if a word is present in a document.

    self.num_docs_ : int
        The number of observed documents in the corpus.

    self.vocab_ : dict
        A dictionary mapping words to their indices in the document-term matrix.

    self.num_features_ : int
        The number of features (word labels) in the resulting plot.

    self.correlation_matrix_ : ndarray of shape (n_features, n_features)
        The computed matrix containing the phi-coefficients between all features.
    """

    def __init__(
        self,
        words,
        ignore_case=False,
        ax=None,
        cmap="RdYlBu",
        colorbar=True,
        fontsize=None,
        **kwargs
    ):
        super(WordCorrelationPlot, self).__init__(ax=ax, **kwargs)

        # Visual parameters
        self.fontsize = fontsize
        self.colorbar = colorbar
        self.cmap = color_sequence(cmap)

        # Fitting parameters
        self.ignore_case = ignore_case
        self.words = self._construct_terms(words, ignore_case)
        self.ngram_range = self._compute_ngram_range()

    def _construct_terms(self, words, ignore_case):
        """
        Constructs the list of terms to be plotted based on the provided words. This
        performs input checking and removes duplicates to produce a list of valid words
        for fitting.

        Parameters
        ----------
        words : list of str
            The raw words or n-grams supplied by the user.

        ignore_case : bool
            If True, the terms are lowercased before duplicates are removed.

        Returns
        -------
        terms : list of str
            The stripped, de-duplicated terms in sorted order.

        Raises
        ------
        YellowbrickValueError
            If no non-blank word remains after stripping whitespace.
        """
        # Remove surrounding whitespace and drop blank entries
        terms = [word.strip() for word in words if len(word.strip()) > 0]

        if len(terms) == 0:
            raise YellowbrickValueError("Must provide at least one word to plot.")

        # Convert to lowercase if ignore_case is set
        if ignore_case:
            terms = [word.lower() for word in terms]

        # Sort and remove duplicates
        return sorted(set(terms))

    def _compute_ngram_range(self):
        """
        Computes the n-gram range to use for vectorization based on the provided words.
        This allows the user to specify multi-word terms for plotting.

        Returns
        -------
        ngram_range : tuple of (int, int)
            The smallest and largest number of whitespace-separated tokens found in
            any of the target terms.
        """
        ngrams = [len(word.split()) for word in self.words]
        return (min(ngrams), max(ngrams))

    def _compute_coefficient(self, m, n):
        """
        Computes the phi-coefficient for two words m and n, which is a correlation
        value between -1 and 1 inclusive.

        Parameters
        ----------
        m : str
            The first term; must be a key of ``self.vocab_``.

        n : str
            The second term; must be a key of ``self.vocab_``.

        Returns
        -------
        phi : float
            The phi-coefficient of the two terms, or NaN when the coefficient is
            undefined because the denominator is zero (for example when one of the
            terms occurs in every document).
        """
        m_col = self.doc_term_matrix_.getcol(self.vocab_[m])
        n_col = self.doc_term_matrix_.getcol(self.vocab_[n])

        # Convert the counts to Python ints: they have arbitrary precision, so the
        # four-way product in the denominator cannot silently wrap around the way
        # fixed-width numpy integers do on large corpora.
        both = int(m_col.multiply(n_col).sum())
        m_total = int(m_col.sum())
        n_total = int(n_col.sum())
        num_docs = int(self.num_docs_)

        # Remaining cells of the 2x2 contingency table
        only_m = m_total - both
        only_n = n_total - both
        neither = num_docs - both - only_m - only_n

        numerator = (both * neither) - (only_m * only_n)
        denominator = m_total * n_total * (num_docs - m_total) * (num_docs - n_total)

        # The coefficient is undefined when any marginal total is zero; report NaN
        # explicitly instead of relying on a 0/0 floating point division.
        if denominator == 0:
            return np.nan

        return float(numerator) / np.sqrt(float(denominator))

    def fit(self, X, y=None):
        """
        The fit method is the primary drawing input for the word correlation
        visualization.

        Parameters
        ----------
        X : list of str or generator
            Should be provided as a list of strings or a generator yielding strings
            that represent the documents in the corpus.

        y : None
            Labels are not used for the word correlation visualization.

        Returns
        -------
        self: instance
            Returns the instance of the transformer/visualizer.

        Raises
        ------
        YellowbrickValueError
            If any of the target words does not occur in the corpus.

        Attributes
        ----------
        self.doc_term_matrix_ : array of shape (n_docs, n_features)
            The computed sparse document-term matrix containing binary values
            indicating if a word is present in a document.

        self.num_docs_ : int
            The number of observed documents in the corpus.

        self.vocab_ : dict
            A dictionary mapping words to their indices in the document-term matrix.

        self.num_features_ : int
            The number of features (word labels) in the resulting plot.

        self.correlation_matrix_ : ndarray of shape (n_features, n_features)
            The computed matrix containing the phi-coefficients between all features.
        """
        # Instantiate the CountVectorizer, restricted to the target words
        vecs = CountVectorizer(
            vocabulary=self.words,
            lowercase=self.ignore_case,
            ngram_range=self.ngram_range,
            binary=True,
        )

        # Get the binary document counts for the target words
        self.doc_term_matrix_ = vecs.fit_transform(X)
        self.num_docs_ = self.doc_term_matrix_.shape[0]
        self.vocab_ = vecs.vocabulary_

        # Verify that all target words exist in the corpus
        for word in self.words:
            if self.doc_term_matrix_.getcol(self.vocab_[word]).sum() == 0:
                raise YellowbrickValueError(
                    "Word '{}' does not exist in the corpus.".format(word)
                )

        # Compute the phi-coefficient for each pair of words
        self.num_features_ = len(self.words)
        self.correlation_matrix_ = np.zeros((self.num_features_, self.num_features_))
        for i, m in enumerate(self.words):
            for j, n in enumerate(self.words):
                self.correlation_matrix_[i, j] = self._compute_coefficient(m, n)

        self.draw(X)
        return self

    def draw(self, X):
        """
        Called from the fit() method, this method draws the heatmap on the figure
        using the computed correlation matrix.

        Parameters
        ----------
        X : list of str or generator
            The corpus passed to fit(). It is not read here; the heatmap is drawn
            from ``self.correlation_matrix_``.

        Returns
        -------
        ax : matplotlib Axes
            The axes the heatmap was drawn on.
        """
        # Use correlation matrix data for the heatmap
        wc_display = self.correlation_matrix_

        # Set up the dimensions of the pcolormesh (cell edges, hence the + 1)
        x_edges = np.arange(self.num_features_ + 1)
        y_edges = np.arange(self.num_features_ + 1)
        self.ax.set_ylim(bottom=0, top=wc_display.shape[0])
        self.ax.set_xlim(left=0, right=wc_display.shape[1])

        # Set the words as the tick labels on the plot. The Y-axis is sorted from top
        # to bottom, the X-axis is sorted from left to right.
        xticklabels = self.words
        yticklabels = self.words[::-1]
        ticks = np.arange(self.num_features_) + 0.5
        self.ax.set(xticks=ticks, yticks=ticks)
        self.ax.set_xticklabels(
            xticklabels, rotation="vertical", fontsize=self.fontsize
        )
        self.ax.set_yticklabels(yticklabels, fontsize=self.fontsize)

        # Flip the Y-axis values so that they match the sorted labels
        wc_display = np.flipud(wc_display)

        # Draw the labels in each heatmap cell
        for row in range(self.num_features_):
            for col in range(self.num_features_):
                # Get the correlation value for the cell
                value = wc_display[row, col]
                svalue = "{:.2f}".format(value)

                # Get a compatible text color for the cell; the value is rescaled
                # from [-1, 1] to [0, 1] before it is handed to the colormap
                base_color = self.cmap(value / 2 + 0.5)
                text_color = find_text_color(base_color)

                # Draw the text at the center of the cell
                # Note: row and column are swapped to match the pcolormesh
                cx, cy = col + 0.5, row + 0.5
                self.ax.text(
                    cx,
                    cy,
                    svalue,
                    va="center",
                    ha="center",
                    color=text_color,
                    fontsize=self.fontsize,
                )

        # Draw the heatmap
        g = self.ax.pcolormesh(
            x_edges, y_edges, wc_display, cmap=self.cmap, vmin=-1, vmax=1
        )

        # Add the color bar
        if self.colorbar:
            self.ax.figure.colorbar(g, ax=self.ax)

        return self.ax

    def finalize(self):
        """
        Prepares the figure for rendering by adding the title. This method is usually
        called from show() and not directly by the user.
        """
        self.set_title("Word Correlation Plot")
        self.fig.tight_layout()
##########################################################################
## Quick Method
##########################################################################
def word_correlation(
    words,
    corpus,
    ignore_case=True,
    ax=None,
    cmap="RdYlBu",
    show=True,
    colorbar=True,
    fontsize=None,
    **kwargs
):
    """Word Correlation

    Displays the binary correlation between the given words across the documents in a
    corpus.

    For a list of words with length n, this produces an n x n heatmap of correlation
    values in the range [-1, 1].

    Parameters
    ----------
    words : list of str
        The corpus words to display in the heatmap.

    corpus : list of str or generator
        The corpus as a list of documents or a generator yielding documents.

    ignore_case : bool, default: True
        If True, all words will be converted to lowercase before processing.

    ax : matplotlib axes, default: None
        The axes to plot the figure on.

    cmap : str, default: "RdYlBu"
        Colormap to use for the heatmap.

    show : bool, default: True
        If True, calls ``show()``, which in turn calls ``plt.show()`` however
        you cannot call ``plt.savefig`` from this signature, nor
        ``clear_figure``. If False, simply calls ``finalize()``

    colorbar : bool, default: True
        If True, adds a colorbar to the figure.

    fontsize : int, default: None
        If not None, sets the font size of the labels.

    kwargs : dict
        Additional keyword arguments passed to the WordCorrelationPlot constructor.

    Returns
    -------
    visualizer : WordCorrelationPlot
        The fitted visualizer that created the figure.
    """
    # Instantiate the visualizer. ignore_case must be forwarded under the
    # constructor's own parameter name, otherwise it ends up in **kwargs and the
    # visualizer falls back to its default of ignore_case=False.
    visualizer = WordCorrelationPlot(
        words=words,
        ignore_case=ignore_case,
        ax=ax,
        cmap=cmap,
        colorbar=colorbar,
        fontsize=fontsize,
        **kwargs
    )

    # Fit the visualizer (calls draw)
    visualizer.fit(corpus)

    # Draw the final visualization
    if show:
        visualizer.show()
    else:
        visualizer.finalize()

    # Return the visualizer
    return visualizer