Updated stuff
Browse files- sagvit.bin → SAG-ViT.pth +2 -2
- model.safetensors +2 -2
- model_components.py +25 -4
- push_model_to_hfhub.py +12 -0
sagvit.bin → SAG-ViT.pth
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8bdfa65b805d74c284af254153960436014a6e26b740f25cf2c6c9289234d9d3
|
3 |
+
size 27137010
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:234c74e97091f54931e7050f775e56465b0abff0596906ad015c45df31e7b12a
|
3 |
+
size 27009200
|
model_components.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
import torch.nn.functional as F
|
4 |
-
from torch_geometric.nn import GATConv, global_mean_pool
|
5 |
|
6 |
from torchvision import models
|
7 |
|
@@ -52,15 +52,36 @@ class GATGNN(nn.Module):
|
|
52 |
This module corresponds to the Graph Attention stage (Section 3.3),
|
53 |
refining local relationships between patches in a learned manner.
|
54 |
"""
|
55 |
-
def __init__(self, in_channels, hidden_channels, out_channels, heads=
|
56 |
super(GATGNN, self).__init__()
|
57 |
# GAT layers:
|
58 |
# First layer maps raw patch embeddings to a higher-level representation.
|
59 |
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
|
60 |
-
#
|
61 |
-
self.conv2 =
|
62 |
self.pool = global_mean_pool
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def forward(self, data):
|
65 |
"""
|
66 |
Input:
|
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
import torch.nn.functional as F
|
4 |
+
from torch_geometric.nn import GATConv, global_mean_pool, GCNConv
|
5 |
|
6 |
from torchvision import models
|
7 |
|
|
|
52 |
This module corresponds to the Graph Attention stage (Section 3.3),
|
53 |
refining local relationships between patches in a learned manner.
|
54 |
"""
|
55 |
+
def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
|
56 |
super(GATGNN, self).__init__()
|
57 |
# GAT layers:
|
58 |
# First layer maps raw patch embeddings to a higher-level representation.
|
59 |
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
|
60 |
+
# Final GCN layer for refined representation
|
61 |
+
self.conv2 = GCNConv(hidden_channels * heads, out_channels)
|
62 |
self.pool = global_mean_pool
|
63 |
|
64 |
+
def forward(self, data):
|
65 |
+
"""
|
66 |
+
Input:
|
67 |
+
- data (PyG Data): Contains x (node features), edge_index (graph edges), and batch indexing.
|
68 |
+
|
69 |
+
Output:
|
70 |
+
- x (Tensor): Aggregated graph-level embedding after mean pooling.
|
71 |
+
"""
|
72 |
+
x, edge_index, batch = data.x, data.edge_index, data.batch
|
73 |
+
|
74 |
+
# GAT layer with ReLU activation
|
75 |
+
x = F.relu(self.conv1(x, edge_index))
|
76 |
+
|
77 |
+
# GCN layer for further aggregation
|
78 |
+
x = self.conv2(x, edge_index)
|
79 |
+
|
80 |
+
# Global mean pooling to obtain graph-level representation
|
81 |
+
out = self.pool(x, batch)
|
82 |
+
|
83 |
+
return out
|
84 |
+
|
85 |
def forward(self, data):
|
86 |
"""
|
87 |
Input:
|
push_model_to_hfhub.py
CHANGED
@@ -1,13 +1,25 @@
|
|
|
|
1 |
from transformers import AutoConfig, AutoModel
|
2 |
from modeling_sagvit import SAGViTClassifier, SAGViTConfig
|
3 |
|
4 |
|
|
|
5 |
AutoConfig.register("sagvit", SAGViTConfig)
|
6 |
AutoModel.register(SAGViTConfig, SAGViTClassifier)
|
|
|
7 |
|
8 |
# Load config and model
|
9 |
config = SAGViTConfig()
|
10 |
model = SAGViTClassifier(config)
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
# Push model and code
|
|
|
13 |
model.push_to_hub("shravvvv/SAG-ViT")
|
|
|
|
|
|
1 |
+
import torch
|
2 |
from transformers import AutoConfig, AutoModel
|
3 |
from modeling_sagvit import SAGViTClassifier, SAGViTConfig
|
4 |
|
5 |
|
6 |
+
print("Registering model...")
|
7 |
AutoConfig.register("sagvit", SAGViTConfig)
|
8 |
AutoModel.register(SAGViTConfig, SAGViTClassifier)
|
9 |
+
print("Registered model...")
|
10 |
|
11 |
# Load config and model
|
12 |
config = SAGViTConfig()
|
13 |
model = SAGViTClassifier(config)
|
14 |
|
15 |
+
# Load the state dict into the model
|
16 |
+
print("Loading model weights...")
|
17 |
+
state_dict = torch.load('SAG-ViT.pth')
|
18 |
+
model.load_state_dict(state_dict)
|
19 |
+
print("Loaded model weights...")
|
20 |
+
|
21 |
# Push model and code
|
22 |
+
model.save_pretrained('.')
|
23 |
model.push_to_hub("shravvvv/SAG-ViT")
|
24 |
+
|
25 |
+
print("Pushed model to Hugging Face hub...")
|