In this tutorial, I'll guide you through implementing a flexible neural network architecture using PyTorch. This architecture, which I'll refer to as the "Twisted MLP", is designed to highlight the adaptability of neural networks to complex tasks. We'll build a model to classify digits from the MNIST dataset, combining a deep main path with a shallow skip connection to capture more from a single input.

Overview of the Twisted MLP Architecture

The Twisted MLP consists of two parallel paths that process the input independently before their outputs are combined later in the network. This setup allows both deep and shallow processing of the same input, capturing features at different levels of abstraction.

[Figure: diagram of the Twisted MLP architecture, showing the two paths meeting at the concatenation node $\textcircled{c}$]

Components of the Twisted MLP:

  1. Input Layer: MNIST input (784 features from a flattened 28x28 image).
  2. First Path (Main Path): two fully connected layers with ReLU activations, 784 -> 64 -> 32.
  3. Second Path (Lower Path, a.k.a. Skip Connection): a single fully connected layer with ReLU, 784 -> 32, giving the input a shallower route to the merge point.
  4. Concatenation of the two paths at $\textcircled{c}$ (32 + 32, resulting in a 64-dimensional tensor).
  5. Further Processing: the concatenated tensor and the skip-path output are each projected to 128 dimensions, then added element-wise (the two merge operations are contrasted in the snippet after this list).
  6. Final Output Layer: a 128 -> 10 linear layer producing one logit per digit class.
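
Before writing the model class, it helps to see the two merge operations from steps 4 and 5 side by side. Here is a minimal sketch; the tensor sizes mirror the architecture, but the random data and the batch size of 8 are just for illustration:

import torch

# Outputs of the two paths for a batch of 8 (illustrative random data)
h = torch.randn(8, 32)  # main path output
u = torch.randn(8, 32)  # skip path output

# Concatenation stacks features side by side: (8, 32) and (8, 32) -> (8, 64)
q = torch.cat((h, u), dim=1)
print(q.shape)  # torch.Size([8, 64])

# Element-wise addition requires matching shapes: (8, 128) + (8, 128) -> (8, 128)
v = torch.randn(8, 128)
k = torch.randn(8, 128)
d = v + k
print(d.shape)  # torch.Size([8, 128])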

Define the Twisted MLP Model Class

import torch
import torch.nn as nn

# Define the neural network
class TwistedMNISTModel(nn.Module):
    def __init__(self):
        super(TwistedMNISTModel, self).__init__()
        
        # First path (main)
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 32)

        # Second path (skip connection)
        self.skip_fc1 = nn.Linear(784, 32)

        # Processing after concatenation
        self.concat_fc = nn.Linear(64, 128)
        self.skip_fc2 = nn.Linear(32, 128)

        # Final output layer
        self.final_fc = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input (28x28 -> 784)

        # First path (main)
        z = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(z))

        # Second path (skip connection)
        u = torch.relu(self.skip_fc1(x))

        # Concatenation
        q = torch.cat((h, u), dim=1)

        # Processing concatenated output
        v = torch.relu(self.concat_fc(q))

        # Further processing of skip connection
        k = torch.relu(self.skip_fc2(u))

        # Add the processed outputs
        d = v + k

        # Final output layer
        hat_y = self.final_fc(d)
        
        return hat_y
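
With the class defined, a quick sanity check confirms that the forward pass produces one logit per class. This is a minimal sketch; the batch size of 64 is an arbitrary choice:

model = TwistedMNISTModel()
dummy = torch.randn(64, 1, 28, 28)  # a fake batch of MNIST-shaped images
logits = model(dummy)
print(logits.shape)  # torch.Size([64, 10])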

Full Code
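
The listing below is a minimal end-to-end sketch rather than a verbatim reproduction of the original script. It assumes torchvision is installed for the MNIST dataset, and the hyperparameters (Adam with lr=1e-3, batch size 64, 5 epochs) are illustrative choices:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class TwistedMNISTModel(nn.Module):
    def __init__(self):
        super(TwistedMNISTModel, self).__init__()
        # First path (main)
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 32)
        # Second path (skip connection)
        self.skip_fc1 = nn.Linear(784, 32)
        # Processing after concatenation
        self.concat_fc = nn.Linear(64, 128)
        self.skip_fc2 = nn.Linear(32, 128)
        # Final output layer
        self.final_fc = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)          # flatten 28x28 -> 784
        z = torch.relu(self.fc1(x))        # main path
        h = torch.relu(self.fc2(z))
        u = torch.relu(self.skip_fc1(x))   # skip path
        q = torch.cat((h, u), dim=1)       # concatenate 32 + 32 -> 64
        v = torch.relu(self.concat_fc(q))  # process concatenation
        k = torch.relu(self.skip_fc2(u))   # process skip path
        d = v + k                          # merge by addition
        return self.final_fc(d)            # logits


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Standard MNIST normalization constants
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = datasets.MNIST("data", train=True, download=True, transform=transform)
    test_set = datasets.MNIST("data", train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=256)

    model = TwistedMNISTModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed settings
    criterion = nn.CrossEntropyLoss()

    for epoch in range(5):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # Report test accuracy after each epoch
        model.eval()
        correct = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
        print(f"epoch {epoch + 1}: test accuracy {correct / len(test_set):.4f}")


if __name__ == "__main__":
    main()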

Conclusion

The Twisted MLP model is a good example of how neural network architectures can be creatively designed. By employing both a deep and a shallow path, it can learn low- and high-level features of the same input. This tutorial demonstrates the flexibility of neural networks in real-world tasks like image classification.