How to Train a CNN Using tf.GradientTape
A simple practical example of how to use TensorFlow's GradientTape to train a convolutional neural network.
Training TensorFlow models using model.fit() works perfectly fine for most applications, so why should you care about using tf.GradientTape()?
For those of you who don't know what tf.GradientTape() is and what it does, here comes a short explanation:
tf.GradientTape() records the operations executed inside its context so that TensorFlow can compute the gradients of a loss with respect to the weights of all sorts of neural networks. These gradients are what backpropagation-based training uses to correct the errors of the network, so that it gradually improves.
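To make this concrete before moving on to a full network, here is a minimal sketch that differentiates a single expression. The variable `x` and the function y = x² are chosen purely for illustration.

```python
import tensorflow as tf

# A single trainable value; the tape watches tf.Variable objects automatically.
x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    # Every operation executed inside this block is recorded on the tape.
    y = x ** 2

# d(x^2)/dx = 2x, so the gradient at x = 3 is 6.
dy_dx = tape.gradient(y, x)
print(float(dy_dx))  # 6.0
```

Training a network follows the same pattern, with the loss in place of `y` and the model weights in place of `x`.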
Using model.fit()
The easiest way to train a model using TensorFlow is to use the model.fit() method. Assuming that you have defined a dataset X_train with corresponding labels y_train, and a function create_my_model() that builds the model architecture, a TensorFlow model can be trained using model.fit() as shown in the example below.
num_epochs = 30
batchsize = 64
model = create_my_model()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(x=X_train, y=y_train, epochs=num_epochs, batch_size=batchsize)
This will train your model in a supervised way over 30 epochs. However, if you want to do something more exotic, like training a generative adversarial network (GAN), or simply want to understand what happens behind the scenes in model.fit(), you can use tf.GradientTape().
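The snippets in this article assume that create_my_model(), X_train and y_train already exist. If you want to run them as they are, the following sketch provides stand-ins. Note that this small CNN is a hypothetical placeholder, not the architecture used in the experiment, and the data is random noise with MNIST-like shapes rather than real MNIST.

```python
import numpy as np
import tensorflow as tf


def create_my_model(input_shape=(28, 28, 1), num_classes=10):
    """Hypothetical stand-in CNN; NOT the architecture used in the experiment."""
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=input_shape),
        tf.keras.layers.Conv2D(16, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])


# Random placeholder data, only here to make the snippets runnable.
rng = np.random.default_rng(0)
X_train = rng.random((256, 28, 28, 1), dtype=np.float32)
# One-hot labels, as expected by the categorical cross-entropy loss.
y_train = tf.keras.utils.to_categorical(
    rng.integers(0, 10, size=256), num_classes=10
)
```

Swap in the real MNIST arrays and your own architecture to reproduce anything meaningful.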
Using tf.GradientTape()
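As a minimal working example, the loop below does roughly the same as model.fit(): it steps through the training data in mini-batches, computes the loss for each batch, and applies the resulting gradients with the Adam optimizer.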
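import tensorflow as tf

num_epochs = 30
batchsize = 64
model = create_my_model()
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

for epoch in range(num_epochs):
    # Full batches only; a trailing partial batch is skipped
    for step in range(X_train.shape[0] // batchsize):
        start_idx = batchsize * step
        end_idx = batchsize * (step + 1)
        X_batch = X_train[start_idx:end_idx]
        y_batch = y_train[start_idx:end_idx]
        with tf.GradientTape() as tape:
            # Forward pass, recorded by the tape; training=True enables dropout etc.
            pred = model(X_batch, training=True)
            # Average the per-sample losses into one scalar
            loss = tf.reduce_mean(
                tf.keras.losses.categorical_crossentropy(y_batch, pred)
            )
        # Backward pass and weight update
        grads = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))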
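model.fit() also reports the loss and accuracy for every epoch, which the loop above does not. The following is a minimal sketch of how that bookkeeping could be added with Keras metric objects. The toy model, the random data and the `history` list are assumptions made so that the example runs on its own, and `reset_state()` is the method name in recent TensorFlow releases.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data and model so the sketch runs on its own.
rng = np.random.default_rng(0)
X_train = rng.random((128, 28, 28, 1), dtype=np.float32)
y_train = tf.keras.utils.to_categorical(
    rng.integers(0, 10, size=128), num_classes=10
)
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])

opt = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_fn = tf.keras.losses.CategoricalCrossentropy()  # averages over the batch
epoch_loss = tf.keras.metrics.Mean()
epoch_acc = tf.keras.metrics.CategoricalAccuracy()


def train_step(X_batch, y_batch):
    """One gradient update, followed by metric bookkeeping."""
    with tf.GradientTape() as tape:
        pred = model(X_batch, training=True)
        loss = loss_fn(y_batch, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))
    epoch_loss.update_state(loss)
    epoch_acc.update_state(y_batch, pred)


batchsize = 64
history = []  # one (loss, accuracy) pair per epoch
for epoch in range(2):
    # Start every epoch with fresh running averages.
    epoch_loss.reset_state()
    epoch_acc.reset_state()
    for start in range(0, len(X_train), batchsize):
        train_step(X_train[start:start + batchsize],
                   y_train[start:start + batchsize])
    history.append((float(epoch_loss.result()), float(epoch_acc.result())))
    print(f"epoch {epoch + 1}: loss={history[-1][0]:.4f}, "
          f"accuracy={history[-1][1]:.4f}")
```

Decorating `train_step` with `@tf.function` is a common way to speed up such a loop, but it is left out here to keep the sketch easy to step through.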
But you might ask: “Can I be 100% sure that using tf.GradientTape() gives me the same good results as when I use model.fit()?”
The answer is … almost, and I will show this in an experiment on the well-known MNIST dataset (Figure 2 shows examples from the MNIST dataset).
Experiment on MNIST
Figure 3 shows the 2D CNN architecture that was trained and validated using 10-fold cross-validation on the MNIST dataset, first using the model.fit() method and then using the tf.GradientTape() method.
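For readers unfamiliar with 10-fold cross-validation, the splitting step can be sketched as follows. This is a NumPy-only illustration and not the code used in the experiment; the helper name `kfold_indices` and the sample count of 1000 are hypothetical.

```python
import numpy as np


def kfold_indices(num_samples, num_folds=10, seed=0):
    """Yield (train_idx, val_idx) pairs for shuffled k-fold cross-validation."""
    rng = np.random.default_rng(seed)
    # Shuffle all sample indices once, then cut them into num_folds chunks.
    folds = np.array_split(rng.permutation(num_samples), num_folds)
    for k in range(num_folds):
        # Fold k is held out for validation; the rest is used for training.
        val_idx = folds[k]
        train_idx = np.concatenate(
            [folds[j] for j in range(num_folds) if j != k]
        )
        yield train_idx, val_idx


for fold, (train_idx, val_idx) in enumerate(kfold_indices(1000)):
    print(f"fold {fold}: {len(train_idx)} training / {len(val_idx)} validation")
```

A fresh model is trained on each training split and evaluated on the matching validation split, so every sample is used for validation exactly once.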
The training curves in Figure 4 show that there are some differences in performance between model.fit() and tf.GradientTape(), and model.fit() seems to perform better. From the loss curve, it seems that the model trained using tf.GradientTape() starts to overfit after just a few epochs. The model trained using model.fit() also overfits, but its loss increases less than that of the tf.GradientTape() model. A possible explanation is that more happens behind the scenes in model.fit() than what we have implemented in the tf.GradientTape() loop.
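Two such differences are easy to point at. For array inputs, model.fit() shuffles the training data before every epoch by default and also trains on the final, smaller batch, whereas the hand-written loop visits the same full batches in the same order each epoch. Whether this explains the gap in Figure 4 was not tested in the experiment, so treat it as an assumption. A minimal sketch of a tf.data input pipeline that removes both differences, using random placeholder data:

```python
import numpy as np
import tensorflow as tf

# Hypothetical placeholder data: 130 samples, so the last batch is incomplete.
rng = np.random.default_rng(0)
X_train = rng.random((130, 28, 28, 1), dtype=np.float32)
y_train = tf.keras.utils.to_categorical(
    rng.integers(0, 10, size=130), num_classes=10
)

batchsize = 64
dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=len(X_train))  # reshuffled on every pass by default
    .batch(batchsize)                   # keeps the smaller final batch
)

# Iterating over the dataset once corresponds to one epoch.
batch_sizes = [int(X_batch.shape[0]) for X_batch, y_batch in dataset]
print(batch_sizes)  # [64, 64, 2]
```

Inside the training loop, `for X_batch, y_batch in dataset:` would then replace the manual index arithmetic, while the tf.GradientTape() block stays the same.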
Conclusion
We have shown that training with tf.GradientTape() is relatively easy to implement. However, model.fit() seems to have some extra built-in behavior that makes it slightly better than our tf.GradientTape() loop, at least for simple tasks such as supervised categorical classification on the MNIST dataset. Nevertheless, learning to use tf.GradientTape() will enable you to build more advanced deep learning applications and give you a better intuition for how neural networks are trained.
Code availability
The full source code for this experiment is available on Kaggle.