# ImageImpaint_Python_II/ApplyModel.py
# (32 lines, 866 B, Python; last changed 2022-07-01 13:35:12 +00:00)
import numpy as np
import torch
from PIL import Image
import DataLoader
import ex4
from ImageImpaint import get_train_device
from netio import load_model
def _to_pil_image(array):
    """Turn an array in the model's value range into a PIL image.

    The array is passed through ``DataLoader.postprocess``, axis 0 is moved
    to the end (presumably channels-first -> height, width, channels --
    confirm against ``DataLoader``), and the values are cast to ``uint8``,
    the only dtype ``Image.fromarray`` accepts for 3-D arrays.
    """
    pixels = np.transpose(DataLoader.postprocess(array), (1, 2, 0))
    return Image.fromarray(pixels.astype(np.uint8))


def apply_model(filepath: str) -> None:
    """Run the stored model on one image and save its input and its output.

    The image at ``filepath`` is cropped and preprocessed with the project's
    ``DataLoader`` helpers, passed through ``ex4.ex4`` and fed to the model
    returned by ``load_model()``. Two files are written to the current
    working directory: ``filename_grid.jpg`` (the array given to the model)
    and ``filename.jpg`` (the model's output).

    Args:
        filepath: Path of the image file to open with PIL.
    """
    device = get_train_device()
    model = load_model()
    model.to(device)
    # Inference only: put dropout / batch-norm layers into evaluation mode.
    # NOTE(review): assumes load_model() returns a torch.nn.Module -- confirm.
    model.eval()

    # The context manager closes the underlying file handle as soon as the
    # pixel data has been turned into an array.
    with Image.open(filepath) as img:
        pic = DataLoader.crop_image(img)
        pic = DataLoader.preprocess(pic, precision=np.float32)

    # Only the first element of ex4's result is used as model input.
    # NOTE(review): (5, 5) and (4, 4) are presumably the grid offset and
    # spacing -- confirm against ex4.
    pic = ex4.ex4(pic, (5, 5), (4, 4))[0]
    _to_pil_image(pic).save("filename_grid.jpg")

    # No gradients are needed here, so skip building the autograd graph.
    with torch.no_grad():
        out = model(torch.from_numpy(pic).to(device))
    out = out.detach().cpu().numpy()

    # Same conversion as for the grid image, including the uint8 cast that
    # the original code applied only to the first picture.
    _to_pil_image(out).save("filename.jpg", format="jpeg")
if __name__ == "__main__":
    # Script entry point: run the pipeline on one fixed sample picture.
    sample_path = "training/000/000017.jpg"
    apply_model(sample_path)