Customizing Training Loops in TensorFlow 2.0

Sayak Paul

Personally, I really like TensorFlow 2.0. I like how the TensorFlow team has expanded the ecosystem and made its pieces interoperable, and I like how far they have pushed the tf.keras integration, so that it is now easy to plug tf.keras into the native TensorFlow modules. But what I like the most is the ability to customize my training loops like never before. Even if I create my models with the classic Sequential API of Keras, I can still write my own training loops from scratch and fully own the training process. If you want a quick summary of what is new in TensorFlow 2.0, this article can be a good place to start.

Using Weights and Biases (W&B) with Keras is as easy as adding the WandbCallback to your model.fit or model.fit_generator calls. But what if you are writing your own training loops? In that case, integrating W&B can be a bit more involved. This article shows how to do that.

You can find the accompanying code in this Colab notebook. Some portions of the code are inspired by the official TensorFlow tutorials and guides.

What is a customized training loop?

You might be wondering what customizing a training loop, or writing one from scratch, looks like. So, before we go into the details of using W&B in a customized training loop, let’s explore that.

While functions like fit and fit_generator make a machine learning engineer’s life easier, they might not be the best choice if you are doing research or want more control over your model’s parameters and how they are updated. This is where GradientTape, a central part of TensorFlow 2.0, comes in. It lets you literally watch the gradients while your model is being trained and see how the parameters of the model get updated using those gradients. Let’s take an example:

# Train the model
@tf.function # Speeds things up
def model_train(features, labels):
    # Define the GradientTape context
    with tf.GradientTape() as tape:
        # Get the probabilities
        predictions = model(features)
        # Calculate the loss
        loss = loss_func(labels, predictions)
    # Get the gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    # Update the weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Update the loss and accuracy
    train_loss(loss)
    train_acc(labels, predictions)

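A couple of things to note here:

The forward pass and the loss computation run inside the tf.GradientTape context, so TensorFlow records those operations and can differentiate the loss with respect to model.trainable_variables.

tape.gradient returns one gradient per trainable variable, and optimizer.apply_gradients pairs each gradient with its variable to perform the update.

The tf.function decorator compiles the Python function into a TensorFlow graph, which usually makes it run faster.

And that’s it!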

This is a rough simulation of the classic fit function provided by Keras, but notice that we now have the flexibility to control how the parameter updates take place in our model, among many other things.
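To see what the tape records in isolation, here is a minimal sketch that is not part of the original notebook. It differentiates y = x * x at x = 3 and then applies one plain gradient-descent step by hand. The variable x and the learning rate of 0.1 are illustrative assumptions.

```python
import tensorflow as tf

# A single trainable scalar stands in for the model's weights (illustrative).
x = tf.Variable(3.0)

# Operations on trainable variables inside the context are recorded on the tape.
with tf.GradientTape() as tape:
    y = x * x

# dy/dx = 2x, so the gradient at x = 3 is 6.
grad = tape.gradient(y, x)

# One step of plain gradient descent written out by hand. A plain SGD
# optimizer would perform the same update (assumed learning rate: 0.1).
learning_rate = 0.1
x.assign_sub(learning_rate * grad)

print(grad.numpy(), x.numpy())
```

The training function above follows the same pattern, only with the model's loss in place of y and the model's trainable variables in place of x.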

Note regarding declarative API vs. imperative API

The model that we used in the above step was created using the classic Sequential API of Keras:

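from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

model = Sequential()
model.add(Conv2D(16, (5, 5), activation="relu",
    input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(32, (5, 5), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
# Flatten the feature maps so the dense layers receive one vector per image
model.add(Flatten())
model.add(Dense(128, activation="relu"))
model.add(Dense(len(CLASSES), activation="softmax"))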

Nothing fancy there, I hear you. What is exciting is that you can now take this sequential model and have complete flexibility over how it is trained. What is even more exciting is that this is not the end of customization: you can also control the flow of operations inside your models. Here’s an example:

class CustomModel(tf.keras.Model):
    def __init__(self):
        super(CustomModel, self).__init__()
        # `shape` (the number of input features) must be defined beforehand
        self.do1 = tf.keras.layers.Dropout(rate=0.2, input_shape=(shape,))
        self.fc1 = tf.keras.layers.Dense(units=64, activation="relu")
        self.do2 = tf.keras.layers.Dropout(rate=0.2)
        self.fc2 = tf.keras.layers.Dense(units=64, activation="relu")
        self.do3 = tf.keras.layers.Dropout(rate=0.2)
        self.out = tf.keras.layers.Dense(units=1, activation="sigmoid")

    def call(self, x):
        # The order of operations is spelled out explicitly here
        x = self.do1(x)
        x = self.fc1(x)
        x = self.do2(x)
        x = self.fc2(x)
        x = self.do3(x)
        return self.out(x)

You can learn more about the synergy between declarative and imperative API designs in TensorFlow here. As an exercise, try converting our sequential model definition to an imperative design like the one shown above.
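One possible answer to that exercise is sketched below. It is an illustration rather than the author's solution: the class name FashionCNN and the default of 10 classes are assumptions, and a Flatten layer is included so that the dense layers receive one flat vector per image.

```python
import tensorflow as tf

class FashionCNN(tf.keras.Model):
    """Subclassed take on the shallow CNN; the class name is hypothetical."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(16, (5, 5), activation="relu")
        self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        self.conv2 = tf.keras.layers.Conv2D(32, (5, 5), activation="relu")
        self.pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(128, activation="relu")
        self.out = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, x):
        # The forward pass is ordinary Python, so its flow can be changed freely.
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        x = self.flatten(x)
        x = self.fc1(x)
        return self.out(x)

cnn = FashionCNN()
dummy_batch = tf.zeros((2, 28, 28, 1))  # two blank 28x28 grayscale images
probabilities = cnn(dummy_batch)
print(probabilities.shape)
```

Because the output layer uses softmax, each row of probabilities sums to 1, one row per input image.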

Logging metrics in the training loops using W&B

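This section will walk you through the following steps: loading and preparing the FashionMNIST dataset, training the model while logging metrics with W&B, and logging sample predictions with the get_sample_predictions function.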

Let’s get cracking on these.

Our dataset: FashionMNIST

The dataset loading and preprocessing steps are pretty simple assuming that you have worked with the dataset before.

# Load the FashionMNIST dataset, scale the pixel values
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train/255.
X_test = X_test/255.

You probably noticed earlier that our sequential model is a shallow convolutional neural network. For the images to work with CNNs (especially the ones defined using Keras), they need a channel dimension. Let’s add one:

# Reshape input data
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)
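As a quick check of what the scaling and reshaping do to shapes and dtypes, here is a small sketch that uses a stand-in array instead of the real dataset. The array name images and its size of 5 are assumptions made for illustration.

```python
import numpy as np

# Stand-in for a few grayscale images; the real X_train has shape (60000, 28, 28).
images = np.zeros((5, 28, 28), dtype=np.uint8)

scaled = images / 255.                      # true division turns uint8 into float64
reshaped = scaled.reshape(-1, 28, 28, 1)    # -1 lets NumPy infer the number of images
expanded = np.expand_dims(scaled, axis=-1)  # equivalent way to append the channel axis

print(reshaped.shape, expanded.shape, reshaped.dtype)
```

Both routes yield the same (images, height, width, channels) layout that the Conv2D layer's input_shape=(28, 28, 1) expects.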

Let’s now batch the dataset:

# Batches of 64
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(64)
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)

And that’s it!

Notice the use of tf.data here. The data module of TensorFlow offers a lot of useful functionality for building flexible and fast data pipelines. Get started with the data module here.
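It helps to know how many batches each pass over the data involves. By default, batch keeps the final, smaller batch (its drop_remainder argument defaults to False), so the counts can be worked out with plain arithmetic. The helper batch_sizes below is a hypothetical name used only for this sketch; the split sizes are the standard 60,000 training and 10,000 test images of FashionMNIST.

```python
def batch_sizes(num_examples, batch_size):
    """Sizes of the batches produced when the final partial batch is kept."""
    full, remainder = divmod(num_examples, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])

train_batches = batch_sizes(60000, 64)  # FashionMNIST training split
test_batches = batch_sizes(10000, 64)   # FashionMNIST test split

print(len(train_batches), train_batches[-1])  # 938 32
print(len(test_batches), test_batches[-1])    # 157 16
```

So each epoch of the loop that follows runs 938 training steps and 157 validation steps.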

Model training and the use of W&B

The model is already defined for us. The instructions on how to train the model are defined in the following function, which we saw earlier:

# Train the model
def model_train(features, labels):
    # Define the GradientTape context
    with tf.GradientTape() as tape:
        # Get the probabilities
        predictions = model(features)
        # Calculate the loss
        loss = loss_func(labels, predictions)
    # Get the gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    # Update the weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Update the loss and accuracy
    train_loss(loss)
    train_acc(labels, predictions)

Similarly, we can define a little function to evaluate our model:

# Validating the model
def model_validate(features, labels):
    predictions = model(features)
    v_loss = loss_func(labels, predictions)

    # Update the validation loss and accuracy
    valid_loss(v_loss)
    valid_acc(labels, predictions)
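Both functions lean on objects that this article uses but never defines: loss_func, optimizer, and the metric trackers. The notebook's exact choices are not shown here, so the setup below is an assumption that fits integer class labels and softmax outputs. The training loop also assumes that wandb.init() has been called, which the sketch leaves out.

```python
import tensorflow as tf

# Assumed choices: integer class labels compared against softmax probabilities.
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Streaming trackers: each call folds one more batch into the running result.
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")
valid_loss = tf.keras.metrics.Mean(name="valid_loss")
valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="valid_acc")

# Tiny hand-made batch to show how the trackers are fed.
labels = tf.constant([0, 1])
predictions = tf.constant([[0.9, 0.1], [0.6, 0.4]])  # second row is a wrong guess

loss = loss_func(labels, predictions)
train_loss(loss)                # accumulate the batch loss
train_acc(labels, predictions)  # accumulate correct and total counts

print(train_loss.result().numpy(), train_acc.result().numpy())
```

With one correct and one wrong prediction, the accuracy tracker reports 0.5, and the loss tracker reports the mean cross-entropy of the two rows.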

Now, we are absolutely ready to start the model training:

# Train the model for 5 epochs
for epoch in range(5):
    # Run the model through train and test sets respectively
    for (features, labels) in train_ds:
        model_train(features, labels)

    for test_features, test_labels in test_ds:
        model_validate(test_features, test_labels)
    # Grab the results
    (loss, acc) = train_loss.result(), train_acc.result()
    (val_loss, val_acc) = valid_loss.result(), valid_acc.result()
    # Clear the current state of the metrics
    # (newer TensorFlow releases name this method reset_state)
    train_loss.reset_states(), train_acc.reset_states()
    valid_loss.reset_states(), valid_acc.reset_states()
    # Local logging
    template = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"
    print(template.format(epoch+1, loss.numpy(), acc.numpy(),
                          val_loss.numpy(), val_acc.numpy()))
    # Logging with W&B
    wandb.log({"train_loss": loss.numpy(),
               "train_accuracy": acc.numpy(),
               "val_loss": val_loss.numpy(),
               "val_accuracy": val_acc.numpy()})
    get_sample_predictions() # More on this later

Locally, it prints something like so:

Epoch 1, loss: 0.544, acc: 0.802, val_loss: 0.429, val_acc: 0.845
Epoch 2, loss: 0.361, acc: 0.871, val_loss: 0.377, val_acc: 0.860
Epoch 3, loss: 0.309, acc: 0.888, val_loss: 0.351, val_acc: 0.869
Epoch 4, loss: 0.277, acc: 0.899, val_loss: 0.336, val_acc: 0.873
Epoch 5, loss: 0.252, acc: 0.908, val_loss: 0.323, val_acc: 0.882
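Per-epoch numbers like these depend on the metrics being reset after every epoch. The pure-Python stand-in below (a hypothetical RunningMean class with made-up batch losses, not a TensorFlow API) mimics a streaming mean to show what happens with and without the reset.

```python
class RunningMean:
    """Pure-Python stand-in that mimics how a streaming mean metric accumulates."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total = 0.0
        self.count = 0

epoch_losses = [[0.8, 0.6], [0.4, 0.2]]  # made-up batch losses for two epochs

with_reset, without_reset = RunningMean(), RunningMean()
reported_with, reported_without = [], []
for batch_losses in epoch_losses:
    for value in batch_losses:
        with_reset.update(value)
        without_reset.update(value)
    reported_with.append(with_reset.result())
    reported_without.append(without_reset.result())
    with_reset.reset()  # analogous to the reset calls in the training loop

print(reported_with)
print(reported_without)
```

With the reset, the second epoch reports roughly 0.3, the mean of its own batches. Without it, the second value is roughly 0.5 because the first epoch's losses are still part of the average.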

And on the W&B run page, you get all the plots:

Looks like the model is training in the right way. I have always wanted plots in which the training and validation metrics are overlaid, since that gives me a much better sense of the training progress than the individual plots shown above. With W&B, this is extremely simple. Here’s a quick video of how you can plot your training and validation accuracies in one plot.

And now I have the overlaid plots:

The get_sample_predictions function

Let’s take a look at the function definition:

# Grab random images from the test set and make predictions using
# the model *while it is training* and log them using W&B
def get_sample_predictions():
    predictions = []
    images = []
    random_indices = np.random.choice(X_test.shape[0], 25)
    for index in random_indices:
        image = X_test[index].reshape(1, 28, 28, 1)
        prediction = np.argmax(model(image).numpy(), axis=1)
        prediction = CLASSES[int(prediction[0])]
        # Collect the image (as a 2-D grayscale array) and its predicted label
        images.append(image.reshape(28, 28))
        predictions.append(prediction)
    wandb.log({"predictions": [wandb.Image(image, caption=prediction)
                               for (image, prediction) in zip(images, predictions)]})

As the comment above suggests, get_sample_predictions logs a set of randomly selected images from the test set, along with the model’s predictions for them, while the model is training. This is very easy to achieve if you use the WandbCallback with the data_type="image", labels=CLASSES arguments. Doing the same in a customized training loop takes a little more code, but it is just as doable.
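The step that turns a softmax output into a readable caption can be tried on its own. In the sketch below, CLASSES is assumed to hold the ten FashionMNIST label names in label-index order (the article does not show its definition), and the probabilities are made up.

```python
import numpy as np

# Assumed definition: FashionMNIST label names in label-index order (0 to 9).
CLASSES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
           "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# Made-up softmax output for a single image, shape (1, 10).
probabilities = np.array([[0.01, 0.02, 0.05, 0.02, 0.10,
                           0.03, 0.07, 0.60, 0.05, 0.05]])

class_index = np.argmax(probabilities, axis=1)  # one index per image in the batch
label = CLASSES[int(class_index[0])]
print(label)  # Sneaker
```

Since the model is called on one image at a time, argmax returns a one-element array, and its single entry is the index into CLASSES.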

Now, when I go to the W&B run page and scroll to the bottom, I can easily see get_sample_predictions doing its magic:


So, that’s it for this article. I hope you will now experiment a lot with the TensorFlow 2.0 features presented in the article. Show us how you customize your training loops and use W&B to automatically keep track of the training progress of your models.
