Handwritten Digit Recognition in PyTorch

PyTorch (https://pytorch.org/) is an open-source machine learning framework developed at Facebook. It provides building blocks such as optimizers, loss functions, fully connected layers, and activation functions for constructing deep learning architectures.

In this lab, we will learn how to recognize handwritten digits (0-9) from the MNIST dataset using a simple fully connected network in PyTorch.

Defining the Neural Network

In this step we will define the architecture we will use for identifying handwritten digits. 

In PyTorch, a neural network is defined as a class that inherits from torch.nn.Module. The constructor first calls the superclass constructor so that PyTorch can keep track of the network's parameters, and then defines each component of the network (the layers, the activation function, and the output function) as an attribute of the class.

import torch

class Net(torch.nn.Module):
    def __init__(self, hidden_sizes=(128, 64), input_size=784, output_size=10):
        super(Net, self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size

        # Define Activation Function
        self.activation = torch.nn.ReLU(inplace=True)

        # Define the log-softmax function that turns the outputs into log-probabilities
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

        # Define the fully connected layers
        self.layer1 = torch.nn.Linear(self.input_size, self.hidden_sizes[0])
        self.layer2 = torch.nn.Linear(self.hidden_sizes[0], self.hidden_sizes[1])
        self.layer3 = torch.nn.Linear(self.hidden_sizes[1], self.output_size)

    def forward(self, input):
        """Performs forward propagation on the input"""

        output = self.layer1(input)
        output = self.activation(output)
        output = self.layer2(output)
        output = self.activation(output)
        output = self.layer3(output)
        output = self.log_softmax(output)
        
        return output        

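The forward function passes the input through the layers that were defined as attributes in the constructor, in order: each of the first two fully connected layers is followed by the ReLU activation, and the last layer is followed by log-softmax. This function implements forward propagation.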
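To make the data flow concrete, the sketch below pushes a small batch through the same stack of layers and checks the shapes along the way. It assumes MNIST images are 28x28 pixels, flattened to 784 values to match input_size. The torch.nn.Sequential container, the batch size of 4, and the random tensor are stand-ins chosen for illustration and are not part of the lab code.

```python
import torch

# Stand-in for Net: the same layers in the same order, written with
# torch.nn.Sequential so this snippet runs on its own (an assumption for
# illustration; the lab itself uses the Net class).
layers = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
    torch.nn.LogSoftmax(dim=1),
)

# Hypothetical batch of 4 grayscale 28x28 images (random values, not real MNIST data).
images = torch.rand(4, 1, 28, 28)

# Flatten each image into a 784-element vector, as the first Linear layer expects.
flat = images.view(images.size(0), -1)

log_probs = layers(flat)          # shape (4, 10): one log-probability per digit
probs = torch.exp(log_probs)      # back to probabilities; each row sums to 1
predicted = probs.argmax(dim=1)   # index of the most likely digit for each image

print(flat.shape, log_probs.shape, predicted.shape)
```

Because the network returns log-probabilities rather than probabilities, taking torch.exp of the output recovers values that sum to 1 across the ten digit classes.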

The next step is to create an object of class Net so that we can utilize the architecture for training.

model = Net()
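
Because every layer is assigned as an attribute of a torch.nn.Module subclass, PyTorch registers its weights and biases automatically. One way to confirm this is to list the parameters of the new object. The sketch below repeats a trimmed copy of Net with the default sizes hard-coded so that it runs on its own; count_parameters is a hypothetical helper introduced only for this example.

```python
import torch


class Net(torch.nn.Module):
    """Trimmed copy of the lab's Net with its default sizes (784 -> 128 -> 64 -> 10)."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(784, 128)
        self.layer2 = torch.nn.Linear(128, 64)
        self.layer3 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return torch.log_softmax(self.layer3(x), dim=1)


def count_parameters(module):
    """Hypothetical helper: total number of trainable values in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


model = Net()

# Each Linear layer contributes a weight matrix and a bias vector.
for name, param in model.named_parameters():
    print(name, tuple(param.shape))

total = count_parameters(model)
print(total)  # 784*128 + 128 + 128*64 + 64 + 64*10 + 10 = 109386
```

The same parameters() iterator is what an optimizer receives when the model is trained.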