lukas-heiligenbrunner
882c6f54bb
All checks were successful
Build Typst document / build_typst_documents (push) Successful in 22s
291 lines
10 KiB
Plaintext
291 lines
10 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "8dbc4931-506b-4eb7-9049-11fda71fa2fd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/q315433/micromamba/envs/pmf/lib/python3.12/site-packages/torch/__init__.py:749: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:431.)\n",
|
|
" _C._set_default_tensor_type(t)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import sys\n",
|
|
"import torch\n",
|
|
"from pyprojroot import here as project_root\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"sys.path.insert(0, str(project_root()))\n",
|
|
"\n",
|
|
"from src.evaluation.utils import get_test_path, get_model\n",
|
|
"from src.evaluation.eval import meta_test\n",
|
|
"\n",
|
|
"from src.train_utils.trainer import train_parser\n",
|
|
"from src.models.feature_extractors.pretrained_fe import get_fe_metadata\n",
|
|
"import torchvision.transforms as transforms\n",
|
|
"from PIL import Image\n",
|
|
"\n",
|
|
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "a90ff098-ad85-45db-8576-54ffe4c8a7cc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def test_transform():\n",
|
|
" def _convert_image_to_rgb(im):\n",
|
|
" return im.convert('RGB')\n",
|
|
"\n",
|
|
" return transforms.Compose([\n",
|
|
" #transforms.Resize(224),\n",
|
|
" transforms.Resize(224),\n",
|
|
" #transforms.CenterCrop(224),\n",
|
|
" _convert_image_to_rgb,\n",
|
|
" transforms.ToTensor(),\n",
|
|
" transforms.Normalize(mean=torch.tensor([0.4815, 0.4578, 0.4082]), std=torch.tensor([0.2686, 0.2613, 0.2758])),\n",
|
|
" ])\n",
|
|
"\n",
|
|
"preprocess = test_transform()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "3a400d41-cc0b-4af2-aafe-fb1c82bf21a2",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Defaulting to float32 dtype\n",
|
|
"Loaded pretrained timm model vit_base_patch16_clip_224.openai\n",
|
|
"../caml_pretrained_models/CAML_CLIP/model.pth\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import enum\n",
|
|
"\n",
|
|
"class T:\n",
|
|
" fe_type = \"timm:vit_base_patch16_clip_224.openai:768\"\n",
|
|
" #fe_type = \"timm:vit_huge_patch14_clip_224.laion2b:1280\"\n",
|
|
" fe_dim = 768\n",
|
|
" fe_dtype = \"float32\"\n",
|
|
" model = \"CAML\"\n",
|
|
" dropout = 0.0\n",
|
|
" encoder_size = \"large\"\n",
|
|
"\n",
|
|
"fe_metadata = get_fe_metadata(T())\n",
|
|
"#test_path = get_test_path(args, data_path)\n",
|
|
"#device = torch.device(f'cuda:{args.gpu}')\n",
|
|
"\n",
|
|
"# Get the model and load its weights.\n",
|
|
"model, model_path = get_model(T(), fe_metadata, device)\n",
|
|
"print(model_path)\n",
|
|
"#print(model)\n",
|
|
"if model_path:\n",
|
|
" model.load_state_dict(torch.load(model_path, map_location=f'cuda:0'), strict=False)\n",
|
|
"model.to(device)\n",
|
|
"_= model.eval()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "534d1750-0433-403c-9d51-c0522747a97f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
|
|
"torch.Size([65, 3, 224, 224])\n",
|
|
"3\n",
|
|
"62\n",
|
|
"torch.Size([62, 4, 768])\n",
|
|
"tensor([0, 1, 2], device='cuda:0')\n",
|
|
"tensor([0, 1, 2], device='cuda:0')\n",
|
|
"torch.Size([3])\n",
|
|
"torch.Size([62, 4, 768])\n",
|
|
"herre\n",
|
|
"torch.Size([62])\n",
|
|
"(62,)\n",
|
|
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
|
|
"torch.Size([65, 3, 224, 224])\n",
|
|
"9\n",
|
|
"56\n",
|
|
"torch.Size([56, 10, 768])\n",
|
|
"tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device='cuda:0')\n",
|
|
"tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device='cuda:0')\n",
|
|
"torch.Size([9])\n",
|
|
"torch.Size([56, 10, 768])\n",
|
|
"herre\n",
|
|
"torch.Size([56])\n",
|
|
"(56,)\n",
|
|
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
|
|
"torch.Size([65, 3, 224, 224])\n",
|
|
"15\n",
|
|
"50\n",
|
|
"torch.Size([50, 16, 768])\n",
|
|
"tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device='cuda:0')\n",
|
|
"tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device='cuda:0')\n",
|
|
"torch.Size([15])\n",
|
|
"torch.Size([50, 16, 768])\n",
|
|
"herre\n",
|
|
"torch.Size([50])\n",
|
|
"(50,)\n",
|
|
"[0.58064516 0.51785714 0.52 ]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
|
"img_path = \"../pmf_cvpr22/data_custom\"\n",
|
|
"\n",
|
|
"def filecnt_in_dir(dirr, typ):\n",
|
|
" _, _, files = next(os.walk(f\"{img_path}/{dirr}/test/{typ}/\"))\n",
|
|
" return len(files)\n",
|
|
"\n",
|
|
"def evaluate(shot, way, folder):\n",
|
|
" ts = [\"good\", \"broken_small\", \"broken_large\", \"contamination\"]\n",
|
|
" tss = [\"good\", \"cable_swap\", \"combined\", \"cut_inner_insulation\", \"cut_outer_insulation\", \"missing_cable\", \"missing_wire\", \"poke_insulation\"]\n",
|
|
" tss = ts\n",
|
|
" cat = [\"bottle\", \"cable\"]\n",
|
|
"\n",
|
|
" #goodnr = (len(tss)-1) * shot\n",
|
|
" \n",
|
|
" with torch.no_grad():\n",
|
|
" #img_supp = [preprocess(Image.open(f\"{img_path}/{folder}/train/good/{i:03d}.png\")).unsqueeze(0).to(device) for i in range(shot)]\n",
|
|
" img_supp = [preprocess(Image.open(f\"{img_path}/{folder}/test/{n}/{i:03d}.png\")).unsqueeze(0).to(device) for n in tss[1:4] for i in range(shot)]\n",
|
|
" \n",
|
|
" tmp = [(preprocess(Image.open(f\"{img_path}/{folder}/test/{n}/{i:03d}.png\")).unsqueeze(0).to(device), tss.index(n)-1) for n in tss[1:4] for i in range(shot, filecnt_in_dir(folder, n))]\n",
|
|
" img_query, query_labels = zip(*tmp)\n",
|
|
" #print(tmp)\n",
|
|
" print(query_labels)\n",
|
|
" \n",
|
|
" img_concat = img_supp + list(img_query)\n",
|
|
" img_concat = torch.cat(img_concat, 0)\n",
|
|
" print(img_concat.shape)\n",
|
|
" print(len(img_supp))\n",
|
|
" print(len(img_query))\n",
|
|
" #shot = (len(tss)-1) * shot\n",
|
|
" \n",
|
|
" #logits = model.meta_test(img_concat, way=4, shot=shot, query_shot=1)\n",
|
|
" #print(logits)\n",
|
|
" #\n",
|
|
" feature_vector = model.get_feature_vector(img_concat)\n",
|
|
" support_features = feature_vector[:way * shot]\n",
|
|
" query_features = feature_vector[way * shot:]\n",
|
|
" b, d = query_features.shape\n",
|
|
" \n",
|
|
" # Reshape query and support to a sequence.\n",
|
|
" support = support_features.reshape(1, way * shot, d).repeat(b, 1, 1)\n",
|
|
" query = query_features.reshape(-1, 1, d)\n",
|
|
" feature_sequences = torch.cat([query, support], dim=1)\n",
|
|
" print(feature_sequences.shape)\n",
|
|
" \n",
|
|
" #labels = torch.LongTensor([i // shot for i in range(shot * way)]).to(device)\n",
|
|
" labels = torch.arange(way).repeat(shot, 1).T.flatten().to(model.device)\n",
|
|
" print(labels)\n",
|
|
" #labels = torch.from_numpy(np.ones(shape=shot, dtype=int)).to(device)\n",
|
|
" #labels = torch.cat([torch.from_numpy(np.zeros(shape=shot, dtype=int)).to(device), labels])\n",
|
|
" print(labels)\n",
|
|
" \n",
|
|
" #labels = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ]).to(device)\n",
|
|
" print(labels.shape)\n",
|
|
" print(feature_sequences.shape)\n",
|
|
" logits = model.transformer_encoder.forward_imagenet_v2(feature_sequences, labels, way, shot)\n",
|
|
" #print(logits)\n",
|
|
" _, max_index = torch.max(logits[:, :way], 1)\n",
|
|
" #print(max_index.cpu().numpy())\n",
|
|
" #bbb = np.ones(shape=(14*4))\n",
|
|
" #bbb[:14] = 0\n",
|
|
" #print(np.mean(max_index.cpu().numpy() == bbb))\n",
|
|
" print(\"herre\")\n",
|
|
" print(max_index.shape)\n",
|
|
" print(np.array(query_labels).shape)\n",
|
|
"\n",
|
|
" return np.mean(max_index.cpu().numpy() == np.array(query_labels))\n",
|
|
"\n",
|
|
"scores = [evaluate(shot, 3, \"bottle\") for shot in [1,3,5]]\n",
|
|
"print(np.array(scores))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "22d68c6b-0df6-4953-8447-7843932fa974",
|
|
"metadata": {},
|
|
"source": [
|
|
"CAML:\n",
|
|
"Results:\n",
|
|
"\n",
|
|
"bottle:\n",
|
|
"jeweils 1,3,5 shots normal\n",
|
|
"[0.40740741 0.39726027 0.30769231]\n",
|
|
"\n",
|
|
"imbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
|
|
"- not possible\n",
|
|
"1q\n",
|
|
"2 ways nur detektieren ob fehlerhaft oder nicht 3,6,9 shots -> wegen model restrictions\n",
|
|
"[0.79012346 0.84415584 0.87671233]\n",
|
|
"\n",
|
|
"imbalanced 2 way 5,10,15,30 -> rest 5\n",
|
|
"- not possible\n",
|
|
"\n",
|
|
"nur fehlerklasse erkennen 1,3,5\n",
|
|
"[0.58064516 0.51785714 0.52 ]\n",
|
|
"\n",
|
|
"\n",
|
|
"cable:\n",
|
|
"jeweils 1,3,5 shots normal\n",
|
|
"[0.24031008 0.19834711 0.15929204]\n",
|
|
"\n",
|
|
"imbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
|
|
"- not possible\n",
|
|
"\n",
|
|
"2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n",
|
|
"[0.57364341 0.54545455 0.59292035]\n",
|
|
"\n",
|
|
"imbalanced 2 way 5,10,15,30 -> rest 5\n",
|
|
"- not possible\n",
|
|
"\n",
|
|
"nur fehlerklasse erkennen 1,3,5\n",
|
|
"[0.12962963 0.36363636 0.58823529]\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.13.1"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|