Using Third-Party Estimators
Many machine learning libraries implement the scikit-learn estimator API so that alternative optimization or decision methods integrate easily into a data science workflow. Because of this, it seems like it should be simple to drop a non-scikit-learn estimator into a Yellowbrick visualizer, and in principle, it is. However, the reality is a bit more complicated.
Yellowbrick visualizers often utilize more than just the method interface of estimators (e.g. fit() and predict()); they also rely on the learned attributes (object properties with a single underscore suffix, e.g. coef_). The issue is that when a third-party estimator does not expose these attributes, truly gnarly exceptions and tracebacks occur. Yellowbrick is meant to aid machine learning diagnostics and reasoning, so instead of just allowing drop-in functionality that may cause confusion, we’ve created wrapper functionality that is a bit kinder with its messaging.
But first, an example.
# Import the wrap function and a Yellowbrick visualizer
from yellowbrick.contrib.wrapper import wrap
from yellowbrick.model_selection import feature_importances
# Instantiate the third party estimator and wrap it, optionally fitting it
model = wrap(ThirdPartyEstimator())
model.fit(X_train, y_train)
# Use the visualizer
oz = feature_importances(model, X_test, y_test, is_fitted=True)
The wrap function initializes the third party model as a ContribEstimator, which passes through all functionality to the underlying estimator. If an error occurs, however, the exception that is raised looks like:
yellowbrick.exceptions.YellowbrickAttributeError: estimator is missing the 'fit'
attribute, which is required for this visualizer - please see the third party
estimators documentation.
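The pass-through behavior described above can be pictured with a small, self-contained sketch. This is not Yellowbrick's implementation: PassThroughWrapper, FriendlyAttributeError, and TinyEstimator are hypothetical names used only to illustrate how attribute delegation with a clearer error message might work.

```python
class FriendlyAttributeError(AttributeError):
    """Hypothetical stand-in for a library-specific attribute error."""


class PassThroughWrapper:
    """Illustrative wrapper that delegates attribute access to an estimator."""

    def __init__(self, estimator):
        self.estimator = estimator

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup on the wrapper fails,
        # so attributes set in __init__ never reach this method.
        # Read "estimator" from __dict__ to avoid infinite recursion.
        try:
            estimator = self.__dict__["estimator"]
        except KeyError:
            raise AttributeError(name) from None
        try:
            return getattr(estimator, name)
        except AttributeError:
            raise FriendlyAttributeError(
                f"estimator is missing the '{name}' attribute, "
                "which is required for this visualizer"
            ) from None


class TinyEstimator:
    """Hypothetical third-party model with fit/predict but no coef_."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


model = PassThroughWrapper(TinyEstimator())
model.fit([[0], [1], [2]], [1.0, 2.0, 3.0])
predictions = model.predict([[5], [6]])

try:
    model.coef_
    message = None
except FriendlyAttributeError as err:
    message = str(err)
```

Methods such as fit and predict reach the wrapped object unchanged, while a missing learned attribute surfaces as a single readable error instead of a deep traceback.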
Some visualizers require the estimator to pass type checking; for example, the estimator must be a classifier, regressor, clusterer, density estimator, or outlier detector. A second argument can be passed to the wrap function declaring the type of estimator:
from yellowbrick.classifier import precision_recall_curve
from yellowbrick.contrib.wrapper import wrap, CLASSIFIER
model = wrap(ThirdPartyClassifier(), CLASSIFIER)
precision_recall_curve(model, X, y)
Or you can simply use the wrap helper functions of the specific type:
from yellowbrick.contrib.wrapper import regressor, classifier, clusterer
from yellowbrick.regressor import prediction_error
from yellowbrick.classifier import classification_report
from yellowbrick.cluster import intercluster_distance
reg = regressor(ThirdPartyRegressor())
prediction_error(reg, X, y)
clf = classifier(ThirdPartyClassifier())
classification_report(clf, X, y)
ctr = clusterer(ThirdPartyClusterer())
intercluster_distance(ctr, X)
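To see why declaring the type matters, it helps to know that scikit-learn style type checks have traditionally inspected an _estimator_type string on the estimator (newer scikit-learn releases may use a different mechanism). The sketch below imitates that style of check in plain Python. resolve_estimator_type, looks_like, UntypedModel, and TypedModel are hypothetical names, not scikit-learn or Yellowbrick functions.

```python
def resolve_estimator_type(estimator, declared_type=None):
    """Prefer an explicitly declared type, else fall back to the estimator's own."""
    if declared_type is not None:
        return declared_type
    return getattr(estimator, "_estimator_type", None)


def looks_like(estimator, expected_type, declared_type=None):
    """Simplified stand-in for an is_classifier-style check."""
    return resolve_estimator_type(estimator, declared_type) == expected_type


class UntypedModel:
    """Hypothetical third-party model that declares no estimator type."""


class TypedModel:
    """Hypothetical third-party model that already declares its type."""

    _estimator_type = "regressor"


# Without a declaration, the untyped model fails the check
untyped_check = looks_like(UntypedModel(), "classifier")

# Declaring the type explicitly lets the same model pass
declared_check = looks_like(UntypedModel(), "classifier", declared_type="classifier")

# When nothing is declared, the estimator's own type is used
passthrough_check = looks_like(TypedModel(), "regressor")
```

An estimator that already carries the right type string needs no declaration; one that does not can be given a type at wrap time.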
So what should you do if a required attribute is missing from your estimator? The simplest and quickest thing to do is to subclass ContribEstimator and add the required functionality.
from yellowbrick.contrib.wrapper import ContribEstimator, CLASSIFIER
from yellowbrick.model_selection import feature_importances

class MyWrapper(ContribEstimator):
    _estimator_type = CLASSIFIER

    @property
    def feature_importances_(self):
        # Expose the third-party importances under the name visualizers expect
        return self.estimator.tree_feature_importances()

model = MyWrapper(ThirdPartyEstimator())
feature_importances(model, X, y)
This is certainly less than ideal - but we’d welcome a contrib PR to add more native functionality to Yellowbrick!
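Before writing such a subclass, it can help to check which learned attributes a fitted estimator actually exposes. The helper below is a hypothetical sketch in plain Python. missing_attributes and FittedStub are illustrative names, and the list of required attributes is an assumption that depends on the visualizer you intend to use.

```python
def missing_attributes(estimator, required):
    """Return the names in `required` that the estimator does not expose."""
    return [name for name in required if not hasattr(estimator, name)]


class FittedStub:
    """Hypothetical fitted model exposing classes_ but no feature_importances_."""

    def __init__(self):
        self.classes_ = [0, 1]

    def tree_feature_importances(self):
        return [0.7, 0.3]


stub = FittedStub()

# The attribute names checked here are assumed for illustration only
gaps = missing_attributes(stub, ["classes_", "feature_importances_", "coef_"])
```

Each name reported in gaps is a candidate for a property on a ContribEstimator subclass, in the style of the MyWrapper example.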
Tested Libraries
The following libraries have been tested with the Yellowbrick wrapper.
xgboost: both the XGBRFRegressor and XGBRFClassifier have been tested with Yellowbrick both with and without the wrapper functionality.
CatBoost: the CatBoostClassifier has been tested with the ClassificationReport visualizer.
The following libraries have been partially tested and will likely work without too much additional effort:
cuML: it is likely that clustering, classification, and regression cuML estimators will work with Yellowbrick visualizers. However, the cuDF datasets have not been tested with Yellowbrick.
Spark MLlib: The Spark DataFrame API and estimators should work with Yellowbrick visualizers in a local notebook context after collection.
Note
If you have used a Python machine learning library not listed here with Yellowbrick, please let us know - we’d love to add it to the list! Also, if you’re using a library that is not wholly compatible, please open an issue so that we can explore how to integrate it with the yellowbrick.contrib module!
API Reference
Wrapper for third-party estimators that implement the sklearn API but do not directly subclass the sklearn.base.BaseEstimator class. This method is a quick way to get other estimators into Yellowbrick, while avoiding weird errors and issues.
- class yellowbrick.contrib.wrapper.ContribEstimator(estimator, estimator_type=None)
Bases: object
Wraps a third party estimator that implements the scikit-learn API and therefore could be used with Yellowbrick but doesn’t subclass BaseEstimator. Since there are a number of pitfalls, this object provides sensible errors and warnings rather than completely blowing up, allowing contrib users to identify issues and fix them, smoothing the path to getting third party estimators into the Yellowbrick ecosystem.
- Parameters
- estimator : object
The non-sklearn estimator to wrap and use for Visualizers
- estimator_type : str, optional
One of “classifier”, “regressor”, “clusterer”, “DensityEstimator”, or “outlier_detector” that allows the contrib estimator to pass the scikit-learn is_classifier, etc. functions. If not specified, the _estimator_type attr is passed through to the underlying estimator.
- yellowbrick.contrib.wrapper.classifier(estimator)
Wrap a third-party classifier to make it available to Yellowbrick visualizers.
- Parameters
- estimator : object
The non-sklearn classifier to wrap and use for Visualizers
- yellowbrick.contrib.wrapper.clusterer(estimator)
Wrap a third-party clusterer to make it available to Yellowbrick visualizers.
- Parameters
- estimator : object
The non-sklearn clusterer to wrap and use for Visualizers
- yellowbrick.contrib.wrapper.regressor(estimator)
Wrap a third-party regressor to make it available to Yellowbrick visualizers.
- Parameters
- estimator : object
The non-sklearn regressor to wrap and use for Visualizers
- yellowbrick.contrib.wrapper.wrap(estimator, estimator_type=None)
Wrap a third-party estimator that implements portions of the scikit-learn API to make it available to Yellowbrick visualizers. If the Yellowbrick visualizer cannot succeed, then a sensible error is raised instead.
- Parameters
- estimator : object
The non-sklearn estimator to wrap and use for Visualizers
- estimator_type : str, optional
One of “classifier”, “regressor”, “clusterer”, “DensityEstimator”, or “outlier_detector” that allows the contrib estimator to pass the scikit-learn is_classifier, etc. functions. If not specified, the _estimator_type attr is passed through to the underlying estimator.