Handwritten Digit Recognition in PyTorch
PyTorch (https://pytorch.org/) is an open-source machine learning framework developed at Facebook. It provides components such as optimizers, loss functions, fully connected layers, and activation functions for building deep learning architectures.
In this lab, we will learn how to recognize handwritten digits (0-9) in the MNIST Dataset using a simple network in PyTorch.
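The training section below assumes a model has already been defined. As a point of reference, a minimal fully connected network for MNIST might look like the sketch below; the hidden size of 128 is an illustrative assumption, not the lab's required architecture.

```python
import torch

# Minimal fully connected classifier for flattened 28x28 MNIST images.
# Layer sizes 784 -> 128 -> 10 are an assumption for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),  # one logit per digit class 0-9
)

# A flattened batch of 4 images maps to 4 rows of 10 class logits.
out = model(torch.randn(4, 784))
```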
Training of the network
After defining the architecture we are going to use, we proceed to train the neural network. To train the network we must define an optimizer, which specifies the optimization technique to use, and a loss function.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
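Note that CrossEntropyLoss applies log-softmax internally, so the model is expected to output raw logits rather than probabilities. A quick sanity check on a made-up toy batch (2 samples, 10 classes; not from the lab's data):

```python
import math
import torch

criterion = torch.nn.CrossEntropyLoss()

# All-zero logits give a uniform distribution over the 10 classes,
# so the per-sample loss is exactly log(10), whatever the target is.
logits = torch.zeros(2, 10)
target = torch.tensor([3, 7])
loss = criterion(logits, target)
```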
In PyTorch, the training is divided into the following steps:
- Computing predictions on data
- Computing loss on the predictions with respect to the target
- Computing gradients by backpropagating through the network
- Updating the weights using the SGD optimizer
for epoch in range(num_epochs):
    running_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        # Flatten each 28x28 image into a 784-element vector
        data = data.view(data.shape[0], -1)
        # Step 1: compute predictions (raw logits)
        logits = model(data)
        # Step 2: compute the loss with respect to the target
        loss = criterion(logits, target)
        running_loss += loss.item()
        # Step 3: backpropagate to compute gradients
        loss.backward()
        # Step 4: update the weights, then zero the gradients
        optimizer.step()
        optimizer.zero_grad()
We call optimizer.zero_grad() to zero out the gradients computed for the current batch. This lets us compute a fresh set of gradients for each parameter on the next batch; otherwise, the gradients accumulate across batches.
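This accumulation behavior can be seen directly on a single parameter. The toy scalar below is made up for illustration; zeroing `w.grad` by hand is what optimizer.zero_grad() does for every parameter it manages.

```python
import torch

# A single trainable parameter (hypothetical toy example, not from the lab).
w = torch.tensor([2.0], requires_grad=True)

(w * 3.0).sum().backward()
first = w.grad.clone()        # d(3w)/dw = 3

(w * 3.0).sum().backward()    # no zeroing: gradients accumulate
accumulated = w.grad.clone()  # 3 + 3 = 6

w.grad.zero_()                # what optimizer.zero_grad() does per parameter
(w * 3.0).sum().backward()
fresh = w.grad.clone()        # back to 3
```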