896 lines
22 KiB
Plaintext
Raw Normal View History

2024-09-12 13:58:14 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"import time\n",
"import random\n",
"import torch\n",
"import torchvision.transforms as transforms\n",
"#import gradio as gr\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from models import get_model\n",
"from dotmap import DotMap\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pretrained weights found at dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\n"
]
}
],
"source": [
"args = DotMap()\n",
"args.deploy = 'finetune'\n",
"args.arch = 'dino_base_patch16'\n",
"args.no_pretrain = True\n",
"\n",
"# Fine-tuned checkpoints: `small` is the dinosmall variant, `full` the dinobase variant\n",
"# (presumably the one matching args.arch above -- confirm before switching).\n",
"small = 'https://huggingface.co/hushell/pmf_metadataset_dino/resolve/main/md_full_128x128_dinosmall_fp16_lr5e-5/best.pth?download=true'\n",
"full = 'https://huggingface.co/hushell/pmf_metadataset_dino/resolve/main/md_full_128x128_dinobase_fp16_lr5e-5/best.pth?download=true'\n",
"args.resume = full\n",
"\n",
"# Never commit credentials: the API key (presumably for Google Custom Search, given `cx`)\n",
"# is read from the environment. The key that used to be hardcoded here was published\n",
"# with the notebook and should be revoked.\n",
"args.api_key = os.environ.get('GOOGLE_API_KEY')\n",
"args.cx = '06d75168141bc47f1'\n",
"\n",
"# adaptation hyper-parameters (an earlier trial used aug_prob = 0.95)\n",
"args.ada_steps = 100\n",
"args.ada_lr = 0.0001\n",
"args.aug_prob = .9\n",
"args.aug_types = ['color', 'translation']\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model = get_model(args)\n",
"model.to(device)\n",
"# weights are loaded on CPU first, then copied into the (already moved) model parameters\n",
"checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')\n",
"model.load_state_dict(checkpoint['model'], strict=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# image transforms\n",
"def test_transform():\n",
"    \"\"\"Build the preprocessing pipeline shared by support and query images.\n",
"\n",
"    Steps: resize so the shorter side is 224 px, convert to 3-channel RGB, turn the PIL\n",
"    image into a float tensor and normalise it with the ImageNet channel mean/std.\n",
"    No center crop is applied, so the whole image is kept. NOTE(review): images of\n",
"    different shapes therefore produce tensors of different sizes, which torch.stack in\n",
"    build_sup_set cannot combine; fine as long as all inputs share one shape.\n",
"    \"\"\"\n",
"    pipeline = [\n",
"        transforms.Resize(224),\n",
"        # transforms.CenterCrop(224),\n",
"        transforms.Lambda(lambda img: img.convert('RGB')),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"    ]\n",
"    return transforms.Compose(pipeline)\n",
"\n",
"preprocess = test_transform()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"def build_sup_set(shotnr, type_name, binary, good_sample_nr):\n",
"    \"\"\"Build the few-shot support set for one category under data_custom/.\n",
"\n",
"    Args:\n",
"        shotnr: number of images taken from each defect class (test/<class>/000.png, ...).\n",
"        type_name: category folder under data_custom/ (e.g. 'cable', 'bottle').\n",
"        binary: if True every defect class gets label 1 ('anomaly'); otherwise the\n",
"            defect classes get labels 1..K in sorted folder-name order.\n",
"        good_sample_nr: number of defect-free images taken from train/good/.\n",
"\n",
"    Returns:\n",
"        supp_x: tensor (1, n_support, 3, H, W) on `device`.\n",
"        supp_y: long tensor (1, n_support) on `device`.\n",
"        mapping: dict class name -> label; always contains 'good': 0.\n",
"    \"\"\"\n",
"    test_dir = f'data_custom/{type_name}/test'\n",
"    # sorted() makes the class -> label assignment independent of filesystem listing order\n",
"    classes = sorted(c for c in next(os.walk(test_dir))[1] if c != 'good')\n",
"\n",
"    def load(path):\n",
"        # `with` closes the image file as soon as it has been preprocessed\n",
"        with Image.open(path) as im:\n",
"            return preprocess(im)  # (3, H, W)\n",
"\n",
"    mapping = {'good': 0}\n",
"    # good samples come from the training split\n",
"    supp_x = [load(f'data_custom/{type_name}/train/good/{i:03d}.png') for i in range(good_sample_nr)]\n",
"    supp_y = [0] * good_sample_nr\n",
"\n",
"    if binary and classes:\n",
"        mapping['anomaly'] = 1\n",
"    for label, c in enumerate(classes, start=1):\n",
"        supp_x.extend(load(f'{test_dir}/{c}/{i:03d}.png') for i in range(shotnr))\n",
"        supp_y.extend([1 if binary else label] * shotnr)\n",
"        if not binary:\n",
"            mapping[c] = label\n",
"\n",
"    supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device)  # (1, n_supp, 3, H, W)\n",
"    supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device)  # (1, n_supp)\n",
"    return supp_x, supp_y, mapping\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def _query_paths(folder, cls, shotnr):\n",
"    \"\"\"Return the sorted query image paths of data_custom/<folder>/test/<cls>/.\n",
"\n",
"    Only *.png files are listed, and the first `shotnr` of them are skipped: for the\n",
"    defect classes those are the images build_sup_set puts into the support set. For\n",
"    'good' the support images come from train/good, but the same skip is applied so\n",
"    the query sets match the original evaluation.\n",
"    \"\"\"\n",
"    class_dir = f'data_custom/{folder}/test/{cls}'\n",
"    # listing only PNGs keeps stray files (e.g. .DS_Store) from being counted as images\n",
"    names = sorted(n for n in next(os.walk(class_dir))[2] if n.endswith('.png'))\n",
"    return [f'{class_dir}/{n}' for n in names[shotnr:]]\n",
"\n",
"\n",
"def _load_query(path):\n",
"    \"\"\"Preprocess one query image into a (1, 1, 3, H, W) tensor on `device`.\"\"\"\n",
"    with Image.open(path) as im:\n",
"        return preprocess(im).unsqueeze(0).unsqueeze(0).to(device)\n",
"\n",
"\n",
"def build_test_set(shotnr, keyy, type, folder='cable'):\n",
"    \"\"\"Return (queries, labels) for one test class.\n",
"\n",
"    queries is a list of (1, 1, 3, H, W) tensors; labels repeats `keyy` once per query.\n",
"    `folder` defaults to 'cable', the category this function used to hard-code.\n",
"    \"\"\"\n",
"    queries = [_load_query(p) for p in _query_paths(folder, type, shotnr)]\n",
"    return queries, [keyy] * len(queries)\n",
"\n",
"\n",
"def test(type, keyy, shotnr, folder, support=None):\n",
"    \"\"\"Run the model on the support set plus all queries of one test class.\n",
"\n",
"    Args:\n",
"        type: test sub-folder name (e.g. 'good', 'cable_swap').\n",
"        keyy: label every query of this class is expected to receive.\n",
"        shotnr: number of leading images skipped (see _query_paths).\n",
"        folder: category folder under data_custom/.\n",
"        support: optional (supp_x, supp_y) pair; defaults to the notebook globals\n",
"            `supp_x` / `supp_y` produced by build_sup_set.\n",
"\n",
"    Returns:\n",
"        Fraction of predictions (arg-max over labels) equal to `keyy`.\n",
"\n",
"    Raises:\n",
"        ValueError: if no query images remain after skipping `shotnr`.\n",
"    \"\"\"\n",
"    sx, sy = support if support is not None else (supp_x, supp_y)\n",
"    paths = _query_paths(folder, type, shotnr)\n",
"    if not paths:\n",
"        raise ValueError(f'no query images in data_custom/{folder}/test/{type} after skipping {shotnr}')\n",
"    queries = torch.cat([_load_query(p) for p in paths])\n",
"\n",
"    # mixed precision only on CUDA; replaces the deprecated torch.cuda.amp.autocast(True),\n",
"    # which merely warned and disabled itself on CPU\n",
"    with torch.autocast(device_type=device.type, enabled=device.type == 'cuda'):\n",
"        output = model(sx, sy, queries)  # last dim = n_labels; leading dims TODO confirm\n",
"\n",
"    probs = output.softmax(dim=-1).detach().cpu().numpy()\n",
"    predictions = np.argmax(probs, axis=-1)\n",
"    return np.mean(predictions == keyy)\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"14\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry9: loss = 0.1475423127412796: 100%|██| 100/100 [00:29<00:00, 3.40it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cut_inner_insulation = 1.0\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry5: loss = 0.20609889924526215: 100%|█| 100/100 [00:29<00:00, 3.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for poke_insulation = 1.0\n",
"12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry7: loss = 0.12025140225887299: 100%|█| 100/100 [00:29<00:00, 3.34it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cable_swap = 0.8571428571428571\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry5: loss = 0.2130972295999527: 100%|██| 100/100 [00:30<00:00, 3.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cut_outer_insulation = 1.0\n",
"58\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry53: loss = 0.13926956057548523: 100%|█| 100/100 [00:30<00:00, 3.30it/s\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for good = 0.16981132075471697\n",
"12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry7: loss = 0.16337624192237854: 100%|█| 100/100 [00:30<00:00, 3.28it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for missing_cable = 1.0\n",
"11\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry6: loss = 0.16593313217163086: 100%|█| 100/100 [00:30<00:00, 3.27it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for combined = 1.0\n",
"13\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry8: loss = 0.16560573875904083: 100%|█| 100/100 [00:30<00:00, 3.27it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for bent_wire = 1.0\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp45, nQry5: loss = 0.18611018359661102: 100%|█| 100/100 [00:30<00:00, 3.28it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for missing_wire = 0.8\n",
"overall accuracy: 0.8696615753219527\n",
"14\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry9: loss = 0.3357824385166168: 100%|██| 100/100 [00:33<00:00, 2.96it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cut_inner_insulation = 0.7777777777777778\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry5: loss = 0.3290153741836548: 100%|██| 100/100 [00:33<00:00, 2.96it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for poke_insulation = 0.6\n",
"12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry7: loss = 0.22177687287330627: 100%|█| 100/100 [00:33<00:00, 2.96it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cable_swap = 0.8571428571428571\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry5: loss = 0.299775630235672: 100%|███| 100/100 [00:33<00:00, 2.96it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cut_outer_insulation = 1.0\n",
"58\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry53: loss = 0.31954386830329895: 100%|█| 100/100 [00:33<00:00, 2.98it/s\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for good = 0.32075471698113206\n",
"12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry7: loss = 0.336273193359375: 100%|███| 100/100 [00:33<00:00, 2.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for missing_cable = 0.8571428571428571\n",
"11\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry6: loss = 0.3643767237663269: 100%|██| 100/100 [00:33<00:00, 2.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for combined = 1.0\n",
"13\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry8: loss = 0.3085792660713196: 100%|██| 100/100 [00:33<00:00, 2.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for bent_wire = 1.0\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp50, nQry5: loss = 0.34715649485588074: 100%|█| 100/100 [00:33<00:00, 2.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for missing_wire = 0.8\n",
"overall accuracy: 0.8014242454494026\n",
"14\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry9: loss = 0.375447154045105: 100%|███| 100/100 [00:36<00:00, 2.76it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cut_inner_insulation = 0.6666666666666666\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry5: loss = 0.42370423674583435: 100%|█| 100/100 [00:36<00:00, 2.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for poke_insulation = 1.0\n",
"12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry7: loss = 0.3982161581516266: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cable_swap = 0.8571428571428571\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry5: loss = 0.3903641104698181: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cut_outer_insulation = 1.0\n",
"58\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry53: loss = 0.4019339382648468: 100%|█| 100/100 [00:36<00:00, 2.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for good = 0.41509433962264153\n",
"12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry7: loss = 0.4283098876476288: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for missing_cable = 0.7142857142857143\n",
"11\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry6: loss = 0.3741377890110016: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for combined = 0.8333333333333334\n",
"13\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry8: loss = 0.3858358860015869: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for bent_wire = 1.0\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp55, nQry5: loss = 0.3570959270000458: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for missing_wire = 0.8\n",
"overall accuracy: 0.8096136567834681\n",
"14\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry9: loss = 0.5021733045578003: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cut_inner_insulation = 0.5555555555555556\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry5: loss = 0.5203520059585571: 100%|██| 100/100 [00:45<00:00, 2.20it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for poke_insulation = 0.4\n",
"12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry7: loss = 0.524366021156311: 100%|███| 100/100 [00:45<00:00, 2.21it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cable_swap = 0.42857142857142855\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry5: loss = 0.5256413221359253: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for cut_outer_insulation = 1.0\n",
"58\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry53: loss = 0.5186663866043091: 100%|█| 100/100 [00:45<00:00, 2.21it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for good = 0.7358490566037735\n",
"12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry7: loss = 0.5123675465583801: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for missing_cable = 0.7142857142857143\n",
"11\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry6: loss = 0.5076506733894348: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for combined = 0.8333333333333334\n",
"13\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry8: loss = 0.490247517824173: 100%|███| 100/100 [00:45<00:00, 2.21it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for bent_wire = 0.875\n",
"10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"lr0.0001, nSupp70, nQry5: loss = 0.3723257780075073: 100%|██| 100/100 [00:45<00:00, 2.21it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"accuracy for missing_wire = 0.4\n",
"overall accuracy: 0.6602883431499785\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Binary (good vs. anomaly) evaluation with an imbalanced support set: `shot` images per\n",
"# defect class and an increasing number of good images. Only the list for the category\n",
"# being run is reset, so results of an earlier run (e.g. bottle_accs) survive.\n",
"#bottle_accs = []\n",
"folder = 'cable'\n",
"shot = 5\n",
"good_shots = [5, 10, 15, 30]\n",
"skip_good = False  # True = do not score the 'good' test class\n",
"\n",
"SEED = 0\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)  # CUDA kernels may still be non-deterministic\n",
"\n",
"# sorted so the per-class output order does not depend on the filesystem\n",
"test_classes = sorted(next(os.walk(f'data_custom/{folder}/test'))[1])\n",
"\n",
"cable_accs = []\n",
"for nr in good_shots:\n",
"    supp_x, supp_y, label_map = build_sup_set(shot, folder, True, nr)\n",
"    accs = []\n",
"    for t in test_classes:\n",
"        if skip_good and t == 'good':\n",
"            continue\n",
"        # in binary mode defect classes are not keys of label_map and fall back to 1 ('anomaly')\n",
"        accuracy = test(t, label_map.get(t, 1), shot, folder)\n",
"        print(f'accuracy for {t} = {accuracy}')\n",
"        accs.append(accuracy)\n",
"    mean_acc = np.mean(accs)\n",
"    print(f'overall accuracy: {mean_acc}')\n",
"    cable_accs.append(mean_acc)\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.57380952 0.76705653 0.84191176]\n"
]
}
],
"source": [
"# bottle_accs is not created anywhere in this notebook as saved (it came from an earlier\n",
"# kernel session), so guard against a NameError on Restart & Run All.\n",
"if 'bottle_accs' in globals():\n",
"    print(np.array(bottle_accs))\n",
"else:\n",
"    print('bottle_accs not computed in this session')"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.86966158 0.80142425 0.80961366 0.66028834]\n"
]
}
],
"source": [
"# one mean accuracy per evaluation-loop iteration (per number of good support shots)\n",
"print(np.asarray(cable_accs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"P>M>F:\n",
"Results:\n",
"\n",
"bottle:\n",
"standard setting: 1, 3, 5 shots per class\n",
"[0.67910401 0.71710526 0.78860294]\n",
"\n",
"imbalanced: more good shots (5, 10, 15, 30), all other classes only 5\n",
"[0.78768382 0.78860294 0.75827206 0.74356618]\n",
"\n",
"2-way (only detect whether an image is defective or not): 1, 3, 5 shots\n",
"[0.86422306 0.93201754 0.93933824]\n",
"\n",
"imbalanced 2-way: good shots 5, 10, 15, 30, all other classes 5\n",
"[0.92371324 0.87867647 0.86397059 0.87775735]\n",
"\n",
"only recognise the defect class: 1, 3, 5 shots\n",
"[0.57380952 0.76705653 0.84191176]\n",
"[0.57380952 0.76705653 0.84191176]\n",
"\n",
"\n",
"cable:\n",
"standard setting: 1, 3, 5 shots per class\n",
"[0.25199021 0.44388328 0.46975059]\n",
"\n",
"imbalanced: more good shots (5, 10, 15, 30), all other classes only 5\n",
"[0.50425859 0.48023277 0.43118282 0.41842534]\n",
"\n",
"2-way (only detect whether an image is defective or not): 1, 3, 5 shots\n",
"[0.79263485 0.8707712 0.86756514]\n",
"\n",
"imbalanced 2-way: good shots 5, 10, 15, 30, all other classes 5\n",
"[0.86966158 0.80142425 0.80961366 0.66028834]\n",
"\n",
"only recognise the defect class: 1, 3, 5 shots\n",
"[0.24383256 0.43800505 0.51304563]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}