Importing scikit-learn Models into Java

Currently scikit-learn is the best general-purpose machine learning package. It is part of the Scientific Python family of tools, built on top of the NumPy matrix processing engine. The code is readable, the documentation is extensive, and the package is popular, so there’s plenty of help available on Stack Overflow when you need it. But perhaps scikit-learn’s biggest selling point is that it’s written in Python, a language well suited to the ad hoc, exploratory working style typical of machine learning. Java machine learning toolkits like Weka and Mallet are mathematically solid, but running mathematical algorithms is only part of the job of data science. There’s also inevitably lots of format munging, directory groveling, glue code, and trying things that don’t work. You want the basics to be as easy as possible. The Python command line achieves a level of transparency that Java, with its boilerplate, IDEs, compilers, complex build systems, and lack of a REPL, cannot match.

[Figure: Illustrations of machine learning classification]

Still, the JVM is a popular platform, and it would be nice to be able to train a model in scikit-learn and then deploy it in Java. There is currently no support for this. The right thing would be for scikit-learn to export its models to a common format like PMML, but that feature does not currently exist.[1] scikit-learn’s only serialization is Python’s native pickle format, which works great, but only for other Python programs. In theory, writing your own serialization should be easy, since a model is just a set of numbers. But those numbers only mean something if the test-time code exactly reproduces the training code’s processing of its input. Any deviation and your finely tuned vector of coefficients becomes nothing more than a numeric jumble.

Let’s take a look at a fairly simple but still non-trivial machine learning model and see what is involved in exporting its semantics in a cross-language way. Say I want to do text classification. I have a corpus of short documents drawn from two genres: cookbooks and descriptions of farm life. I have tab-delimited text files that look like this.

0   The horse and the cow lived on the farm
1   Boil two eggs for five minutes
0   The hayloft of the barn was full
1   Drain the pasta

The first column is an integer class label and the second is a document. I want the computer to learn how to hypothesize a 0 or 1 for any string input it is given. A standard approach would be to treat the documents as bags of words and build a Naive Bayes model over them. To make things more sophisticated, let’s train on bi-grams in addition to individual words, and work with Tf-Idf values instead of raw counts. scikit-learn makes this easy. Here is the bulk of the code needed to train such a model.

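from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline

def train_model(corpus):
    # corpus yields byte strings of the form "label<TAB>document\n",
    # e.g. the lines of a file opened in binary mode.
    labels = []
    data = []
    for line in corpus:
        label, document = line.decode('utf-8').rstrip('\n').split('\t', 1)
        labels.append(int(label))
        data.append(document)
    # Unigram and bigram counts -> Tf-Idf weights -> multinomial Naive Bayes.
    model = Pipeline([('vect', CountVectorizer(ngram_range=(1, 2))),
                      ('tfidf', TfidfTransformer()),
                      ('clf', MultinomialNB())]).fit(data, labels)
    return model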
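A quick way to see exactly which terms the trained pipeline will look up, and therefore what a decoder in another language has to reproduce, is CountVectorizer’s build_analyzer() method. This sketch uses the same ngram_range as train_model; the lowercasing and the two-character minimum token length come from scikit-learn’s defaults for the lowercase and token_pattern parameters, not from anything set explicitly above.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Same settings as the 'vect' step in train_model.
analyzer = CountVectorizer(ngram_range=(1, 2)).build_analyzer()

# Unigrams come first, then bigrams, all lowercased.
print(analyzer("Drain the pasta"))
# ['drain', 'the', 'pasta', 'drain the', 'the pasta']

# The default token_pattern keeps only tokens of two or more word
# characters, so "2" vanishes and "boil eggs" becomes a bigram.
print(analyzer("Boil 2 eggs"))
# ['boil', 'eggs', 'boil eggs']
```

Details like these are easy to miss when porting: a decoder that kept "2" as a token would build the bigrams "boil 2" and "2 eggs" instead, and silently score a different vector.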

The model returned is a scikit-learn object. To export it to another language, we have to extract its meaningful parts and serialize them in a general-purpose way. These meaningful parts are sets of numbers. Specifically, for each dimension in the vector space (that is, each term) there is an Idf score and one coefficient per class. Additionally, there is a scalar bias for each class. So what we have is a vector of numbers plus a mapping from strings to vectors of numbers.

Bias 0    Bias 1
-0.693    -0.587

Term          Idf     Coefficient 0   Coefficient 1
garlic        4.673   -8.327          -6.825
peel garlic   3.522   -12.805         -10.505
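For reference, here is how these numbers combine at prediction time, assuming the pipeline’s default settings (TfidfTransformer with smooth_idf=True and norm='l2'). This follows scikit-learn’s documented behavior rather than anything stated here; presumably the bias b_c is MultinomialNB’s log class prior log P(c) and the coefficient w_{c,t} its smoothed log P(t | c). Here n is the number of training documents, df(t) the number of them containing term t, and tf(t, d) the count of t in document d.

```latex
\begin{aligned}
\mathrm{idf}(t) &= \ln\frac{1 + n}{1 + \mathrm{df}(t)} + 1 \\
x_t &= \frac{\mathrm{tf}(t, d)\,\mathrm{idf}(t)}{\sqrt{\sum_{t'} \bigl(\mathrm{tf}(t', d)\,\mathrm{idf}(t')\bigr)^2}} \\
\mathrm{score}_c(d) &= b_c + \sum_t x_t\, w_{c,t}, \qquad \hat{c} = \operatorname*{arg\,max}_c \mathrm{score}_c(d)
\end{aligned}
```

Only the precomputed idf(t) values need to be exported, so n and df(t) never leave Python.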

You have to do some detective work to figure out where inside the scikit-learn objects these numbers actually reside, but once you have them you can serialize them in a language-agnostic way by writing them out as JSON. Sure, the file will be huge, and representing floating-point numbers as strings is wildly inefficient, but we can always gzip the thing.
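As a sketch of that detective work, the fitted pipeline exposes the numbers through the standard attributes vocabulary_ (CountVectorizer), idf_ (TfidfTransformer), and feature_log_prob_ and class_log_prior_ (MultinomialNB). The helper names export_model, save_model, and load_model and the JSON layout are hypothetical illustrations, not the file format used by the Linear N-gram Model project.

```python
import gzip
import json
import os
import tempfile

from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline


def export_model(model):
    """Collect the numbers a decoder needs from a fitted vect/tfidf/clf pipeline."""
    vect = model.named_steps['vect']
    tfidf = model.named_steps['tfidf']
    clf = model.named_steps['clf']
    terms = {}
    for term, column in vect.vocabulary_.items():
        terms[term] = {
            'idf': float(tfidf.idf_[column]),
            # log P(term | class), one value per class in clf.classes_ order
            'coefficients': [float(w) for w in clf.feature_log_prob_[:, column]],
        }
    return {
        'classes': [int(c) for c in clf.classes_],
        'biases': [float(b) for b in clf.class_log_prior_],  # log P(class)
        'terms': terms,
    }


def save_model(exported, path):
    # Plain JSON is bulky, so gzip it on the way out.
    with gzip.open(path, 'wt', encoding='utf-8') as f:
        json.dump(exported, f)


def load_model(path):
    with gzip.open(path, 'rt', encoding='utf-8') as f:
        return json.load(f)


# Usage: train on the four example documents and write the model out.
docs = ["The horse and the cow lived on the farm",
        "Boil two eggs for five minutes",
        "The hayloft of the barn was full",
        "Drain the pasta"]
model = Pipeline([('vect', CountVectorizer(ngram_range=(1, 2))),
                  ('tfidf', TfidfTransformer()),
                  ('clf', MultinomialNB())]).fit(docs, [0, 1, 0, 1])
exported = export_model(model)
path = os.path.join(tempfile.mkdtemp(), 'model.json.gz')
save_model(exported, path)
```

Converting every NumPy value to a plain Python float or int before dumping keeps the json module happy, and Python writes floats with enough digits that they read back exactly.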

Now the Java decoder needs to (1) load this file, (2) turn the input into n-gram terms, (3) build a vector of term Tf-Idf scores, and (4) linearly transform that vector using the model’s coefficients and biases. None of this is particularly difficult, but you have to make sure that the Java decoder performs each of these steps in exactly the same way as the Python training code, so that the numbers passed between them retain their meaning.
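One way to pin down "exactly the same way" is to write a plain reference decoder first and check it against scikit-learn, then port it line by line. The following is a minimal Python sketch of steps 2 to 4, assuming the default CountVectorizer and TfidfTransformer settings and the hypothetical JSON layout of an exported dict (classes, biases, and per-term idf and coefficients); analyze and decode are illustrative names, not the project’s Java API. For self-containment, step 1 is replaced by exporting straight from a freshly trained pipeline.

```python
import math
import re
from collections import Counter

from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")  # CountVectorizer's default token_pattern


def analyze(text):
    """Step 2: lowercase, tokenize, and emit unigrams plus bigrams."""
    tokens = TOKEN_RE.findall(text.lower())
    return tokens + [' '.join(pair) for pair in zip(tokens, tokens[1:])]


def decode(exported, text):
    """Steps 3 and 4: Tf-Idf vector, then bias plus coefficient dot product."""
    terms = exported['terms']
    # Terms never seen in training have no column, so they are skipped.
    counts = Counter(t for t in analyze(text) if t in terms)
    weights = {t: n * terms[t]['idf'] for t, n in counts.items()}
    # TfidfTransformer's default norm='l2' scales each document to unit length.
    norm = math.sqrt(sum(w * w for w in weights.values()))
    if norm > 0:
        weights = {t: w / norm for t, w in weights.items()}
    scores = list(exported['biases'])
    for t, w in weights.items():
        for c, coefficient in enumerate(terms[t]['coefficients']):
            scores[c] += w * coefficient
    best = max(range(len(scores)), key=lambda c: scores[c])
    return exported['classes'][best], scores


# Step 1 stand-in: export the parameters from a pipeline trained on the
# example corpus (a real decoder would read them from the JSON file).
docs = ["The horse and the cow lived on the farm",
        "Boil two eggs for five minutes",
        "The hayloft of the barn was full",
        "Drain the pasta"]
model = Pipeline([('vect', CountVectorizer(ngram_range=(1, 2))),
                  ('tfidf', TfidfTransformer()),
                  ('clf', MultinomialNB())]).fit(docs, [0, 1, 0, 1])
vect, tfidf, clf = (model.named_steps[k] for k in ('vect', 'tfidf', 'clf'))
exported = {
    'classes': [int(c) for c in clf.classes_],
    'biases': [float(b) for b in clf.class_log_prior_],
    'terms': {t: {'idf': float(tfidf.idf_[i]),
                  'coefficients': [float(w) for w in clf.feature_log_prob_[:, i]]}
              for t, i in vect.vocabulary_.items()},
}

for text in ["The harvest was finished early this year",
             "We fed the horses and the pigs",
             "Place the garlic in a pan"]:
    label, scores = decode(exported, text)
    print(label, ' '.join('%.4f' % s for s in scores), text, sep='\t')
```

The scores are unnormalized joint log likelihoods. Subtracting their log-sum-exp should reproduce model.predict_log_proba, which makes a convenient regression test to run against the Java port as well.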

The Linear N-gram Model project contains a Python training script and a Java decoder that does this. Train a model in Python on a corpus like the one shown above, run it in Java on unlabeled text, and it will produce class predictions and log likelihoods like so.

0   -47.8674 -47.1280   The harvest was finished early this year
0   -47.0950 -42.8352   We fed the horses and the pigs
1   -45.3605 -46.8341   Place the garlic in a pan

This project can serve as starter example code for machine learning researchers faced with a similar cross-language serialization task.

[1] But check out Py2PMML, which looks like it gets you part of the way there. (Hat tip darknightelf.)
