shravvvv committed on
Commit 46de2be · verified · 1 Parent(s): 80bbf6c

Update README.md

Files changed (1)
  1. README.md +149 -5
README.md CHANGED
@@ -1,9 +1,153 @@
  ---
  tags:
- - model_hub_mixin
- - pytorch_model_hub_mixin
  ---

- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- - Library: [More Information Needed]
- - Docs: [More Information Needed]
  ---
  tags:
+ - image_classification
+ - computer_vision
+ license: mit
+ datasets:
+ - p2pfl/CIFAR10
+ language:
+ - en
+ pipeline_tag: image-classification
+ metrics:
+ - f1
  ---
 
+ # SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers
+
+ ### Model Description
+
+ Implementation of the ***SAG-ViT*** model as proposed in the paper [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420).
+
+ SAG-ViT is a novel transformer framework designed to enhance Vision Transformers (ViT) with scale-awareness and refined patch-level feature embeddings. It extracts multiscale features with EfficientNetV2, organizes patches into a graph based on their spatial relationships, and refines them with a Graph Attention Network (GAT). A Transformer encoder then integrates these embeddings globally, capturing long-range dependencies for comprehensive image understanding.
+
+ ### Model Architecture
+
+ ![SAGViTArchitecture](https://github.com/shravan-18/SAG-ViT/blob/main/images/SAG-ViT.png)
+
+ _Image source: [SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers](https://arxiv.org/abs/2411.09420)_
+
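+ As a rough illustration of how the stages described above compose, the sketch below chains an EfficientNetV2 feature extractor, a simplified dot-product graph attention over a 4-connected patch grid (standing in for the GAT), and a Transformer encoder. The module names, dimensions, and adjacency construction here are illustrative assumptions, not the actual implementation; see the linked repository for the real model.
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from torchvision.models import efficientnet_v2_s
+
+ def grid_adjacency(h, w):
+     # 4-connected adjacency over an h x w patch grid, with self-loops (assumed here).
+     idx = torch.arange(h * w).reshape(h, w)
+     adj = torch.eye(h * w)
+     for di, dj in [(0, 1), (1, 0)]:
+         a = idx[: h - di, : w - dj].reshape(-1)
+         b = idx[di:, dj:].reshape(-1)
+         adj[a, b] = 1
+         adj[b, a] = 1
+     return adj
+
+ class SimpleGraphAttention(nn.Module):
+     # Single-head dot-product attention restricted to graph neighbours (a simplified stand-in for a GAT layer).
+     def __init__(self, dim):
+         super().__init__()
+         self.q, self.k, self.v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
+
+     def forward(self, x, adj):
+         # x: (N, P, D) patch embeddings; adj: (P, P) binary adjacency
+         scores = self.q(x) @ self.k(x).transpose(-2, -1) / x.size(-1) ** 0.5
+         scores = scores.masked_fill(adj == 0, float("-inf"))
+         return x + scores.softmax(dim=-1) @ self.v(x)
+
+ class SAGViTSketch(nn.Module):
+     def __init__(self, num_classes=10, dim=256):
+         super().__init__()
+         self.backbone = efficientnet_v2_s(weights=None).features   # CNN feature extractor
+         self.proj = nn.Conv2d(1280, dim, kernel_size=1)             # per-cell patch embeddings
+         self.gat = SimpleGraphAttention(dim)                        # local, graph-based refinement
+         layer = nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
+         self.encoder = nn.TransformerEncoder(layer, num_layers=4)   # global long-range context
+         self.head = nn.Linear(dim, num_classes)
+
+     def forward(self, x):
+         f = self.proj(self.backbone(x))          # (N, D, h, w) feature map
+         n, d, h, w = f.shape
+         patches = f.flatten(2).transpose(1, 2)   # (N, P, D), one token per spatial cell
+         patches = self.gat(patches, grid_adjacency(h, w).to(x.device))
+         patches = self.encoder(patches)          # long-range dependencies
+         return self.head(patches.mean(dim=1))    # classification logits
+
+ logits = SAGViTSketch()(torch.randn(2, 3, 224, 224))
+ print(logits.shape)  # torch.Size([2, 10])
+ ```
+
+ The real model differs in how high-fidelity patches are extracted and how the graph is constructed; the point of the sketch is only the backbone → graph attention → Transformer encoder ordering.
+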
29
+ ### Usage
30
+
31
+ SAG-ViT expect input images normalized in the same way,
32
+ i.e. mini-batches of 3-channel RGB images of shape `(N, 3, H, W)`, where `N` is the number of images, `H` and `W` are expected to be at least `49` pixels.
33
+ The images have to be loaded in to a range of `[0, 1]` and then normalized using `mean = [0.485, 0.456, 0.406]`
34
+ and `std = [0.229, 0.224, 0.225]`.
35
+
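+ A minimal preprocessing sketch for the constraints just stated (the full inference example below additionally resizes to `224x224`):
+
+ ```python
+ from PIL import Image
+ from torchvision import transforms
+
+ preprocess = transforms.Compose([
+     transforms.ToTensor(),  # converts to a float tensor in [0, 1]
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+
+ img = Image.open("path/to/your/image.jpg").convert("RGB")
+ assert img.width >= 49 and img.height >= 49, "H and W must be at least 49 pixels"
+ batch = preprocess(img).unsqueeze(0)  # shape (1, 3, H, W)
+ ```
+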
+ To train or run inference with our model, follow these steps:
+
+ Clone our repository and load the model pretrained on the CIFAR-10 dataset.
+ ```bash
+ git clone https://huggingface.co/shravvvv/SAG-ViT
+ cd SAG-ViT
+ ```
+
+ Install required dependencies.
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Use `from_pretrained` to load the model from the Hugging Face Hub and run inference on a sample input image.
+ ```python
+ from transformers import AutoModel, AutoConfig
+ from PIL import Image
+ from torchvision import transforms
+ import torch
+
+ # Step 1: Load the model and configuration directly from the Hugging Face Hub
+ repo_name = "shravvvv/SAG-ViT"
+ config = AutoConfig.from_pretrained(repo_name)               # Load config from hub
+ model = AutoModel.from_pretrained(repo_name, config=config)  # Load model from hub
+
+ # Step 2: Define the transformation for the input image
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),  # Resize to match the expected input size
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization (see above)
+ ])
+
+ # Step 3: Load and preprocess the input image
+ input_image_path = "path/to/your/image.jpg"
+ img = Image.open(input_image_path).convert("RGB")
+ img = transform(img).unsqueeze(0)  # Add batch dimension
+
+ # Step 4: Ensure the model is in evaluation mode
+ model.eval()
+
+ # Step 5: Run inference
+ with torch.no_grad():
+     outputs = model(img)
+     logits = outputs.logits  # Access logits from the ModelOutput
+
+ # Step 6: Post-process the predictions
+ predicted_class_index = torch.argmax(logits, dim=1)  # Get the predicted class index
+
+ # CIFAR-10 label mapping
+ class_names = [
+     'airplane', 'automobile', 'bird', 'cat', 'deer',
+     'dog', 'frog', 'horse', 'ship', 'truck'
+ ]
+
+ # Get the predicted class name from the class index
+ predicted_class_name = class_names[predicted_class_index.item()]
+ print(f"Predicted class: {predicted_class_name}")
+ ```
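+
+ If you also want per-class probabilities rather than just the top prediction, a small follow-up on the `logits` tensor from the example above:
+
+ ```python
+ # Convert logits to probabilities over the 10 CIFAR-10 classes
+ probs = torch.softmax(logits, dim=1)
+ print({name: round(p, 4) for name, p in zip(class_names, probs[0].tolist())})
+ ```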
+
+ ### Running Tests
+
+ If you clone our [repository](https://github.com/shravan-18/SAG-ViT), the *tests* folder contains unit tests for each of our model's modules. Make sure you have a Python environment with the required dependencies installed, then run:
+ ```bash
+ python -m unittest discover -s tests
+ ```
+
+ or, if you are using `pytest`:
+ ```bash
+ pytest tests
+ ```
+
+ ### Results
+
+ We evaluated SAG-ViT on diverse datasets:
+ - **CIFAR-10** (natural images)
+ - **GTSRB** (traffic sign recognition)
+ - **NCT-CRC-HE-100K** (histopathological images)
+ - **NWPU-RESISC45** (remote sensing imagery)
+ - **PlantVillage** (agricultural imagery)
+
+ SAG-ViT achieves the best results on every benchmark among the evaluated backbones, as shown in the table below (F1 scores):
+
+ <center>
+
+ | Backbone | CIFAR-10 | GTSRB | NCT-CRC-HE-100K | NWPU-RESISC45 | PlantVillage |
+ |--------------------|----------|--------|-----------------|---------------|--------------|
+ | DenseNet201 | 0.5427 | 0.9862 | 0.9214 | 0.4493 | 0.8725 |
+ | Vgg16 | 0.5345 | 0.8180 | 0.8234 | 0.4114 | 0.7064 |
+ | Vgg19 | 0.5307 | 0.7551 | 0.8178 | 0.3844 | 0.6811 |
+ | DenseNet121 | 0.5290 | 0.9813 | 0.9247 | 0.4381 | 0.8321 |
+ | AlexNet | 0.6126 | 0.9059 | 0.8743 | 0.4397 | 0.7684 |
+ | Inception | 0.7734 | 0.8934 | 0.8707 | 0.8707 | 0.8216 |
+ | ResNet | 0.9172 | 0.9134 | 0.9478 | 0.9103 | 0.8905 |
+ | MobileNet | 0.9169 | 0.3006 | 0.4965 | 0.1667 | 0.2213 |
+ | ViT - S | 0.8465 | 0.8542 | 0.8234 | 0.6116 | 0.8654 |
+ | ViT - L | 0.8637 | 0.8613 | 0.8345 | 0.8358 | 0.8842 |
+ | MNASNet1_0 | 0.1032 | 0.0024 | 0.0212 | 0.0011 | 0.0049 |
+ | ShuffleNet_V2_x1_0 | 0.3523 | 0.4244 | 0.4598 | 0.1808 | 0.3190 |
+ | SqueezeNet1_0 | 0.4328 | 0.8392 | 0.7843 | 0.3913 | 0.6638 |
+ | GoogLeNet | 0.4954 | 0.9455 | 0.8631 | 0.3720 | 0.7726 |
+ | **Proposed (SAG-ViT)** | **0.9574** | **0.9958** | **0.9861** | **0.9549** | **0.9772** |
+
+ </center>
+
+ ## Citation
+
+ If you find our [paper](https://arxiv.org/abs/2411.09420) and [code](https://github.com/shravan-18/SAG-ViT) helpful for your research, please consider citing our work and giving the repository a star:
+
+ ```bibtex
+ @misc{SAGViT,
+       title={SAG-ViT: A Scale-Aware, High-Fidelity Patching Approach with Graph Attention for Vision Transformers},
+       author={Shravan Venkatraman and Jaskaran Singh Walia and Joe Dhanith P R},
+       year={2024},
+       eprint={2411.09420},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV},
+       url={https://arxiv.org/abs/2411.09420},
+ }
+ ```