Machine Learning with Python

Train your own image classification model with Keras and TensorFlow


Image classification models assign images to classes. We usually want to divide the images into groups that reflect which objects are in a picture. For example, we can train an image classification model that distinguishes "dog" from "cat", but far more complex classifications with many more classes are of course possible too.

Machine Learning with Python – It’s all about bananas

In principle, you can make any kind of group classification: maybe you have always wanted to automatically distinguish people who wear glasses from those who don't, or beach photos from mountain photos. There are basically no limits to your imagination, provided that you have pictures (in this case, your data) on hand with which you can train a mathematical model for your task.

Keras and TensorFlow

The actual training of the model is very easy, because today there are a number of easy-to-use open source libraries with which even beginners can quickly achieve results. One of these libraries, which I will use for the example, is Keras with the TensorFlow backend [1]. Keras is an open source deep learning API that allows easy, fast and flexible training of neural networks. The API is modular and works much like a LEGO system for machine learning: neural networks are composed of layers of different types that build on each other, so they can be made as sophisticated as you like. The actual calculations are not done by Keras itself but by an underlying backend. TensorFlow has become the standard backend, and Keras is now part of the core TensorFlow library (as tf.keras). The two libraries are therefore designed to be used together, and prototyping neural networks is much easier with Keras than with TensorFlow directly.

Artificial neural networks


Artificial neural networks consist of nodes that can be thought of as analogous to human nerve cells. These nodes are arranged in layers. A neural network always begins with an input layer, into which the existing data flows (the so-called input). At the end, there is an output layer that represents the result (the so-called output). In between, we can have as many layers as we need; these are called hidden layers. Deep learning refers to machine learning with large neural networks that have many hidden layers. Deep learning is particularly successful at solving very complex problems with many levels of abstraction; image recognition is one example.
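To make these terms concrete, here is a minimal Keras sketch of a small fully connected network with an input layer, two hidden layers and an output layer. The layer sizes (4 inputs, 8 hidden nodes, 3 output classes) are arbitrary example values, not taken from the article. Note that Keras does not list the input layer in model.layers; only the three Dense layers appear there.

```python
import numpy as np
from tensorflow.keras import layers, models

model = models.Sequential([
    layers.Input(shape=(4,)),               # input layer: 4 values per sample
    layers.Dense(8, activation="relu"),     # hidden layer 1
    layers.Dense(8, activation="relu"),     # hidden layer 2
    layers.Dense(3, activation="softmax"),  # output layer: one node per class
])

# One sample in, three class probabilities out
probabilities = model.predict(np.zeros((1, 4)), verbose=0)
print(probabilities.shape)  # (1, 3)
```

Because the biases start at zero, an all-zero input produces identical values at every output node, so the softmax assigns each class a probability of one third until the network has been trained.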

Convolutional Neural Networks for image recognition

Image classification models are nowadays usually trained using Convolutional Neural Networks (CNNs), a special type of neural network. CNNs learn by finding different levels of abstraction of the possible classes. This means that general patterns such as edges are usually recognized in the first hidden layers of the network. The further we go through the hidden layers towards the output, the more specific the recognized patterns become: for example textures, rough patterns, individual parts of the objects in the images and ultimately entire objects.

In CNNs, groups of neighboring pixels are considered together. This allows the network to efficiently learn the context of patterns and objects and to recognize them at other positions in the image. Specifically, this works with so-called sliding windows: windows that each cover a group of pixels and that scan the image from top left to bottom right. For each window position, a mathematical operation called convolution is performed: the pixel values in the window are multiplied element-wise by the values of a so-called filter, and the products are summed into a single output value. Depending on the values in the filter, the convolution produces a specific transformation of the original image.
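The sliding-window operation described above can be sketched in a few lines of NumPy. This is an illustration under simplifying assumptions (one grayscale channel, stride 1, no padding), not how Keras implements it internally. Like Keras' Conv2D layer, it does not flip the filter, so strictly speaking it computes a cross-correlation, which is what deep learning libraries call convolution.

```python
import numpy as np

def convolve2d(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Slide `kernel` over `image` (stride 1, no padding) and return the feature map."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    # Move the window from top left to bottom right
    for i in range(out_h):
        for j in range(out_w):
            window = image[i:i + kh, j:j + kw]
            # Multiply pixel values and filter values element-wise, then sum
            out[i, j] = np.sum(window * kernel)
    return out

# 5 x 5 grayscale image: dark left part, bright right part (a vertical edge)
image = np.array([[0, 0, 0, 9, 9]] * 5, dtype=float)

# Filter that responds to dark-to-bright transitions from left to right
vertical_edge = np.array([[-1, 0, 1],
                          [-1, 0, 1],
                          [-1, 0, 1]], dtype=float)

feature_map = convolve2d(image, vertical_edge)
print(feature_map)  # each row is [0, 27, 27]: a response only near the edge
```

The window lying entirely in the dark region yields 0, while the two windows that overlap the edge yield 27. A blur filter would instead contain equal positive values, for example np.full((3, 3), 1 / 9), which averages each window.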

For example, filters can blur the original image or detect horizontal or vertical edges. In principle, however, any values can be inserted into a filter, so that a wide variety of patterns can be extracted from an image. The values at the individual positions of the filters are exactly what our neural network learns. In this way, the CNN learns which transformations it needs to perform in order to recognize the right patterns and objects in the images.
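In Keras, these learnable filter values are the weights of a Conv2D layer. A minimal sketch (the input size of 64 x 64 RGB pixels and the 16 filters are example values) shows how many values such a layer learns: each filter has 3 x 3 values per input channel, plus one bias per filter. The values start out random and are adjusted during training.

```python
from tensorflow.keras import layers, models

# One convolutional layer with 16 learnable 3 x 3 filters on 64 x 64 RGB images
model = models.Sequential([
    layers.Input(shape=(64, 64, 3)),
    layers.Conv2D(16, (3, 3), activation="relu"),
])

kernel, bias = model.layers[0].get_weights()
print(kernel.shape)          # (3, 3, 3, 16): filter height, width, input channels, filters
print(model.count_params())  # 3 * 3 * 3 * 16 filter values + 16 biases = 448
print(model.output_shape)    # (None, 62, 62, 16): one feature map per filter
```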

Pre-trained networks

Because Keras provides us with a range of pre-trained image classification models, we can use them directly to achieve very good results for our own tasks even if we have just a few images.

A pre-trained network has been trained on a large amount of data and saved together with the learned parameters. Image classification models learn various patterns of the objects in images, so-called features. The idea is that we can reuse the general features learned on this large dataset for our own classification task (feature extraction by the convolutional layers). Only the more specific features of our own images have to be learned in addition (fine-tuning). In the example (see the box “Example 1”), I use VGG16 as the base model, which was trained on the well-known ImageNet dataset [2]. Instead of VGG16, you can of course also use newer models such as Xception. An overview of all pre-trained networks available in Keras can be found on the Keras website [3].

Example 1

Creation of the base model VGG16 with weights learned on ImageNet in Keras. Since we want to replace the classification layers at the top of the network with our own, we set include_top = False. To obtain a training result faster, I scale the images down a bit. The full code can be found in the Jupyter notebook accompanying this article.

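from tensorflow.keras.applications import VGG16

img_height, img_width, channels = 64, 64, 3  # example values (VGG16 needs at least 32 x 32)
base_model = VGG16(weights = 'imagenet',
                   include_top = False,
                   input_shape = (img_height, img_width, channels))  # (height, width, channels)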
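To see what the base model passes on to our own layers, you can inspect its output shape. This sketch uses weights=None so that it builds the VGG16 architecture without downloading the ImageNet weights, and it assumes an image size of 64 x 64 pixels; for actual transfer learning you would use weights='imagenet' as in Example 1.

```python
from tensorflow.keras.applications import VGG16

# weights=None builds the architecture without downloading ImageNet weights
base_model = VGG16(weights=None, include_top=False, input_shape=(64, 64, 3))

# Output of the last convolutional block for a batch of 64 x 64 images
print(base_model.output_shape)  # (None, 2, 2, 512)
```

VGG16 has five max-pooling stages that each halve height and width (64, 32, 16, 8, 4, 2), so every image is turned into 2 x 2 feature maps with 512 channels. This multi-dimensional output is why a Flatten() layer is needed before any Dense layer.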

The number of images we need is difficult to state in general, because the number of images required per group (also called class) depends on how much the objects you want to classify vary. If the objects in a group are very similar, a model will achieve good accuracy even with less data; this is different if you want to classify objects that vary greatly. A rule of thumb is that you need at least a thousand images per class if you train a new model from scratch. However, if you use a pre-trained model, as we do here, and only adapt it to your task, a few hundred images per class are sometimes enough.

Unfortunately, it is not enough just to have images: each image must have a label that states which object can be seen in it. If in doubt, you need to label each image by hand to generate the training dataset for the image classification model.

Example implementation

The images I use in this example are from Kaggle [4] and show different types of fruit on a white background. We have put some of these images into a Docker image for you, which also contains all the required libraries so that you can get started right away. The entire analysis, which you can rebuild and adapt, can be found in the Jupyter notebook for this article, which you can use together with the Docker image. The pictures are saved in the fruits-360 folder, which contains two subfolders: Training and Test. Each of these contains one subfolder per fruit type, named after the fruit (Fig. 1).
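Since the number of images needed depends on the class, it can be worth checking how many images each class folder contains. A small standard-library sketch, assuming the folder layout described above (one subfolder per class) and that the images are stored as .jpg, .jpeg or .png files; the function name is made up for this example.

```python
from pathlib import Path

def count_images_per_class(root: str) -> dict[str, int]:
    """Count the image files in each class subfolder of `root`."""
    extensions = {".jpg", ".jpeg", ".png"}
    counts = {}
    # Each subfolder is one class; its name is the class label
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in extensions
        )
    return counts
```

For example, count_images_per_class("fruits-360/Training") would return a dictionary that maps each fruit type to its number of training images.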

Fig. 1: Random selection of images of fruit used in the example (all images: © 2017 Mihai Oltean, Horea Muresan, MIT License)


To read in the images, we use two Keras tools that are made specifically for the case in which the images are sorted into subfolders and the names of the subfolders represent the class labels:

  • ImageDataGenerator: generates batches of image data. Here we normalize the pixel values by dividing them by the maximum value of 255 to get values between 0 and 1; for the training data we also use data augmentation (randomly transforming the training images on the fly, for example by rotating, shifting or zooming them, so that the model sees many slightly different variants). Warning: DO NOT use data augmentation on validation data!
  • flow_from_directory: reads the images from the class subfolders on disk in batches and processes them according to the defined ImageDataGenerator.
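A sketch of how the two tools might be combined, assuming the folder layout described above. The augmentation parameters, batch size and seed are example values, not the article's settings, and make_generators is a hypothetical helper name. In current TensorFlow versions, ImageDataGenerator is deprecated in favor of tf.keras.utils.image_dataset_from_directory, but it still illustrates the approach used here.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def make_generators(train_dir, valid_dir, img_height, img_width, batch_size=32):
    # Training data: rescale to [0, 1] and apply random augmentation
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        rotation_range=40,        # random rotations up to 40 degrees
        width_shift_range=0.2,    # random horizontal shifts (fraction of width)
        height_shift_range=0.2,   # random vertical shifts (fraction of height)
        zoom_range=0.2,           # random zoom
        horizontal_flip=True,     # random horizontal flips
    )
    # Validation data: only rescale, NO augmentation
    valid_datagen = ImageDataGenerator(rescale=1.0 / 255)

    # Class labels are taken from the subfolder names
    train_gen = train_datagen.flow_from_directory(
        train_dir,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode="categorical",  # one-hot encoded labels
        seed=42,
    )
    valid_gen = valid_datagen.flow_from_directory(
        valid_dir,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode="categorical",
        shuffle=False,
    )
    return train_gen, valid_gen
```

A call could look like make_generators("fruits-360/Training", "fruits-360/Test", 64, 64); using the Test folder for validation is an assumption here.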

In this Keras example, we use the simpler sequential API (as opposed to the slightly more complex but more flexible functional API). Sequential models consist of layers stacked linearly on top of one another, with exactly one input and one output layer. Data flows through the hidden layers in between in only one direction: from input to output. Most neural networks can be built as sequential models, so they are sufficient for most use cases.

We first initialize the model and then add the base model as well as our own layers. Our own layers are dense (fully connected) layers, in which every node is connected to every node of the previous layer. Because the convolutional layers of VGG16 output multi-dimensional feature maps, we first have to flatten this output with Flatten() before it can be passed to a Dense layer. Finally, we add an output layer with one node per class that can be predicted (Listing 1). To make sure that only our own layers are learned during training, we have to freeze all other layers. This is done by setting the trainable attribute of the base model to False.

Listing 1: Model definition with the sequential API from Keras

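from tensorflow.keras import layers, models

# Create the model
model = models.Sequential()

# Add the base model (VGG16 from Example 1)
model.add(base_model)

# Flatten the multi-dimensional VGG16 output for the Dense layers
model.add(layers.Flatten())
# Add new layers (output_n = number of classes)
model.add(layers.Dense(519, activation='relu'))
model.add(layers.Dense(output_n, activation='softmax'))

# Freeze the base model so that only the new layers are trained
base_model.trainable = False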
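To check that freezing works, you can compare the number of trainable weight tensors before and after setting trainable = False. This sketch rebuilds the model from Listing 1 with weights=None (same architecture, no ImageNet download), an assumed input size of 64 x 64 and a hypothetical output_n = 5; it also adds an explicit Input so that the model is built immediately.

```python
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16

output_n = 5  # hypothetical number of classes

# Same architecture as in Example 1, but without downloading ImageNet weights
base_model = VGG16(weights=None, include_top=False, input_shape=(64, 64, 3))

model = models.Sequential([
    layers.Input(shape=(64, 64, 3)),  # explicit input so the model is built at once
    base_model,
    layers.Flatten(),
    layers.Dense(519, activation="relu"),
    layers.Dense(output_n, activation="softmax"),
])

trainable_before = len(model.trainable_weights)
base_model.trainable = False  # freeze all VGG16 layers
trainable_after = len(model.trainable_weights)
print(trainable_before, trainable_after)  # 30 4
```

Before freezing, the 13 convolutional layers of VGG16 contribute 26 tensors (one kernel and one bias each) and the two Dense layers 4 more; afterwards only the 4 tensors of our own layers remain. Changes to trainable take effect for training only once the model is compiled (again).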

Here I use the approach described in the book “Deep Learning with Python” by Keras developer François Chollet [5]:

  1. Add your own layers to the end of the base model.
  2. Freeze the base model.
  3. Train your own layers.
  4. “Thaw” the last convolutional layers (trainable = True).
  5. Train these last convolutional layers together with your own layers.
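Step 4 can be implemented by unfreezing only the last convolutional block of VGG16, which in Keras consists of the layers block5_conv1 to block5_conv3 (plus block5_pool). A sketch under these assumptions; the helper name unfreeze_from is made up for this example, and weights=None is used only to avoid the download.

```python
from tensorflow.keras.applications import VGG16

def unfreeze_from(base_model, first_trainable="block5_conv1"):
    """Freeze all layers before `first_trainable`; unfreeze it and all later layers."""
    # The model-level flag must be True, otherwise no weights count as trainable
    base_model.trainable = True
    trainable = False
    for layer in base_model.layers:
        if layer.name == first_trainable:
            trainable = True
        layer.trainable = trainable
    return base_model

# weights=None only to avoid the download; use weights='imagenet' in practice
base_model = VGG16(weights=None, include_top=False, input_shape=(64, 64, 3))
base_model.trainable = False  # state after steps 2 and 3
unfreeze_from(base_model)     # step 4: thaw the last convolutional block

print([layer.name for layer in base_model.layers if layer.trainable])
# ['block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']
```

After changing trainable, the model has to be compiled again, with a small learning rate, before step 5.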

Because we only want to fine-tune the classification, we use a very small learning rate in the optimization process, so that the weight updates stay small and do not destroy the features the pre-trained layers have already learned. Since we read in our data with the ImageDataGenerator, we use the fit_generator function accordingly (in current TensorFlow 2 versions, fit accepts generators directly and fit_generator is deprecated). We pass it the training and validation data, the number of epochs, and steps_per_epoch, which defines how many batches of augmented images are drawn from the generator in each epoch (see box “Example 2”).
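A sketch of both settings. The article does not name the optimizer or the exact learning rate; RMSprop with learning_rate=1e-5 is an assumed example, and the small Dense model merely stands in for the VGG16-based model. The hypothetical helper steps_for shows how steps_per_epoch is commonly derived from the number of images and the batch size.

```python
import math
from tensorflow.keras import layers, models, optimizers

def steps_for(n_images: int, batch_size: int) -> int:
    """Number of batches needed to draw n_images from a generator once."""
    return math.ceil(n_images / batch_size)

# Hypothetical stand-in for the VGG16-based model from Listing 1
model = models.Sequential([
    layers.Input(shape=(8,)),
    layers.Dense(3, activation="softmax"),
])

# Very small learning rate (assumed value) so fine-tuning only nudges the weights
model.compile(loss="categorical_crossentropy",
              optimizer=optimizers.RMSprop(learning_rate=1e-5),
              metrics=["accuracy"])

print(steps_for(1000, 32))  # 32: 31 full batches plus one batch of 8 images
```

With metrics=["accuracy"], TensorFlow 2 logs the metrics as accuracy and val_accuracy; the output in Example 2, produced by an older Keras version, shows them as acc and val_acc.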

Example 2

Training the image classification model in Keras. The output shows the current epoch, the steps and the estimated remaining time, as well as performance metrics for training and validation data. Because early stopping was used here, training ended after 13 epochs: validation accuracy had not improved for several consecutive epochs.

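# train_image_array_gen (assumed name of the training generator), valid_image_array_gen,
# steps_per_epoch, validation_steps and callbacks_list come from earlier steps in the notebook.
# In current TensorFlow 2 versions, call model.fit(...) with the same arguments instead.
history = model.fit_generator(
  train_image_array_gen,
  steps_per_epoch = steps_per_epoch,
  epochs = 100,
  validation_data = valid_image_array_gen,
  validation_steps = validation_steps,
  callbacks = callbacks_list,
  verbose = 1)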

Epoch 1/100
722/722 [==============================] - 329s 455ms/step - loss: 1.5751 - acc: 0.6057 - val_loss: 0.3644 - val_acc: 0.8660


Epoch 13/100
722/722 [==============================] - 284s 394ms/step - loss: 0.8131 - acc: 0.8642 - val_loss: 0.8141 - val_acc: 0.7959

Epoch 00013: val_acc did not improve from 0.98289
Epoch 00013: early stopping
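The messages at the end of the output look like those of Keras' EarlyStopping and ModelCheckpoint callbacks. The article does not show how callbacks_list was defined; a possible definition, with the patience value and the file name as assumptions, could look like this:

```python
from tensorflow.keras import callbacks

callbacks_list = [
    # Stop when validation accuracy has not improved for 5 epochs (assumed patience)
    callbacks.EarlyStopping(monitor="val_accuracy", mode="max", patience=5, verbose=1),
    # Save the model whenever validation accuracy reaches a new best value;
    # with verbose=1 it also reports epochs in which it did not improve
    callbacks.ModelCheckpoint("best_model.keras", monitor="val_accuracy", mode="max",
                              save_best_only=True, verbose=1),
]
```

In TensorFlow 2, the monitored name is "val_accuracy" rather than the "val_acc" shown in the output above. EarlyStopping also accepts restore_best_weights=True if the model should end up with the weights of the best epoch instead of the last one.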

The output during training already gives us an overview of how the performance metrics on the training and validation data develop over the epochs. We can also plot them with Matplotlib (Fig. 2).

Fig. 2: Accuracy of the image classification model per epoch during training in Keras; blue curve: Accuracy on training data, orange curve: Accuracy on validation data
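A minimal Matplotlib sketch for such a plot, assuming history is the object returned by fit_generator (or fit) in Example 2. It handles both metric names, acc from older Keras versions and accuracy from TensorFlow 2. The function name plot_accuracy is made up for this example.

```python
import matplotlib.pyplot as plt

def plot_accuracy(history_dict):
    """Plot training and validation accuracy per epoch from history.history."""
    # Older Keras logs 'acc'/'val_acc', TensorFlow 2 logs 'accuracy'/'val_accuracy'
    key = "accuracy" if "accuracy" in history_dict else "acc"
    epochs = range(1, len(history_dict[key]) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, history_dict[key], label="training")
    ax.plot(epochs, history_dict["val_" + key], label="validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel("accuracy")
    ax.legend()
    return fig
```

Calling plot_accuracy(history.history) draws the training curve first and the validation curve second, so with Matplotlib's default colors they come out blue and orange, as in Fig. 2.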


Finally, we can use the trained model to make predictions on new data, for example, on an image of a banana from Wikipedia (Fig. 3).

Fig. 3: Prediction of a new image using the Keras-trained image classification model to detect fruit in images; the image was recognized as a banana with a probability of 100% (source: Wikipedia [6])
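A sketch of the prediction step: the new image has to be resized and rescaled exactly like the training images, and the predicted index has to be mapped back to a class name. It assumes that the training generator provides the class_indices dictionary that flow_from_directory creates; the helper names prepare_image and predict_class are hypothetical.

```python
import numpy as np
from tensorflow.keras.utils import img_to_array, load_img

def prepare_image(path, img_height, img_width):
    """Load an image, resize it and rescale it like the training images."""
    img = load_img(path, target_size=(img_height, img_width))
    x = img_to_array(img) / 255.0     # same rescaling as in ImageDataGenerator
    return np.expand_dims(x, axis=0)  # add batch dimension: (1, height, width, 3)

def predict_class(model, x, class_indices):
    """Return the most likely class name and its predicted probability."""
    # class_indices maps class names to indices; invert it
    index_to_class = {index: name for name, index in class_indices.items()}
    probabilities = model.predict(x, verbose=0)[0]
    best = int(np.argmax(probabilities))
    return index_to_class[best], float(probabilities[best])
```

For example, predict_class(model, prepare_image("banana.jpg", img_height, img_width), train_image_array_gen.class_indices) would return the most likely fruit type and its probability; the file name and the generator name are assumptions.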



When we work with just a few training images, overfitting is a common problem. Data augmentation, as we have used it here, can help to reduce it. You can also try to draw more augmented images per epoch by increasing steps_per_epoch. Other hyperparameters, such as the learning rate, momentum, and the number and size of the (thawed) last layers, can be optimized as well; the Hyperas library [7], for example, is suitable for this. If none of this helps, you may simply not have enough training images and should try to obtain additional labeled images.
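Hyperas automates such a hyperparameter search. As a much simpler illustration of the idea (this is not Hyperas' API), here is a sketch of a manual search over learning rates in plain Keras. The small Dense model, the function names and all values are assumptions standing in for the fine-tuned VGG16 model.

```python
import numpy as np
from tensorflow.keras import layers, models, optimizers

def build_model(learning_rate, n_features, n_classes):
    """Hypothetical small classifier standing in for the fine-tuned model."""
    model = models.Sequential([
        layers.Input(shape=(n_features,)),
        layers.Dense(16, activation="relu"),
        layers.Dense(n_classes, activation="softmax"),
    ])
    model.compile(loss="categorical_crossentropy",
                  optimizer=optimizers.RMSprop(learning_rate=learning_rate),
                  metrics=["accuracy"])
    return model

def search_learning_rates(x, y, x_val, y_val, rates, epochs=5):
    """Train one fresh model per learning rate; return its best validation accuracy."""
    results = {}
    for lr in rates:
        model = build_model(lr, x.shape[1], y.shape[1])
        history = model.fit(x, y, validation_data=(x_val, y_val),
                            epochs=epochs, verbose=0)
        results[lr] = max(history.history["val_accuracy"])
    return results
```

The best learning rate is then max(results, key=results.get). With real image data, every candidate means a full training run, so such searches quickly become expensive.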

If you want to learn more about image classification, but also about the basics of machine learning or natural language processing, you can find more information online [8].

Links & literature

[1] Keras
[2] ImageNet
[3] Pre-trained Keras networks
[4] Kaggle
[5] Chollet, François: “Deep Learning with Python”. Manning, 2017.
[6] Image of a banana
[7] Hyperas


