Handwritten Digit Recognition in PyTorch

PyTorch (https://pytorch.org/) is an open-source machine learning framework developed at Facebook. It provides the building blocks needed to assemble deep learning architectures, such as fully connected layers, activation functions, loss functions, and optimizers.

In this lab, we will learn how to recognize handwritten digits (0-9) from the MNIST dataset using a simple network in PyTorch.
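As a rough sketch of how these components fit together, the snippet below builds a small fully connected classifier for 28x28 grayscale images (the MNIST image size) and runs one optimization step on a random stand-in batch. The hidden width of 128, the SGD learning rate of 0.01, and the helper name `train_step` are illustrative assumptions, not the architecture this lab necessarily uses.

```python
import torch
from torch import nn

# Hypothetical simple network: flatten 28x28 images, one hidden layer, 10 outputs (digits 0-9)
model = nn.Sequential(
    nn.Flatten(),             # (batch, 1, 28, 28) -> (batch, 784)
    nn.Linear(28 * 28, 128),  # fully connected layer
    nn.ReLU(),                # activation function
    nn.Linear(128, 10),       # one logit per digit class
)

loss_fn = nn.CrossEntropyLoss()                           # takes raw logits and integer class labels
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # plain stochastic gradient descent


def train_step(images, labels):
    """Run one forward pass, backward pass, and parameter update; return the loss value."""
    optimizer.zero_grad()           # clear gradients left over from the previous step
    logits = model(images)
    loss = loss_fn(logits, labels)
    loss.backward()                 # compute gradients of the loss w.r.t. the parameters
    optimizer.step()                # update the parameters using those gradients
    return loss.item()


# Random stand-in batch shaped like MNIST: 64 images, 1 channel, 28x28 pixels
images = torch.rand(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
print(model(images).shape)  # torch.Size([64, 10])
print(train_step(images, labels))
```

Each row of the model's output holds 10 scores, one per digit; the predicted digit is the index of the largest score.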

Data-loaders

Data loaders are objects that wrap a dataset and serve it in batches of samples. Commonly used batch sizes (the number of samples per batch, not the number of batches) are 32, 64, 128, 256, and 512. The choice of batch size can influence whether a model underfits or overfits, and it is also limited by memory: when memory is constrained, we should choose a smaller batch size so that each batch fits in memory.

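import torch
# Shuffle the training batches every epoch; keep the test order fixed for reproducible evaluation
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)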
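The difference between batch size and number of batches can be checked directly. The sketch below wraps a stand-in dataset with the same shape as the MNIST training split (60,000 images of 1x28x28) in a `DataLoader`; the blank images and random labels are an assumption made so the example runs without downloading anything. With `batch_size=64`, the loader yields ceil(60000 / 64) = 938 batches, and the last one holds the 32 leftover samples.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the MNIST training split: 60,000 blank 1x28x28 images with random digit labels
num_samples = 60_000
images = torch.zeros(num_samples, 1, 28, 28, dtype=torch.uint8)
labels = torch.randint(0, 10, (num_samples,))
train_data = TensorDataset(images, labels)

# batch_size = samples per batch; len(loader) = number of batches per epoch
for batch_size in (32, 64, 128, 256, 512):
    loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    print(batch_size, len(loader), math.ceil(num_samples / batch_size))

# Collect the size of every batch in one epoch at batch_size=64
loader = DataLoader(train_data, batch_size=64, shuffle=True)
batch_sizes = [x.shape[0] for x, _ in loader]
print(len(batch_sizes), batch_sizes[-1])  # 938 32
```

If a partial final batch is undesirable, passing `drop_last=True` to `DataLoader` discards it, leaving 937 full batches of 64.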


Discussion