shravvvv commited on
Commit
986f758
·
1 Parent(s): dd84d71

Updated stuff

Browse files
sagvit.bin → SAG-ViT.pth RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ff3f3f603954fd987783ae22320d9e941ddb1e53acb9ce5b341d8911bdf3d100
3
- size 27102094
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bdfa65b805d74c284af254153960436014a6e26b740f25cf2c6c9289234d9d3
3
+ size 27137010
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2b23a28871ba0814540b11999cb6dc70b9c9e4d0bf446d2e8770e8c26d244387
3
- size 32389992
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:234c74e97091f54931e7050f775e56465b0abff0596906ad015c45df31e7b12a
3
+ size 27009200
model_components.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  from torch import nn
3
  import torch.nn.functional as F
4
- from torch_geometric.nn import GATConv, global_mean_pool
5
 
6
  from torchvision import models
7
 
@@ -52,15 +52,36 @@ class GATGNN(nn.Module):
52
  This module corresponds to the Graph Attention stage (Section 3.3),
53
  refining local relationships between patches in a learned manner.
54
  """
55
- def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
56
  super(GATGNN, self).__init__()
57
  # GAT layers:
58
  # First layer maps raw patch embeddings to a higher-level representation.
59
  self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
60
- # Second layer produces final node embeddings with a single head.
61
- self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
62
  self.pool = global_mean_pool
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def forward(self, data):
65
  """
66
  Input:
 
1
  import torch
2
  from torch import nn
3
  import torch.nn.functional as F
4
+ from torch_geometric.nn import GATConv, global_mean_pool, GCNConv
5
 
6
  from torchvision import models
7
 
 
52
  This module corresponds to the Graph Attention stage (Section 3.3),
53
  refining local relationships between patches in a learned manner.
54
  """
55
+ def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
56
  super(GATGNN, self).__init__()
57
  # GAT layers:
58
  # First layer maps raw patch embeddings to a higher-level representation.
59
  self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
60
+ # Final GCN layer for refined representation
61
+ self.conv2 = GCNConv(hidden_channels * heads, out_channels)
62
  self.pool = global_mean_pool
63
 
64
+ def forward(self, data):
65
+ """
66
+ Input:
67
+ - data (PyG Data): Contains x (node features), edge_index (graph edges), and batch indexing.
68
+
69
+ Output:
70
+ - x (Tensor): Aggregated graph-level embedding after mean pooling.
71
+ """
72
+ x, edge_index, batch = data.x, data.edge_index, data.batch
73
+
74
+ # GAT layer with ReLU activation
75
+ x = F.relu(self.conv1(x, edge_index))
76
+
77
+ # GCN layer for further aggregation
78
+ x = self.conv2(x, edge_index)
79
+
80
+ # Global mean pooling to obtain graph-level representation
81
+ out = self.pool(x, batch)
82
+
83
+ return out
84
+
85
  def forward(self, data):
86
  """
87
  Input:
push_model_to_hfhub.py CHANGED
@@ -1,13 +1,25 @@
 
1
  from transformers import AutoConfig, AutoModel
2
  from modeling_sagvit import SAGViTClassifier, SAGViTConfig
3
 
4
 
 
5
  AutoConfig.register("sagvit", SAGViTConfig)
6
  AutoModel.register(SAGViTConfig, SAGViTClassifier)
 
7
 
8
  # Load config and model
9
  config = SAGViTConfig()
10
  model = SAGViTClassifier(config)
11
 
 
 
 
 
 
 
12
  # Push model and code
 
13
  model.push_to_hub("shravvvv/SAG-ViT")
 
 
 
1
+ import torch
2
  from transformers import AutoConfig, AutoModel
3
  from modeling_sagvit import SAGViTClassifier, SAGViTConfig
4
 
5
 
6
+ print("Registering model...")
7
  AutoConfig.register("sagvit", SAGViTConfig)
8
  AutoModel.register(SAGViTConfig, SAGViTClassifier)
9
+ print("Registered model...")
10
 
11
  # Load config and model
12
  config = SAGViTConfig()
13
  model = SAGViTClassifier(config)
14
 
15
+ # Load the state dict into the model
16
+ print("Loading model weights...")
17
+ state_dict = torch.load('SAG-ViT.pth')
18
+ model.load_state_dict(state_dict)
19
+ print("Loaded model weights...")
20
+
21
  # Push model and code
22
+ model.save_pretrained('.')
23
  model.push_to_hub("shravvvv/SAG-ViT")
24
+
25
+ print("Pushed model to Hugging Face hub...")