import torch


class ImageNN(torch.nn.Module):
    """Neural network for image inputs (skeleton: layers not yet defined).

    Subclasses ``torch.nn.Module``. Both the layer structure and the forward
    computation are still TODO, so the module currently registers no
    parameters or submodules.
    """

    def __init__(self) -> None:
        super().__init__()
        # TODO: define the network layers here (as torch.nn submodules
        # assigned to attributes) so they are registered with this module.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` (not implemented yet).

        Args:
            x: Input tensor. Presumably an image batch, e.g.
                (batch, channels, height, width); TODO confirm once the
                architecture is defined.

        Returns:
            Intended to return the network output tensor once implemented.

        Raises:
            NotImplementedError: Always, until the forward pass is written.
                Raising instead of silently returning ``None`` makes an
                accidental call fail at the call site rather than downstream.
        """
        # TODO: implement the forward pass.
        raise NotImplementedError("ImageNN.forward is not implemented yet")