Updated model registration
Browse files- register_model.py +1 -1
- sagvit_config.py +1 -1
register_model.py
CHANGED
@@ -6,7 +6,7 @@ from sagvit_config import SAGViTConfig
|
|
6 |
from sag_vit_model import SAGViTClassifier
|
7 |
|
8 |
# Register the configuration
|
9 |
-
CONFIG_MAPPING.register("
|
10 |
|
11 |
# Register the model
|
12 |
MODEL_MAPPING.register(SAGViTConfig, SAGViTClassifier)
|
|
|
6 |
from sag_vit_model import SAGViTClassifier
|
7 |
|
8 |
# Register the configuration
|
9 |
+
CONFIG_MAPPING.register("sag-vit", SAGViTConfig)
|
10 |
|
11 |
# Register the model
|
12 |
MODEL_MAPPING.register(SAGViTConfig, SAGViTClassifier)
|
sagvit_config.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from transformers import PretrainedConfig
|
2 |
|
3 |
class SAGViTConfig(PretrainedConfig):
|
4 |
-
model_type = "
|
5 |
|
6 |
def __init__(self,
|
7 |
d_model=64,
|
|
|
1 |
from transformers import PretrainedConfig
|
2 |
|
3 |
class SAGViTConfig(PretrainedConfig):
|
4 |
+
model_type = "sag-vit"
|
5 |
|
6 |
def __init__(self,
|
7 |
d_model=64,
|