{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8dbc4931-506b-4eb7-9049-11fda71fa2fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/q315433/micromamba/envs/pmf/lib/python3.12/site-packages/torch/__init__.py:749: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:431.)\n",
      " _C._set_default_tensor_type(t)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import torch\n",
    "from pyprojroot import here as project_root\n",
    "import numpy as np\n",
    "\n",
    "# Make the project's `src` package importable before the imports below.\n",
    "sys.path.insert(0, str(project_root()))\n",
    "\n",
    "from src.evaluation.utils import get_test_path, get_model\n",
    "from src.evaluation.eval import meta_test\n",
    "\n",
    "from src.train_utils.trainer import train_parser\n",
    "from src.models.feature_extractors.pretrained_fe import get_fe_metadata\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import Image\n",
    "\n",
    "# Fall back to CPU when no GPU is available; later cells must not assume CUDA.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a90ff098-ad85-45db-8576-54ffe4c8a7cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_transform(size=224):\n",
    "    \"\"\"Build the CLIP-style evaluation preprocessing pipeline.\n",
    "\n",
    "    Resize the shorter side to `size`, center-crop to `size` x `size`, convert to RGB,\n",
    "    convert to a tensor and normalise with the CLIP mean/std. The center crop is a\n",
    "    no-op for square images, but guarantees that every image yields a tensor of the\n",
    "    same shape, which `torch.cat` over a list of images (and the fixed-resolution\n",
    "    ViT backbone) requires.\n",
    "\n",
    "    Args:\n",
    "        size: target edge length in pixels (224 for vit_base_patch16_clip_224).\n",
    "\n",
    "    Returns:\n",
    "        A `transforms.Compose` mapping a PIL image to a normalised (3, size, size) tensor.\n",
    "    \"\"\"\n",
    "    def _convert_image_to_rgb(im):\n",
    "        # Grayscale / RGBA inputs must become 3-channel before ToTensor + Normalize.\n",
    "        return im.convert('RGB')\n",
    "\n",
    "    return transforms.Compose([\n",
    "        transforms.Resize(size),\n",
    "        transforms.CenterCrop(size),\n",
    "        _convert_image_to_rgb,\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=torch.tensor([0.4815, 0.4578, 0.4082]), std=torch.tensor([0.2686, 0.2613, 0.2758])),\n",
    "    ])\n",
    "\n",
    "preprocess = test_transform()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3a400d41-cc0b-4af2-aafe-fb1c82bf21a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defaulting to float32 dtype\n",
      "Loaded pretrained timm model 
