How to Train a CNN Using tf.GradientTape

A simple practical example of how to use TensorFlow's GradientTape to train a convolutional neural network.

Bjørn-Jostein Singstad

--

Figure 1. A practical example of using tf.GradientTape with MNIST

Training TensorFlow models using model.fit() works perfectly fine for most applications, so why should you care about tf.GradientTape()?

For those of you who don't know what tf.GradientTape() is and what it does, here is a short explanation:

tf.GradientTape() lets you compute gradients while training all sorts of neural networks. These gradients are essential for backpropagation, the step that corrects the network's errors so that it gradually improves.
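To make that concrete, here is a minimal, standalone sketch (the variable names are only for illustration) of how tf.GradientTape() records operations and returns a gradient:

import tensorflow as tf

x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = x ** 2  # every operation on x is recorded on the "tape"
dy_dx = tape.gradient(y, x)  # dy/dx = 2x, so this evaluates to 6.0
print(dy_dx)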

Using model.fit()

The easiest way to train a model with TensorFlow is the model.fit() method. Assuming you have defined a dataset X_train with corresponding labels y_train, and a function create_my_model() that builds your model architecture, a TensorFlow model can be trained with model.fit() as shown in the example below.
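For completeness, here is one possible (hypothetical) create_my_model() and MNIST data setup that the snippets in this article assume; the actual architecture used in the experiment is shown in Figure 3, and the full code is linked at the end.

import tensorflow as tf

def create_my_model():
    # a small 2D CNN for 28x28 grayscale images and 10 classes
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])

# MNIST with one-hot labels, matching the categorical cross-entropy loss
(X_train, y_train), _ = tf.keras.datasets.mnist.load_data()
X_train = X_train[..., None].astype("float32") / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)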

num_epochs = 30
batchsize = 64

model = create_my_model()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(x=X_train, y=y_train, epochs=num_epochs, batch_size=batchsize)

This will train your model in a supervised way for 30 epochs. However, if you want to do something more exotic, like training a generative adversarial network (GAN), or if you simply want to understand what happens behind the curtains in model.fit(), you can use tf.GradientTape().

Using tf.GradientTape()

As a minimum working example, the following does essentially the same as model.fit().

num_epochs = 30
batchsize = 64

model = create_my_model()
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

for epoch in range(num_epochs):
    for step in range(X_train.shape[0] // batchsize):
        # slice out the current mini-batch
        start_idx = batchsize * step
        end_idx = batchsize * (step + 1)
        X_batch = X_train[start_idx:end_idx]
        y_batch = y_train[start_idx:end_idx]
        # record the forward pass on the tape
        with tf.GradientTape() as tape:
            pred = model(X_batch)
            loss = tf.keras.losses.categorical_crossentropy(y_batch, pred)
        # backpropagate and update the weights
        grads = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))
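A common refinement, if you want the loop above to run faster, is to move the inner step into a function decorated with @tf.function so that TensorFlow traces it once and re-runs it as a compiled graph (model.fit() also runs its training step as a graph by default). A sketch of such a step, reusing the model and opt defined above:

@tf.function
def train_step(X_batch, y_batch):
    # same logic as the inner loop above, but executed as a compiled graph
    with tf.GradientTape() as tape:
        pred = model(X_batch)
        loss = tf.keras.losses.categorical_crossentropy(y_batch, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    return loss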

But you might ask: "Can I be 100% sure that using tf.GradientTape() gives me the same good results as when I use model.fit()?"

The answer is … almost, and I will show you this in an experiment on the well-known MNIST dataset (Figure 2 shows a few examples from it).

Figure 2. Ten examples of handwritten digits from the MNIST dataset

Experiment on MNIST

Figure 3 shows the 2D CNN architecture that was trained and validated using 10-fold cross-validation on the MNIST dataset, first using the model.fit() method and then using the tf.GradientTape() method.

Figure 3. The 2D CNN model architecture used in this experiment
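As a rough sketch of the cross-validation setup (using scikit-learn's KFold; the random seed and exact arguments here are assumptions, see the linked source code for the actual setup):

from sklearn.model_selection import KFold

kfold = KFold(n_splits=10, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kfold.split(X_train)):
    # a fresh model per fold, trained here with model.fit();
    # the tf.GradientTape() arm uses the same folds with the custom loop above
    model = create_my_model()
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    model.fit(X_train[train_idx], y_train[train_idx],
              validation_data=(X_train[val_idx], y_train[val_idx]),
              epochs=num_epochs, batch_size=batchsize)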

The training curves in Figure 4 show that there are some differences in performance between model.fit() and tf.GradientTape(), and model.fit() appears to perform noticeably better. From the loss curve, the model trained with tf.GradientTape() starts to overfit after just a few epochs. The model trained with model.fit() also overfits, but its increase in loss is smaller than for the tf.GradientTape() model. A possible explanation is that more happens behind the curtains in model.fit() than what we have implemented in our tf.GradientTape() loop.

Figure 4. Training curves showing the mean ± standard deviation of the loss and accuracy during training of the 2D CNN model on the MNIST dataset. The orange color represents the model trained using the tf.GradientTape() method, while the blue color represents the model trained using the model.fit() method.
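Two details that the minimal loop above leaves out, and that model.fit() handles automatically, are calling the model with training=True (so that layers such as dropout and batch normalization run in training mode) and reducing the per-sample losses to a single mean over the batch. A sketch of the inner step with these added (still only a partial reconstruction of what model.fit() does):

with tf.GradientTape() as tape:
    pred = model(X_batch, training=True)  # run dropout/batch-norm in training mode
    # average the per-sample losses over the batch, as Keras does internally
    loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_batch, pred))
grads = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(grads, model.trainable_variables))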

Conclusion

We have shown that tf.GradientTape() is relatively easy to implement. However, model.fit() seems to do some extra work under the hood that makes it perform slightly better than our tf.GradientTape() loop, at least for simple tasks such as supervised categorical classification on the MNIST dataset. Nevertheless, learning to use tf.GradientTape() will enable you to build more advanced deep learning applications and give you a better intuition of how neural networks are trained.

Code availability

The full source code for this experiment is available on Kaggle.

--

Bjørn-Jostein Singstad

IT advisor at Vestfold Hospital Trust and PhD candidate at Akershus University Hospital, working on the development of AI-enabled ECG interpretation algorithms.