11 Jan 2019 / Chris Van Pelt, CVP of Weights & Biases

Optimizing CIFAR-10 Hyperparameters with W&B and SageMaker

Hyperparameter sweeps are a great way to squeeze extra performance out of an algorithm, but we often skip them because they’re expensive and tricky to set up. AWS SageMaker makes it easy to run hyperparameter sweeps on your existing ML code, and W&B makes it effortless to see the results.

I had code for a CNN that classifies images in the CIFAR-10 dataset, and I wanted to find the best set of hyperparameters for it.

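Here’s a snippet from a working example where I used W&B with SageMaker. The estimator holds the settings shared by every training job, and hyperparameter_ranges defines the search space the tuner explores.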

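```python
import os

import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import (CategoricalParameter, ContinuousParameter,
                             IntegerParameter)

# IAM role SageMaker uses to run the training jobs (works inside a
# SageMaker notebook instance; elsewhere, pass the role ARN directly)
role = sagemaker.get_execution_role()

# Settings shared by every training job in the sweep
estimator = PyTorch(entry_point="cifar10.py",
                    source_dir=os.path.join(os.getcwd(), "source"),
                    role=role,
                    framework_version='1.0.0.dev',
                    train_instance_count=1,
                    train_instance_type='ml.c5.xlarge',
                    hyperparameters={
                        'epochs': 50,
                        'momentum': 0.9
                    })

# Search space: the tuner picks a value from each range for every job
hyperparameter_ranges = {
    'lr': ContinuousParameter(0.0001, 0.001),
    'hidden_nodes': IntegerParameter(20, 100),
    'batch_size': CategoricalParameter([128, 256, 512]),
    'conv1_channels': CategoricalParameter([32, 64, 128]),
    'conv2_channels': CategoricalParameter([64, 128, 256, 512]),
}
```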
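SageMaker's framework containers hand each hyperparameter to the entry_point script as a command-line flag named after its dictionary key (for example `--lr 0.0005`), so cifar10.py has to parse them. The sketch below is a hypothetical version of that parsing step, not the author's actual script: the function name and all default values are placeholders, and only the flag names mirror the keys above.

```python
import argparse


def parse_hyperparameters(argv=None):
    """Parse the hyperparameters SageMaker passes as command-line flags.

    The defaults are hypothetical placeholders; in a real sweep, every
    value is supplied by the estimator or chosen by the tuner.
    """
    parser = argparse.ArgumentParser(description="CIFAR-10 CNN training")
    # Fixed values from the estimator's `hyperparameters` dict
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--momentum", type=float, default=0.9)
    # Values the tuner picks from `hyperparameter_ranges` for each job
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--hidden_nodes", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--conv1_channels", type=int, default=32)
    parser.add_argument("--conv2_channels", type=int, default=64)
    # argv=None makes argparse read sys.argv, as it would inside the container
    return parser.parse_args(argv)


# Simulate the flags one tuning job might receive
args = parse_hyperparameters(["--lr", "0.0005", "--batch_size", "256"])
print(args.lr, args.batch_size)  # 0.0005 256
```

Inside the real script, passing `vars(args)` as the `config` argument of `wandb.init` is one way to record each job's hyperparameters in W&B so they show up as columns in the sweep table.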

SageMaker launches a separate training job, each on its own AWS instance, for every hyperparameter combination it tries. W&B tracks everything that happens and makes it easy to visualize the sweep. Here’s a table in W&B tracking every run in the sweep, sorted by test accuracy. Test accuracy ranges from 10% to 76.45% depending on the hyperparameters; since CIFAR-10 has ten classes, 10% is chance level, so the worst configurations learned essentially nothing.
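To make "one training job per combination" concrete, here is a minimal sketch of how a tuner turns the ranges into job configurations. It deliberately swaps in plain random search: SageMaker's HyperparameterTuner defaults to a Bayesian strategy that uses earlier results to pick later combinations, and caps the total with its `max_jobs` setting. The `SEARCH_SPACE` encoding and function names are hypothetical stand-ins for the SageMaker parameter classes.

```python
import random

# Hypothetical mirror of `hyperparameter_ranges`: name -> (kind, spec)
SEARCH_SPACE = {
    "lr": ("continuous", (0.0001, 0.001)),
    "hidden_nodes": ("integer", (20, 100)),
    "batch_size": ("categorical", [128, 256, 512]),
    "conv1_channels": ("categorical", [32, 64, 128]),
    "conv2_channels": ("categorical", [64, 128, 256, 512]),
}


def sample_config(space, rng):
    """Draw one hyperparameter combination uniformly at random."""
    config = {}
    for name, (kind, spec) in space.items():
        if kind == "continuous":
            low, high = spec
            config[name] = rng.uniform(low, high)
        elif kind == "integer":
            low, high = spec
            config[name] = rng.randint(low, high)  # both bounds inclusive
        elif kind == "categorical":
            config[name] = rng.choice(spec)
        else:
            raise ValueError(f"unknown parameter kind: {kind}")
    return config


def random_search(space, n_jobs, seed=0):
    """Return n_jobs configurations; each would become one training job."""
    rng = random.Random(seed)  # seeded so the sweep is reproducible
    return [sample_config(space, rng) for _ in range(n_jobs)]


for job_id, config in enumerate(random_search(SEARCH_SPACE, n_jobs=3)):
    print(job_id, config)
```

Random search is a reasonable baseline because it needs no feedback between jobs, so every job can run in parallel; a Bayesian strategy trades some parallelism for smarter choices later in the sweep.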

To dig deeper into the patterns linking hyperparameters to accuracy on different classes, we generated a parallel coordinates plot in W&B. The first five columns are the configuration parameters, the far right column is the test accuracy, and each line corresponds to a single run. I highlighted the runs with the best test accuracy to see where they land on the other columns, and found that lower learning rates (config:lr) and fewer hidden nodes (config:hidden_nodes) correlated with higher test accuracy.
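The parallel coordinates plot shows these correlations visually. If you want a number to back up the eyeball reading, a Pearson coefficient between each config column and test accuracy is a quick check; a negative value for `lr` would match the pattern above. This is a sketch under assumptions: each run is a hypothetical flat dict of config values plus a `test_acc` metric (for example, exported from W&B), and Pearson only captures linear trends, so a rank correlation such as Spearman's may be more robust.

```python
import math


def pearson(xs, ys):
    """Pearson correlation between a hyperparameter column and a metric."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two equal-length sequences with >= 2 items")
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        raise ValueError("a constant column has no defined correlation")
    return cov / math.sqrt(var_x * var_y)


def correlations(runs, metric="test_acc"):
    """Correlate every config key with `metric` across all runs.

    `runs` is a hypothetical list of dicts such as
    {"lr": ..., "hidden_nodes": ..., "test_acc": ...}.
    """
    keys = [k for k in runs[0] if k != metric]
    ys = [run[metric] for run in runs]
    return {k: pearson([run[k] for run in runs], ys) for k in keys}
```

With only a few dozen runs, treat the coefficients as hints rather than proof, and remember they describe the searched range only (here, learning rates between 0.0001 and 0.001).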

Parallel coordinates plot highlighting runs with the highest test accuracy

We can also zoom in on individual accuracy metrics, such as accuracy on each class, or compare how different models classify a single image. Here are the top 10 models overall, struggling to correctly identify this picture of a dog.
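Per-class accuracy is what reveals a weakness like this: a model can score well overall while still confusing dogs with cats. A minimal sketch, assuming labels and predictions are integer class indices in CIFAR-10's standard order (in a PyTorch script the predictions would typically come from `outputs.argmax(dim=1)`), plus a helper for picking the top runs. The function names and the `test_acc` key are hypothetical.

```python
from collections import Counter

# CIFAR-10 class names, indexed by label 0 through 9
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def per_class_accuracy(labels, predictions, class_names=CIFAR10_CLASSES):
    """Fraction of test images of each class that the model labels correctly."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    totals = Counter(labels)
    correct = Counter(y for y, p in zip(labels, predictions) if y == p)
    # Skip classes that never appear, to avoid dividing by zero
    return {name: correct[i] / totals[i]
            for i, name in enumerate(class_names) if totals[i] > 0}


def top_runs(runs, k=10, metric="test_acc"):
    """Return the k runs with the highest value of `metric`, best first."""
    return sorted(runs, key=lambda run: run[metric], reverse=True)[:k]
```

Logging the per-class dictionary alongside overall test accuracy lets you sort or filter runs by a single class, such as "dog", in the same way the sweep table is sorted by test accuracy.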

To demonstrate the integration, we set up a sweep example in W&B over the CIFAR-10 dataset using PyTorch. If you want to reproduce it, my code is on GitHub. For more on the integration, check out our docs.
