This commit is contained in:
parent
05756fcc7b
commit
c717b5e466
290
notebooks/caml.ipynb
Normal file
290
notebooks/caml.ipynb
Normal file
@ -0,0 +1,290 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "8dbc4931-506b-4eb7-9049-11fda71fa2fd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/q315433/micromamba/envs/pmf/lib/python3.12/site-packages/torch/__init__.py:749: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:431.)\n",
|
||||
" _C._set_default_tensor_type(t)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import torch\n",
|
||||
"from pyprojroot import here as project_root\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"sys.path.insert(0, str(project_root()))\n",
|
||||
"\n",
|
||||
"from src.evaluation.utils import get_test_path, get_model\n",
|
||||
"from src.evaluation.eval import meta_test\n",
|
||||
"\n",
|
||||
"from src.train_utils.trainer import train_parser\n",
|
||||
"from src.models.feature_extractors.pretrained_fe import get_fe_metadata\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "a90ff098-ad85-45db-8576-54ffe4c8a7cc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test_transform():\n",
|
||||
" def _convert_image_to_rgb(im):\n",
|
||||
" return im.convert('RGB')\n",
|
||||
"\n",
|
||||
" return transforms.Compose([\n",
|
||||
" #transforms.Resize(224),\n",
|
||||
" transforms.Resize(224),\n",
|
||||
" #transforms.CenterCrop(224),\n",
|
||||
" _convert_image_to_rgb,\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize(mean=torch.tensor([0.4815, 0.4578, 0.4082]), std=torch.tensor([0.2686, 0.2613, 0.2758])),\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
"preprocess = test_transform()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "3a400d41-cc0b-4af2-aafe-fb1c82bf21a2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Defaulting to float32 dtype\n",
|
||||
"Loaded pretrained timm model vit_base_patch16_clip_224.openai\n",
|
||||
"../caml_pretrained_models/CAML_CLIP/model.pth\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import enum\n",
|
||||
"\n",
|
||||
"class T:\n",
|
||||
" fe_type = \"timm:vit_base_patch16_clip_224.openai:768\"\n",
|
||||
" #fe_type = \"timm:vit_huge_patch14_clip_224.laion2b:1280\"\n",
|
||||
" fe_dim = 768\n",
|
||||
" fe_dtype = \"float32\"\n",
|
||||
" model = \"CAML\"\n",
|
||||
" dropout = 0.0\n",
|
||||
" encoder_size = \"large\"\n",
|
||||
"\n",
|
||||
"fe_metadata = get_fe_metadata(T())\n",
|
||||
"#test_path = get_test_path(args, data_path)\n",
|
||||
"#device = torch.device(f'cuda:{args.gpu}')\n",
|
||||
"\n",
|
||||
"# Get the model and load its weights.\n",
|
||||
"model, model_path = get_model(T(), fe_metadata, device)\n",
|
||||
"print(model_path)\n",
|
||||
"#print(model)\n",
|
||||
"if model_path:\n",
|
||||
" model.load_state_dict(torch.load(model_path, map_location=f'cuda:0'), strict=False)\n",
|
||||
"model.to(device)\n",
|
||||
"_= model.eval()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "534d1750-0433-403c-9d51-c0522747a97f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
|
||||
"torch.Size([65, 3, 224, 224])\n",
|
||||
"3\n",
|
||||
"62\n",
|
||||
"torch.Size([62, 4, 768])\n",
|
||||
"tensor([0, 1, 2], device='cuda:0')\n",
|
||||
"tensor([0, 1, 2], device='cuda:0')\n",
|
||||
"torch.Size([3])\n",
|
||||
"torch.Size([62, 4, 768])\n",
|
||||
"herre\n",
|
||||
"torch.Size([62])\n",
|
||||
"(62,)\n",
|
||||
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
|
||||
"torch.Size([65, 3, 224, 224])\n",
|
||||
"9\n",
|
||||
"56\n",
|
||||
"torch.Size([56, 10, 768])\n",
|
||||
"tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device='cuda:0')\n",
|
||||
"tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device='cuda:0')\n",
|
||||
"torch.Size([9])\n",
|
||||
"torch.Size([56, 10, 768])\n",
|
||||
"herre\n",
|
||||
"torch.Size([56])\n",
|
||||
"(56,)\n",
|
||||
"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)\n",
|
||||
"torch.Size([65, 3, 224, 224])\n",
|
||||
"15\n",
|
||||
"50\n",
|
||||
"torch.Size([50, 16, 768])\n",
|
||||
"tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device='cuda:0')\n",
|
||||
"tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2], device='cuda:0')\n",
|
||||
"torch.Size([15])\n",
|
||||
"torch.Size([50, 16, 768])\n",
|
||||
"herre\n",
|
||||
"torch.Size([50])\n",
|
||||
"(50,)\n",
|
||||
"[0.58064516 0.51785714 0.52 ]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"img_path = \"../pmf_cvpr22/data_custom\"\n",
|
||||
"\n",
|
||||
"def filecnt_in_dir(dirr, typ):\n",
|
||||
" _, _, files = next(os.walk(f\"{img_path}/{dirr}/test/{typ}/\"))\n",
|
||||
" return len(files)\n",
|
||||
"\n",
|
||||
"def evaluate(shot, way, folder):\n",
|
||||
" ts = [\"good\", \"broken_small\", \"broken_large\", \"contamination\"]\n",
|
||||
" tss = [\"good\", \"cable_swap\", \"combined\", \"cut_inner_insulation\", \"cut_outer_insulation\", \"missing_cable\", \"missing_wire\", \"poke_insulation\"]\n",
|
||||
" tss = ts\n",
|
||||
" cat = [\"bottle\", \"cable\"]\n",
|
||||
"\n",
|
||||
" #goodnr = (len(tss)-1) * shot\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" #img_supp = [preprocess(Image.open(f\"{img_path}/{folder}/train/good/{i:03d}.png\")).unsqueeze(0).to(device) for i in range(shot)]\n",
|
||||
" img_supp = [preprocess(Image.open(f\"{img_path}/{folder}/test/{n}/{i:03d}.png\")).unsqueeze(0).to(device) for n in tss[1:4] for i in range(shot)]\n",
|
||||
" \n",
|
||||
" tmp = [(preprocess(Image.open(f\"{img_path}/{folder}/test/{n}/{i:03d}.png\")).unsqueeze(0).to(device), tss.index(n)-1) for n in tss[1:4] for i in range(shot, filecnt_in_dir(folder, n))]\n",
|
||||
" img_query, query_labels = zip(*tmp)\n",
|
||||
" #print(tmp)\n",
|
||||
" print(query_labels)\n",
|
||||
" \n",
|
||||
" img_concat = img_supp + list(img_query)\n",
|
||||
" img_concat = torch.cat(img_concat, 0)\n",
|
||||
" print(img_concat.shape)\n",
|
||||
" print(len(img_supp))\n",
|
||||
" print(len(img_query))\n",
|
||||
" #shot = (len(tss)-1) * shot\n",
|
||||
" \n",
|
||||
" #logits = model.meta_test(img_concat, way=4, shot=shot, query_shot=1)\n",
|
||||
" #print(logits)\n",
|
||||
" #\n",
|
||||
" feature_vector = model.get_feature_vector(img_concat)\n",
|
||||
" support_features = feature_vector[:way * shot]\n",
|
||||
" query_features = feature_vector[way * shot:]\n",
|
||||
" b, d = query_features.shape\n",
|
||||
" \n",
|
||||
" # Reshape query and support to a sequence.\n",
|
||||
" support = support_features.reshape(1, way * shot, d).repeat(b, 1, 1)\n",
|
||||
" query = query_features.reshape(-1, 1, d)\n",
|
||||
" feature_sequences = torch.cat([query, support], dim=1)\n",
|
||||
" print(feature_sequences.shape)\n",
|
||||
" \n",
|
||||
" #labels = torch.LongTensor([i // shot for i in range(shot * way)]).to(device)\n",
|
||||
" labels = torch.arange(way).repeat(shot, 1).T.flatten().to(model.device)\n",
|
||||
" print(labels)\n",
|
||||
" #labels = torch.from_numpy(np.ones(shape=shot, dtype=int)).to(device)\n",
|
||||
" #labels = torch.cat([torch.from_numpy(np.zeros(shape=shot, dtype=int)).to(device), labels])\n",
|
||||
" print(labels)\n",
|
||||
" \n",
|
||||
" #labels = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ]).to(device)\n",
|
||||
" print(labels.shape)\n",
|
||||
" print(feature_sequences.shape)\n",
|
||||
" logits = model.transformer_encoder.forward_imagenet_v2(feature_sequences, labels, way, shot)\n",
|
||||
" #print(logits)\n",
|
||||
" _, max_index = torch.max(logits[:, :way], 1)\n",
|
||||
" #print(max_index.cpu().numpy())\n",
|
||||
" #bbb = np.ones(shape=(14*4))\n",
|
||||
" #bbb[:14] = 0\n",
|
||||
" #print(np.mean(max_index.cpu().numpy() == bbb))\n",
|
||||
" print(\"herre\")\n",
|
||||
" print(max_index.shape)\n",
|
||||
" print(np.array(query_labels).shape)\n",
|
||||
"\n",
|
||||
" return np.mean(max_index.cpu().numpy() == np.array(query_labels))\n",
|
||||
"\n",
|
||||
"scores = [evaluate(shot, 3, \"bottle\") for shot in [1,3,5]]\n",
|
||||
"print(np.array(scores))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "22d68c6b-0df6-4953-8447-7843932fa974",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"CAML:\n",
|
||||
"Resulsts:\n",
|
||||
"\n",
|
||||
"bottle:\n",
|
||||
"jeweils 1,3,5 shots normal\n",
|
||||
"[0.40740741 0.39726027 0.30769231]\n",
|
||||
"\n",
|
||||
"inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
|
||||
"- not possible\n",
|
||||
"1q\n",
|
||||
"2 ways nur detektieren ob fehlerhaft oder nicht 3,6,9 shots -> wegen model restrictions\n",
|
||||
"[0.79012346 0.84415584 0.87671233]\n",
|
||||
"\n",
|
||||
"inbalance 2 way 5,10,15,30 -> rest 5\n",
|
||||
"- not possible\n",
|
||||
"\n",
|
||||
"nur fehlerklasse erkennen 1,3,5\n",
|
||||
"[0.58064516 0.51785714 0.52 ]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"cable:\n",
|
||||
"jeweils 1,3,5 shots normal\n",
|
||||
"[0.24031008 0.19834711 0.15929204]\n",
|
||||
"\n",
|
||||
"inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
|
||||
"- not possible\n",
|
||||
"\n",
|
||||
"2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n",
|
||||
"[0.57364341 0.54545455 0.59292035]\n",
|
||||
"\n",
|
||||
"inbalance 2 way 5,10,15,30 -> rest 5\n",
|
||||
"- not possible\n",
|
||||
"\n",
|
||||
"nur fehlerklasse erkennen 1,3,5\n",
|
||||
"[0.12962963 0.36363636 0.58823529]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
168
notebooks/plots.ipynb
Normal file
168
notebooks/plots.ipynb
Normal file
File diff suppressed because one or more lines are too long
895
notebooks/pmf.ipynb
Normal file
895
notebooks/pmf.ipynb
Normal file
@ -0,0 +1,895 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"import time\n",
|
||||
"import random\n",
|
||||
"import torch\n",
|
||||
"import torchvision.transforms as transforms\n",
|
||||
"#import gradio as gr\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"from models import get_model\n",
|
||||
"from dotmap import DotMap\n",
|
||||
"from PIL import Image"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Pretrained weights found at dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"args = DotMap()\n",
|
||||
"args.deploy = 'finetune'\n",
|
||||
"args.arch = 'dino_base_patch16'\n",
|
||||
"args.no_pretrain = True\n",
|
||||
"small = \"https://huggingface.co/hushell/pmf_metadataset_dino/resolve/main/md_full_128x128_dinosmall_fp16_lr5e-5/best.pth?download=true\"\n",
|
||||
"full = 'https://huggingface.co/hushell/pmf_metadataset_dino/resolve/main/md_full_128x128_dinobase_fp16_lr5e-5/best.pth?download=true'\n",
|
||||
"args.resume = full\n",
|
||||
"args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'\n",
|
||||
"args.cx = '06d75168141bc47f1'\n",
|
||||
"\n",
|
||||
"args.ada_steps = 100\n",
|
||||
"#args.ada_lr= 0.0001\n",
|
||||
"#args.aug_prob = .95\n",
|
||||
"args.ada_lr= 0.0001\n",
|
||||
"args.aug_prob = .9\n",
|
||||
"args.aug_types = [\"color\", \"translation\"]\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"model = get_model(args)\n",
|
||||
"model.to(device)\n",
|
||||
"checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')\n",
|
||||
"model.load_state_dict(checkpoint['model'], strict=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# image transforms\n",
|
||||
"def test_transform():\n",
|
||||
" def _convert_image_to_rgb(im):\n",
|
||||
" return im.convert('RGB')\n",
|
||||
"\n",
|
||||
" return transforms.Compose([\n",
|
||||
" transforms.Resize(224),\n",
|
||||
" #transforms.CenterCrop(224),\n",
|
||||
" _convert_image_to_rgb,\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
|
||||
" std=[0.229, 0.224, 0.225]),\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
"preprocess = test_transform()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_sup_set(shotnr, type_name, binary, good_sample_nr):\n",
|
||||
" classes = next(os.walk(f'data_custom/{type_name}/test'))[1]\n",
|
||||
" classes.remove(\"good\")\n",
|
||||
"\n",
|
||||
" supp_x = []\n",
|
||||
" supp_y = []\n",
|
||||
" mapping = {\n",
|
||||
" \"good\" : 0\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" # add good manually\n",
|
||||
" x_good = [Image.open(f\"data_custom/{type_name}/train/good/{x:03d}.png\") for x in range(0, good_sample_nr)]\n",
|
||||
" supp_x.extend([preprocess(x) for x in x_good]) # (3, H, W))\n",
|
||||
" supp_y.extend([0] * good_sample_nr)\n",
|
||||
" \n",
|
||||
" for i,c in enumerate(classes):\n",
|
||||
" #i-=1\n",
|
||||
" x_im = [Image.open(f\"data_custom/{type_name}/test/{c}/{x:03d}.png\") for x in range(0, shotnr)]\n",
|
||||
" supp_x.extend([preprocess(x) for x in x_im]) # (3, H, W))\n",
|
||||
" if binary:\n",
|
||||
" supp_y.extend([1] * shotnr)\n",
|
||||
" mapping[\"anomaly\"] = 1\n",
|
||||
" else:\n",
|
||||
" supp_y.extend([i+1] * shotnr)\n",
|
||||
" mapping[c] = i+1\n",
|
||||
" \n",
|
||||
" supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device) # (1, n_supp*n_labels, 3, H, W)\n",
|
||||
" supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device) # (1, n_supp*n_labels)\n",
|
||||
" return supp_x, supp_y, mapping\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_test_set(shotnr, keyy, type):\n",
|
||||
" _, _, files = next(os.walk(f\"data_custom/cable/test/{type}/\"))\n",
|
||||
" file_count = len(files)\n",
|
||||
" print(file_count)\n",
|
||||
"\n",
|
||||
" queries = [preprocess(Image.open(f\"data_custom/cable/test/{type}/{i:03d}.png\")).unsqueeze(0).unsqueeze(0).to(device) for i in range(shotnr,file_count)]\n",
|
||||
" labels = [keyy for x in range(shotnr,file_count)]\n",
|
||||
" return queries, labels\n",
|
||||
"\n",
|
||||
"def test(type, keyy, shotnr, folder):\n",
|
||||
" predictions = []\n",
|
||||
" _, _, files = next(os.walk(f\"data_custom/{folder}/test/{type}/\"))\n",
|
||||
" file_count = len(files)\n",
|
||||
" print(file_count)\n",
|
||||
"\n",
|
||||
" queries = [preprocess(Image.open(f\"data_custom/{folder}/test/{type}/{i:03d}.png\")).unsqueeze(0).unsqueeze(0).to(device) for i in range(shotnr,file_count)]\n",
|
||||
" queries = torch.cat(queries)\n",
|
||||
" with torch.cuda.amp.autocast(True):\n",
|
||||
" output = model(supp_x, supp_y, queries) # (1, 1, n_labels)\n",
|
||||
"\n",
|
||||
" probs = output.softmax(dim=-1).detach().cpu().numpy()\n",
|
||||
" predictions = np.argmax(probs, axis=2)\n",
|
||||
" print()\n",
|
||||
" return np.mean([x == keyy for x in predictions])\n",
|
||||
" pass\n",
|
||||
" \n",
|
||||
"#def test2(folder):\n",
|
||||
"# accs = []\n",
|
||||
"# queries = []\n",
|
||||
"# labels = []\n",
|
||||
"# for t in next(os.walk(f'data_custom/cable/test'))[1]:\n",
|
||||
"# q, l = build_test_set(shots, types.get(t, 1), t)\n",
|
||||
"# queries+=q\n",
|
||||
"# labels+=l\n",
|
||||
"#\n",
|
||||
"# queries = torch.cat(queries)\n",
|
||||
"# labels = np.array(labels)\n",
|
||||
"#\n",
|
||||
"# with torch.cuda.amp.autocast(True):\n",
|
||||
"# output = model(supp_x, supp_y, queries) # (1, 1, n_labels)\n",
|
||||
"#\n",
|
||||
"# probs = output.softmax(dim=-1).detach().cpu().numpy()\n",
|
||||
"# predictions = np.argmax(probs, axis=2)\n",
|
||||
"# print()\n",
|
||||
"# return np.mean([predictions == labels])\n",
|
||||
"# pass\n",
|
||||
"\n",
|
||||
"#print(f\"overall accuracy: {test(\"cable\")}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"14\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry9: loss = 0.1475423127412796: 100%|██| 100/100 [00:29<00:00, 3.40it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cut_inner_insulation = 1.0\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry5: loss = 0.20609889924526215: 100%|█| 100/100 [00:29<00:00, 3.37it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for poke_insulation = 1.0\n",
|
||||
"12\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry7: loss = 0.12025140225887299: 100%|█| 100/100 [00:29<00:00, 3.34it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cable_swap = 0.8571428571428571\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry5: loss = 0.2130972295999527: 100%|██| 100/100 [00:30<00:00, 3.30it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cut_outer_insulation = 1.0\n",
|
||||
"58\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry53: loss = 0.13926956057548523: 100%|█| 100/100 [00:30<00:00, 3.30it/s\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for good = 0.16981132075471697\n",
|
||||
"12\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry7: loss = 0.16337624192237854: 100%|█| 100/100 [00:30<00:00, 3.28it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for missing_cable = 1.0\n",
|
||||
"11\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry6: loss = 0.16593313217163086: 100%|█| 100/100 [00:30<00:00, 3.27it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for combined = 1.0\n",
|
||||
"13\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry8: loss = 0.16560573875904083: 100%|█| 100/100 [00:30<00:00, 3.27it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for bent_wire = 1.0\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp45, nQry5: loss = 0.18611018359661102: 100%|█| 100/100 [00:30<00:00, 3.28it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for missing_wire = 0.8\n",
|
||||
"overall accuracy: 0.8696615753219527\n",
|
||||
"14\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry9: loss = 0.3357824385166168: 100%|██| 100/100 [00:33<00:00, 2.96it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cut_inner_insulation = 0.7777777777777778\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry5: loss = 0.3290153741836548: 100%|██| 100/100 [00:33<00:00, 2.96it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for poke_insulation = 0.6\n",
|
||||
"12\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry7: loss = 0.22177687287330627: 100%|█| 100/100 [00:33<00:00, 2.96it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cable_swap = 0.8571428571428571\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry5: loss = 0.299775630235672: 100%|███| 100/100 [00:33<00:00, 2.96it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cut_outer_insulation = 1.0\n",
|
||||
"58\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry53: loss = 0.31954386830329895: 100%|█| 100/100 [00:33<00:00, 2.98it/s\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for good = 0.32075471698113206\n",
|
||||
"12\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry7: loss = 0.336273193359375: 100%|███| 100/100 [00:33<00:00, 2.98it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for missing_cable = 0.8571428571428571\n",
|
||||
"11\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry6: loss = 0.3643767237663269: 100%|██| 100/100 [00:33<00:00, 2.98it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for combined = 1.0\n",
|
||||
"13\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry8: loss = 0.3085792660713196: 100%|██| 100/100 [00:33<00:00, 2.98it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for bent_wire = 1.0\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp50, nQry5: loss = 0.34715649485588074: 100%|█| 100/100 [00:33<00:00, 2.98it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for missing_wire = 0.8\n",
|
||||
"overall accuracy: 0.8014242454494026\n",
|
||||
"14\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry9: loss = 0.375447154045105: 100%|███| 100/100 [00:36<00:00, 2.76it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cut_inner_insulation = 0.6666666666666666\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry5: loss = 0.42370423674583435: 100%|█| 100/100 [00:36<00:00, 2.75it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for poke_insulation = 1.0\n",
|
||||
"12\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry7: loss = 0.3982161581516266: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cable_swap = 0.8571428571428571\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry5: loss = 0.3903641104698181: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cut_outer_insulation = 1.0\n",
|
||||
"58\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry53: loss = 0.4019339382648468: 100%|█| 100/100 [00:36<00:00, 2.75it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for good = 0.41509433962264153\n",
|
||||
"12\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry7: loss = 0.4283098876476288: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for missing_cable = 0.7142857142857143\n",
|
||||
"11\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry6: loss = 0.3741377890110016: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for combined = 0.8333333333333334\n",
|
||||
"13\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry8: loss = 0.3858358860015869: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for bent_wire = 1.0\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp55, nQry5: loss = 0.3570959270000458: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for missing_wire = 0.8\n",
|
||||
"overall accuracy: 0.8096136567834681\n",
|
||||
"14\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry9: loss = 0.5021733045578003: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cut_inner_insulation = 0.5555555555555556\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry5: loss = 0.5203520059585571: 100%|██| 100/100 [00:45<00:00, 2.20it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for poke_insulation = 0.4\n",
|
||||
"12\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry7: loss = 0.524366021156311: 100%|███| 100/100 [00:45<00:00, 2.21it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cable_swap = 0.42857142857142855\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry5: loss = 0.5256413221359253: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for cut_outer_insulation = 1.0\n",
|
||||
"58\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry53: loss = 0.5186663866043091: 100%|█| 100/100 [00:45<00:00, 2.21it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for good = 0.7358490566037735\n",
|
||||
"12\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry7: loss = 0.5123675465583801: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for missing_cable = 0.7142857142857143\n",
|
||||
"11\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry6: loss = 0.5076506733894348: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for combined = 0.8333333333333334\n",
|
||||
"13\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry8: loss = 0.490247517824173: 100%|███| 100/100 [00:45<00:00, 2.21it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for bent_wire = 0.875\n",
|
||||
"10\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"lr0.0001, nSupp70, nQry5: loss = 0.3723257780075073: 100%|██| 100/100 [00:45<00:00, 2.21it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"accuracy for missing_wire = 0.4\n",
|
||||
"overall accuracy: 0.6602883431499785\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#bottle_accs = []\n",
|
||||
"cable_accs = []\n",
|
||||
"\n",
|
||||
"for nr in [5, 10, 15, 30]:\n",
|
||||
" folder = \"cable\"\n",
|
||||
" shot = 5\n",
|
||||
" supp_x, supp_y, types = build_sup_set(shot, folder, True, nr)\n",
|
||||
" accs = []\n",
|
||||
" for t in next(os.walk(f'data_custom/{folder}/test'))[1]:\n",
|
||||
" #if t == \"good\":\n",
|
||||
" # continue\n",
|
||||
" accuracy = test(t, types.get(t, 1), shot, folder)\n",
|
||||
" print(f\"accuracy for {t} = {accuracy}\")\n",
|
||||
" accs.append(accuracy)\n",
|
||||
" print(f\"overall accuracy: {np.mean(accs)}\")\n",
|
||||
" cable_accs.append(np.mean(accs))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0.57380952 0.76705653 0.84191176]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(np.array(bottle_accs))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0.86966158 0.80142425 0.80961366 0.66028834]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(np.array(cable_accs))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"P>M>F:\n",
|
||||
"Resulsts:\n",
|
||||
"\n",
|
||||
"bottle:\n",
|
||||
"jeweils 1,3,5 shots normal\n",
|
||||
"[0.67910401 0.71710526 0.78860294]\n",
|
||||
"\n",
|
||||
"inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
|
||||
"[0.78768382 0.78860294 0.75827206 0.74356618]\n",
|
||||
"\n",
|
||||
"2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n",
|
||||
"[0.86422306 0.93201754 0.93933824]\n",
|
||||
"\n",
|
||||
"inbalance 2 way 5,10,15,30 -> rest 5\n",
|
||||
"[0.92371324 0.87867647 0.86397059 0.87775735]\n",
|
||||
"\n",
|
||||
"nur fehlerklasse erkennen 1,3,5\n",
|
||||
"[0.57380952 0.76705653 0.84191176]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"cable:\n",
|
||||
"jeweils 1,3,5 shots normal\n",
|
||||
"[0.25199021 0.44388328 0.46975059]\n",
|
||||
"\n",
|
||||
"inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
|
||||
"[0.50425859 0.48023277 0.43118282 0.41842534]\n",
|
||||
"\n",
|
||||
"2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n",
|
||||
"[0.79263485 0.8707712 0.86756514]\n",
|
||||
"\n",
|
||||
"inbalance 2 way 5,10,15,30 -> rest 5\n",
|
||||
"[0.86966158 0.80142425 0.80961366 0.66028834]\n",
|
||||
"\n",
|
||||
"nur fehlerklasse erkennen 1,3,5\n",
|
||||
"[0.24383256 0.43800505 0.51304563]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
495
notebooks/resnet50.ipynb
Normal file
495
notebooks/resnet50.ipynb
Normal file
@ -0,0 +1,495 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"imports imported\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#import numpy as np\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import cv2\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from torch import optim, nn\n",
|
||||
"import torchvision\n",
|
||||
"from torchvision import datasets, models, transforms\n",
|
||||
"import albumentations as A\n",
|
||||
"from albumentations.pytorch import ToTensorV2\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"imports imported\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Identity(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Identity, self).__init__()\n",
|
||||
" \n",
|
||||
" def forward(self, x):\n",
|
||||
" return x"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ResNet(\n",
|
||||
" (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
|
||||
" (layer1): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer2): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (3): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer3): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (3): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (4): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (5): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (layer4): Sequential(\n",
|
||||
" (0): Bottleneck(\n",
|
||||
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" (downsample): Sequential(\n",
|
||||
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
|
||||
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (1): Bottleneck(\n",
|
||||
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" (2): Bottleneck(\n",
|
||||
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
|
||||
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
|
||||
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
||||
" (relu): ReLU(inplace=True)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
|
||||
" (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
|
||||
")\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"resnet50 = models.resnet50(weights=models.ResNetshotnr0_Weights.DEFAULT)\n",
|
||||
"\n",
|
||||
"print(resnetshotnr0)\n",
|
||||
"# Step 2: Modify the model to output features from the layer before the fully connected layer\n",
|
||||
"class ResNetshotnr0Embeddings(nn.Module):\n",
|
||||
" def __init__(self, original_model, layernr):\n",
|
||||
" super(ResNetshotnr0Embeddings, self).__init__()\n",
|
||||
" #print(list(original_model.children())[4 + layernr])\n",
|
||||
" #print(nn.Sequential(*list(original_model.children())[:4 + shotnr]))\n",
|
||||
" self.features = nn.Sequential(*list(original_model.children())[:4+layernr])\n",
|
||||
" #self.features = nn.Sequential(*list(original_model.children())[:-1]) # Exclude the fully connected layer\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = self.features(x)\n",
|
||||
" x = torch.flatten(x, 1) # Flatten the tensor to (batch_size, 2048)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"# Instantiate the modified model\n",
|
||||
"model = ResNetshotnr0Embeddings(resnetshotnr0, shotnr) # 3 = layer before fully connected one\n",
|
||||
"model.eval() # Set the model to evaluation mode\n",
|
||||
"print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"...............\n",
|
||||
"accuracy for broken_large = 0.6666666666666666\n",
|
||||
".................\n",
|
||||
"accuracy for broken_small = 0.8823529411764706\n",
|
||||
"................\n",
|
||||
"accuracy for contamination = 0.8125\n",
|
||||
"overall accuracy: 0.7871732026143791\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.metrics.pairwise import cosine_similarity,euclidean_distances\n",
|
||||
"from metric_learn import LMNN,NCA\n",
|
||||
"import math\n",
|
||||
"\n",
|
||||
"pipe = A.Compose([A.Resize(256,256), A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ToTensorV2()])\n",
|
||||
"#pipe = A.Compose([A.Resize(256,256), ToTensorV2()])\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"m = ResNet50Embeddings(resnet50, 5) # 5 = all without fully ocnnected\n",
|
||||
"m.eval()\n",
|
||||
"m.to(device)\n",
|
||||
"\n",
|
||||
"def read_img(path):\n",
|
||||
" img = cv2.imread(path, cv2.IMREAD_COLOR)\n",
|
||||
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
|
||||
" #plt.imshow(img)\n",
|
||||
"\n",
|
||||
" imgpiped = pipe(image=img)[\"image\"].unsqueeze(0)\n",
|
||||
" return imgpiped\n",
|
||||
"\n",
|
||||
"def compare_embeddings(emb1, emb2, distance_metric):\n",
|
||||
" #cosi = torch.nn.CosineSimilarity(dim=0) \n",
|
||||
" #output = cosine_similarity([emb1.flatten(), emb2.flatten()])\n",
|
||||
" #output = euclidean_distances([emb1.flatten(), emb2.flatten()], [emb1.flatten(), emb2.flatten()])\n",
|
||||
" output = distance_metric(emb1, emb2)\n",
|
||||
" return output\n",
|
||||
"\n",
|
||||
"def merge_embeddings(embeddings):\n",
|
||||
" # todo calc cluster center or similar\n",
|
||||
" return np.average(embeddings, axis=0)\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"#embedding_good_1 = m(read_img(f\"./data/bottle/test/good/001.png\")).detach().numpy()\n",
|
||||
"#embedding_good_2 = m(read_img(f\"./data/bottle/test/good/002.png\")).detach().numpy()\n",
|
||||
"#embedding_good = merge_embeddings([embedding_good_1, embedding_good_2])\n",
|
||||
"#embedding_contermination_1 = m(read_img(f\"./data/bottle/test/contamination/001.png\")).detach().numpy()\n",
|
||||
"#embedding_contermination_2 = m(read_img(f\"./data/bottle/test/contamination/002.png\")).detach().numpy()\n",
|
||||
"#embedding_contermination = merge_embeddings([embedding_contermination_1, embedding_contermination_2])\n",
|
||||
"#embedding_broken_small_1 = m(read_img(f\"./data/bottle/test/broken_small/001.png\")).detach().numpy()\n",
|
||||
"\n",
|
||||
"#embeddings_test = m(read_img(f\"./data/bottle/test/contamination/004.png\")).detach().numpy()\n",
|
||||
"\n",
|
||||
"#score = compare_embeddings(embedding_good_1, embeddings_test)\n",
|
||||
"\n",
|
||||
"#def calc_base_emb(t):\n",
|
||||
"# base_emb_1 = m(read_img(f\"./data/bottle/test/{t}/000.png\")).detach().numpy()\n",
|
||||
"# base_emb_2 = m(read_img(f\"./data/bottle/test/{t}/001.png\")).detach().numpy()\n",
|
||||
"# base_emb_3 = m(read_img(f\"./data/bottle/test/{t}/002.png\")).detach().numpy()\n",
|
||||
"# base_emb_4 = m(read_img(f\"./data/bottle/test/{t}/003.png\")).detach().numpy()\n",
|
||||
"# base_emb_5 = m(read_img(f\"./data/bottle/test/{t}/004.png\")).detach().numpy()\n",
|
||||
"# base_emb = merge_embeddings([base_emb_1, base_emb_2, base_emb_3, base_emb_4, base_emb_5])\n",
|
||||
"# return base_emb\n",
|
||||
"\n",
|
||||
"MAIN_TYPE=\"bottle\"\n",
|
||||
"\n",
|
||||
"def calc_base_emb(t, nr):\n",
|
||||
" embs = []\n",
|
||||
" for i in range(nr):\n",
|
||||
" if t == \"good\":\n",
|
||||
" emb = m(read_img(f\"./data/{MAIN_TYPE}/train/{t}/{i:03d}.png\")).detach().numpy()\n",
|
||||
" else:\n",
|
||||
" emb = m(read_img(f\"./data/{MAIN_TYPE}/test/{t}/{i:03d}.png\")).detach().numpy()\n",
|
||||
" embs.append(emb)\n",
|
||||
" base_emb = merge_embeddings(embs)\n",
|
||||
" return base_emb\n",
|
||||
"\n",
|
||||
"shotnr=5\n",
|
||||
"goodnr=5\n",
|
||||
"\n",
|
||||
"types = {#\"good\": calc_base_emb(\"good\", goodnr), \n",
|
||||
" #\"bad\": merge_embeddings([calc_base_emb(\"broken_large\", shotnr), calc_base_emb(\"broken_small\", shotnr), calc_base_emb(\"contamination\", shotnr)]),\n",
|
||||
" \"broken_large\": calc_base_emb(\"broken_large\", shotnr), \n",
|
||||
" \"broken_small\": calc_base_emb(\"broken_small\", shotnr), \n",
|
||||
" \"contamination\": calc_base_emb(\"contamination\", shotnr)\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"#types = {#\"good\": calc_base_emb(\"good\", goodnr), \n",
|
||||
"# #\"bad\": merge_embeddings([calc_base_emb(\"bent_wire\", shotnr), calc_base_emb(\"cable_swap\", shotnr), calc_base_emb(\"combined\", shotnr), \n",
|
||||
"# # calc_base_emb(\"cut_inner_insulation\", shotnr), calc_base_emb(\"cut_outer_insulation\", shotnr), calc_base_emb(\"missing_cable\", shotnr), \n",
|
||||
"# # calc_base_emb(\"missing_wire\", shotnr), calc_base_emb(\"poke_insulation\", shotnr)]),\n",
|
||||
"# \"bent_wire\": calc_base_emb(\"bent_wire\", shotnr),\n",
|
||||
"# \"cable_swap\": calc_base_emb(\"cable_swap\", shotnr), \n",
|
||||
"# \"combined\": calc_base_emb(\"combined\", shotnr), \n",
|
||||
"# \"cut_inner_insulation\": calc_base_emb(\"cut_inner_insulation\", shotnr), \n",
|
||||
"# \"cut_outer_insulation\": calc_base_emb(\"cut_outer_insulation\", shotnr), \n",
|
||||
"# \"missing_cable\": calc_base_emb(\"missing_cable\", shotnr), \n",
|
||||
"# \"missing_wire\": calc_base_emb(\"missing_wire\", shotnr), \n",
|
||||
"# \"poke_insulation\": calc_base_emb(\"poke_insulation\", shotnr), \n",
|
||||
"# }\n",
|
||||
"\n",
|
||||
"# euclidean distance\n",
|
||||
"euclidean_distance_metric = lambda emb1,emb2 : math.pow(euclidean_distances([emb1.flatten(), emb2.flatten()], [emb1.flatten(), emb2.flatten()])[0][1], 2)\n",
|
||||
"# cosine metric\n",
|
||||
"cosine_similarity_metric = lambda emb1,emb2 : cosine_similarity([emb1.flatten(), emb2.flatten()])[0][1]\n",
|
||||
"\n",
|
||||
"lmnn = LMNN(n_neighbors=2, learn_rate=1e-3, verbose=True)\n",
|
||||
"#lmnn.fit(data, [0,0,0,1,1,1,2,2,2,3,3,3])\n",
|
||||
"\n",
|
||||
"lmnn_similarity_metric = lambda emb1,emb2 : lmnn.get_metric()(emb1.flatten(), emb2.flatten())\n",
|
||||
"\n",
|
||||
"Smaller_Better_Metric = False\n",
|
||||
"\n",
|
||||
"def test(type):\n",
|
||||
" predictions = []\n",
|
||||
"\n",
|
||||
" _, _, files = next(os.walk(f\"./data/{MAIN_TYPE}/test/{type}/\"))\n",
|
||||
" file_count = len(files)\n",
|
||||
" for i in range(5,file_count):\n",
|
||||
" print(\".\", end=\"\")\n",
|
||||
"\n",
|
||||
" emb = m(read_img(f\"./data/{MAIN_TYPE}/test/{type}/{i:03d}.png\")).detach().numpy()\n",
|
||||
" curr_score = .0 if Smaller_Better_Metric else 999999.0\n",
|
||||
" max_type = \"\"\n",
|
||||
" for t in types.keys():\n",
|
||||
"# for t in [\"good\", \"bad\"]:\n",
|
||||
" score = compare_embeddings(emb, types[t], euclidean_distance_metric)\n",
|
||||
" \n",
|
||||
" if Smaller_Better_Metric:\n",
|
||||
" if score > curr_score:\n",
|
||||
" curr_score = score\n",
|
||||
" max_type = t\n",
|
||||
" else:\n",
|
||||
" if score < curr_score:\n",
|
||||
" curr_score = score\n",
|
||||
" max_type = t\n",
|
||||
" pass\n",
|
||||
" predictions.append(max_type)\n",
|
||||
" pass\n",
|
||||
" print()\n",
|
||||
" return np.mean([x == type for x in predictions])\n",
|
||||
"# return np.mean([x == (\"good\" if type == \"good\" else \"bad\") for x in predictions])\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"accs = []\n",
|
||||
"for t in types.keys():\n",
|
||||
" if t == \"bad\":\n",
|
||||
" continue\n",
|
||||
" accuracy = test(t)\n",
|
||||
" print(f\"accuracy for {t} = {accuracy}\")\n",
|
||||
" accs.append(accuracy)\n",
|
||||
"print(f\"overall accuracy: {np.mean(accs)}\")\n",
|
||||
"#print(m)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"RESNET 50:\n",
|
||||
"Resulsts:\n",
|
||||
"\n",
|
||||
"bottle:\n",
|
||||
"jeweils 1,3,5 shots normal\n",
|
||||
"[0.5892857142857143 0.7321428571428571 0.75]\n",
|
||||
"\n",
|
||||
"inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
|
||||
"[0.75 0.732 0.696 0.696]\n",
|
||||
"\n",
|
||||
"2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n",
|
||||
"[0.8395 0.8315 0.8031]\n",
|
||||
"\n",
|
||||
"inbalance 2 way 5,10,15,30 -> rest 5\n",
|
||||
"[0.8031 0.81893 0.8336 0.8031]\n",
|
||||
"\n",
|
||||
"nur fehlerklasse erkennen 1,3,5\n",
|
||||
"[0.7638 0.7428 0.787]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"cable:\n",
|
||||
"jeweils 1,3,5 shots normal\n",
|
||||
"[0.21808 0.43815 0.4321478]\n",
|
||||
"\n",
|
||||
"inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
|
||||
"[0.4321478 0.432986 0.42340 0.4464635]\n",
|
||||
"\n",
|
||||
"2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n",
|
||||
"[0.8592 0.8772 0.8495]\n",
|
||||
"\n",
|
||||
"inbalance 2 way 5,10,15,30 -> rest 5\n",
|
||||
"[0.8495 0.8180 0.7460 0.6846]\n",
|
||||
"\n",
|
||||
"nur fehlerklasse erkennen 1,3,5\n",
|
||||
"[0.240 0.4740 0.4805]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1 +1,29 @@
|
||||
Test intro
|
||||
\section{Introduction}\label{sec:introduction}
|
||||
\subsection{Motivation}\label{subsec:motivation}
|
||||
For most supervised learning tasks lots of training samples are essential.
|
||||
With too less training data the model will not generalize well and not fit a real world task.
|
||||
Labeling datasets is commonly seen as an expensive task and wants to be avoided as much as possible.\cite{generalAI}
|
||||
That's why there is a machine-learning field called active learning.
|
||||
The general approach is to train a model that predicts within every iteration a ranking metric or Pseudo-Labels which then can be used to rank the importance of samples to be labeled by an oracle.
|
||||
These labeled samples are then used to train the model.\cite{activelearning}
|
||||
|
||||
The goal of this practical work is to test active learning within a simple classification task and evaluate its performance.
|
||||
\subsection{Research Questions}\label{subsec:research-questions}
|
||||
|
||||
\subsubsection{Is Few-Shot learning a suitable fit for anomaly detection?}
|
||||
|
||||
Should Few-Shot learning be used for anomaly detection tasks?
|
||||
How does it compare to well established algorithms such as Patchcore or EfficientAD?
|
||||
|
||||
\subsubsection{How does disbalancing the Shot number affect performance?}
|
||||
Does giving the Few-Shot learner more good than bad samples improve the model performance?
|
||||
|
||||
\subsubsection{How does the 3 methods perform in only detecting the anomaly class?}
|
||||
How much does the performance improve if only detecting an anomaly or not?
|
||||
How does it compare to PatchCore and EfficientAD?
|
||||
|
||||
\subsubsection{Extra: How does Euclidean distance compare to Cosine-similarity when using ResNet as a feature-extractor?}
|
||||
I've tried different distance measures -> but results are pretty much the same.
|
||||
|
||||
\subsection{Outline}\label{subsec:outline}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user