Handwritten Digit Recognition in PyTorch

PyTorch (https://pytorch.org/) is an open-source machine learning framework developed at Facebook. It provides building blocks such as optimizers, loss functions, fully connected layers, and activation functions for constructing deep learning architectures.
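As a rough illustration of how these components fit together, here is a minimal sketch that performs one training step on random toy data. The layer sizes, learning rate, and data are arbitrary choices made for this example and are not part of the lab itself.

```python
import torch
from torch import nn

torch.manual_seed(0)  # make the toy data and initial weights reproducible

# Fully connected layers with an activation function in between
model = nn.Sequential(
    nn.Linear(4, 16),  # fully connected layer: 4 input features -> 16 hidden units
    nn.ReLU(),         # activation function
    nn.Linear(16, 3),  # fully connected layer: 16 hidden units -> 3 class scores
)
loss_fn = nn.CrossEntropyLoss()                          # loss function for classification
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # optimizer (stochastic gradient descent)

# Toy data: 8 random samples with 4 features each, labels in {0, 1, 2}
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

# One training step: forward pass, loss, backward pass, parameter update
optimizer.zero_grad()
logits = model(x)
loss_before = loss_fn(logits, y)
loss_before.backward()
optimizer.step()

# Re-evaluate the loss with the updated parameters
with torch.no_grad():
    loss_after = loss_fn(model(x), y)

print(f"loss before: {loss_before.item():.4f}, after: {loss_after.item():.4f}")
```

The same pattern (zero the gradients, compute the loss, call `backward()`, then `step()`) is repeated over many batches when training a real network.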

In this lab, we will learn how to recognize handwritten digits (0-9) from the MNIST dataset using a simple neural network in PyTorch.
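To make "a simple network" concrete, the sketch below shows one possible fully connected classifier for MNIST-sized inputs: each image is 28 x 28 grayscale pixels and there are 10 digit classes. The class name `DigitClassifier` and the hidden sizes 128 and 64 are hypothetical choices for illustration; the network used later in this lab may differ.

```python
import torch
from torch import nn


class DigitClassifier(nn.Module):
    """Hypothetical small fully connected network for 28x28 grayscale digit images."""

    def __init__(self, hidden1: int = 128, hidden2: int = 64, num_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                     # (N, 1, 28, 28) -> (N, 784)
            nn.Linear(28 * 28, hidden1),
            nn.ReLU(),
            nn.Linear(hidden1, hidden2),
            nn.ReLU(),
            nn.Linear(hidden2, num_classes),  # one raw score (logit) per digit 0-9
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = DigitClassifier()
fake_batch = torch.rand(64, 1, 28, 28)       # random stand-in for a batch of MNIST images
logits = model(fake_batch)
predicted_digits = logits.argmax(dim=1)      # index of the highest score = predicted digit
print(logits.shape, predicted_digits.shape)  # torch.Size([64, 10]) torch.Size([64])
```

Because the network is untrained, these predictions are essentially random; only the input and output shapes are meaningful here.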

Data Visualisation

It is good practice to visualize the dataset we are working with. Seeing the data shows what kind of examples are present, which can inform decisions about how to design and train the neural network.
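Besides plotting images, a quick numeric check of a single batch (tensor shapes, pixel value range, and how many examples of each digit appear) can reveal problems early. The sketch below assumes a hypothetical stand-in dataset of random tensors shaped like MNIST so that it runs without downloading anything; in the lab you would use `train_loader` instead. With torchvision's MNIST and `transforms.ToTensor()`, each image batch has shape `(batch_size, 1, 28, 28)` with pixel values scaled to [0, 1].

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for MNIST: 256 random 1x28x28 "images" with pixel values
# in [0, 1) and random labels 0-9. Replace with the real training set in practice.
images_all = torch.rand(256, 1, 28, 28)
labels_all = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images_all, labels_all), batch_size=64, shuffle=True)

images, labels = next(iter(loader))  # fetch one batch

print("batch shapes:", images.shape, labels.shape)
print("pixel range: [{:.3f}, {:.3f}]".format(images.min().item(), images.max().item()))

# Number of examples of each digit in this batch (a quick class-balance check)
counts = torch.bincount(labels, minlength=10)
print("label counts:", counts.tolist())
```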

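import matplotlib.pyplot as plt

# Fetch one batch from the training DataLoader (train_loader is assumed to exist already)
data_iterator = iter(train_loader)
images, labels = next(data_iterator)  # DataLoader iterators no longer support .next()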

print('Shape of input: {}, Shape of output: {}'.format(images.shape, labels.shape))

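# Plot the first 60 images of the batch in a 6 x 10 grid
# (the batch size of train_loader must be at least 60)
fig = plt.figure()
num_images = 60
for image_index in range(num_images):
    plt.subplot(6, 10, image_index + 1)
    # squeeze() drops the channel dimension: (1, 28, 28) -> (28, 28)
    plt.imshow(images[image_index].numpy().squeeze(), cmap='gray_r')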