Java Development News:

Using CI-Bayes

By Joseph Ottinger

01 Feb 2009 | TheServerSide.com

CI-Bayes is a Java project that provides classification features. As of version 1.0.4, this includes three Bayesian classifiers: a "simple Bayesian" classifier, a naïve classifier, and - the most useful - a Fisher classifier. (The classifiers mostly differ in how they compensate for some probabilities in the dataset; generally, the Fisher classifier will give you the best numbers.)

Bayesian classification is built on Thomas Bayes' theorem of probability, which states - in a very simple paraphrase - that the likelihood of an object O being classifiable as class C1 depends on how often elements E (contained in object O) have previously appeared in class C1 as opposed to some other class C2.
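
In standard notation, the theorem that paraphrase compresses is:

    P(C | E) = P(E | C) * P(C) / P(E)

that is, the probability that an object belongs to class C given the evidence E it contains is driven by how often that evidence has appeared in class C before (P(E | C)), weighted by how common class C is overall (P(C)).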

As you might be able to tell from the description (to which mathematicians will surely take exception), Bayesian algorithms are trained, not intuitive - a Bayesian classifier will have no way to tell you what classification something falls in until it's been trained, and often won't be able to generate meaningful results until after extensive training has taken place.

Bayesian classification is often thought of in the context of spam filtering; Apache's SpamAssassin used to use a Bayesian algorithm to classify spam, but Bayesian systems are subject to "poisoning" and also have difficulty classifying data that has not been seen before. Thus, SpamAssassin now uses a multi-level Perceptron, a type of neural network, instead; a multi-level Perceptron can "intuit" relationships between words, which makes it seem slightly "smarter" than a Bayesian algorithm.

Even with its flaws, Bayesian analysis remains effective. For one thing, it's simple to work with: interfacing with generalized neural networks involves a translation step that can seem somewhat arcane. In addition, Bayesian algorithms are computationally cheap, and thus fast. While your needs may eventually necessitate another approach, a Bayesian classifier is easy to integrate and use.

Installing CI-Bayes

CI-Bayes is a project on java.net, located at https://ci-bayes.dev.java.net. It can be checked out from SVN with the following command:

svn checkout https://ci-bayes.dev.java.net/svn/ci-bayes/trunk ci-bayes

It uses Maven to build; running mvn install from the main directory will download the dependencies (most of which are used only for testing) and install them into your local Maven repository, along with the CI-Bayes distribution jar itself.

After CI-Bayes has been installed into a Maven repository, using it is very simple: add the following dependency to your pom.xml:

<dependency>
   <groupId>com.enigmastation</groupId>
   <artifactId>ci-bayes</artifactId>
   <version>1.0.4</version>
</dependency>

If you're not using Maven for your project, the CI-Bayes jar will be in the target directory after the build. You'll also need to grab Javolution, Jean-Marie Dautelle's utility library, which CI-Bayes uses for its collection classes.

Early Going

Using CI-Bayes is very simple. In concept, all you have to do is create a classifier, train it, then throw some text at it to see what classification it assigns. We can do this with some simple TestNG tests.

For my simple tests here, I used the first paragraph from two Wikipedia articles, Rush and Java, and read them into Strings called rush and java, respectively.
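
For reference, here's a minimal sketch of that setup. The file names and the readFile() helper are hypothetical; any way of getting the two paragraphs into Strings will do.

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

// Hypothetical helper: reads an entire text file into a single String.
static String readFile(String path) throws IOException {
    StringBuilder text = new StringBuilder();
    BufferedReader reader = new BufferedReader(new FileReader(path));
    try {
        String line;
        while ((line = reader.readLine()) != null) {
            text.append(line).append(' ');
        }
    } finally {
        reader.close();
    }
    return text.toString();
}

// In the test class, for example in a TestNG @BeforeClass method:
// rush = readFile("rush.txt");
// java = readFile("java.txt");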

Then I used the following method as a test:

@Test
public void testClassifier() {
    NaiveClassifier classifier = new NaiveClassifierImpl();
    classifier.train(rush, "music");
    classifier.train(java, "programming");
    assert classifier.getClassification(
            "Eclipse is an integrated development environment, or IDE.", "unknown")
            .equals("programming");
}

For reference, here's the training text used for the "programming" category - the first paragraph of the Java article:

Java is a programming language originally developed by Sun Microsystems and released in 1995 as a core component of Sun Microsystems' Java platform. The language derives much of its syntax from C and C++ but has a simpler object model and fewer low-level facilities. Java applications are typically compiled to bytecode that can run on any Java virtual machine (JVM) regardless of computer architecture.

This is a poor test - using only one body of text for each classification - but even so, we get a correct result. In fact, if we look at the actual probabilities generated - which we can get from classifier.getProbabilities() - we'll find that "programming" got a "score" of 0.023, and "music" got a mere 0.008. These are poor numbers indeed - but they're enough to generate a valid result this time.

A better test would use more training sessions. If we run the training texts through thirty times each, by way of a simple loop, our probabilities change, but not by much: "programming" gets a score of 0.031 and "music" gets a dismal 0.001. Repeating identical training texts sharpens the separation somewhat, but only somewhat.

In the real world, a user would have the opportunity to confirm or deny the classification. Let's use another test, one that trains thirty times with the training texts, classifies the test text, confirms that classification by training on the test text, and then classifies it again:

@Test
public void testClassifierWithConfirmation() {
    NaiveClassifier classifier = new NaiveClassifierImpl();
    for (int i = 0; i < 30; i++) {
        classifier.train(rush, "music");
        classifier.train(java, "programming");
    }
    String testText = "Eclipse is an integrated development environment, or IDE.";
    assert classifier.getClassification(testText, "unknown")
            .equals("programming");
    System.out.println(Arrays.toString(classifier.getProbabilities(testText)));
    classifier.train(testText, "programming");
    assert classifier.getClassification(testText, "unknown")
            .equals("programming");
    System.out.println(Arrays.toString(classifier.getProbabilities(testText)));
}

Our output now looks like this:

[[ClassifierProbability:category="programming",score=0.031],
    [ClassifierProbability:category="music",score=0.001]]
[[ClassifierProbability:category="programming",score=0.003],
    [ClassifierProbability:category="music",score=0]]

The results barely changed - but they definitely confirmed that our testing text isn't about music! (Well, according to our training data, at least... IDEs can now do various audio things, proving that every IDE secretly wishes it were EMACS.)

Changing to a FisherClassifier changes things. Not the results - we don't want those to change, since we know they're right - but the probability scores change dramatically, using a baseline of 1.0 as a measure of probability instead of a simple computed value. This is roughly analogous to Lucene's "match score," where you can see how likely it is that a match was found. The actual values aren't quite the same, even adjusted for the baseline, but that's because the Fisher algorithm calculates probability slightly differently.
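
The swap itself is a one-line change. As a sketch (FisherClassifier and FisherClassifierImpl are the same types used in the listener example below):

// Swap the naive classifier for the Fisher classifier; the training and
// classification calls stay the same, only the reported scores change.
FisherClassifier classifier = new FisherClassifierImpl();
classifier.train(rush, "music");
classifier.train(java, "programming");
System.out.println(Arrays.toString(classifier.getProbabilities(
        "Eclipse is an integrated development environment, or IDE.")));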

Nice, but how do I use it in my application?

The problem with our tests is that they build a new classifier every time and train it from scratch. In a real-world application, you'd obviously want to retain training over time.

Let's see how we can add a listener to the classifier such that we can see the new training events as they occur - and once we can do that, we'll see how we can store the data into a relational database.

Every classifier has an addListener(ClassifierListener) method. These listeners handle two types of events: FeatureIncrements and CategoryIncrements. CategoryIncrements keep track of how many training sessions have been associated with a given category; FeatureIncrements do the same thing for specific words in that category.

Here's a test that shows the listener in action:

@Test
public void showListenerTest() {
    FisherClassifier classifier = new FisherClassifierImpl();
    classifier.addListener(new ClassifierListener() {
        public void handleFeatureUpdate(FeatureIncrement featureIncrement) {
            System.out.println(featureIncrement);
        }

        public void handleCategoryUpdate(CategoryIncrement categoryIncrement) {
            System.out.println(categoryIncrement);
        }
    });
    classifier.train("Now is the time for all good men " +
            "to come to the aid of their country", "test");
}

The tail end of its output:

...
FeatureIncrement[feature=come,category=test,count=1]
FeatureIncrement[feature=aid,category=test,count=1]
FeatureIncrement[feature=their,category=test,count=1]
FeatureIncrement[feature=countri,category=test,count=1]
CategoryIncrement[category=test,count=1]

Now we can create a listener to store this data into a database.

Before we begin: This is terrible code. Do not use it in anything even close to a production system. It's been written for speed and (sort of) brevity; it's compact and complete and has no dependency except HSQLDB, and it was quick to throw together. These are its only strengths. Using it in real code is sure to result in peasants chasing you down with torches and pitchforks and other such unpleasant things, or perhaps journalists throwing shoes at you. Neither alternative is good.

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

public class DBListener implements ClassifierListener {
    // Instance initializer: loads the HSQLDB JDBC driver, registers a shutdown
    // hook that issues HSQLDB's "shutdown" statement on exit, and creates the
    // two tables. Exceptions (including "table already exists") are just printed.
    {
        try {
            Class.forName("org.hsqldb.jdbcDriver");
            Runtime.getRuntime().addShutdownHook(new Thread() {
                @Override
                public void run() {
                    Connection conn = null;
                    try {
                        conn = getConnection();
                        PreparedStatement ps = conn.prepareStatement("shutdown");
                        ps.execute();
                        conn.close();
                    } catch (SQLException e) {
                        e.printStackTrace();
                    }
                }
            });
            Connection conn = getConnection();
            PreparedStatement ps = conn.prepareStatement(
                    "create table cats(category varchar primary key, counts integer)");
            ps.execute();
            ps = conn.prepareStatement(
                    "create table feats(feature varchar, category varchar, "+
                            "counts integer, primary key(feature, category))");
            ps.execute();
            conn.close();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
            System.exit(1);
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    Connection getConnection() throws SQLException {
        return DriverManager.getConnection("jdbc:hsqldb:file:testdb", "sa", "");
    }


    // Try to update the stored count for this feature/category pair; if no row
    // was updated, insert one instead (a poor man's "upsert").
    public void handleFeatureUpdate(FeatureIncrement featureIncrement) {
        try {
            Connection conn=getConnection();
            PreparedStatement ps=conn.prepareStatement(
                    "update feats set counts=? where feature=? and category=?");
            ps.setInt(1, featureIncrement.getCount());
            ps.setString(2, featureIncrement.getFeature());
            ps.setString(3, featureIncrement.getCategory());
            int updates=ps.executeUpdate();
            if(updates==0) {
                ps=conn.prepareStatement(
                        "insert into feats (counts, feature, category) values (?,?,?)");
                ps.setInt(1, featureIncrement.getCount());
                ps.setString(2, featureIncrement.getFeature());
                ps.setString(3, featureIncrement.getCategory());
                ps.executeUpdate();
            }
            conn.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }

    // The same update-or-insert pattern, for the per-category training counts.
    public void handleCategoryUpdate(CategoryIncrement categoryIncrement) {
        try {
            Connection conn=getConnection();
            PreparedStatement ps=conn.prepareStatement(
                    "update cats set counts=? where category=?");
            ps.setInt(1, categoryIncrement.getCount());
            ps.setString(2, categoryIncrement.getCategory());
            int updates=ps.executeUpdate();
            if(updates==0) {
                ps=conn.prepareStatement(
                        "insert into cats (counts, category) values (?,?)");
                ps.setInt(1, categoryIncrement.getCount());
                ps.setString(2, categoryIncrement.getCategory());
                ps.executeUpdate();
            }
            conn.close();
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}

Now, if we add this listener to a classifier, and then run a training session on it, we'll get the categories and features stored handily (albeit slowly and inefficiently) in a database. However, we can't really use the data yet, because this is a write-only listener. We'll need to add another method to load the data back from the database.

Here's yet another example that shows loading the entire classifier dataset from a database:

public void load(Classifier classifier) {
    try {
        Connection conn=getConnection();
        // Restore the per-category training counts and register each category.
        PreparedStatement ps=conn.prepareStatement("select category, counts from cats");
        ResultSet rs=ps.executeQuery();
        while(rs.next()) {
            String cat=rs.getString(1);
            int count=rs.getInt(2);
            classifier.getCategoryDocCount().put(cat, count);
            classifier.addCategory(cat);
        }
        ps=conn.prepareStatement("select category, feature, counts from feats");
        rs=ps.executeQuery();
        while(rs.next()) {
            String cat=rs.getString(1);
            String feat=rs.getString(2);
            int count=rs.getInt(3);
            ClassifierMap cm=classifier.getCategoryFeatureMap().getFeature(feat);
            cm.put(cat, count);
        }
        conn.close();
    } catch (SQLException e) {
        e.printStackTrace();
    }
}

Note the tasks it performs:

  1. It iterates over the categories, making sure to store the counts (and the category) into the internal category map;

  2. It iterates over the feature/category set, making sure to store the counts into the category/feature map.

Both steps are necessary, and all aspects of each step are necessary. A common mistake is to store the category count without updating the category set itself.
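
Putting the pieces together, the wiring might look something like this. This is only a sketch: DBListener and its load() method are the homegrown code above, not part of CI-Bayes itself.

// Rebuild a classifier from the database, then keep persisting new training.
FisherClassifier classifier = new FisherClassifierImpl();
DBListener listener = new DBListener();

listener.load(classifier);         // restore categories and per-feature counts
classifier.addListener(listener);  // subsequent train() calls update the database

classifier.train("Scala is a language for the Java virtual machine", "programming");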

Conclusion, Future Directions, and Observations

Now you have the basis for a Bayesian classifier that can store its dataset between invocations. However, there are a few possibilities for improved performance if your situation needs it.

For one thing, the DBListener shows an example of loading the entire dataset into memory every time. It's entirely possible that a large dataset could make this prohibitive. Thus, you might extend the FisherClassifier to override the create*() methods to try to load from the database if the map doesn't already contain a value (i.e., load on first query for a given feature or category). This would be much slower until enough of the feature set had been loaded to make it worthwhile, but it might preserve heap space.

Another thing you might notice from examining the database records is that many features aren't complete words! For example, the word "architecture" is stored as "architectur." This is because the default word-splitting class in CI-Bayes performs stemming, which means reducing a word to its root form. (Thus, "architectur" matches both "architecture" and "architectures".) The stemming class in CI-Bayes is fairly good (based on code by Kevin Burton), but it's important to realize that other stemming algorithms exist, including those found in Apache's Lucene project.

Lastly, while Bayesian classification is normally applied to text, it's not necessarily limited to text. CI-Bayes works with string representations of the objects passed to it, so as long as your objects have meaningful string representations, you can classify them too - for example, by supplying a custom word lister that skips stemming, you might classify biological data or other kinds of structured data, depending on your needs.

Enjoy.

Biography

Joe Ottinger is a former editor of TheServerSide and a writer on technology topics.