from transformers import PretrainedConfig


class SAGViTConfig(PretrainedConfig):
    """Configuration class for the SAG-ViT model."""

    model_type = "sag-vit"

    def __init__(
        self,
        d_model=64,
        dim_feedforward=64,
        gcn_hidden=128,
        gcn_out=64,
        hidden_mlp_features=64,
        in_channels=2560,
        nhead=4,
        num_classes=10,
        num_layers=2,
        patch_size=(4, 4),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model                          # Transformer embedding dimension
        self.dim_feedforward = dim_feedforward          # Transformer feed-forward dimension
        self.gcn_hidden = gcn_hidden                    # GCN hidden dimension
        self.gcn_out = gcn_out                          # GCN output dimension
        self.hidden_mlp_features = hidden_mlp_features  # Hidden width of the classification MLP
        self.in_channels = in_channels                  # Channels of the input feature map
        self.nhead = nhead                              # Number of attention heads
        self.num_classes = num_classes                  # Number of output classes
        self.num_layers = num_layers                    # Number of Transformer encoder layers
        self.patch_size = patch_size                    # Patch size as (height, width)
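
# A minimal usage sketch (not part of the original file): it instantiates the
# configuration, overrides one field, and round-trips it through the standard
# PretrainedConfig save/load methods. The "./sag-vit-config" directory name is
# only an illustrative placeholder.
if __name__ == "__main__":
    config = SAGViTConfig(num_classes=100)        # override the default of 10 classes
    config.save_pretrained("./sag-vit-config")    # writes config.json to the directory
    reloaded = SAGViTConfig.from_pretrained("./sag-vit-config")
    print(reloaded.num_classes, reloaded.patch_size)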