#import "@preview/fletcher:0.5.3" as fletcher: diagram, node, edge
#import fletcher.shapes: rect, diamond
#import "utils.typ": todo
#import "@preview/subpar:0.1.1"
= Implementation <sectionimplementation>
The three methods described (ResNet50, CAML, P>M>F) were implemented in a Jupyter notebook and compared to each other.
== Experiments <experiments>
For each of the three methods, the following use cases are tested:
- Detection of the anomaly class (1, 3, 5 shots)
  - Every faulty class and the good class are detected.
- 2-way classification (1, 3, 5 shots)
  - Only faulty versus not faulty is detected. All samples of the faulty classes are treated as a single class.
- Detection of only the anomaly classes (1, 3, 5 shots)
  - Similar to the first test but without the good class. Only faulty classes are detected.
- Imbalanced 2-way classification (5, 10, 15, 30 good shots, 5 bad shots)
  - Similar to the 2-way classification but with an imbalanced number of good shots.
- Imbalanced target class prediction (5, 10, 15, 30 good shots, 5 bad shots)#todo[Avoid bullet points and write flow text?]
  - Detect only the faulty classes, without the good class, with an imbalanced number of shots.

All of these experiments were conducted on the bottle and cable classes of the MVTec AD dataset.
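
The use cases above differ mainly in how the original class labels are mapped to target labels. The following is a minimal sketch of such a mapping, not the notebook code. It assumes that the non-defective class is called `good`, as in MVTec AD; the function name `map_label` and the mode names are hypothetical.

```python
from typing import Optional

GOOD = "good"  # assumed name of the non-defective class


def map_label(label: str, mode: str) -> Optional[str]:
    """Map an original class label to the target label of a use case.

    Returns None when the sample is excluded from the use case.
    The mode names are hypothetical and only serve this illustration.
    """
    if mode == "all_classes":  # every faulty class and the good class
        return label
    if mode == "two_way":  # faulty vs. not faulty
        return GOOD if label == GOOD else "faulty"
    if mode == "faulty_only":  # faulty classes only, good class dropped
        return None if label == GOOD else label
    raise ValueError(f"unknown mode: {mode}")


# Example labels of the bottle category.
labels = ["good", "broken_large", "broken_small", "contamination"]
two_way = [map_label(l, "two_way") for l in labels]
faulty_only = [
    m for m in (map_label(l, "faulty_only") for l in labels) if m is not None
]
```

With this mapping the bottle category yields four target classes in the first use case, two in the 2-way case and three when only faulty classes are kept.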
== Experiment Setup
All experiments were performed on the bottle and cable classes of the MVTec AD dataset.
The corresponding number of shots was randomly selected from the dataset.
The remaining images were used to test the model and measure the accuracy.
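
The split into support and query samples can be sketched as follows. This is an illustration under assumptions, not the notebook code: the helper names `split_support_query` and `accuracy`, the fixed seed and the toy file names are all hypothetical.

```python
import random
from typing import Dict, List, Tuple


def split_support_query(
    samples: Dict[str, List[str]],
    shots: Dict[str, int],
    seed: int = 0,
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """Randomly pick shots[c] support samples per class; the rest is the query set."""
    rng = random.Random(seed)
    support, query = {}, {}
    for cls, items in samples.items():
        k = shots[cls]
        if k > len(items):
            raise ValueError(f"class {cls!r} has only {len(items)} samples")
        shuffled = items[:]  # copy so the input list stays untouched
        rng.shuffle(shuffled)
        support[cls] = shuffled[:k]
        query[cls] = shuffled[k:]
    return support, query


def accuracy(predictions: List[str], targets: List[str]) -> float:
    """Fraction of query samples whose predicted class equals the true class."""
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)


# Toy file names; an imbalanced split with more good than faulty shots.
samples = {
    "good": [f"good_{i}.png" for i in range(20)],
    "faulty": [f"faulty_{i}.png" for i in range(12)],
}
support, query = split_support_query(samples, shots={"good": 10, "faulty": 5})
```

Passing a different number of shots per class covers both the balanced and the imbalanced experiments with the same helper.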
#todo[Maybe add the real number of samples per class]
== ResNet50 <resnet50impl>
=== Approach
The simplest approach is to use a pre-trained ResNet50 model as a feature extractor.
Features are extracted from both the support and the query set to obtain a lower-dimensional representation of the images.
The support set embeddings are then compared to the query set embeddings.
To predict the class of a query image, the class whose support embedding has the smallest distance to the query embedding is chosen.
If a class has more than one support embedding, the mean of those embeddings is used (class center).
This approach is similar to a prototypical network @snell2017prototypicalnetworksfewshotlearning and to the work _Just Use a Library of Pre-trained Feature Extractors and a Simple Classifier_ @chowdhury2021fewshotimageclassificationjust, but with a simple distance metric instead of a neural network.

In this bachelor thesis a pre-trained ResNet50 (IMAGENET1K_V2) PyTorch model was used.
It is pre-trained on the ImageNet dataset and has 50 layers.

To obtain the embeddings, the last layer of the model was removed and the output of the second-to-last layer was used as the embedding.
The following diagram visualizes the ResNet50 architecture and marks the cut point.~@chowdhury2021fewshotimageclassificationjust
#diagram(
  spacing: (5mm, 5mm),
  node-stroke: 1pt,
  node-fill: eastern,
  edge-stroke: 1pt,

  // Input
  node((1, 1), "Input", shape: rect, width: 30mm, height: 10mm, name: <input>),

  // Conv1
  node((1, 0), "Conv1\n7x7, 64", shape: rect, width: 30mm, height: 15mm, name: <conv1>),
  edge(<input>, <conv1>, "->"),

  // MaxPool
  node((1, -1), "MaxPool\n3x3", shape: rect, width: 30mm, height: 15mm, name: <maxpool>),
  edge(<conv1>, <maxpool>, "->"),

  // Residual Blocks
  node((3, -1), "Residual Block 1\n3x [64, 64, 256]", shape: rect, width: 40mm, height: 15mm, name: <res1>),
  edge(<maxpool>, <res1>, "->"),

  node((3, 0), "Residual Block 2\n4x [128, 128, 512]", shape: rect, width: 40mm, height: 15mm, name: <res2>),
  edge(<res1>, <res2>, "->"),

  node((3, 1), "Residual Block 3\n6x [256, 256, 1024]", shape: rect, width: 40mm, height: 15mm, name: <res3>),
  edge(<res2>, <res3>, "->"),

  node((3, 2), "Residual Block 4\n3x [512, 512, 2048]", shape: rect, width: 40mm, height: 15mm, name: <res4>),
  edge(<res3>, <res4>, "->"),

  // Cutting Line
  edge(<res4>, <avgpool>, marks: "..|..>", stroke: 1pt, label: "Cut here", label-pos: 0.5, label-side: left),

  // AvgPool + FC
  node((7, 2), "AvgPool\n1x1", shape: rect, width: 30mm, height: 10mm, name: <avgpool>),
  //edge(<res4>, <avgpool>, "->"),

  node((7, 1), "Fully Connected\n1000 classes", shape: rect, width: 40mm, height: 10mm, name: <fc>),
  edge(<avgpool>, <fc>, "->"),

  // Output
  node((7, 0), "Output", shape: rect, width: 30mm, height: 10mm, name: <output>),
  edge(<fc>, <output>, "->")
)
After creating the embeddings for the support and query sets, the Euclidean distance between each query embedding and each class center is calculated.
The class with the smallest distance is chosen as the predicted class.
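
The classification step described above can be sketched in a few lines. This is a minimal illustration on toy two-dimensional vectors, not the notebook code; the function names `class_centers` and `predict` and the class name `broken` are hypothetical, and plain Python lists stand in for the actual embedding tensors.

```python
import math
from typing import Dict, List, Sequence

Vector = Sequence[float]


def class_centers(support: Dict[str, List[Vector]]) -> Dict[str, List[float]]:
    """Average the support embeddings of each class into one class center."""
    centers = {}
    for cls, embeddings in support.items():
        dim = len(embeddings[0])
        centers[cls] = [
            sum(e[i] for e in embeddings) / len(embeddings) for i in range(dim)
        ]
    return centers


def predict(query: Vector, centers: Dict[str, List[float]]) -> str:
    """Return the class whose center has the smallest Euclidean distance to the query."""
    return min(centers, key=lambda cls: math.dist(query, centers[cls]))


# Toy 2-D embeddings; real embeddings would be much higher-dimensional.
support = {
    "good": [[0.0, 0.0], [0.0, 2.0]],
    "broken": [[4.0, 4.0], [6.0, 4.0]],
}
centers = class_centers(support)
prediction = predict([1.0, 1.5], centers)
```

With a single shot per class the class center is simply the one support embedding, so the same code covers the 1-shot case.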
=== Results <resnet50perf>
This method performed better than expected for such a simple method.
As shown in @resnet50bottleperfa, with a normal 5-shot / 4-way classification the model achieved an accuracy of 75%.
When only detecting whether an anomaly occurred or not, the performance is significantly better and peaks at 81% with 5 shots / 2 ways.
Interestingly, the model performed slightly better with fewer shots in this case.
Moreover, in @resnet50bottleperfa the detection of the anomaly class only (3-way) shows a similar pattern to the normal 4-way classification.
The more shots, the better the performance, and it peaks at around 88% accuracy with 5 shots.

In @resnet50bottleperfb the model was tested with imbalanced class distributions.
With [5,10,15,30] good shots and 5 bad shots the model performed worse than with balanced classes.
The more good shots, the worse the performance.
The only exception is the faulty-or-not detection (2-way), where the model peaked at 15 good shots with 83% accuracy.
#subpar.grid(
  figure(image("rsc/resnet/ResNet50-bottle.png"), caption: [
    Normal [1,3,5] shots
  ]), <resnet50bottleperfa>,
  figure(image("rsc/resnet/ResNet50-bottle-inbalanced.png"), caption: [
    Imbalanced [5,10,15,30] shots
  ]), <resnet50bottleperfb>,
  columns: (1fr, 1fr),
  caption: [ResNet50 performance on bottle class],
  label: <resnet50bottleperf>,
)

The same experiments were conducted on the cable class and the results are shown in @resnet50cableperfa and @resnet50cableperfb.
The results are very similar to those for the bottle class.
Generally, the more shots, the better the accuracy.
However, the maximum accuracy reached is lower than on the bottle class,
which is expected as the cable class consists of 8 faulty classes.

#subpar.grid(
  figure(image("rsc/resnet/ResNet50-cable.png"), caption: [
    Normal [1,3,5] shots
  ]), <resnet50cableperfa>,
  figure(image("rsc/resnet/ResNet50-cable-inbalanced.png"), caption: [
    Imbalanced [5,10,15,30] shots
  ]), <resnet50cableperfb>,
  columns: (1fr, 1fr),
  caption: [ResNet50 performance on cable class],
  label: <resnet50cableperf>,
)
== P>M>F
=== Approach
For P>M>F the pre-trained model weights from the original paper were used.
A DINO model, pre-trained by Facebook, is used as the backbone feature extractor.
This is a vision transformer with a patch size of 16 and 12 attention heads, trained in a self-supervised fashion.
This feature extractor was meta-trained on 10 public image datasets #footnote[ImageNet-1k, Omniglot, FGVC-Aircraft, CUB-200-2011, Describable Textures, QuickDraw,
FGVCx Fungi, VGG Flower, Traffic Signs and MSCOCO~#cite(<pmfpaper>)]
of diverse domains by the authors of the original paper.#cite(<pmfpaper>)

Finally, this model is fine-tuned on the support set of every test iteration.
Every time the support set changes, the model has to be fine-tuned again.
In a real-world scenario this should not be necessary because the support set is fixed and only the query set changes.
=== Results
The results of P>M>F look very promising and improve on the ResNet50 method by a large margin.
In @pmfbottleperfa the model reached an accuracy of 79% with 5 shots / 4-way classification.
The 2-way classification (faulty or not) performed even better and peaked at 94% accuracy with 5 shots.

Similar to the ResNet50 method in @resnet50perf, the tests with an imbalanced class distribution performed worse than with balanced classes.
Adding more good shots to the support set therefore did not help and reduced the accuracy in these tests.
#subpar.grid(
  figure(image("rsc/pmf/P>M>F-bottle.png"), caption: [
    Normal [1,3,5] shots
  ]), <pmfbottleperfa>,
  figure(image("rsc/pmf/P>M>F-bottle-inbalanced.png"), caption: [
    Imbalanced [5,10,15,30] shots
  ]), <pmfbottleperfb>,
  columns: (1fr, 1fr),
  caption: [P>M>F performance on bottle class],
  label: <pmfbottleperf>,
)

#subpar.grid(
  figure(image("rsc/pmf/P>M>F-cable.png"), caption: [
    Normal [1,3,5] shots
  ]), <pmfcableperfa>,
  figure(image("rsc/pmf/P>M>F-cable-inbalanced.png"), caption: [
    Imbalanced [5,10,15,30] shots
  ]), <pmfcableperfb>,
  columns: (1fr, 1fr),
  caption: [P>M>F performance on cable class],
  label: <pmfcableperf>,
)
== CAML
=== Approach
For the CAML implementation the pre-trained model weights from the original paper were used.
This imposes a maximum sequence length on the non-causal sequence model.
For this reason the two imbalanced test cases could not be conducted for this method.

As a feature extractor a ViT-B/16 model was used, which is a Vision Transformer with a patch size of 16.
This feature extractor was already pre-trained when used by the authors of the original paper.
For the non-causal sequence model a transformer model was used.
It consists of 24 layers with 16 attention heads, a hidden dimension of 1024 and an output MLP size of 4096.
This transformer was trained on a huge number of images as described in @CAML.
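
The effect of the sequence length limit mentioned above can be illustrated with a small check. This sketch rests on two assumptions: that the sequence model receives one element per support image plus one for the query, and that the limit is a fixed number. The value of `MAX_SEQ_LEN` is purely hypothetical and is not the limit of the released weights; the function names are hypothetical as well.

```python
from typing import Dict


def sequence_length(shots_per_class: Dict[str, int]) -> int:
    """Assumed sequence length: one element per support image plus one for the query."""
    return sum(shots_per_class.values()) + 1


def fits(shots_per_class: Dict[str, int], max_seq_len: int) -> bool:
    """Check whether a support set plus one query stays within the limit."""
    return sequence_length(shots_per_class) <= max_seq_len


MAX_SEQ_LEN = 26  # hypothetical limit, NOT the value of the released weights

balanced = {"good": 5, "faulty": 5}
imbalanced = {"good": 30, "faulty": 5}
```

Under these assumptions the balanced support set needs 11 sequence elements and the largest imbalanced one needs 36, which shows how adding good shots can exceed a fixed limit.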
=== Results
The results were not as good as expected.
This might be caused by the fact that the model was not fine-tuned for any industrial dataset domain.
The model was trained on a large number of general-purpose images and is not fine-tuned at all.
It might not handle very similar images well.

Compared to the other two methods, CAML performed poorly in almost all experiments.
The normal few-shot classification reached only 40% accuracy in @camlperfa at best.
The only test in which it did surprisingly well was the detection of the anomaly class for the cable class in @camlperfb, where it reached almost 60% accuracy.
#subpar.grid(
  figure(image("rsc/caml/CAML-bottle.png"), caption: [
    Normal [1,3,5] shots - Bottle
  ]), <camlperfa>,
  figure(image("rsc/caml/CAML-cable.png"), caption: [
    Normal [1,3,5] shots - Cable
  ]), <camlperfb>,
  columns: (1fr, 1fr),
  caption: [CAML performance],
  label: <camlperf>,
)