{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import time\n",
    "import random\n",
    "import torch\n",
    "import torchvision.transforms as transforms\n",
    "#import gradio as gr\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from models import get_model\n",
    "from dotmap import DotMap\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pretrained weights found at dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\n"
     ]
    }
   ],
   "source": [
    "args = DotMap()\n",
    "args.deploy = 'finetune'\n",
    "args.arch = 'dino_base_patch16'\n",
    "args.no_pretrain = True\n",
    "small = \"https://huggingface.co/hushell/pmf_metadataset_dino/resolve/main/md_full_128x128_dinosmall_fp16_lr5e-5/best.pth?download=true\"\n",
    "full = 'https://huggingface.co/hushell/pmf_metadataset_dino/resolve/main/md_full_128x128_dinobase_fp16_lr5e-5/best.pth?download=true'\n",
    "args.resume = full\n",
    "args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'\n",
    "args.cx = '06d75168141bc47f1'\n",
    "\n",
    "args.ada_steps = 100\n",
    "#args.ada_lr= 0.0001\n",
    "#args.aug_prob = .95\n",
    "args.ada_lr= 0.0001\n",
    "args.aug_prob = .9\n",
    "args.aug_types = [\"color\", \"translation\"]\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = get_model(args)\n",
    "model.to(device)\n",
    "checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')\n",
    "model.load_state_dict(checkpoint['model'], strict=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image transforms\n",
    "def test_transform():\n",
    "    def _convert_image_to_rgb(im):\n",
    "        return im.convert('RGB')\n",
    "\n",
    "    return transforms.Compose([\n",
    "        transforms.Resize(224),\n",
    "        #transforms.CenterCrop(224),\n",
    "        _convert_image_to_rgb,\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                             std=[0.229, 0.224, 0.225]),\n",
    "        ])\n",
    "\n",
    "preprocess = test_transform()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_sup_set(shotnr, type_name, binary, good_sample_nr):\n",
    "    classes = next(os.walk(f'data_custom/{type_name}/test'))[1]\n",
    "    classes.remove(\"good\")\n",
    "\n",
    "    supp_x = []\n",
    "    supp_y = []\n",
    "    mapping = {\n",
    "        \"good\" : 0\n",
    "    }\n",
    "\n",
    "    # add good manually\n",
    "    x_good = [Image.open(f\"data_custom/{type_name}/train/good/{x:03d}.png\") for x in range(0, good_sample_nr)]\n",
    "    supp_x.extend([preprocess(x) for x in x_good]) # (3, H, W))\n",
    "    supp_y.extend([0] * good_sample_nr)\n",
    "    \n",
    "    for i,c in enumerate(classes):\n",
    "        #i-=1\n",
    "        x_im = [Image.open(f\"data_custom/{type_name}/test/{c}/{x:03d}.png\") for x in range(0, shotnr)]\n",
    "        supp_x.extend([preprocess(x) for x in x_im]) # (3, H, W))\n",
    "        if binary:\n",
    "            supp_y.extend([1] * shotnr)\n",
    "            mapping[\"anomaly\"] = 1\n",
    "        else:\n",
    "            supp_y.extend([i+1] * shotnr)\n",
    "            mapping[c] = i+1\n",
    "    \n",
    "    supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device) # (1, n_supp*n_labels, 3, H, W)\n",
    "    supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device) # (1, n_supp*n_labels)\n",
    "    return supp_x, supp_y, mapping\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def build_test_set(shotnr, keyy, type):\n",
    "    _, _, files = next(os.walk(f\"data_custom/cable/test/{type}/\"))\n",
    "    file_count = len(files)\n",
    "    print(file_count)\n",
    "\n",
    "    queries = [preprocess(Image.open(f\"data_custom/cable/test/{type}/{i:03d}.png\")).unsqueeze(0).unsqueeze(0).to(device) for i in range(shotnr,file_count)]\n",
    "    labels = [keyy for x in range(shotnr,file_count)]\n",
    "    return queries, labels\n",
    "\n",
    "def test(type, keyy, shotnr, folder):\n",
    "    predictions = []\n",
    "    _, _, files = next(os.walk(f\"data_custom/{folder}/test/{type}/\"))\n",
    "    file_count = len(files)\n",
    "    print(file_count)\n",
    "\n",
    "    queries = [preprocess(Image.open(f\"data_custom/{folder}/test/{type}/{i:03d}.png\")).unsqueeze(0).unsqueeze(0).to(device) for i in range(shotnr,file_count)]\n",
    "    queries = torch.cat(queries)\n",
    "    with torch.cuda.amp.autocast(True):\n",
    "        output = model(supp_x, supp_y, queries) # (1, 1, n_labels)\n",
    "\n",
    "    probs = output.softmax(dim=-1).detach().cpu().numpy()\n",
    "    predictions = np.argmax(probs, axis=2)\n",
    "    print()\n",
    "    return np.mean([x == keyy for x in predictions])\n",
    "    pass\n",
    "    \n",
    "#def test2(folder):\n",
    "#    accs = []\n",
    "#    queries = []\n",
    "#    labels = []\n",
    "#    for t in next(os.walk(f'data_custom/cable/test'))[1]:\n",
    "#        q, l = build_test_set(shots, types.get(t, 1), t)\n",
    "#        queries+=q\n",
    "#        labels+=l\n",
    "#\n",
    "#    queries = torch.cat(queries)\n",
    "#    labels = np.array(labels)\n",
    "#\n",
    "#    with torch.cuda.amp.autocast(True):\n",
    "#        output = model(supp_x, supp_y, queries) # (1, 1, n_labels)\n",
    "#\n",
    "#    probs = output.softmax(dim=-1).detach().cpu().numpy()\n",
    "#    predictions = np.argmax(probs, axis=2)\n",
    "#    print()\n",
    "#    return np.mean([predictions == labels])\n",
    "#    pass\n",
    "\n",
    "#print(f\"overall accuracy: {test(\"cable\")}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry9: loss = 0.1475423127412796: 100%|██| 100/100 [00:29<00:00,  3.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cut_inner_insulation = 1.0\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry5: loss = 0.20609889924526215: 100%|█| 100/100 [00:29<00:00,  3.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for poke_insulation = 1.0\n",
      "12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry7: loss = 0.12025140225887299: 100%|█| 100/100 [00:29<00:00,  3.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cable_swap = 0.8571428571428571\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry5: loss = 0.2130972295999527: 100%|██| 100/100 [00:30<00:00,  3.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cut_outer_insulation = 1.0\n",
      "58\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry53: loss = 0.13926956057548523: 100%|█| 100/100 [00:30<00:00,  3.30it/s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for good = 0.16981132075471697\n",
      "12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry7: loss = 0.16337624192237854: 100%|█| 100/100 [00:30<00:00,  3.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for missing_cable = 1.0\n",
      "11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry6: loss = 0.16593313217163086: 100%|█| 100/100 [00:30<00:00,  3.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for combined = 1.0\n",
      "13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry8: loss = 0.16560573875904083: 100%|█| 100/100 [00:30<00:00,  3.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for bent_wire = 1.0\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp45, nQry5: loss = 0.18611018359661102: 100%|█| 100/100 [00:30<00:00,  3.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for missing_wire = 0.8\n",
      "overall accuracy: 0.8696615753219527\n",
      "14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry9: loss = 0.3357824385166168: 100%|██| 100/100 [00:33<00:00,  2.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cut_inner_insulation = 0.7777777777777778\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry5: loss = 0.3290153741836548: 100%|██| 100/100 [00:33<00:00,  2.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for poke_insulation = 0.6\n",
      "12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry7: loss = 0.22177687287330627: 100%|█| 100/100 [00:33<00:00,  2.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cable_swap = 0.8571428571428571\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry5: loss = 0.299775630235672: 100%|███| 100/100 [00:33<00:00,  2.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cut_outer_insulation = 1.0\n",
      "58\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry53: loss = 0.31954386830329895: 100%|█| 100/100 [00:33<00:00,  2.98it/s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for good = 0.32075471698113206\n",
      "12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry7: loss = 0.336273193359375: 100%|███| 100/100 [00:33<00:00,  2.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for missing_cable = 0.8571428571428571\n",
      "11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry6: loss = 0.3643767237663269: 100%|██| 100/100 [00:33<00:00,  2.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for combined = 1.0\n",
      "13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry8: loss = 0.3085792660713196: 100%|██| 100/100 [00:33<00:00,  2.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for bent_wire = 1.0\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp50, nQry5: loss = 0.34715649485588074: 100%|█| 100/100 [00:33<00:00,  2.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for missing_wire = 0.8\n",
      "overall accuracy: 0.8014242454494026\n",
      "14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry9: loss = 0.375447154045105: 100%|███| 100/100 [00:36<00:00,  2.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cut_inner_insulation = 0.6666666666666666\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry5: loss = 0.42370423674583435: 100%|█| 100/100 [00:36<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for poke_insulation = 1.0\n",
      "12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry7: loss = 0.3982161581516266: 100%|██| 100/100 [00:36<00:00,  2.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cable_swap = 0.8571428571428571\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry5: loss = 0.3903641104698181: 100%|██| 100/100 [00:36<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cut_outer_insulation = 1.0\n",
      "58\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry53: loss = 0.4019339382648468: 100%|█| 100/100 [00:36<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for good = 0.41509433962264153\n",
      "12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry7: loss = 0.4283098876476288: 100%|██| 100/100 [00:36<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for missing_cable = 0.7142857142857143\n",
      "11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry6: loss = 0.3741377890110016: 100%|██| 100/100 [00:36<00:00,  2.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for combined = 0.8333333333333334\n",
      "13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry8: loss = 0.3858358860015869: 100%|██| 100/100 [00:36<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for bent_wire = 1.0\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp55, nQry5: loss = 0.3570959270000458: 100%|██| 100/100 [00:36<00:00,  2.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for missing_wire = 0.8\n",
      "overall accuracy: 0.8096136567834681\n",
      "14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry9: loss = 0.5021733045578003: 100%|██| 100/100 [00:45<00:00,  2.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cut_inner_insulation = 0.5555555555555556\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry5: loss = 0.5203520059585571: 100%|██| 100/100 [00:45<00:00,  2.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for poke_insulation = 0.4\n",
      "12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry7: loss = 0.524366021156311: 100%|███| 100/100 [00:45<00:00,  2.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cable_swap = 0.42857142857142855\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry5: loss = 0.5256413221359253: 100%|██| 100/100 [00:45<00:00,  2.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for cut_outer_insulation = 1.0\n",
      "58\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry53: loss = 0.5186663866043091: 100%|█| 100/100 [00:45<00:00,  2.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for good = 0.7358490566037735\n",
      "12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry7: loss = 0.5123675465583801: 100%|██| 100/100 [00:45<00:00,  2.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for missing_cable = 0.7142857142857143\n",
      "11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry6: loss = 0.5076506733894348: 100%|██| 100/100 [00:45<00:00,  2.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for combined = 0.8333333333333334\n",
      "13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry8: loss = 0.490247517824173: 100%|███| 100/100 [00:45<00:00,  2.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for bent_wire = 0.875\n",
      "10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "lr0.0001, nSupp70, nQry5: loss = 0.3723257780075073: 100%|██| 100/100 [00:45<00:00,  2.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "accuracy for missing_wire = 0.4\n",
      "overall accuracy: 0.6602883431499785\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "#bottle_accs = []\n",
    "cable_accs = []\n",
    "\n",
    "for nr in [5, 10, 15, 30]:\n",
    "    folder = \"cable\"\n",
    "    shot = 5\n",
    "    supp_x, supp_y, types = build_sup_set(shot, folder, True, nr)\n",
    "    accs = []\n",
    "    for t in next(os.walk(f'data_custom/{folder}/test'))[1]:\n",
    "        #if t == \"good\":\n",
    "        #    continue\n",
    "        accuracy = test(t, types.get(t, 1), shot, folder)\n",
    "        print(f\"accuracy for {t} = {accuracy}\")\n",
    "        accs.append(accuracy)\n",
    "    print(f\"overall accuracy: {np.mean(accs)}\")\n",
    "    cable_accs.append(np.mean(accs))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.57380952 0.76705653 0.84191176]\n"
     ]
    }
   ],
   "source": [
    "print(np.array(bottle_accs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.86966158 0.80142425 0.80961366 0.66028834]\n"
     ]
    }
   ],
   "source": [
    "print(np.array(cable_accs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "P>M>F:\n",
    "Resulsts:\n",
    "\n",
    "bottle:\n",
    "jeweils 1,3,5 shots normal\n",
    "[0.67910401 0.71710526 0.78860294]\n",
    "\n",
    "inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
    "[0.78768382 0.78860294 0.75827206 0.74356618]\n",
    "\n",
    "2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n",
    "[0.86422306 0.93201754 0.93933824]\n",
    "\n",
    "inbalance 2 way 5,10,15,30 -> rest 5\n",
    "[0.92371324 0.87867647 0.86397059 0.87775735]\n",
    "\n",
    "nur fehlerklasse erkennen 1,3,5\n",
    "[0.57380952 0.76705653 0.84191176]\n",
    "\n",
    "\n",
    "cable:\n",
    "jeweils 1,3,5 shots normal\n",
    "[0.25199021 0.44388328 0.46975059]\n",
    "\n",
    "inbalanced - mehr good shots 5,10,15,30 -> alle anderen nur 5\n",
    "[0.50425859 0.48023277 0.43118282 0.41842534]\n",
    "\n",
    "2 ways nur detektieren ob fehlerhaft oder nicht 1,3,5 shots\n",
    "[0.79263485 0.8707712  0.86756514]\n",
    "\n",
    "inbalance 2 way 5,10,15,30 -> rest 5\n",
    "[0.86966158 0.80142425 0.80961366 0.66028834]\n",
    "\n",
    "nur fehlerklasse erkennen 1,3,5\n",
    "[0.24383256 0.43800505 0.51304563]\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}