Image Feature Extraction
Transformers
Safetensors
dinov2
fepegar committed
Commit 450a84c · 1 Parent(s): d59bc91

Add some configs and a module

Files changed (3)
  1. augmentations.py +147 -0
  2. ssl_default_config.yaml +135 -0
  3. vitb14_cxr.yaml +31 -0
augmentations.py ADDED
@@ -0,0 +1,147 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # See LICENSE in the repo root for license information.
+ #
+ # Portions:
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import logging
+
+ from PIL import Image
+ from torchvision import transforms
+
+ from .transforms import (
+     GaussianBlur,
+     MaybeToTensor,
+     make_normalize_transform,
+ )
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ class DataAugmentationDINO(object):
+     def __init__(
+         self,
+         global_crops_scale,
+         local_crops_scale,
+         local_crops_number,
+         global_crops_size=224,
+         local_crops_size=96,
+     ):
+         self.global_crops_scale = global_crops_scale
+         self.local_crops_scale = local_crops_scale
+         self.local_crops_number = local_crops_number
+         self.global_crops_size = global_crops_size
+         self.local_crops_size = local_crops_size
+
+         logger.info("###################################")
+         logger.info("Using data augmentation parameters:")
+         logger.info(f"global_crops_scale: {global_crops_scale}")
+         logger.info(f"local_crops_scale: {local_crops_scale}")
+         logger.info(f"local_crops_number: {local_crops_number}")
+         logger.info(f"global_crops_size: {global_crops_size}")
+         logger.info(f"local_crops_size: {local_crops_size}")
+         logger.info("###################################")
+
+         # random resized crop and flip
+         self.geometric_augmentation_global = transforms.Compose(
+             [
+                 transforms.RandomResizedCrop(
+                     global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
+                 ),
+                 transforms.RandomHorizontalFlip(p=0.5),
+             ]
+         )
+
+         self.geometric_augmentation_local = transforms.Compose(
+             [
+                 transforms.RandomResizedCrop(
+                     local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
+                 ),
+                 transforms.RandomHorizontalFlip(p=0.5),
+             ]
+         )
+
+         # color distortions / blurring
+         color_jittering = transforms.Compose(
+             [
+                 transforms.RandomApply(
+                     [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
+                     p=0.8,
+                 ),
+                 transforms.RandomGrayscale(p=0.2),
+             ]
+         )
+
+         global_transfo1_extra = GaussianBlur(p=0.5)
+
+         global_transfo2_extra = transforms.Compose(
+             [
+                 GaussianBlur(p=0.1),
+             ]
+         )
+
+         local_transfo_extra = GaussianBlur(p=0.5)
+
+         # normalization
+         self.normalize = transforms.Compose(
+             [
+                 MaybeToTensor(),
+                 make_normalize_transform(),
+             ]
+         )
+
+         self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
+         self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
+         self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
+
+     def __call__(self, image):
+         output = {}
+
+         # global crops:
+         im1_base = self.geometric_augmentation_global(image)
+         global_crop_1 = self.global_transfo1(im1_base)
+
+         im2_base = self.geometric_augmentation_global(image)
+         global_crop_2 = self.global_transfo2(im2_base)
+
+         output["global_crops"] = [global_crop_1, global_crop_2]
+
+         # global crops for teacher:
+         output["global_crops_teacher"] = [global_crop_1, global_crop_2]
+
+         # local crops:
+         local_crops = [
+             self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
+         ]
+         output["local_crops"] = local_crops
+         output["offsets"] = ()
+
+         return output
+
+
+ def get_online_classification_augmentation_from_config(cfg) -> transforms.Compose:
+     augmentation_config = cfg.evaluation.online.augmentation
+     interpolation = getattr(Image.Resampling, augmentation_config.interpolation)
+     resize_size = crop_size = cfg.crops.global_crops_size
+     resize = transforms.Resize(resize_size, interpolation=interpolation)
+     crop = transforms.CenterCrop(crop_size)
+     affine = transforms.RandomAffine(
+         degrees=augmentation_config.degrees,
+         scale=augmentation_config.scale,
+         shear=augmentation_config.shear,
+         interpolation=interpolation,
+     )
+     transforms_list = [
+         resize,
+         crop,
+         affine,
+         MaybeToTensor(),
+         make_normalize_transform(),
+     ]
+     if augmentation_config.horizontal_flip:
+         transforms_list.append(transforms.RandomHorizontalFlip())
+     return transforms.Compose(transforms_list)
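For context, a minimal sketch of how DataAugmentationDINO is typically driven, using the crop settings from vitb14_cxr.yaml below. The import path follows the upstream DINOv2 layout (dinov2.data.augmentations) and, like the input file name, is an assumption; this repo's layout may differ.

    # Usage sketch, not part of the commit; import path and input file are assumed.
    from PIL import Image
    from dinov2.data.augmentations import DataAugmentationDINO  # assumed path

    augmentation = DataAugmentationDINO(
        global_crops_scale=(0.50, 1.00),  # crops.global_crops_scale in vitb14_cxr.yaml
        local_crops_scale=(0.20, 0.50),   # crops.local_crops_scale
        local_crops_number=8,             # crops.local_crops_number
        global_crops_size=518,            # crops.global_crops_size
        local_crops_size=196,             # crops.local_crops_size
    )

    image = Image.open("chest_xray.png").convert("RGB")  # hypothetical input image
    views = augmentation(image)
    print(len(views["global_crops"]), len(views["local_crops"]))  # 2 8

Note the asymmetry between the two global pipelines (GaussianBlur with p=0.5 in global_transfo1 versus p=0.1 in global_transfo2): the two global views are deliberately distorted differently, following the DINO recipe.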
ssl_default_config.yaml ADDED
@@ -0,0 +1,135 @@
+ MODEL:
+   WEIGHTS: ''
+ compute_precision:
+   grad_scaler: true
+   teacher:
+     backbone:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp16
+         buffer_dtype: fp32
+     dino_head:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp16
+         buffer_dtype: fp32
+     ibot_head:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp16
+         buffer_dtype: fp32
+   student:
+     backbone:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp16
+         buffer_dtype: fp32
+     dino_head:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp32
+         buffer_dtype: fp32
+     ibot_head:
+       sharding_strategy: SHARD_GRAD_OP
+       mixed_precision:
+         param_dtype: fp16
+         reduce_dtype: fp32
+         buffer_dtype: fp32
+ dino:
+   loss_weight: 1.0
+   head_n_prototypes: 65536
+   head_bottleneck_dim: 256
+   head_nlayers: 3
+   head_hidden_dim: 2048
+   koleo_loss_weight: 0.1
+ ibot:
+   loss_weight: 1.0
+   mask_sample_probability: 0.5
+   mask_ratio_min_max:
+   - 0.1
+   - 0.5
+   separate_head: false
+   head_n_prototypes: 65536
+   head_bottleneck_dim: 256
+   head_nlayers: 3
+   head_hidden_dim: 2048
+ train:
+   batch_size_per_gpu: 64
+   dataset_path: ImageNet:split=TRAIN
+   output_dir: .
+   saveckp_every_n_epoch: 5
+   seed: 0
+   num_workers: 10
+   OFFICIAL_EPOCH_LENGTH: 0  # automatic rescaling based on the dataset length is applied if this is set to 0
+   cache_dataset: true
+   centering: "centering"  # or "sinkhorn_knopp"
+ student:
+   arch: vit_large
+   patch_size: 16
+   drop_block_rate: 0.0
+   drop_path_rate: 0.3
+   layerscale: 1.0e-05
+   drop_path_uniform: true
+   pretrained_weights: ''
+   ffn_layer: "mlp"
+   block_chunks: 0
+   qkv_bias: true
+   proj_bias: true
+   ffn_bias: true
+   num_register_tokens: 0
+   interpolate_antialias: false
+   interpolate_offset: 0.1
+   load_weights: true
+   checkpoints_dir: null
+ teacher:
+   momentum_teacher: 0.992
+   final_momentum_teacher: 1
+   warmup_teacher_temp: 0.04
+   teacher_temp: 0.07
+   warmup_teacher_temp_epochs: 30
+ optim:
+   epochs: 100
+   weight_decay: 0.04
+   weight_decay_end: 0.4
+   base_lr: 0.004  # learning rate for a batch size of 1024
+   lr: 0.  # will be set after applying the scaling rule
+   warmup_epochs: 10
+   min_lr: 1.0e-06
+   clip_grad: 3.0
+   freeze_last_layer_epochs: 1
+   scaling_rule: sqrt_wrt_1024
+   patch_embed_lr_mult: 0.2
+   layerwise_decay: 0.9
+   adamw_beta1: 0.9
+   adamw_beta2: 0.999
+ crops:
+   global_crops_scale:
+   - 0.32
+   - 1.0
+   local_crops_number: 8
+   local_crops_scale:
+   - 0.05
+   - 0.32
+   global_crops_size: 224
+   local_crops_size: 96
+ evaluation:
+   eval_period_iterations: 12500
+   dataset_str: None
+   online:  # see dinov2.eval.linear_callback for documentation
+     learning_rate: 1e-6  # will be multiplied by batch size and number of devices
+     num_last_blocks: 1
+     add_avg_pool: true
+     num_update_epochs_per_eval: 3
+     augmentation:
+       degrees: 30
+       scale:
+       - 0.8
+       - 1.2
+       shear: 15
+       interpolation: BICUBIC
+       horizontal_flip: true
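The evaluation.online.augmentation block is what get_online_classification_augmentation_from_config in augmentations.py above consumes. A sketch of that round trip, assuming OmegaConf for config loading (as in the upstream DINOv2 codebase) and the same assumed module path as before; this repo's actual entry points may differ.

    # Sketch with assumptions flagged: OmegaConf loader and upstream module path.
    from omegaconf import OmegaConf
    from dinov2.data.augmentations import (  # assumed path
        get_online_classification_augmentation_from_config,
    )

    cfg = OmegaConf.load("ssl_default_config.yaml")
    eval_transform = get_online_classification_augmentation_from_config(cfg)
    # Resulting pipeline, per the config above:
    #   Resize(224, BICUBIC) -> CenterCrop(224)
    #   -> RandomAffine(degrees=30, scale=(0.8, 1.2), shear=15)
    #   -> MaybeToTensor -> Normalize -> RandomHorizontalFlip (horizontal_flip: true)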
vitb14_cxr.yaml ADDED
@@ -0,0 +1,31 @@
+ # This corresponds to the CXR config.
+ train:
+   batch_size_per_gpu: 40  # For nodes with V100s (32 GB), use 20.
+   saveckp_every_n_epoch: 25
+ student:
+   arch: vit_base
+   block_chunks: 4
+   patch_size: 14
+   drop_block_rate: 0.00
+   drop_path_rate: 0.30
+ teacher:
+   warmup_teacher_temp_epochs: 50
+ optim:
+   epochs: 100
+   warmup_epochs: 5
+   base_lr: 0.001
+ evaluation:
+   eval_period_iterations: 300
+   tasks:  # from the metadata.csv file of the processed CANDID dataset
+   - pneumothorax
+ crops:
+   global_crops_size: 518
+   local_crops_size: 196
+   global_crops_scale:
+   - 0.50
+   - 1.00
+   local_crops_number: 8
+   local_crops_scale:
+   - 0.20
+   - 0.50
+ pretrained: true
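Two points worth noting about these overrides: the crop sizes are multiples of patch_size: 14 (518 = 14 × 37, 196 = 14 × 14), and base_lr feeds the sqrt_wrt_1024 rule from ssl_default_config.yaml. A sketch of how the two files combine and what that rule yields, assuming OmegaConf merge semantics (later config wins key by key, as upstream) and a hypothetical 8-GPU job; whether this repo applies the rule identically is an assumption.

    # Sketch with stated assumptions: OmegaConf merging and a hypothetical 8-GPU job.
    import math
    from omegaconf import OmegaConf

    default_cfg = OmegaConf.load("ssl_default_config.yaml")
    cxr_cfg = OmegaConf.load("vitb14_cxr.yaml")
    cfg = OmegaConf.merge(default_cfg, cxr_cfg)  # overrides win key by key

    assert cfg.student.arch == "vit_base"      # overridden (default: vit_large)
    assert cfg.crops.global_crops_size == 518  # overridden (default: 224)
    assert cfg.optim.weight_decay == 0.04      # inherited from the defaults

    # optim.scaling_rule is sqrt_wrt_1024; upstream DINOv2 implements it as
    # lr = base_lr * sqrt(global_batch_size / 1024) -- assumed to apply here too.
    num_gpus = 8  # hypothetical
    global_batch = cfg.train.batch_size_per_gpu * num_gpus  # 40 * 8 = 320
    lr = cfg.optim.base_lr * math.sqrt(global_batch / 1024)
    print(f"lr = {lr:.2e}")  # ~5.59e-04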