 Machine Learning – Mini Batch K-Means

In an earlier post, I described how DBSCAN is far more time-efficient at clustering than K-Means. It turns out that there is a modified K-Means algorithm which is much faster than the original: Mini Batch K-Means clustering. It is mostly useful in web applications, where the amount of data can be huge and the time available for clustering may be limited.

The Problem

Suppose we have a dataset of 500000 records, and we want to divide them into 100 clusters. The complexity of the original K-Means clustering algorithm is O(n*K*I*f), where n is the number of records, K is the number of clusters we want, I is the number of iterations and f is the number of features in a record. With n = 500000 and K = 100, every single iteration needs 50 million distance computations, each over f features, so the original algorithm becomes painfully slow on data of this size.

The Idea

The idea of the algorithm is to represent the dataset by a smaller subset of the data. For example, for a dataset of 500000 records, we might actually use only about 100000 entries for training. You may object that the resulting cluster centroids will not be a good representation of the actual clusters, but in practice the centroids generalize well over the whole dataset, and they are obtained in far less time than with the original algorithm.
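To get a feel for the savings, here is a rough back-of-the-envelope sketch that plugs numbers into the n*K*I*f cost formula. The record and cluster counts come from the example above; the iteration count, feature count and batch size are hypothetical values picked only for illustration (100 batches of 1000 points add up to the 100000 entries mentioned above).

```python
def kmeans_ops(points_per_iteration, k, iterations, features):
    """Rough operation count: points * clusters * iterations * features."""
    return points_per_iteration * k * iterations * features

n, k = 500_000, 100              # records and clusters from the example
iterations, features = 100, 50   # hypothetical, for illustration only
batch_size = 1_000               # hypothetical mini-batch size

# Batch K-Means touches every record in every iteration.
full_batch = kmeans_ops(n, k, iterations, features)
# Mini Batch K-Means touches only batch_size records per iteration.
mini_batch = kmeans_ops(batch_size, k, iterations, features)

print(f"full batch: {full_batch:.1e} operations")
print(f"mini batch: {mini_batch:.1e} operations")
print(f"ratio: {full_batch // mini_batch}x")
```

Under these assumed numbers the mini-batch version does about 500 times less work. The ratio is simply n divided by the batch size when both versions run the same number of iterations, so it says nothing about how many iterations each one needs to reach a given accuracy.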

The Algorithm

The algorithm takes a small, randomly chosen batch of the dataset in each iteration. It assigns each data point in the batch to a cluster, based on the current locations of the cluster centroids. It then updates the centroid locations using the points from the batch. The update is a gradient descent step with a per-centroid learning rate, which is much cheaper than a full Batch K-Means update. Formally, the algorithm is as follows:

Given: k, mini-batch size b, iterations t, data set X
Initialize each c ∈ C with an x picked randomly from X
v ← 0
for i = 1 to t do
    M ← b examples picked randomly from X
    for x ∈ M do
        d[x] ← f(C, x)          // Cache the center nearest to x
    end for
    for x ∈ M do
        c ← d[x]                // Get cached center for this x
        v[c] ← v[c] + 1         // Update per-center counts
        η ← 1 / v[c]            // Get per-center learning rate
        c ← (1 − η)c + ηx       // Take gradient step
    end for
end for
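
The pseudocode above can be sketched in a few lines of Python. This is a minimal illustration using only the standard library, not the C implementation mentioned at the end of this post. The function names, the seed parameter, the choice of squared Euclidean distance for f(C, x), and sampling each batch without replacement are my own assumptions. Note that with η = 1 / v[c], each center ends up being the running mean of all the points assigned to it so far.

```python
import random

def squared_distance(a, b):
    """Squared Euclidean distance between two points."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def nearest_center(centers, x):
    """Index of the center closest to x (the f(C, x) step)."""
    return min(range(len(centers)),
               key=lambda j: squared_distance(centers[j], x))

def mini_batch_kmeans(X, k, batch_size, iterations, seed=0):
    rng = random.Random(seed)
    # Initialize each center with a point picked randomly from X.
    centers = [list(x) for x in rng.sample(X, k)]
    counts = [0] * k  # v: per-center counts
    for _ in range(iterations):
        # M: b examples picked randomly (assumes batch_size <= len(X)).
        batch = rng.sample(X, batch_size)
        # Cache the nearest center for each point before moving any center.
        cached = [nearest_center(centers, x) for x in batch]
        for x, j in zip(batch, cached):
            counts[j] += 1
            eta = 1.0 / counts[j]  # per-center learning rate
            # Gradient step: c <- (1 - eta) * c + eta * x
            centers[j] = [(1 - eta) * c + eta * xi
                          for c, xi in zip(centers[j], x)]
    return centers

# Toy usage: two one-dimensional blobs around 0 and 10.
data_rng = random.Random(42)
data = [(data_rng.gauss(0, 0.5),) for _ in range(500)]
data += [(data_rng.gauss(10, 0.5),) for _ in range(500)]

centers = mini_batch_kmeans(data, k=2, batch_size=50, iterations=100)
print(sorted(c[0] for c in centers))
```

On this toy data the two centers should end up close to 0 and 10. The one-dimensional points only keep the example small; the function itself works on points of any dimension, as long as every point has the same number of features.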

The graph below shows a comparison of Stochastic Gradient Descent (SGD) K-Means, Mini Batch K-Means and Batch K-Means. SGD is the fastest, but gives a significantly larger error than the other algorithms. Mini Batch is preferred because it strikes a good balance between accuracy and computation time, which makes real-time clustering possible. It is important to notice that this graph is for a particular value of K. According to the original paper by D. Sculley, the algorithm works fine for K = 50 with a batch size of 1000. Choosing K is one place where DBSCAN scores over K-Means; there is a way to get the parameter values for DBSCAN depending on the data.

I have written the Mini Batch K-Means algorithm in C. You can find the code over here.