ML Conference
The Conference for Machine Learning Innovation

Deep Learning: Not only in Python

Although there are powerful and comprehensive machine learning solutions for the JVM, such as the DL4J framework, it may be necessary to use TensorFlow in practice. This can be the case, for example, if a certain algorithm exists only in a TensorFlow implementation and porting it to another framework would take too much effort. Although you usually interact with TensorFlow via a Python API, the underlying engine is written in C++. Using the TensorFlow Java wrapper library, you can train TensorFlow models and run inference with them from the JVM without having to rely on Python. Existing interfaces, data sources, and infrastructure can be integrated with TensorFlow without leaving the JVM.

AI and deep learning are certainly hot topics at the moment and despite some initial setbacks, e.g. in the field of self-driving cars, the potential of deep learning is far from exhausted. But there are still many areas of IT in which the topic is only just gaining momentum. It is therefore particularly important to investigate how deep learning systems can be implemented on the JVM, as Java (both the language and the platform) is still the dominant technology in the enterprise sector.

TensorFlow is one of the most important frameworks in the field of deep learning. Despite the increasing popularity of Keras, it is still inconceivable to do without it, especially as the AI top dog Google continues to drive its development forward. This article shows how TensorFlow can be used on the JVM to train models and run inference with them.

What is the combination of TensorFlow and JVM suitable for?

DL4J is the only professional deep learning framework that is really at home on the JVM, so if you would like to use deep learning on the JVM, DL4J is usually the best choice. TensorFlow – like many machine learning frameworks – is mainly used with Python. However, there are reasons to use TensorFlow within a JVM context:

  • You want to use a process which has an implementation in TensorFlow, but not in DL4J, and the porting effort is too high.
  • You are working with a data science team that is used to working with TensorFlow and Python, but the target infrastructure runs on the JVM.
  • The data needed for the training lies within a Java infrastructure (databases, custom data formats, APIs), and in order to get to the data, existing interface code would have to be ported from Java to Python.

The JVM TensorFlow combination is therefore always useful if an existing Java environment is available and, for personnel or project related reasons, TensorFlow has to be used for deep learning (see the box: “TensorFlow and JVM – always a good idea?”).

TensorFlow and JVM – always a good idea?

Although there may be good reasons for this combination, it is also important to mention what may speak against it. Especially the choice of TensorFlow should be well considered:

  • TensorFlow is not a suitable framework for deep learning or machine learning beginners.
  • TensorFlow is not user-friendly: The API changes quickly and in the mass of instructions it is often not clear which path is the best.
  • TensorFlow isn’t better just because it is made by Google: Deep learning is math, and math is the same for everyone. TensorFlow does not create “smarter” AIs than other frameworks. It is also not faster than the alternatives (but also not “dumber” or slower).

If you want to get into deep learning and stay on the JVM, DL4J is strongly recommended, and it is a good choice for professional enterprise projects in particular. If you would rather look over the fence and try out a bit of Python, it is worth looking at the alternatives to TensorFlow as well. Here, you are currently better off with Keras, thanks to its much more convenient API.

How does TensorFlow work?

Before you start to use a new framework, it is important to take a look at what happens under the hood (see the box: “TensorFlow cheat sheet”). When thinking of TensorFlow, the first things that come to mind are AI and neural networks. But from a technical point of view, TensorFlow is mainly a framework that can execute complex, iterative, parallel calculations on tensors, GPU-accelerated where possible. Although deep learning is the main field of application for TensorFlow, it can also be used for any other calculation.

A TensorFlow program – or better: the configuration of a calculation – is always structured like a graph in TensorFlow. The nodes of the graph represent operations, such as adding or multiplying, but also loading and saving. Everything that TensorFlow does takes place in the nodes of a previously defined calculation graph. The nodes (operations) of the graph are connected by edges through which the data flows in the form of tensors. Hence the name TensorFlow.

All calculations in TensorFlow take place in a so-called session. In the session, either a finished graph is loaded, or a new graph is created piece by piece through API calls. Special nodes in the graph can contain variables, which must be initialized before the graph can work. Once this has happened and a session with a finished, initialized graph exists, all interaction with TensorFlow consists of calling operations in the graph. What is calculated depends on which output nodes of the graph are queried. The entire graph is therefore not executed, only the operations that provide input for the queried node, then their input nodes, and so on, back to the input operations, which must be filled with the necessary input tensors. An important feature of TensorFlow is that all operations are automatically differentiated for the user, which is needed for the training of neural networks. Because this happens automatically, the user can safely ignore it.
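
The principle that only the queried part of a graph is executed can be illustrated in a few lines of plain Python. This is a minimal sketch and not how TensorFlow is implemented internally: the dictionary-based graph format, the node names, and the run function are assumptions made purely for illustration.

```python
# Toy dataflow graph: each node name maps to (function, input node names).
# Illustration only; this is not how TensorFlow is implemented internally.
GRAPH = {
    "sum": (lambda a, b: a + b, ["x", "y"]),
    "product": (lambda a, b: a * b, ["x", "y"]),
    "scaled": (lambda s: 10 * s, ["sum"]),
}


def run(graph, fetch, feeds):
    """Evaluate node `fetch`; return its value and the list of executed nodes."""
    cache = dict(feeds)  # fed input values count as already computed
    executed = []

    def evaluate(name):
        if name not in cache:
            func, inputs = graph[name]
            args = [evaluate(dep) for dep in inputs]
            cache[name] = func(*args)
            executed.append(name)
        return cache[name]

    return evaluate(fetch), executed


value, executed = run(GRAPH, "scaled", {"x": 2, "y": 3})
print(value)     # 50
print(executed)  # ['sum', 'scaled']: "product" was never executed
```

Querying "scaled" pulls in "sum" and the two fed inputs, while the unrelated "product" node is never touched.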

Usually, the graph is defined via a Python API. It can be represented graphically with auxiliary programs (Fig. 1). But such representations only serve for debugging: The graphs are not programmed graphically as in a visual programming language such as LabVIEW.

Fig. 1: A (small) section of a TensorFlow graph: The numbers on the edges indicate the size of the tensor flowing through them, the arrows indicate the direction.

Although, in most examples, Python is used to interact with TensorFlow, the actual engine is written in C/C++. Therefore, you can use TensorFlow with any language that can call C functions. Thus, you can also perform calculations in TensorFlow from the JVM.

TensorFlow cheat sheet

  • Tensor: The basis for calculations in TensorFlow. A tensor is actually an object from linear algebra, but for our purposes it is completely sufficient to consider a tensor as a multidimensional array (mostly of float or double values, sometimes also of char or boolean values). TensorFlow uses tensors for everything. All data that TensorFlow consumes, produces, and uses internally is packaged in tensors – hence the name.
  • Graph: The definition of TensorFlow calculation procedures is usually stored in a file called graph.pb in a ProtoBuf binary format, similar to a Java .class file.
  • Training: When training a machine learning method, data and expected results are presented to the algorithm over and over again, whereupon the algorithm adjusts the internal parameters of the model to improve the result. Sometimes this is called “learning”, although it has little to do with human learning.
  • Inference: Depending on the application, you may want to use a machine learning process to classify, predict, translate, create content, and much more. All these applications are summarized under the term inference. Inference therefore means as much as “using a procedure to obtain a result”. This is what we want to do most of the time in live use after training. During inference, a procedure does not learn.
  • Model: the learned parameters of a machine learning procedure, for example, a neural net. This is the result of the learning process and necessary to obtain results (the variable state of the graph, so to speak). It is distributed over several files and stored in one *.index and one or more *.data files, for example, *.data-00000-of-00001. The first number indicates the consecutive number of the file (starting at zero), the second the total number of files.
  • Session: the context in which TensorFlow is executed, such as a running JVM instance. In order to use TensorFlow, we need to create a session in which a graph is loaded that is initialized with a model. In Java, a JVM instance must be started in which classes are loaded that are instantiated with constructor parameters.
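
To make the “multidimensional array” view of a tensor concrete, here is a minimal sketch in plain Python. The helper shape is hypothetical (it is not part of TensorFlow) and assumes rectangular nested lists; a plain number is treated as a zero-dimensional tensor.

```python
def shape(value):
    """Return the shape of a rectangular nested list; a plain number has shape ()."""
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        if not value:  # an empty list has no first element to descend into
            break
        value = value[0]
    return tuple(dims)


label = 1                                       # zero-dimensional tensor (scalar)
image = [[0.0, 0.5], [1.0, 0.25], [0.1, 0.9]]   # two-dimensional tensor, 3 x 2
batch = [image, image]                          # three-dimensional tensor, 2 x 3 x 2

print(shape(label))  # ()
print(shape(image))  # (3, 2)
print(shape(batch))  # (2, 3, 2)
```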

TensorFlow training and inference with Python

The training of a TensorFlow model with Python (box: “tf.data or feeding?”) can be separated into the following steps:

  • Create the graph, either via several API calls that compose the graph or through loading a *.pb file that contains the graph
  • Create a session for the graph
  • Initialize the graph variable either by calling a special operation in the graph which fills the variables with default values or through loading a pre-trained model

After these three steps, we have an executable TensorFlow session with a functioning model. If we want to (further) train it, the following three steps are always executed in a loop until the model has learned enough – either by defining a fixed number of training steps beforehand or by waiting until the training error drops below a certain level:

  • Package the input data into tensors and assign them to the input nodes
  • Select the output nodes and pack them into a list
  • Execute the session: a special command causes the session to perform the necessary operations to generate the selected output

But where does the training take place? It is done by executing the correct output nodes. For TensorFlow, there is no difference between training and inference: In both cases, mathematical operations are simply performed in the calculation graph. We speak of training if these operations lead to a neural network learning a better weighting to solve a problem. The API calls for training and any other type of usage, however, are the same.
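
The training loop with its two stopping criteria can be sketched without TensorFlow. The following plain-Python example is only an illustration under simple assumptions: the “model” is a single weight w fitted to y = w · x by gradient descent, and the function name train and its parameters are hypothetical.

```python
def train(inputs, labels, max_steps=1000, target_error=1e-6, learning_rate=0.05):
    """Fit y = w * x by gradient descent; return (w, steps_run, last_error)."""
    w = 0.0
    error = float("inf")
    steps_run = 0
    for _ in range(max_steps):          # criterion 1: fixed budget of training steps
        residuals = [w * x - y for x, y in zip(inputs, labels)]
        error = sum(r * r for r in residuals) / len(inputs)
        if error < target_error:        # criterion 2: training error is low enough
            break
        gradient = sum(2 * r * x for r, x in zip(residuals, inputs)) / len(inputs)
        w -= learning_rate * gradient   # plays the role of the "optimize" node
        steps_run += 1
    return w, steps_run, error


w, steps_run, error = train([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(w, steps_run, error)
```

In TensorFlow, the body of the loop would be a single session call that feeds the inputs and labels and executes the training node; the loop structure itself stays the same.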

During training, our input consists of the data that is to be learned (for example, an image as a two-dimensional tensor and the label “dog” or “cat” in the form of an integer ID in a zero-dimensional tensor). By running the correct nodes, TensorFlow updates some variables in the graph to improve the prediction. The main practical difference between training and inference is that during training we periodically save the current state of the graph variables, which are constantly changing, whereas this is pointless during inference because they remain constant.

tf.data or feeding?

There are two ways to load training data into the graph when you train a TensorFlow model in Python: the tf.data API or so-called “feeding”, i.e. the transfer of individual data for each calculation step. The tf.data API is implemented natively, integrated directly into the graph, and therefore very fast – but also complicated to use and very difficult to debug. The feeding method is easy to use and understand, but it needs Python code at runtime. As a result, Python often becomes the bottleneck for the far more expensive graphics card, and valuable GPU capacity goes unused. But which approach do we take in Java? Fortunately, Java is much faster than Python at tasks like data import and pre-processing, so here we get the best of both worlds: We can use the easy-to-understand feeding method and still get full performance. That is why we leave the tf.data API out of this article: We just don’t need it.

The TensorFlow Java API

Since TensorFlow is implemented internally in C/C++, all operations that are needed in Python for training or inference can also be called from the JVM via JNI. Fortunately, we no longer have to bother wrapping the low-level C API with JNI ourselves, as Google has already done this for us. The necessary libraries are, as usual, available on Maven Central. There are four different artifacts, all in the group org.tensorflow:

  • tensorflow: A metapackage with dependencies on libtensorflow and libtensorflow_jni; in order to avoid confusion, it should not be used.
  • libtensorflow: The API against which you program in Java; this is the compile and runtime dependency and the central entry point.
  • libtensorflow_jni: Contains the native CPU dependencies for libtensorflow; this artifact is needed at runtime when using a machine without GPU; it contains native code for Windows, Linux and Mac; TensorFlow is completely included, you don’t have to install Python or TensorFlow on the running system.
  • libtensorflow_jni_gpu: The GPU equivalent to libtensorflow_jni; you should use this dependency if you use a computer with an NVIDIA GPU and CUDA and cuDNN are installed correctly; it only works under Windows and Linux, there is no GPU support for TensorFlow under macOS.

The version numbers of the Java wrappers correspond to the version number of the included TensorFlow version. We should always use the newest stable release. We only have to pay attention if the code is supposed to be executed on a computer with a GPU (box: “Selecting the GPU to be used”). Not every TensorFlow version supports every CUDA and cuDNN version (CUDA is NVIDIA’s platform for running parallel calculations on graphics cards, cuDNN is a CUDA-based library for neural networks). We must ensure that the CUDA and TensorFlow versions match. Currently, all TensorFlow versions from 1.13 on support the same CUDA version: 10.0. A Java-based solution also has a great advantage over Python software when it comes to installing the finished software: Thanks to Maven, our resulting artifact already includes all dependencies. Neither Python nor TensorFlow nor any Python libraries have to be pre-installed, and no installations have to be managed with a tool like Anaconda.

You should not use the top-level dependency tensorflow; it is better to use libtensorflow and one of the *_jni implementations directly. The reason is that the tensorflow artifact has a dependency on libtensorflow_jni (the CPU variant). If we now add libtensorflow_jni_gpu, the CPU-native code is still used, and one wonders why everything runs so slowly despite the GPU. The Gradle dependencies for TensorFlow training on the GPU look like this:

implementation "org.tensorflow:libtensorflow:1.14.0"
runtimeOnly "org.tensorflow:libtensorflow_jni_gpu:1.14.0"

The required Java API for training and inference is simple and manageable. Only four classes are important: Graph, Session, Tensor and Tensors. We can now see how to use them correctly by rebuilding the Python-typical training steps in Java.

TensorFlow training in Java

The first step in training is to define the graph. Unfortunately, we have to make our first and only compromise right at the beginning. A graph can also be built step by step using the Java API, but for many node types, the Python API automatically generates helper nodes that are required for the frictionless use of the graph. In order to build these in Java, we would need very detailed knowledge of the Python API internals. This step must therefore be done once in advance in Python. We then store the resulting graph file as a Java resource in order to load it into the JVM later. Saving the current graph in Python is very easy:

with open(filename, 'wb') as f:
    f.write(tf.get_default_graph().as_graph_def().SerializeToString())

Important: Even though the method used here is called SerializeToString(), the result is still a binary file. For convenience, we should also save the initialized variables at this point. Initializing the variables in the graph from the JVM would be easy, but consistently following the procedure shown here makes it easier to do transfer learning with complex models later on, where an already existing state of a model is trained further and adapted (Listing 1).

Listing 1

# This Python command creates a node for initialization
init_op = tf.global_variables_initializer()
# The saver is an auxiliary class that stores a model in Python.
saver = tf.train.Saver()
# Save is a graph operation
# and can only be executed within a session
with tf.Session() as sess:
  # Initialize the variables
  sess.run(init_op)
  # Save the state
  save_path = saver.save(sess, filename)

Now that we have saved the graph and the model, we can load and train them in Java. For the sake of brevity, the following examples are in Kotlin but can be transferred to any JVM language:

// create empty graph
val graph = Graph()
// load the *.pb file - either from a file or from resources
val graphDefBytes = javaClass.getResource(resourceName).readBytes()
// reconstruct graph from file
graph.importGraphDef(graphDefBytes)

Now we have loaded the TensorFlow graph into the JVM. In order to do something with it, we need a session:

val session = Session(graph)

We only have to load the latest state of the variables before we can really get started. This can be either the file initially saved in Python or the last state of a previous training run, for example, in order to continue that training. Loading the variables is just another operation in the TensorFlow graph, and this operation needs a string packed into a tensor. The string contains the name of the *.index file without the suffix, so foo instead of foo.index.
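
The naming rule can be sketched as a small helper, shown here in Python for brevity. The function name checkpoint_prefix is hypothetical and not part of any TensorFlow API; it only strips the trailing .index suffix.

```python
from pathlib import PurePosixPath


def checkpoint_prefix(index_file):
    """Return the checkpoint path without the trailing '.index' suffix."""
    path = PurePosixPath(index_file)
    if path.suffix != ".index":
        raise ValueError(f"not an index file: {index_file}")
    return str(path.with_suffix(""))


print(checkpoint_prefix("models/foo.index"))  # models/foo
```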

Here, we need the Tensors class for the first time. This class contains helper functions that package Java data types into Tensor objects and automatically make sure that the Tensor has the correct shape. Important for every Tensor object: It contains memory that has been allocated outside the JVM. Therefore, it must be closed manually, for which it implements the AutoCloseable interface. In Java, a separate try { … } finally { tensor.close(); } block (or a try-with-resources statement) must be created for each tensor. Fortunately, this is much easier in Kotlin with use:

Tensors.create(path).use { pathTensor ->
  session.runner().feed("save/Const", pathTensor)
    .addTarget("save/restore_all").run()
}

Here we can see all necessary parts of a TensorFlow action on the JVM:

  • A runner is created for the session; this class has a builder API that defines what is supposed to be executed.
  • The input node for the loading and saving (“save/Const”) is filled with the tensor which contains the file name.
  • The target node is defined as the target for loading.
  • The action is executed.

The trick for all operations is to know their names. Since we build the graph ourselves beforehand and can define the name of a node at creation, we can choose them ourselves. The exceptions are the nodes for loading and saving, which always have the names shown here.

Selecting the GPU to be used

Sometimes we don’t want to block all GPUs on systems with multiple GPUs, for example, to run multiple trainings in parallel. For this, we can configure the TensorFlow graph, which normally automatically allocates the GPU or GPUs, so that only one GPU is used. This has the big disadvantage, though, that the graph is then “hard wired” to a certain GPU and can be used only on this GPU. It is much more convenient to show or hide the GPUs via an environment variable before starting the JVM. This can easily be done with the environment variable CUDA_VISIBLE_DEVICES. Here, we can specify a comma-separated list of CUDA devices that should be visible in the current shell. Caution: The numbering starts at 0, not at 1. The following console command, for example, activates only the second graphics card for TensorFlow (or other deep learning frameworks):

export CUDA_VISIBLE_DEVICES=1

We have now seen all the operations needed to interact with TensorFlow from the JVM, so carrying out a training step is very easy. Let’s assume that our input is an array of loaded images whose grayscale pixel values have been converted to float values in the range 0 to 1. Each image belongs to a class defined by an int value, for example, 0 = dog, 1 = cat. The input for a batch (multiple images are always trained at once) is then a float[][] array, which contains the images, and an int[] array, which contains the classes to learn. A training step can now be executed as follows (Listing 2).

Listing 2

fun train(inputs: Array<FloatArray>, labels: IntArray) {
  withResources {
    val results: List<Tensor<*>> = session.runner()
      .feed("inputs", Tensors.create(inputs).use())
      .feed("labels", Tensors.create(labels).use())
      .fetch("total_loss:0")
      .fetch("accuracy:0")
      .fetch("prediction:0")
      .addTarget("optimize")
      .run().useAll()
    val trainingError = results[0].floatValue()
    val accuracy      = results[1].floatValue()
    val prediction    = results[2].intValue()
  }
}

We see the same pattern again: A runner is created, the inputs are packaged into tensors, the target is selected (“optimize”) and the action is executed. But now there is something new: We get values back. The names of the nodes whose values are to be returned are defined with fetch. These names carry the suffix “:0”. A node can have multiple outputs, and the suffix :0 means that the output with index 0 of the node should be returned.
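
The naming convention can be made explicit with a small parser, again sketched in plain Python. The function parse_tensor_name is hypothetical; treating a name without a suffix as output 0 mirrors the shorthand that the Java fetch method documents, but the code below is only an illustration and does not call TensorFlow.

```python
def parse_tensor_name(name):
    """Split 'node:index' into (node_name, output_index); no suffix means output 0."""
    node, separator, index = name.rpartition(":")
    if not separator:
        return name, 0
    return node, int(index)


print(parse_tensor_name("total_loss:0"))  # ('total_loss', 0)
print(parse_tensor_name("split:1"))       # ('split', 1)
print(parse_tensor_name("optimize"))      # ('optimize', 0)
```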

The output is a list of Tensor objects. These can be converted into various primitive types and arrays to make the result available. Important: The Tensor objects created by the API also have to be closed. Normally, we would have to iterate over the entries in the list and close them in a finally block. However, this is very inconvenient and difficult to read. Therefore, it is useful to define an extended use API in Kotlin with which several objects within a block are registered via use or useAll (for lists of AutoCloseables) and then closed safely (Listing 3).

Listing 3

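class Resources : AutoCloseable {
  private val resources = mutableListOf<AutoCloseable>()

  // Registers a single resource so that it is closed at the end of the block
  fun <T: AutoCloseable> T.use(): T {
    resources += this
    return this
  }

  // Registers all resources of a collection
  fun <T: Collection<AutoCloseable>> T.useAll(): T {
    resources.addAll(this)
    return this
  }

  // Closes everything in reverse order and collects exceptions
  override fun close() {
    var exception: Exception? = null
    for (resource in resources.reversed()) {
      try {
        resource.close()
      } catch (closeException: Exception) {
        if (exception == null) {
          exception = closeException
        } else {
          exception.addSuppressed(closeException)
        }
      }
    }
    if (exception != null) throw exception
  }
}

inline fun <T> withResources(block: Resources.() -> T): T =
  Resources().use(block)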

This useful trick allows you to close all tensors within a TensorFlow call conveniently and safely. Inference in Java is now really easy. Remember: Every action on the TensorFlow graph is performed by filling input nodes with input tensors and querying the correct output nodes. For our example above, this means: The code remains the same, except that we don’t feed the inputs for the correct solution (labels). This makes sense because we don’t know them yet. On the output side, we do not request the nodes for the error calculation and the update of the neural net (total_loss:0, accuracy:0, optimize), so nothing is learned. Instead, we only query the result (prediction:0). Since the labels are not needed to calculate the result, everything works just like before: There is no error because the part of the graph that trains the neural net remains inactive.
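
Why the missing labels cause no error can be shown with a toy graph in plain Python. This is a sketch under stated assumptions, not TensorFlow code: the graph dictionary, the threshold “model”, and the run function are made up for illustration, with node names borrowed from the example above.

```python
# Toy graph: "prediction" depends on "inputs" only,
# while "total_loss" additionally needs "labels".
GRAPH = {
    "prediction": (lambda xs: [1 if x > 0.5 else 0 for x in xs], ["inputs"]),
    "total_loss": (lambda p, t: sum(abs(a - b) for a, b in zip(p, t)),
                   ["prediction", "labels"]),
}


def run(fetch, feeds):
    """Evaluate node `fetch`, touching only the nodes it depends on."""
    if fetch in feeds:
        return feeds[fetch]
    if fetch not in GRAPH:
        raise KeyError(f"input node '{fetch}' was not fed")
    func, dependencies = GRAPH[fetch]
    return func(*(run(dep, feeds) for dep in dependencies))


# Inference: only the inputs are fed, the labels are never requested.
print(run("prediction", {"inputs": [0.9, 0.2]}))  # [1, 0]

# Querying a training-related node without labels fails, as expected.
try:
    run("total_loss", {"inputs": [0.9, 0.2]})
except KeyError as error:
    print("cannot compute the loss:", error)
```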

Practical experiences

The method presented here is not just an interesting experiment: The author has already used it successfully in several commercial projects. Several advantages have emerged in practical use:

  • The Java API is fast and efficient: There is no performance loss compared to a pure Python application. On the contrary: Since Java is much faster than Python for tasks like data import and pre-processing, it is even easier to implement a high-performance training process.
  • Training runs stably over several days; Google’s Java implementation has proven to be very reliable.
  • The deployment of the finished product is much easier than that of Python-based products, since only a Java runtime environment and the correct CUDA drivers need to be present – all other dependencies are part of the Java TensorFlow library.
  • TensorFlow’s low-level persistence API (as presented here) is easier to use than many of the “official” methods, such as estimators.

The only real drawback is that part of the project is still Python-based – the definition of the graph. So you need a team that is at least partly at home in the Python world.
