Handwritten Digit Recognition in PyTorch

PyTorch (https://pytorch.org/) is an open-source machine learning framework originally developed at Facebook. It provides the building blocks for deep learning architectures, such as fully connected layers, activation functions, loss functions, and optimizers.
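As a rough illustration of how these components fit together, the sketch below wires two fully connected layers (`nn.Linear`) and an activation function (`nn.ReLU`) into a model, then runs one training step with a loss function (`nn.CrossEntropyLoss`) and an optimizer (`torch.optim.SGD`). The layer sizes, learning rate, and random toy data are arbitrary assumptions chosen only for the demonstration; they are not the lab's settings.

```python
import torch
from torch import nn

torch.manual_seed(0)  # reproducible random weights and data

# Fully connected layers with an activation function in between
# (sizes are arbitrary, chosen only for this toy example)
model = nn.Sequential(
    nn.Linear(4, 16),  # 4 input features -> 16 hidden units
    nn.ReLU(),         # activation function
    nn.Linear(16, 3),  # 16 hidden units -> 3 class scores
)

# A loss function and an optimizer that will update the model's parameters
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# A toy batch: 8 samples with 4 features each and a class label (0, 1 or 2) per sample
inputs = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))

# Keep a copy of the last layer's bias to show that training changes it
bias_before = model[2].bias.detach().clone()

# One training step: forward pass, loss, backward pass, parameter update
outputs = model(inputs)
loss = loss_fn(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(outputs.shape)  # torch.Size([8, 3])
print(loss.item())    # a single non-negative number
print(torch.equal(bias_before, model[2].bias))  # False: the optimizer updated the weights
```

Real training simply repeats this forward/loss/backward/step cycle over many batches.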

In this lab, we will learn how to recognize handwritten digits (0-9) from the MNIST dataset using a simple neural network in PyTorch.
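The model itself is not shown in this section, so the following is only a minimal sketch of what a "simple network" for MNIST might look like, assuming a fully connected design: each 28x28 image is flattened into 784 pixel values and mapped to 10 outputs, one per digit. The hidden sizes (128 and 64) and the final `nn.LogSoftmax` layer are illustrative assumptions, not the lab's confirmed architecture.

```python
import torch
from torch import nn

# Hypothetical architecture: 784 input pixels -> 128 -> 64 -> 10 digit classes
model = nn.Sequential(
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.LogSoftmax(dim=1),  # one log-probability per digit 0-9
)

# A fake batch shaped like MNIST images: (batch, channels, height, width)
images = torch.rand(5, 1, 28, 28)

# Flatten each 1x28x28 image into a 784-element vector
flat = images.view(images.shape[0], -1)
log_probs = model(flat)

print(flat.shape)       # torch.Size([5, 784])
print(log_probs.shape)  # torch.Size([5, 10])
print(log_probs.exp().sum(dim=1))  # each row sums to (approximately) 1
```

Because the output is one score per digit for each image, picking the largest score in each row gives the predicted digit.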

Testing your network

Finally, we evaluate the trained model on the test set.

When called with dim=1, torch.max returns a tuple of two tensors: the first holds the maximum value in each row of the input, and the second holds the column index at which that maximum occurs. Each row of the model's output contains the scores for one image across the ten digit classes, so that column index is the predicted digit.
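A small worked example may make the two outputs concrete. The scores below are made-up numbers standing in for model outputs (3 images, 4 classes instead of 10, to keep it short):

```python
import torch

# Made-up model outputs: 3 images (rows) scored against 4 classes (columns)
scores = torch.tensor([
    [0.1, 2.5, 0.3, 0.0],
    [1.2, 0.4, 0.9, 3.1],
    [0.7, 0.2, 0.5, 0.6],
])

# values: the largest score in each row; indices: the column where it occurs
values, indices = torch.max(scores, dim=1)
print(values)   # tensor([2.5000, 3.1000, 0.7000])
print(indices)  # tensor([1, 3, 0])

# torch.argmax returns only the indices, which is all the prediction needs
print(torch.argmax(scores, dim=1))  # tensor([1, 3, 0])
```

Only the indices matter for classification, which is why the test loop discards the first output with `_`.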

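import torch

model.eval()  # put layers such as dropout into inference mode
got_correct = 0

# Gradients are not needed for evaluation, so disable them for the whole loop
with torch.no_grad():
    for data, target in test_loader:
        # Flatten each image into a single vector (784 values for a 28x28 MNIST image)
        data = data.view(data.shape[0], -1)
        output = model(data)  # one score per digit class for each image

        # The column index of the largest score in each row is the predicted digit
        _, predicted_digit = torch.max(output, dim=1)
        got_correct += (predicted_digit == target).sum().item()

total = len(test_loader.dataset)
print("Number Of Images Tested =", total)
print("\nModel Accuracy =", got_correct / total)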
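The same evaluation can be packaged as a reusable function. The sketch below is an assumption-laden illustration rather than the lab's own code: `evaluate` is a hypothetical helper name, and the toy check uses `torch.nn.Identity` as a stand-in "model" on one-hot inputs so that the correct answer is known in advance (two of the ten labels are deliberately wrong, so the accuracy should be 0.8).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader):
    """Hypothetical helper: fraction of samples whose predicted class matches the target."""
    model.eval()  # inference mode for layers such as dropout
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in loader:
            data = data.view(data.shape[0], -1)  # flatten each sample
            predicted = model(data).argmax(dim=1)
            correct += (predicted == target).sum().item()
            total += target.shape[0]  # count the samples actually seen
    return correct / total


# Toy check: the "model" passes its input through unchanged, so each
# one-hot row is predicted as the class where its 1 sits.
inputs = torch.eye(10)              # row i is one-hot for class i
targets = torch.arange(10)
targets[:2] = torch.tensor([5, 5])  # make the first two labels wrong
loader = DataLoader(TensorDataset(inputs, targets), batch_size=4)

accuracy = evaluate(torch.nn.Identity(), loader)
print(accuracy)  # 0.8
```

Counting `total` from the batches instead of `len(loader.dataset)` keeps the result correct even if the loader drops an incomplete final batch.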
Discussion