In recent years, the Keras + TensorFlow tandem has faced a competitor framework steadily gaining traction across the deep learning community: JAX. But what exactly is JAX, what are its capabilities, and how does it resemble and differ from the Keras API, which has historically been the near-universal way of using TensorFlow, the biggest deep learning library in Python? This article answers these questions.
What is Keras?
Keras was born in 2015 as an interface that simplifies the use of well-established libraries for building neural network architectures, such as TensorFlow. Although it was initially created as a standalone framework, Keras eventually became part of TensorFlow, the major Python library for efficiently training and deploying scalable deep neural networks. Keras thus became an abstraction layer on top of TensorFlow: in other words, it made working with “raw” TensorFlow much easier.
Keras provides implementations of the most common building blocks of neural network architectures: layers of neurons, objective and activation functions, optimizers, and so on. Special types of deep neural network architectures like convolutional neural networks (CNNs) and recurrent neural networks (RNNs) are easily constructed by using Keras abstraction classes and methods.
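To illustrate this level of abstraction, here is a minimal sketch (not taken from the article; the layer sizes and the MNIST-like input shape are arbitrary choices for illustration) of a small CNN assembled from those building blocks:

```python
from tensorflow import keras
from tensorflow.keras import layers

# Layers, activations, the loss, and the optimizer are all Keras abstractions.
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),          # e.g. grayscale 28x28 images
    layers.Conv2D(32, kernel_size=3, activation="relu"),
    layers.MaxPooling2D(pool_size=2),
    layers.Flatten(),
    layers.Dense(64, activation="relu"),
    layers.Dense(10, activation="softmax"),  # 10-class classifier
])

# Objective function and optimizer are supplied as high-level abstractions too.
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model.summary()
```

Everything here is specified declaratively; Keras translates these abstractions into the underlying TensorFlow operations.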
What is JAX?
JAX is a comparatively newer framework, aimed not only at deep learning but at machine learning development as a whole. It was released by Google in 2018, and its core focus is high-performance numerical computation. Concretely, JAX makes numerical code written with Python and NumPy (Python's most widely used numerical computing library) simpler and faster, with seamless support for high-performance execution on GPUs and TPUs. This is an important advantage over plain NumPy for scientific and numerical computing, since NumPy only runs on the CPU.
Thanks to its balance between intuitiveness and versatile high-performance execution modes, JAX is rapidly gaining a reputation as one of the most advanced frameworks for machine learning and deep learning development, with a chance of eventually displacing frameworks like TensorFlow and PyTorch. Its automatic differentiation feature is particularly handy for efficiently performing the gradient-based computations behind training a deep neural network.
In short, JAX unifies the capabilities of scientific and high-performance computing into a single framework.
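As a rough illustration of that unification, below is a minimal sketch combining JAX's NumPy-like API with automatic differentiation (grad) and JIT compilation (jit); the squared-error loss and placeholder data are toy choices made purely for this example:

```python
import jax.numpy as jnp
from jax import grad, jit

def loss(w, x, y):
    # Simple squared-error loss for a linear model: predictions are x @ w
    preds = jnp.dot(x, w)
    return jnp.mean((preds - y) ** 2)

# grad() builds a new function that computes d(loss)/d(w);
# jit() compiles it so it runs efficiently on CPU, GPU, or TPU.
grad_loss = jit(grad(loss))

x = jnp.ones((8, 3))                 # toy batch of 8 examples, 3 features
y = jnp.zeros(8)                     # toy targets
w = jnp.array([0.5, -0.2, 0.1])      # toy weights

print(grad_loss(w, x, y))            # gradient of the loss with respect to w
```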
Similarities and Differences Between Keras and JAX
Now that we have a glimpse of what Keras and JAX are, we’ll list some features shared by both frameworks, as well as a number of aspects in which they differ.
Similarities:
- Deep learning model development: both frameworks are popularly used to build and train deep learning models.
- GPU/TPU acceleration: both Keras and JAX can take advantage of accelerated hardware like GPUs and TPUs to train models efficiently.
- Automatic differentiation: the two frameworks incorporate mechanisms for automatically computing gradients, a key process underlying the optimization of models during their training.
- Interoperability with deep learning libraries: both frameworks are compatible with the popular deep learning library TensorFlow.
Differences:
- Abstraction level: while both solutions provide some level of abstraction, Keras is more suited to users looking for a very high-level, easy-to-use API, whereas JAX prioritizes flexibility and control, staying at a lower level of abstraction with a focus on numerical computation.
- Backend: Keras is strongly dependent on TensorFlow as its backend. JAX, by contrast, does not depend on TensorFlow; it compiles numerical code via Just-In-Time (JIT) compilation. That said, JAX and TensorFlow can be used jointly and complement each other well in certain situations, e.g. for integrating advanced mathematical transformations into high-level deep learning architectures.
- Ease of use: closely related to the abstraction level, Keras is designed to be easy and quick to use. JAX, while more powerful, requires deeper technical knowledge to use smoothly.
- Function transformations: this is a feature exclusive to JAX, offering advanced transformation capabilities like automatic vectorization and parallel execution (a short sketch follows this list).
- Automatic optimization: again, JAX is in the spotlight here, being more flexible and able to optimize arbitrary functions beyond the scope of neural networks (which is why it also suits other machine learning methods like ensembles), whereas Keras is exclusively focused on deep learning models.
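To make the function transformations point more concrete, here is a minimal sketch (an illustrative per-example prediction function, not from the article) showing vmap vectorizing a function over a batch and jit compiling the result:

```python
import jax.numpy as jnp
from jax import vmap, jit

def predict(w, x):
    # Per-example prediction: a single dot product between weights and features
    return jnp.dot(w, x)

# vmap maps predict over the leading axis of the batch of inputs,
# while broadcasting the weights w (in_axes=(None, 0)); jit compiles the result.
batched_predict = jit(vmap(predict, in_axes=(None, 0)))

w = jnp.array([0.1, 0.2, 0.3])   # toy weights
batch = jnp.ones((16, 3))        # toy batch of 16 examples
print(batched_predict(w, batch).shape)  # (16,)
```

The design choice here is that predict is written for a single example, and vmap mechanically produces its batched version, which is the kind of transformation Keras does not expose directly.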
So, Which One Shall I Choose?
Having gained an understanding of the similarities and differences between the two frameworks, deciding which one to choose for the problem or scenario at hand is not a big chore.
Keras is the go-to option for users looking for ease of use, a gentler learning curve, and a higher level of abstraction. This API on top of the TensorFlow library will get them prototyping and using a variety of deep learning models for predictive and inference tasks in no time.
On the other hand, JAX is a more powerful and versatile option for experienced developers seeking added capabilities like optimized computations and advanced function transformations, without being tightly bound to TensorFlow or to deep learning modeling, although it demands more control and lower-level engineering decisions from the user.
Iván Palomares Carrascosa is a leader, writer, speaker, and adviser in AI, machine learning, deep learning & LLMs. He trains and guides others in harnessing AI in the real world.