291 lines
10 KiB
Plaintext
Raw Normal View History

2024-09-12 13:58:14 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "8dbc4931-506b-4eb7-9049-11fda71fa2fd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/q315433/micromamba/envs/pmf/lib/python3.12/site-packages/torch/__init__.py:749: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:431.)\n",
" _C._set_default_tensor_type(t)\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torchvision.transforms as transforms\n",
"from PIL import Image\n",
"from pyprojroot import here as project_root\n",
"\n",
"sys.path.insert(0, str(project_root()))\n",
"\n",
"from src.evaluation.utils import get_test_path, get_model\n",
"from src.evaluation.eval import meta_test\n",
"from src.train_utils.trainer import train_parser\n",
"from src.models.feature_extractors.pretrained_fe import get_fe_metadata\n",
"\n",
"# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a90ff098-ad85-45db-8576-54ffe4c8a7cc",
"metadata": {},
"outputs": [],
"source": [
"def test_transform(size=224):\n",
"    \"\"\"Build the evaluation-time preprocessing pipeline (PIL image -> normalized tensor).\n",
"\n",
"    Resizes the shorter side to ``size`` and center-crops to ``size`` x ``size``, so every\n",
"    image yields the same tensor shape and a batch can be built with ``torch.cat``; then\n",
"    converts to RGB, turns the image into a float tensor and normalizes it per channel.\n",
"\n",
"    Args:\n",
"        size: output side length in pixels; 224 matches the ``*_224`` ViT backbone configured below.\n",
"\n",
"    Returns:\n",
"        A ``transforms.Compose`` producing a (3, size, size) tensor.\n",
"    \"\"\"\n",
"    def _convert_image_to_rgb(im):\n",
"        return im.convert('RGB')\n",
"\n",
"    return transforms.Compose([\n",
"        transforms.Resize(size),\n",
"        # A no-op for square inputs; guarantees a square output for non-square ones.\n",
"        transforms.CenterCrop(size),\n",
"        _convert_image_to_rgb,\n",
"        transforms.ToTensor(),\n",
"        # Presumably the OpenAI CLIP normalization statistics (rounded), matching the CLIP\n",
"        # feature extractor -- TODO confirm against the pretrained backbone's config.\n",
"        transforms.Normalize(mean=torch.tensor([0.4815, 0.4578, 0.4082]), std=torch.tensor([0.2686, 0.2613, 0.2758])),\n",
"    ])\n",
"\n",
"\n",
"preprocess = test_transform()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3a400d41-cc0b-4af2-aafe-fb1c82bf21a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Defaulting to float32 dtype\n",
"Loaded pretrained timm model vit_base_patch16_clip_224.openai\n",
"../caml_pretrained_models/CAML_CLIP/model.pth\n"
]
}
],
"source": [
"class T:\n",
"    \"\"\"Minimal stand-in for the parsed training ``args`` read by get_fe_metadata / get_model.\n",
"\n",
"    NOTE(review): presumably mirrors the relevant ``train_parser`` options -- confirm whenever\n",
"    the project's argument names change.\n",
"    \"\"\"\n",
"    fe_type = \"timm:vit_base_patch16_clip_224.openai:768\"\n",
"    # fe_type = \"timm:vit_huge_patch14_clip_224.laion2b:1280\"\n",
"    # Derived from the trailing \":<dim>\" of fe_type so switching backbones needs only one edit.\n",
"    fe_dim = int(fe_type.rsplit(\":\", 1)[1])\n",
"    fe_dtype = \"float32\"\n",
"    model = \"CAML\"\n",
"    dropout = 0.0\n",
"    encoder_size = \"large\"\n",
"\n",
"\n",
"fe_metadata = get_fe_metadata(T())\n",
"\n",
"# Build the model; get_model also returns a checkpoint path, which is loaded only when truthy.\n",
"model, model_path = get_model(T(), fe_metadata, device)\n",
"print(model_path)\n",
"if model_path:\n",
"    # Load onto the active device rather than a hard-coded \"cuda:0\" so CPU-only machines work.\n",
"    # weights_only=True limits unpickling to tensors and plain containers (the default since PyTorch 2.6).\n",
"    state_dict = torch.load(model_path, map_location=device, weights_only=True)\n",
"    # strict=False tolerates mismatched keys; report how many so a wrong checkpoint is noticed.\n",
"    # Assumes model is a torch.nn.Module, whose load_state_dict returns the missing/unexpected keys.\n",
"    incompatible = model.load_state_dict(state_dict, strict=False)\n",
"    print(f\"missing keys: {len(incompatible.missing_keys)}, unexpected keys: {len(incompatible.unexpected_keys)}\")\n",
"model.to(device)\n",
"_ = model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "534d1750-0433-403c-9d51-c0522747a97f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
"torch.Size([65, 3, 224, 224])\n",
"3\n",
"62\n",
"torch.Size([62, 4, 768])\n",
"tensor([0, 1, 2], device='cuda:0')\n",
"tensor([0, 1, 2], device='cuda:0')\n",
"torch.Size([3])\n",
"torch.Size([62, 4, 768])\n",
"herre\n",
"torch.Size([62])\n",
"(62,)\n",
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
"torch.Size([65, 3, 224, 224])\n",
"9\n",
"56\n",
"torch.Size([56, 10, 768])\n",
"tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device='cuda:0')\n",
"tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device='cuda:0')\n",
"torch.Size([9])\n",
"torch.Size([56, 10, 768])\n",
"herre\n",
"torch.Size([56])\n",
"(56,)\n",
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
"torch.Size([65, 3, 224, 224])\n",
"15\n",
"50\n",
"torch.Size([50, 16, 768])\n",
"tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device='cuda:0')\n",
"tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device='cuda:0')\n",
"torch.Size([15])\n",
"torch.Size([50, 16, 768])\n",
"herre\n",
"torch.Size([50])\n",
"(50,)\n",
"[0.58064516 0.51785714 0.52 ]\n"
]
}
],
"source": [
"img_path = \"../pmf_cvpr22/data_custom\"\n",
"\n",
"# Sub-folders of `{img_path}/<category>/test/` for each category. The first entry must be\n",
"# the defect-free \"good\" class: evaluate() skips it and classifies only the defect types.\n",
"DEFECT_TYPES = {\n",
"    \"bottle\": [\"good\", \"broken_small\", \"broken_large\", \"contamination\"],\n",
"    \"cable\": [\"good\", \"cable_swap\", \"combined\", \"cut_inner_insulation\",\n",
"              \"cut_outer_insulation\", \"missing_cable\", \"missing_wire\", \"poke_insulation\"],\n",
"}\n",
"\n",
"\n",
"def filecnt_in_dir(dirr, typ):\n",
"    \"\"\"Count the ``.png`` files in ``{img_path}/{dirr}/test/{typ}/``.\n",
"\n",
"    Only PNGs are counted because images are loaded by name as ``{i:03d}.png``; any other\n",
"    file in the folder would push the index past the last image. Assumes the PNGs are named\n",
"    ``000.png``, ``001.png``, ... without gaps. Raises ``FileNotFoundError`` for a missing folder.\n",
"    \"\"\"\n",
"    directory = f\"{img_path}/{dirr}/test/{typ}/\"\n",
"    return sum(1 for name in os.listdir(directory)\n",
"               if name.endswith(\".png\") and os.path.isfile(os.path.join(directory, name)))\n",
"\n",
"\n",
"def _load_image(folder, typ, idx):\n",
"    \"\"\"Load ``{img_path}/{folder}/test/{typ}/{idx:03d}.png`` as a preprocessed batch of one on ``device``.\"\"\"\n",
"    # The context manager closes the file handle once preprocess() has produced the tensor.\n",
"    with Image.open(f\"{img_path}/{folder}/test/{typ}/{idx:03d}.png\") as im:\n",
"        return preprocess(im).unsqueeze(0).to(device)\n",
"\n",
"\n",
"def evaluate(shot, way, folder):\n",
"    \"\"\"Few-shot accuracy of classifying the defect type of ``folder``'s test images.\n",
"\n",
"    The episode uses the first ``way`` defect classes of ``DEFECT_TYPES[folder]`` (\"good\" is\n",
"    excluded). Images ``000 .. shot-1`` of each class are the support set; every remaining\n",
"    image of that class is a query labelled with the class's position in the episode.\n",
"\n",
"    Args:\n",
"        shot: number of support images per class.\n",
"        way: number of defect classes in the episode.\n",
"        folder: category sub-folder of ``img_path``; must be a key of ``DEFECT_TYPES``.\n",
"\n",
"    Returns:\n",
"        Fraction of query images whose predicted class equals their true class.\n",
"\n",
"    Raises:\n",
"        KeyError: ``folder`` is not in ``DEFECT_TYPES``.\n",
"        ValueError: ``way`` is out of range for the category, or no query images remain.\n",
"    \"\"\"\n",
"    defect_types = DEFECT_TYPES[folder][1:]  # drop \"good\"\n",
"    if not 1 <= way <= len(defect_types):\n",
"        raise ValueError(f\"way={way} is out of range: {folder!r} has {len(defect_types)} defect classes\")\n",
"    classes = defect_types[:way]\n",
"\n",
"    # Class-major order: all shots of class 0, then class 1, ... (the support labels below rely on it).\n",
"    support_imgs = [_load_image(folder, typ, i) for typ in classes for i in range(shot)]\n",
"    queries = [(_load_image(folder, typ, i), label)\n",
"               for label, typ in enumerate(classes)\n",
"               for i in range(shot, filecnt_in_dir(folder, typ))]\n",
"    if not queries:\n",
"        raise ValueError(f\"no query images left in {folder!r} with shot={shot}\")\n",
"    query_imgs, query_labels = zip(*queries)\n",
"\n",
"    with torch.no_grad():\n",
"        feature_vector = model.get_feature_vector(torch.cat(support_imgs + list(query_imgs), 0))\n",
"        support_features = feature_vector[:way * shot]\n",
"        query_features = feature_vector[way * shot:]\n",
"        b, d = query_features.shape\n",
"\n",
"        # One sequence per query: [query, support_1, ..., support_{way*shot}] -> (b, 1 + way*shot, d).\n",
"        support = support_features.reshape(1, way * shot, d).repeat(b, 1, 1)\n",
"        feature_sequences = torch.cat([query_features.reshape(-1, 1, d), support], dim=1)\n",
"\n",
"        # Support labels in the same class-major order: [0]*shot + [1]*shot + ...\n",
"        labels = torch.arange(way, device=feature_sequences.device).repeat_interleave(shot)\n",
"        logits = model.transformer_encoder.forward_imagenet_v2(feature_sequences, labels, way, shot)\n",
"        predictions = logits[:, :way].argmax(dim=1)\n",
"\n",
"    return np.mean(predictions.cpu().numpy() == np.array(query_labels))\n",
"\n",
"\n",
"FOLDER = \"bottle\"\n",
"WAY = 3\n",
"SHOTS = [1, 3, 5]\n",
"scores = [evaluate(shot, WAY, FOLDER) for shot in SHOTS]\n",
"np.array(scores)"
]
},
{
"cell_type": "markdown",
"id": "22d68c6b-0df6-4953-8447-7843932fa974",
"metadata": {},
"source": [
"## CAML results\n",
"\n",
"### bottle\n",
"\n",
"- Standard setup, 1, 3 and 5 shots per class:\n",
"  `[0.40740741 0.39726027 0.30769231]`\n",
"- Imbalanced, more \"good\" shots (5, 10, 15, 30) with only 5 shots for every other class: not possible.\n",
"- 1q\n",
"- 2-way, only detecting defective vs. not defective, 3, 6 and 9 shots (because of model restrictions):\n",
"  `[0.79012346 0.84415584 0.87671233]`\n",
"- Imbalanced 2-way (5, 10, 15, 30 shots; the rest 5): not possible.\n",
"- Defect type only (defect classes without \"good\"), 1, 3 and 5 shots:\n",
"  `[0.58064516 0.51785714 0.52 ]`\n",
"\n",
"\n",
"### cable\n",
"\n",
"- Standard setup, 1, 3 and 5 shots per class:\n",
"  `[0.24031008 0.19834711 0.15929204]`\n",
"- Imbalanced, more \"good\" shots (5, 10, 15, 30) with only 5 shots for every other class: not possible.\n",
"- 2-way, only detecting defective vs. not defective, 1, 3 and 5 shots:\n",
"  `[0.57364341 0.54545455 0.59292035]`\n",
"- Imbalanced 2-way (5, 10, 15, 30 shots; the rest 5): not possible.\n",
"- Defect type only (defect classes without \"good\"), 1, 3 and 5 shots:\n",
"  `[0.12962963 0.36363636 0.58823529]`\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2025-01-01 20:53:05 +01:00
"version": "3.13.1"
2024-09-12 13:58:14 +02:00
}
},
"nbformat": 4,
"nbformat_minor": 5
}