import sys

import numpy as np
import torch
from PIL import Image

import DataLoader
import ex4
from ImageImpaint import get_train_device
from netio import load_model


def apply_model(
    filepath: str,
    *,
    grid_path: str = "filename_grid.jpg",
    output_path: str = "filename.jpg",
    offset: tuple = (5, 5),
    spacing: tuple = (4, 4),
) -> None:
    """Run the loaded model on a single image and save the input grid and prediction.

    The image at ``filepath`` is preprocessed with ``DataLoader.preprocess`` and
    passed through ``ex4.ex4``. The first element of ``ex4.ex4``'s result is the
    network input. That input is saved to ``grid_path`` for inspection. The model
    output, after ``DataLoader.postprocess``, is saved as a JPEG to ``output_path``.

    Args:
        filepath: Path of the image to process.
        grid_path: Where to save the ``ex4`` input grid (keyword-only).
        output_path: Where to save the model prediction as JPEG (keyword-only).
        offset: Second positional argument forwarded to ``ex4.ex4``.
            Presumably the grid offset; verify against ``ex4``.
        spacing: Third positional argument forwarded to ``ex4.ex4``.
            Presumably the grid spacing; verify against ``ex4``.
    """
    device = get_train_device()
    model = load_model()
    model.to(device)
    # Inference only. Put layers like dropout/batch-norm into evaluation mode in
    # case load_model returns the network in training mode, which this file
    # cannot see.
    model.eval()

    # The context manager closes the file handle Image.open keeps.
    # NOTE(review): assumes DataLoader.preprocess returns an array that no longer
    # references `img` once the `with` block exits. Confirm.
    with Image.open(filepath) as img:
        pic = DataLoader.preprocess(img, precision=np.float32)

    pic = ex4.ex4(pic, offset, spacing)[0]

    # Save the network input for inspection. The (1, 2, 0) transpose implies a
    # 3-D channel-first array. Clip before the uint8 cast so out-of-range values
    # saturate instead of wrapping around.
    grid = np.clip(np.transpose(pic * 255.0, (1, 2, 0)), 0.0, 255.0)
    Image.fromarray(grid.astype(np.uint8)).save(grid_path)

    # No gradients are needed for a forward pass; skip building the autograd graph.
    with torch.no_grad():
        out = model(torch.from_numpy(pic).to(device))
    out = DataLoader.postprocess(out.cpu().detach().numpy())
    out = np.transpose(out, (1, 2, 0))
    Image.fromarray(out).save(output_path, format="jpeg")


if __name__ == '__main__':
    # Optional image path on the command line; default is the original sample.
    apply_model(sys.argv[1] if len(sys.argv) > 1 else "training/000/000017.jpg")