diff --git a/ApplyModel.py b/ApplyModel.py index 374f042..eec2a97 100644 --- a/ApplyModel.py +++ b/ApplyModel.py @@ -1,11 +1,14 @@ +import pickle + import numpy as np import torch from PIL import Image +import Compress import DataLoader import ex4 from ImageImpaint import get_train_device -from netio import load_model +from netio import load_model, eval_evalset, write_to_pickle def apply_model(filepath: str): @@ -26,6 +29,37 @@ def apply_model(filepath: str): im = Image.fromarray(out) im.save("filename.jpg", format="jpeg") +def test(): + # read the provided testing pickle file + print("Generating pickle file with privided test data") + PICKEL_PATH = "test" + + model = load_model() + model.eval() + + loader,_ = DataLoader.get_image_loader("training/", np.float32) + outarr = np.zeros(dtype=np.uint8, shape=(8663, 3, 100, 100)) + targetarr = np.zeros(dtype=np.uint8, shape=(8663, 3, 100, 100)) + + i = 0 + for input, target in loader: + out = model(input) + out = DataLoader.postprocess(out.cpu().detach().numpy()) + outarr[i] = out + targetarr[i] = DataLoader.postprocess(target.cpu().detach().numpy()) + print(f'\rApplying model [{i}/{len(loader)}]', end='') + i += 1 + if i==8663: + break + write_to_pickle(PICKEL_PATH + "_pred.pkl", list(outarr)) + # compress the generated pickle arr + Compress.compress(PICKEL_PATH + "_pred.pkl") + + write_to_pickle(PICKEL_PATH + "_target.pkl", list(targetarr)) + # compress the generated pickle arr + Compress.compress(PICKEL_PATH + "_target.pkl") if __name__ == '__main__': apply_model("training/000/000017.jpg") + eval_evalset() + # test() diff --git a/ImageImpaint.py b/ImageImpaint.py index 667eaaa..e24f7b9 100644 --- a/ImageImpaint.py +++ b/ImageImpaint.py @@ -115,8 +115,8 @@ def plot(inputs, targets, predictions, path, update, epoch): for ax, data, title in zip(axes, [inputs, targets, predictions], ["Input", "Target", "Prediction"]): ax.clear() ax.set_title(title) - # ax.imshow(DataLoader.postprocess(np.transpose(data[i], (1, 2, 0))), interpolation="none") - ax.imshow(np.transpose((data[i]), (1, 2, 0)), interpolation="none") + ax.imshow(DataLoader.postprocess(np.transpose(data[i], (1, 2, 0))), interpolation="none") + # ax.imshow(np.transpose((data[i]), (1, 2, 0)), interpolation="none") ax.set_axis_off() fig.savefig(os.path.join(path, f"{epoch:02d}_{update:07d}_{i:02d}.png"), dpi=100) diff --git a/Scoring.py b/Scoring.py index e8002f7..7c40548 100644 --- a/Scoring.py +++ b/Scoring.py @@ -27,8 +27,6 @@ import zipfile import dill as pkl import numpy as np -import onnx -import onnxruntime TEST_DATA_PATH = r"/daten/challenge/django/data/datasets/image_inpainting_2022/test.zip"