"""LeNet5 implementation."""

from torch import nn
import torch


class LeNet5(nn.Module):
    """LeNet5 network (variant with BatchNorm after each convolution).

    Two ``conv -> BatchNorm -> ReLU -> max-pool`` stages are followed by three
    fully connected layers. The first fully connected layer expects
    ``16 * 5 * 5 = 400`` features. The convolutional stages produce exactly
    that many only for a 28x28 input, so other spatial sizes fail at
    ``self.dense``.
    """

    def __init__(self, num_classes: int, in_channels: int = 1) -> None:
        """Build the layers.

        Args:
            num_classes (int): Number of output logits.
            in_channels (int): Number of channels of the input image.
                Defaults to 1, which matches the previous fixed behaviour.
        """
        super().__init__()
        # 28x28 -> conv(k=5, p=2) keeps 28x28 -> max-pool(2) -> 14x14
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 6, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 14x14 -> conv(k=5, p=0) -> 10x10 -> max-pool(2) -> 5x5
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 16 channels * 5 * 5 spatial positions from layer2 (for 28x28 input).
        self.dense = nn.Linear(16 * 5 * 5, 120)
        self.relu = nn.ReLU()
        self.dense1 = nn.Linear(120, 84)
        self.relu1 = nn.ReLU()
        self.dense2 = nn.Linear(84, num_classes)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward method.

        Args:
            inputs (torch.Tensor): `(N, in_channels, 28, 28)`

        Returns:
            torch.Tensor: `(N, num_classes)` un-normalised logits.
        """
        out = self.layer1(inputs)
        out = self.layer2(out)
        # flatten instead of reshape(N, -1): reshape cannot infer the -1
        # dimension when N == 0 and raises on an empty batch.
        out = torch.flatten(out, start_dim=1)
        out = self.dense(out)
        out = self.relu(out)
        out = self.dense1(out)
        out = self.relu1(out)
        out = self.dense2(out)
        return out
