{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import time\n", "import random\n", "import torch\n", "import torchvision.transforms as transforms\n", "#import gradio as gr\n", "import matplotlib.pyplot as plt\n", "\n", "from models import get_model\n", "from dotmap import DotMap\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pretrained weights found at dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\n" ] } ], "source": [ "args = DotMap()\n", "args.deploy = 'finetune'\n", "args.arch = 'dino_base_patch16'\n", "args.no_pretrain = True\n", "small = \"https://huggingface.co/hushell/pmf_metadataset_dino/resolve/main/md_full_128x128_dinosmall_fp16_lr5e-5/best.pth?download=true\"\n", "full = 'https://huggingface.co/hushell/pmf_metadataset_dino/resolve/main/md_full_128x128_dinobase_fp16_lr5e-5/best.pth?download=true'\n", "args.resume = full\n", "args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'\n", "args.cx = '06d75168141bc47f1'\n", "\n", "args.ada_steps = 100\n", "#args.ada_lr= 0.0001\n", "#args.aug_prob = .95\n", "args.ada_lr= 0.0001\n", "args.aug_prob = .9\n", "args.aug_types = [\"color\", \"translation\"]\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model = get_model(args)\n", "model.to(device)\n", "checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')\n", "model.load_state_dict(checkpoint['model'], strict=True)\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# image transforms\n", "def test_transform():\n", " def _convert_image_to_rgb(im):\n", " return im.convert('RGB')\n", "\n", " return transforms.Compose([\n", " transforms.Resize(224),\n", " #transforms.CenterCrop(224),\n", " _convert_image_to_rgb,\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", " std=[0.229, 0.224, 0.225]),\n", " ])\n", "\n", "preprocess = test_transform()" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "def build_sup_set(shotnr, type_name, binary, good_sample_nr):\n", " classes = next(os.walk(f'data_custom/{type_name}/test'))[1]\n", " classes.remove(\"good\")\n", "\n", " supp_x = []\n", " supp_y = []\n", " mapping = {\n", " \"good\" : 0\n", " }\n", "\n", " # add good manually\n", " x_good = [Image.open(f\"data_custom/{type_name}/train/good/{x:03d}.png\") for x in range(0, good_sample_nr)]\n", " supp_x.extend([preprocess(x) for x in x_good]) # (3, H, W))\n", " supp_y.extend([0] * good_sample_nr)\n", " \n", " for i,c in enumerate(classes):\n", " #i-=1\n", " x_im = [Image.open(f\"data_custom/{type_name}/test/{c}/{x:03d}.png\") for x in range(0, shotnr)]\n", " supp_x.extend([preprocess(x) for x in x_im]) # (3, H, W))\n", " if binary:\n", " supp_y.extend([1] * shotnr)\n", " mapping[\"anomaly\"] = 1\n", " else:\n", " supp_y.extend([i+1] * shotnr)\n", " mapping[c] = i+1\n", " \n", " supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device) # (1, n_supp*n_labels, 3, H, W)\n", " supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device) # (1, n_supp*n_labels)\n", " return supp_x, supp_y, mapping\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": true }, "outputs": [], "source": [ "def build_test_set(shotnr, keyy, type):\n", " _, _, files = next(os.walk(f\"data_custom/cable/test/{type}/\"))\n", " file_count = len(files)\n", " print(file_count)\n", "\n", " queries = [preprocess(Image.open(f\"data_custom/cable/test/{type}/{i:03d}.png\")).unsqueeze(0).unsqueeze(0).to(device) for i in range(shotnr,file_count)]\n", " labels = [keyy for x in range(shotnr,file_count)]\n", " return queries, labels\n", "\n", "def test(type, keyy, shotnr, folder):\n", " predictions = []\n", " _, _, files = next(os.walk(f\"data_custom/{folder}/test/{type}/\"))\n", " file_count = len(files)\n", " print(file_count)\n", "\n", " queries = [preprocess(Image.open(f\"data_custom/{folder}/test/{type}/{i:03d}.png\")).unsqueeze(0).unsqueeze(0).to(device) for i in range(shotnr,file_count)]\n", " queries = torch.cat(queries)\n", " with torch.cuda.amp.autocast(True):\n", " output = model(supp_x, supp_y, queries) # (1, 1, n_labels)\n", "\n", " probs = output.softmax(dim=-1).detach().cpu().numpy()\n", " predictions = np.argmax(probs, axis=2)\n", " print()\n", " return np.mean([x == keyy for x in predictions])\n", " pass\n", " \n", "#def test2(folder):\n", "# accs = []\n", "# queries = []\n", "# labels = []\n", "# for t in next(os.walk(f'data_custom/cable/test'))[1]:\n", "# q, l = build_test_set(shots, types.get(t, 1), t)\n", "# queries+=q\n", "# labels+=l\n", "#\n", "# queries = torch.cat(queries)\n", "# labels = np.array(labels)\n", "#\n", "# with torch.cuda.amp.autocast(True):\n", "# output = model(supp_x, supp_y, queries) # (1, 1, n_labels)\n", "#\n", "# probs = output.softmax(dim=-1).detach().cpu().numpy()\n", "# predictions = np.argmax(probs, axis=2)\n", "# print()\n", "# return np.mean([predictions == labels])\n", "# pass\n", "\n", "#print(f\"overall accuracy: {test(\"cable\")}\")\n", "\n" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry9: loss = 0.1475423127412796: 100%|██| 100/100 [00:29<00:00, 3.40it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cut_inner_insulation = 1.0\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry5: loss = 0.20609889924526215: 100%|█| 100/100 [00:29<00:00, 3.37it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for poke_insulation = 1.0\n", "12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry7: loss = 0.12025140225887299: 100%|█| 100/100 [00:29<00:00, 3.34it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cable_swap = 0.8571428571428571\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry5: loss = 0.2130972295999527: 100%|██| 100/100 [00:30<00:00, 3.30it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cut_outer_insulation = 1.0\n", "58\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry53: loss = 0.13926956057548523: 100%|█| 100/100 [00:30<00:00, 3.30it/s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for good = 0.16981132075471697\n", "12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry7: loss = 0.16337624192237854: 100%|█| 100/100 [00:30<00:00, 3.28it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for missing_cable = 1.0\n", "11\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry6: loss = 0.16593313217163086: 100%|█| 100/100 [00:30<00:00, 3.27it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for combined = 1.0\n", "13\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry8: loss = 0.16560573875904083: 100%|█| 100/100 [00:30<00:00, 3.27it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for bent_wire = 1.0\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp45, nQry5: loss = 0.18611018359661102: 100%|█| 100/100 [00:30<00:00, 3.28it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for missing_wire = 0.8\n", "overall accuracy: 0.8696615753219527\n", "14\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry9: loss = 0.3357824385166168: 100%|██| 100/100 [00:33<00:00, 2.96it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cut_inner_insulation = 0.7777777777777778\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry5: loss = 0.3290153741836548: 100%|██| 100/100 [00:33<00:00, 2.96it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for poke_insulation = 0.6\n", "12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry7: loss = 0.22177687287330627: 100%|█| 100/100 [00:33<00:00, 2.96it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cable_swap = 0.8571428571428571\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry5: loss = 0.299775630235672: 100%|███| 100/100 [00:33<00:00, 2.96it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cut_outer_insulation = 1.0\n", "58\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry53: loss = 0.31954386830329895: 100%|█| 100/100 [00:33<00:00, 2.98it/s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for good = 0.32075471698113206\n", "12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry7: loss = 0.336273193359375: 100%|███| 100/100 [00:33<00:00, 2.98it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for missing_cable = 0.8571428571428571\n", "11\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry6: loss = 0.3643767237663269: 100%|██| 100/100 [00:33<00:00, 2.98it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for combined = 1.0\n", "13\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry8: loss = 0.3085792660713196: 100%|██| 100/100 [00:33<00:00, 2.98it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for bent_wire = 1.0\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp50, nQry5: loss = 0.34715649485588074: 100%|█| 100/100 [00:33<00:00, 2.98it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for missing_wire = 0.8\n", "overall accuracy: 0.8014242454494026\n", "14\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry9: loss = 0.375447154045105: 100%|███| 100/100 [00:36<00:00, 2.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cut_inner_insulation = 0.6666666666666666\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry5: loss = 0.42370423674583435: 100%|█| 100/100 [00:36<00:00, 2.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for poke_insulation = 1.0\n", "12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry7: loss = 0.3982161581516266: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cable_swap = 0.8571428571428571\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry5: loss = 0.3903641104698181: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cut_outer_insulation = 1.0\n", "58\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry53: loss = 0.4019339382648468: 100%|█| 100/100 [00:36<00:00, 2.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for good = 0.41509433962264153\n", "12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry7: loss = 0.4283098876476288: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for missing_cable = 0.7142857142857143\n", "11\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry6: loss = 0.3741377890110016: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for combined = 0.8333333333333334\n", "13\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry8: loss = 0.3858358860015869: 100%|██| 100/100 [00:36<00:00, 2.75it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for bent_wire = 1.0\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp55, nQry5: loss = 0.3570959270000458: 100%|██| 100/100 [00:36<00:00, 2.74it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for missing_wire = 0.8\n", "overall accuracy: 0.8096136567834681\n", "14\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry9: loss = 0.5021733045578003: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cut_inner_insulation = 0.5555555555555556\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry5: loss = 0.5203520059585571: 100%|██| 100/100 [00:45<00:00, 2.20it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for poke_insulation = 0.4\n", "12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry7: loss = 0.524366021156311: 100%|███| 100/100 [00:45<00:00, 2.21it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cable_swap = 0.42857142857142855\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry5: loss = 0.5256413221359253: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for cut_outer_insulation = 1.0\n", "58\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry53: loss = 0.5186663866043091: 100%|█| 100/100 [00:45<00:00, 2.21it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for good = 0.7358490566037735\n", "12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry7: loss = 0.5123675465583801: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for missing_cable = 0.7142857142857143\n", "11\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry6: loss = 0.5076506733894348: 100%|██| 100/100 [00:45<00:00, 2.21it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for combined = 0.8333333333333334\n", "13\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry8: loss = 0.490247517824173: 100%|███| 100/100 [00:45<00:00, 2.21it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for bent_wire = 0.875\n", "10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "lr0.0001, nSupp70, nQry5: loss = 0.3723257780075073: 100%|██| 100/100 [00:45<00:00, 2.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "accuracy for missing_wire = 0.4\n", "overall accuracy: 0.6602883431499785\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "#bottle_accs = []\n", "cable_accs = []\n", "\n", "for nr in [5, 10, 15, 30]:\n", " folder = \"cable\"\n", " shot = 5\n", " supp_x, supp_y, types = build_sup_set(shot, folder, True, nr)\n", " accs = []\n", " for t in next(os.walk(f'data_custom/{folder}/test'))[1]:\n", " #if t == \"good\":\n", " # continue\n", " accuracy = test(t, types.get(t, 1), shot, folder)\n", " print(f\"accuracy for {t} = {accuracy}\")\n", " accs.append(accuracy)\n", " print(f\"overall accuracy: {np.mean(accs)}\")\n", " cable_accs.append(np.mean(accs))\n" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.57380952 0.76705653 0.84191176]\n" ] } ], "source": [ "print(np.array(bottle_accs))" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.86966158 0.80142425 0.80961366 0.66028834]\n" ] } ], "source": [ "print(np.array(cable_accs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "P>M>F:\n", "Resulsts:\n", "\n", "bottle:\n", "jeweils 1,3,5 shots normal\n", "[0.67910401 0.71710526 0.78860294]\n", "\n", "inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n", "[0.78768382 0.78860294 0.75827206 0.74356618]\n", "\n", "2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n", "[0.86422306 0.93201754 0.93933824]\n", "\n", "inbalance 2 way 5,10,15,30 -> rest 5\n", "[0.92371324 0.87867647 0.86397059 0.87775735]\n", "\n", "nur fehlerklasse erkennen 1,3,5\n", "[0.57380952 0.76705653 0.84191176]\n", "\n", "\n", "cable:\n", "jeweils 1,3,5 shots normal\n", "[0.25199021 0.44388328 0.46975059]\n", "\n", "inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n", "[0.50425859 0.48023277 0.43118282 0.41842534]\n", "\n", "2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n", "[0.79263485 0.8707712 0.86756514]\n", "\n", "inbalance 2 way 5,10,15,30 -> rest 5\n", "[0.86966158 0.80142425 0.80961366 0.66028834]\n", "\n", "nur fehlerklasse erkennen 1,3,5\n", "[0.24383256 0.43800505 0.51304563]\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 4 }