vit_base_patch16_clip_224.openai\n",
      "../caml_pretrained_models/CAML_CLIP/model.pth\n"
     ]
    }
   ],
   "source": [
    "import enum\n",
    "\n",
    "class T:\n",
    "    \"\"\"Stand-in for the argument namespace passed to `get_fe_metadata` / `get_model`.\n",
    "\n",
    "    Only class attributes are used, so a single instance is enough.\n",
    "    \"\"\"\n",
    "    fe_type = \"timm:vit_base_patch16_clip_224.openai:768\"\n",
    "    # Alternative backbone (presumably also needs fe_dim = 1280 -- TODO confirm):\n",
    "    # fe_type = \"timm:vit_huge_patch14_clip_224.laion2b:1280\"\n",
    "    fe_dim = 768\n",
    "    fe_dtype = \"float32\"\n",
    "    model = \"CAML\"\n",
    "    dropout = 0.0\n",
    "    encoder_size = \"large\"\n",
    "\n",
    "args = T()\n",
    "fe_metadata = get_fe_metadata(args)\n",
    "\n",
    "# Get the model and load its weights.\n",
    "model, model_path = get_model(args, fe_metadata, device)\n",
    "print(model_path)\n",
    "if model_path:\n",
    "    # Map onto the selected `device` rather than a hard-coded 'cuda:0', so the\n",
    "    # checkpoint also loads on CPU-only machines.\n",
    "    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.\n",
    "    state_dict = torch.load(model_path, map_location=device)\n",
    "    # NOTE(review): strict=False silently ignores missing/unexpected keys;\n",
    "    # inspect the return value of load_state_dict if results look off.\n",
    "    model.load_state_dict(state_dict, strict=False)\n",
    "model.to(device)\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "534d1750-0433-403c-9d51-c0522747a97f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
      "torch.Size([65, 3, 224, 224])\n",
      "3\n",
      "62\n",
      "torch.Size([62, 4, 768])\n",
      "tensor([0, 1, 2], device='cuda:0')\n",
      "tensor([0, 1, 2], device='cuda:0')\n",
      "torch.Size([3])\n",
      "torch.Size([62, 4, 768])\n",
      "herre\n",
      "torch.Size([62])\n",
      "(62,)\n",
      "(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
      "torch.Size([65, 3, 224, 224])\n",
      "9\n",
      "56\n",
      "torch.Size([56, 10, 768])\n",
      "tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device='cuda:0')\n",
      "tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device='cuda:0')\n",
      "torch.Size([9])\n",
      "torch.Size([56, 10, 768])\n",
      "herre\n",
      "torch.Size([56])\n",
      "(56,)\n",
      "(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
      "torch.Size([65, 3, 224, 224])\n",
      "15\n",
      "50\n",
      "torch.Size([50, 16, 768])\n",
      "tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device='cuda:0')\n",
      "tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device='cuda:0')\n",
      "torch.Size([15])\n",
      "torch.Size([50, 16, 768])\n",
      "herre\n",
      "torch.Size([50])\n",
      "(50,)\n",
      "[0.58064516 0.51785714 0.52 ]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "img_path = \"../pmf_cvpr22/data_custom\"\n",
    "\n",
    "# Defect sub-directories per category, in label order (the 'good' class is excluded:\n",
    "# this experiment only recognises the defect type).\n",
    "BOTTLE_DEFECTS = (\"broken_small\", \"broken_large\", \"contamination\")\n",
    "CABLE_DEFECTS = (\"cable_swap\", \"combined\", \"cut_inner_insulation\", \"cut_outer_insulation\", \"missing_cable\", \"missing_wire\", \"poke_insulation\")\n",
    "\n",
    "\n",
    "def filecnt_in_dir(dirr, typ):\n",
    "    \"\"\"Count the ``.png`` files directly inside ``{img_path}/{dirr}/test/{typ}/``.\n",
    "\n",
    "    Only PNGs are counted because `evaluate` opens images by the name pattern\n",
    "    ``{i:03d}.png``; any other file in the directory would inflate the count and\n",
    "    make `evaluate` try to open an image that does not exist.\n",
    "    \"\"\"\n",
    "    _, _, files = next(os.walk(f\"{img_path}/{dirr}/test/{typ}/\"))\n",
    "    return sum(1 for name in files if name.endswith(\".png\"))\n",
    "\n",
    "\n",
    "def _load_image(path):\n",
    "    \"\"\"Open one image, apply `preprocess` and return it as a batch of one on `device`.\"\"\"\n",
    "    # The context manager closes the file handle that a bare Image.open() would leak.\n",
    "    with Image.open(path) as im:\n",
    "        return preprocess(im).unsqueeze(0).to(device)\n",
    "\n",
    "\n",
    "def evaluate(shot, way, folder, defect_types=BOTTLE_DEFECTS):\n",
    "    \"\"\"Run one few-shot episode with CAML on the test images of `folder`; return the accuracy.\n",
    "\n",
    "    The episode uses the first `way` entries of `defect_types` as classes. For each\n",
    "    class, images ``000.png`` .. ``{shot-1:03d}.png`` form the support set and every\n",
    "    remaining image of that class is a query. Each query is classified on its own\n",
    "    against the full support set.\n",
    "\n",
    "    Args:\n",
    "        shot: number of support images per class (>= 1).\n",
    "        way: number of classes in the episode (1 .. len(defect_types)).\n",
    "        folder: category directory under `img_path`, e.g. \"bottle\".\n",
    "        defect_types: ordered defect sub-directories; the position is the class label.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of query images whose predicted class equals their true class.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `shot` / `way` are out of range or no query images remain.\n",
    "    \"\"\"\n",
    "    if shot < 1:\n",
    "        raise ValueError(f\"shot must be >= 1, got {shot}\")\n",
    "    if not 1 <= way <= len(defect_types):\n",
    "        raise ValueError(f\"way must be in [1, {len(defect_types)}], got {way}\")\n",
    "    # Derive the episode classes from `way` so the support slicing below always matches.\n",
    "    classes = list(defect_types)[:way]\n",
    "\n",
    "    support_imgs, query_imgs, query_labels = [], [], []\n",
    "    with torch.no_grad():\n",
    "        for label, cls in enumerate(classes):\n",
    "            cls_dir = f\"{img_path}/{folder}/test/{cls}\"\n",
    "            n_files = filecnt_in_dir(folder, cls)\n",
    "            # Support and query images are both kept in class-major order; the\n",
    "            # support labels built below rely on that order.\n",
    "            support_imgs += [_load_image(f\"{cls_dir}/{i:03d}.png\") for i in range(shot)]\n",
    "            queries = [_load_image(f\"{cls_dir}/{i:03d}.png\") for i in range(shot, n_files)]\n",
    "            query_imgs += queries\n",
    "            query_labels += [label] * len(queries)\n",
    "        if not query_imgs:\n",
    "            raise ValueError(f\"no query images left in {folder!r} with shot={shot}\")\n",
    "\n",
    "        feature_vector = model.get_feature_vector(torch.cat(support_imgs + query_imgs, 0))\n",
    "        n_support = way * shot\n",
    "        support_features = feature_vector[:n_support]\n",
    "        query_features = feature_vector[n_support:]\n",
    "        b, d = query_features.shape\n",
    "\n",
    "        # One sequence per query: [query, support_1, ..., support_n] -> (b, 1 + n_support, d).\n",
    "        support = support_features.reshape(1, n_support, d).repeat(b, 1, 1)\n",
    "        query = query_features.reshape(-1, 1, d)\n",
    "        feature_sequences = torch.cat([query, support], dim=1)\n",
    "\n",
    "        # Support labels [0]*shot + [1]*shot + ... match the support image order.\n",
    "        support_labels = torch.arange(way).repeat_interleave(shot).to(model.device)\n",
    "        logits = model.transformer_encoder.forward_imagenet_v2(feature_sequences, support_labels, way, shot)\n",
    "        # Only the first `way` logits belong to episode classes.\n",
    "        predictions = logits[:, :way].argmax(dim=1)\n",
    "\n",
    "    return np.mean(predictions.cpu().numpy() == np.array(query_labels))\n",
    "\n",
    "\n",
    "scores = [evaluate(shot, 3, \"bottle\") for shot in [1, 3, 5]]\n",
    "print(np.array(scores))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22d68c6b-0df6-4953-8447-7843932fa974",
   "metadata": {},
   "source": [
    "CAML:\n",
    "Results:\n",
    "\n",
    "bottle:\n",
    "jeweils 1,3,5 shots normal\n",
    "[0.40740741 0.39726027 0.30769231]\n",
    "\n",
    "imbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
    "- not possible\n",
    "1q\n",
    "2 ways nur detektieren ob fehlerhaft oder nicht 3,6,9 
shots -> wegen model restrictions\n", "[0.79012346 0.84415584 0.87671233]\n", "\n", "imbalanced 2 way 5,10,15,30 -> rest 5\n", "- not possible\n", "\n", "nur Fehlerklasse erkennen 1,3,5\n", "[0.58064516 0.51785714 0.52 ]\n", "\n", "\n", "cable:\n", "jeweils 1,3,5 shots normal\n", "[0.24031008 0.19834711 0.15929204]\n", "\n", "imbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n", "- not possible\n", "\n", "2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n", "[0.57364341 0.54545455 0.59292035]\n", "\n", "imbalanced 2 way 5,10,15,30 -> rest 5\n", "- not possible\n", "\n", "nur Fehlerklasse erkennen 1,3,5\n", "[0.12962963 0.36363636 0.58823529]\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 5 }