Extract Dominant Colors from an existing Image — K-Means Clustering Algorithm

An Interesting Implementation of K-Means Clustering Algorithm

Nandini Bansal
The Startup


Introduction

I had just finished Andrew Ng’s Machine Learning course on Coursera and was excited to try my hand at different projects, applying everything I had learned over the past five months beyond the course’s weekly assignments. Just like everyone else, I searched Google for interesting Machine Learning projects for beginners and came across this idea. The project is not groundbreaking, but it gave me a good start in this field.

This tutorial will help you to implement the K-Means Clustering Algorithm to extract dominant colors from an existing image. Before digging into the code, let us walk through the background of the K-Means algorithm.

K-Means Clustering Algorithm

K-Means is one of the most popular and simplest unsupervised learning algorithms. Given data points scattered in an n-dimensional space, it groups points that are close to each other into the same cluster. After randomly initializing k cluster centroids, the algorithm performs two steps iteratively:

  1. Cluster Assignment: Each data point is assigned to the cluster whose centroid is nearest to it.
  2. Move Centroid: The mean of all the points in a cluster is calculated and the cluster centroid is relocated to that mean.

The data points are then re-assigned to clusters based on the new centroid locations.

Iterative Steps of K-Means Algorithm

After a certain number of iterations, the cluster centroids stop moving and the data points no longer change clusters. At this point, the algorithm has converged.
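To make the two steps concrete, here is a minimal sketch of the loop in plain Python. It is a toy version for illustration only, not the scikit-learn implementation used later in this article; the function names `k_means` and `squared_distance`, the `max_iters` cap, and the `seed` parameter are my own assumptions.

```python
import random


def squared_distance(p, q):
    """Squared Euclidean distance between two equal-length points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def k_means(points, k, max_iters=100, seed=0):
    """Toy K-Means over a list of tuples; returns (centroids, labels)."""
    rng = random.Random(seed)
    # Random initialization: pick k of the data points as starting centroids.
    centroids = rng.sample(points, k)
    labels = None
    for _ in range(max_iters):
        # Step 1, cluster assignment: each point joins its nearest centroid.
        new_labels = [
            min(range(k), key=lambda c: squared_distance(p, centroids[c]))
            for p in points
        ]
        # Converged: no point changed cluster since the previous iteration.
        if new_labels == labels:
            break
        labels = new_labels
        # Step 2, move centroid: relocate each centroid to the mean of its points.
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # keep the old centroid if a cluster ends up empty
                centroids[c] = tuple(sum(dim) / len(members) for dim in zip(*members))
    return centroids, labels


# Two well-separated groups of 2-D points.
points = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
centroids, labels = k_means(points, k=2)
```

With well-separated groups like these, the first three points end up in one cluster and the last three in the other, and each centroid sits at the mean of its group.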

Now that we know how the K-Means algorithm works, let us dive into the code to extract colors from an existing image.

We will follow a functional approach to develop this program, i.e., the entire program is split into multiple functions. The module imports needed for this program are given below.

Before starting the main function, we will be creating an ArgumentParser() object to accept command-line arguments and corresponding variables to accept the values of command-line arguments. I have kept two “optional” command-line arguments namely, clusters and imagepath.

The clusters argument sets the number of colors you wish to extract from the image, while imagepath takes the path to the image, including the file name. By default, the program extracts 5 colors and reads an image called poster.jpg from the images folder. You can change these defaults as you like. We will also define the WIDTH and HEIGHT that the image is resized to before colors are extracted from it. I have kept both the width and the height at 128px.
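A minimal sketch of this setup is shown below. The argument names, defaults, and the 128px constants come from the description above; the exact flag spelling (`--clusters`, `--imagepath`), the help texts, and the helper names `build_parser` and `parse_args` are assumptions of mine and may differ from the code in the repository.

```python
import argparse

# Fixed size (in pixels) that the image is resized to before clustering.
WIDTH = 128
HEIGHT = 128


def build_parser():
    """Create the parser with the two optional command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Extract dominant colors from an image."
    )
    parser.add_argument("--clusters", type=int, default=5,
                        help="number of colors to extract")
    parser.add_argument("--imagepath", type=str, default="images/poster.jpg",
                        help="path to the image, including the file name")
    return parser


def parse_args(argv=None):
    """Parse argv (or sys.argv[1:] when argv is None)."""
    return build_parser().parse_args(argv)


# Passing an explicit list makes the defaults easy to inspect.
args = parse_args([])
```

On the command line this would be run as something like `python script.py --clusters 3 --imagepath images/photo.jpg`, where the script and image names are placeholders.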

For hex codes and their corresponding color names, I have used a JSON file. The entire dictionary of color names and their hex codes has been picked up from the JavaScript file (available for public consumption) given below:
http://chir.ag/projects/ntc/ntc.js (JavaScript file)
http://chir.ag/projects/ntc/ (link to creator’s website)

We will read the JSON file into a variable named color_dict. The key-value pairs of the JSON can then be easily accessed in our program through this dictionary variable.
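Reading the file can be sketched as follows. I am assuming here that the JSON maps each hex code to its color name; the helper name `load_color_dict` is hypothetical, and the actual file name and layout in the repository may differ.

```python
import json


def load_color_dict(path):
    """Load a JSON object from path and return it as a Python dict."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```

After `color_dict = load_color_dict(path_to_json)`, a lookup such as `color_dict["#ff0000"]` would return the stored name for that hex code, provided the file contains that key.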

Let us now start with taking an image as input and passing it to the K-Means Algorithm.

As you can see, the above function “TrainKMeans” accepts an image file as an argument. In the very first step, we resize the image to the dimensions that we defined earlier in the program. As you may have noticed, I have used a custom function to resize the image.

In my custom resize function, I resize the longer dimension of the image to the fixed HEIGHT or WIDTH and rescale the other dimension so that the aspect ratio of the image stays constant.
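The size calculation behind such a resize can be sketched without any image library. The function name `scaled_size` and its parameters are hypothetical; the resulting pair would then be passed to the resize call of whichever image library is in use.

```python
def scaled_size(width, height, max_width=128, max_height=128):
    """Return (new_width, new_height) with the longer side pinned to its limit."""
    if width >= height:
        # Landscape or square: fix the width, scale the height by the same factor.
        new_width = max_width
        new_height = max(1, round(height * max_width / width))
    else:
        # Portrait: fix the height, scale the width by the same factor.
        new_height = max_height
        new_width = max(1, round(width * max_height / height))
    return new_width, new_height


print(scaled_size(1920, 1080))  # (128, 72)
print(scaled_size(600, 1200))   # (64, 128)
```

The `max(1, ...)` guard only matters for extremely thin images, where rounding could otherwise produce a zero-pixel side.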

Going back to the TrainKMeans function: after resizing the image, I convert it to a numpy array and then reshape it so that every pixel becomes one row of three values (R, G, B), ready to be clustered in the next step.
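The reshape step looks roughly like this. The small array below is a stand-in for a resized image, and the helper name `image_to_pixel_rows` is my own; only the idea of one row per pixel comes from the article.

```python
import numpy as np


def image_to_pixel_rows(img_array):
    """Flatten a (height, width, 3) RGB array into (height * width, 3)."""
    height, width, channels = img_array.shape
    return img_array.reshape(height * width, channels)


# Stand-in for a resized image: 2 rows by 3 columns of RGB pixels.
img_array = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
img_vector = image_to_pixel_rows(img_array)
print(img_vector.shape)  # (6, 3)
```

Each row of `img_vector` is now a single point in RGB space, which is the layout a clustering routine expects: one sample per row, one feature per column.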

We are now all set to create color clusters in our image. Using KMeans(), we create a model whose hyperparameter n_clusters is set to clusters, the command-line argument we accepted at the beginning of the program, and whose random_state is set to zero.

In the next step, I fit the model and predict a cluster for every pixel of our input image. Using the cluster centers (RGB values), we can find the hex code of the color that each cluster represents, and for that I have used a custom function called rgb_to_hex.
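As a rough sketch, assuming scikit-learn's `KMeans` class, fitting and predicting could look like this. The synthetic pixel array, the helper name `fit_color_clusters`, and the explicit `n_init=10` (added only to keep behavior the same across scikit-learn versions) are my assumptions and not taken from the repository.

```python
import numpy as np
from sklearn.cluster import KMeans


def fit_color_clusters(img_vector, clusters):
    """Fit K-Means on pixel rows; return the model and each pixel's cluster."""
    kmeans = KMeans(n_clusters=clusters, random_state=0, n_init=10)
    labels = kmeans.fit_predict(img_vector)
    return kmeans, labels


# Synthetic "image": 50 red, 30 green, and 20 blue pixels.
img_vector = np.array(
    [[255, 0, 0]] * 50 + [[0, 255, 0]] * 30 + [[0, 0, 255]] * 20, dtype=float
)
kmeans, labels = fit_color_clusters(img_vector, clusters=3)

# Cluster centers are RGB values; round them to whole numbers for display.
centers = kmeans.cluster_centers_.round().astype(int)
```

Here `labels` holds one cluster index per pixel, and each row of `centers` is the RGB value of one dominant color.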

It is a very simple function that uses the to_hex function of matplotlib.colors. I normalize the RGB values to lie in the range of 0 to 1 and then convert them to their respective hex codes. Now we have the hex code of each color cluster.
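To show what the conversion does, here is a dependency-free stand-in written with plain string formatting instead of matplotlib's to_hex. It is meant to illustrate the normalize-then-format idea, not to reproduce the author's function line by line.

```python
def rgb_to_hex(rgb):
    """Convert an (R, G, B) triple in the 0-255 range to '#rrggbb'."""
    # Normalize each channel to 0..1 and clamp out-of-range values.
    normalized = [min(max(channel / 255.0, 0.0), 1.0) for channel in rgb]
    # Scale back to whole bytes and format each as two lowercase hex digits.
    return "#" + "".join(f"{round(value * 255):02x}" for value in normalized)


print(rgb_to_hex((255, 0, 0)))    # #ff0000
print(rgb_to_hex((18, 52, 86)))   # #123456
```

Cluster centers are usually fractional, so the rounding step matters: a center such as (127.6, 0, 255) becomes `#8000ff`.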

In the next step, we are finding the names of each color using the findColorName() function.

In the findColorName function, we call yet another custom function named get_color_name(), which returns two values: aname (actual name) and cname (nearest color name).

In this function, I use a third-party module, webcolors, to convert RGB values to a color name. By default, the webcolors function looks the color up in the CSS3 color list. If it cannot find the color there, it raises a ValueError, which I have handled using another custom function called closest_colour(). That function calculates the Euclidean distance between the input RGB value and every RGB value present in the JSON, then selects and returns the color with the smallest distance.
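The nearest-color fallback can be sketched as below. The four-entry palette is a made-up stand-in for the full JSON dictionary, and I am assuming it maps hex codes to names; `hex_to_rgb` is a helper I added for the example. Squared distances are compared because they give the same nearest color as the Euclidean distance without needing a square root.

```python
def hex_to_rgb(hex_code):
    """Convert '#rrggbb' to an (R, G, B) tuple of integers."""
    hex_code = hex_code.lstrip("#")
    return tuple(int(hex_code[i:i + 2], 16) for i in (0, 2, 4))


def closest_colour(rgb, color_dict):
    """Return the name of the palette color nearest to rgb in RGB space."""
    def distance(hex_code):
        # Squared Euclidean distance between the input and a palette entry.
        return sum((a - b) ** 2 for a, b in zip(rgb, hex_to_rgb(hex_code)))

    return color_dict[min(color_dict, key=distance)]


# Hypothetical palette standing in for the JSON file.
color_dict = {
    "#ff0000": "Red",
    "#008000": "Green",
    "#0000ff": "Blue",
    "#ffffff": "White",
}
print(closest_colour((250, 10, 5), color_dict))  # Red
```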

A dictionary of hex codes and their respective names is created in the TrainKMeans() function. I then create a list of all the RGB points present in the image using img_vector.

An empty data frame, cluster_map, is initialized in the next step. I then create a column called position that holds the RGB value of every data point (pixel) in the image, and in the cluster column I store the cluster number that each data point (pixel) has been grouped into.

Then, in the color and color_name columns, I store the hex code and the corresponding color name for every pixel of the image. Finally, we return the cluster_map data frame and the kmeans object.

I have then plotted every data point (pixel) of the image in a 3D space using a scatter plot with colors identified in the image and a pie chart that shows the color distribution of the image.
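The numbers behind the pie chart are simply the share of pixels that fall into each color. A small sketch with the standard library is shown below; the function name `color_shares` and the sample data are hypothetical, and the article itself keeps this information in the cluster_map data frame.

```python
from collections import Counter


def color_shares(pixel_colors):
    """Map each color to its fraction of all pixels, most common first."""
    counts = Counter(pixel_colors)
    total = sum(counts.values())
    return {color: count / total for color, count in counts.most_common()}


# One hex code per pixel, e.g. taken from the color column.
pixel_colors = ["#ff0000"] * 50 + ["#00ff00"] * 30 + ["#0000ff"] * 20
shares = color_shares(pixel_colors)
print(shares)  # {'#ff0000': 0.5, '#00ff00': 0.3, '#0000ff': 0.2}
```

The values can be passed as slice sizes to a pie chart, with the keys doubling as both labels and slice colors.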

This is all for this article. In the next part, I will show you how to create an API for this program using Flask. You can find the entire code for this project at https://github.com/nandinib1999/DominantColors.

For any questions and suggestions, please feel free to connect with me on LinkedIn.

Thanks for reading!

~ Nandini
