thai thong commited on
Commit
f4c379b
·
0 Parent(s):

add file via upload

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +164 -0
  2. app.py +88 -0
  3. benchmarks.py +142 -0
  4. data/coco.yaml +106 -0
  5. data/hyps/hyp.scratch-high-ver2.yaml +30 -0
  6. data/hyps/hyp.scratch-high.yaml +30 -0
  7. detect.py +235 -0
  8. detect_dual.py +232 -0
  9. export.py +686 -0
  10. hubconf.py +107 -0
  11. models/__init__.py +1 -0
  12. models/attention/blocks.py +631 -0
  13. models/common.py +1360 -0
  14. models/experimental.py +275 -0
  15. models/tf.py +596 -0
  16. models/yolo.py +853 -0
  17. requirements.txt +47 -0
  18. utils/__init__.py +75 -0
  19. utils/augmentations.py +395 -0
  20. utils/autoanchor.py +164 -0
  21. utils/autobatch.py +67 -0
  22. utils/callbacks.py +71 -0
  23. utils/coco_utils.py +108 -0
  24. utils/dataloaders.py +1217 -0
  25. utils/downloads.py +103 -0
  26. utils/general.py +1227 -0
  27. utils/lion.py +67 -0
  28. utils/loggers/__init__.py +653 -0
  29. utils/loggers/clearml/__init__.py +1 -0
  30. utils/loggers/clearml/clearml_utils.py +157 -0
  31. utils/loggers/clearml/hpo.py +84 -0
  32. utils/loggers/comet/__init__.py +508 -0
  33. utils/loggers/comet/comet_utils.py +150 -0
  34. utils/loggers/comet/hpo.py +118 -0
  35. utils/loggers/comet/optimizer_config.json +209 -0
  36. utils/loggers/wandb/__init__.py +1 -0
  37. utils/loggers/wandb/log_dataset.py +27 -0
  38. utils/loggers/wandb/sweep.py +41 -0
  39. utils/loggers/wandb/sweep.yaml +143 -0
  40. utils/loggers/wandb/wandb_utils.py +589 -0
  41. utils/metrics.py +397 -0
  42. utils/panoptic/__init__.py +1 -0
  43. utils/panoptic/augmentations.py +183 -0
  44. utils/panoptic/dataloaders.py +478 -0
  45. utils/panoptic/general.py +137 -0
  46. utils/panoptic/loss.py +186 -0
  47. utils/panoptic/loss_tal.py +285 -0
  48. utils/panoptic/metrics.py +272 -0
  49. utils/panoptic/plots.py +164 -0
  50. utils/panoptic/tal/__init__.py +1 -0
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+ runs_/
162
+ weights/
163
+ *.DS_Store
164
+ .vscode/
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ #import spaces
3
+ #from huggingface_hub import hf_hub_download
4
+ from detect import run
5
+
6
+ def yolov9_inference(model_id, image_size, conf_threshold, iou_threshold, input_path = None):
7
+
8
+ # if img_path is not None:
9
+ # # Load the model
10
+ # # model_path = download_models(model_id)
11
+ # model = load_model(model_id)
12
+ # # Set model parameters
13
+ # model.conf = conf_threshold
14
+ # model.iou = iou_threshold
15
+ # # Perform inference
16
+ # results = model(img_path, size=image_size)
17
+ # # Optionally, show detection bounding boxes on image
18
+ # output = results.render()
19
+ # return output[0]
20
+ # else:
21
+ model = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, hide_conf=True, device='cpu')
22
+ return model, model
23
+
24
+ def app():
25
+ with gr.Blocks():
26
+ with gr.Row():
27
+ with gr.Column():
28
+ img_path = gr.Image(type="filepath", label="Image")
29
+ input_path = gr.Video(label="Input video")
30
+ model_id = gr.Dropdown(
31
+ label="Model",
32
+ choices=[
33
+ "last_best_model.pt",
34
+ ],
35
+ value="./last_best_model.pt",
36
+ )
37
+ image_size = gr.Slider(
38
+ label="Image Size",
39
+ minimum=320,
40
+ maximum=1280,
41
+ step=32,
42
+ value=640,
43
+ )
44
+ conf_threshold = gr.Slider(
45
+ label="Confidence Threshold",
46
+ minimum=0.1,
47
+ maximum=1.0,
48
+ step=0.1,
49
+ value=0.4,
50
+ )
51
+ iou_threshold = gr.Slider(
52
+ label="IoU Threshold",
53
+ minimum=0.1,
54
+ maximum=1.0,
55
+ step=0.1,
56
+ value=0.5,
57
+ )
58
+ yolov9_infer = gr.Button(value="Inference")
59
+
60
+ with gr.Column():
61
+ output = gr.Video(label="Output")
62
+ output_path = gr.Textbox(label="Output path")
63
+ yolov9_infer.click(
64
+ fn=yolov9_inference,
65
+ inputs=[
66
+ model_id,
67
+ image_size,
68
+ conf_threshold,
69
+ iou_threshold,
70
+ input_path
71
+ ],
72
+ outputs=[output, output_path],
73
+ )
74
+
75
+
76
+ gradio_app = gr.Blocks()
77
+ with gradio_app:
78
+ gr.HTML(
79
+ """
80
+ <h1 style='text-align: center'>
81
+ YOLOv9: Real-time Object Detection
82
+ </h1>
83
+ """)
84
+ with gr.Row():
85
+ with gr.Column():
86
+ app()
87
+
88
+ gradio_app.launch(debug=True)
benchmarks.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import platform
3
+ import sys
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import pandas as pd
8
+
9
+ FILE = Path(__file__).resolve()
10
+ ROOT = FILE.parents[0] # YOLO root directory
11
+ if str(ROOT) not in sys.path:
12
+ sys.path.append(str(ROOT)) # add ROOT to PATH
13
+ # ROOT = ROOT.relative_to(Path.cwd()) # relative
14
+
15
+ import export
16
+ from models.experimental import attempt_load
17
+ from models.yolo import SegmentationModel
18
+ from segment.val import run as val_seg
19
+ from utils import notebook_init
20
+ from utils.general import LOGGER, check_yaml, file_size, print_args
21
+ from utils.torch_utils import select_device
22
+ from val import run as val_det
23
+
24
+
25
+ def run(
26
+ weights=ROOT / 'yolo.pt', # weights path
27
+ imgsz=640, # inference size (pixels)
28
+ batch_size=1, # batch size
29
+ data=ROOT / 'data/coco.yaml', # dataset.yaml path
30
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
31
+ half=False, # use FP16 half-precision inference
32
+ test=False, # test exports only
33
+ pt_only=False, # test PyTorch only
34
+ hard_fail=False, # throw error on benchmark failure
35
+ ):
36
+ y, t = [], time.time()
37
+ device = select_device(device)
38
+ model_type = type(attempt_load(weights, fuse=False)) # DetectionModel, SegmentationModel, etc.
39
+ for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, CPU, GPU)
40
+ try:
41
+ assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
42
+ assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
43
+ if 'cpu' in device.type:
44
+ assert cpu, 'inference not supported on CPU'
45
+ if 'cuda' in device.type:
46
+ assert gpu, 'inference not supported on GPU'
47
+
48
+ # Export
49
+ if f == '-':
50
+ w = weights # PyTorch format
51
+ else:
52
+ w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others
53
+ assert suffix in str(w), 'export failed'
54
+
55
+ # Validate
56
+ if model_type == SegmentationModel:
57
+ result = val_seg(data, w, batch_size, imgsz, plots=False, device=device, task='speed', half=half)
58
+ metric = result[0][7] # (box(p, r, map50, map), mask(p, r, map50, map), *loss(box, obj, cls))
59
+ else: # DetectionModel:
60
+ result = val_det(data, w, batch_size, imgsz, plots=False, device=device, task='speed', half=half)
61
+ metric = result[0][3] # (p, r, map50, map, *loss(box, obj, cls))
62
+ speed = result[2][1] # times (preprocess, inference, postprocess)
63
+ y.append([name, round(file_size(w), 1), round(metric, 4), round(speed, 2)]) # MB, mAP, t_inference
64
+ except Exception as e:
65
+ if hard_fail:
66
+ assert type(e) is AssertionError, f'Benchmark --hard-fail for {name}: {e}'
67
+ LOGGER.warning(f'WARNING ⚠️ Benchmark failure for {name}: {e}')
68
+ y.append([name, None, None, None]) # mAP, t_inference
69
+ if pt_only and i == 0:
70
+ break # break after PyTorch
71
+
72
+ # Print results
73
+ LOGGER.info('\n')
74
+ parse_opt()
75
+ notebook_init() # print system info
76
+ c = ['Format', 'Size (MB)', 'mAP50-95', 'Inference time (ms)'] if map else ['Format', 'Export', '', '']
77
+ py = pd.DataFrame(y, columns=c)
78
+ LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
79
+ LOGGER.info(str(py if map else py.iloc[:, :2]))
80
+ if hard_fail and isinstance(hard_fail, str):
81
+ metrics = py['mAP50-95'].array # values to compare to floor
82
+ floor = eval(hard_fail) # minimum metric floor to pass
83
+ assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: mAP50-95 < floor {floor}'
84
+ return py
85
+
86
+
87
+ def test(
88
+ weights=ROOT / 'yolo.pt', # weights path
89
+ imgsz=640, # inference size (pixels)
90
+ batch_size=1, # batch size
91
+ data=ROOT / 'data/coco128.yaml', # dataset.yaml path
92
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
93
+ half=False, # use FP16 half-precision inference
94
+ test=False, # test exports only
95
+ pt_only=False, # test PyTorch only
96
+ hard_fail=False, # throw error on benchmark failure
97
+ ):
98
+ y, t = [], time.time()
99
+ device = select_device(device)
100
+ for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
101
+ try:
102
+ w = weights if f == '-' else \
103
+ export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights
104
+ assert suffix in str(w), 'export failed'
105
+ y.append([name, True])
106
+ except Exception:
107
+ y.append([name, False]) # mAP, t_inference
108
+
109
+ # Print results
110
+ LOGGER.info('\n')
111
+ parse_opt()
112
+ notebook_init() # print system info
113
+ py = pd.DataFrame(y, columns=['Format', 'Export'])
114
+ LOGGER.info(f'\nExports complete ({time.time() - t:.2f}s)')
115
+ LOGGER.info(str(py))
116
+ return py
117
+
118
+
119
+ def parse_opt():
120
+ parser = argparse.ArgumentParser()
121
+ parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='weights path')
122
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
123
+ parser.add_argument('--batch-size', type=int, default=1, help='batch size')
124
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
125
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
126
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
127
+ parser.add_argument('--test', action='store_true', help='test exports only')
128
+ parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
129
+ parser.add_argument('--hard-fail', nargs='?', const=True, default=False, help='Exception on error or < min metric')
130
+ opt = parser.parse_args()
131
+ opt.data = check_yaml(opt.data) # check YAML
132
+ print_args(vars(opt))
133
+ return opt
134
+
135
+
136
+ def main(opt):
137
+ test(**vars(opt)) if opt.test else run(**vars(opt))
138
+
139
+
140
+ if __name__ == "__main__":
141
+ opt = parse_opt()
142
+ main(opt)
data/coco.yaml ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ path: ../datasets/coco # dataset root dir
2
+ train: train2017.txt # train images (relative to 'path') 118287 images
3
+ val: val2017.txt # val images (relative to 'path') 5000 images
4
+ test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
5
+
6
+ # Classes
7
+ names:
8
+ 0: person
9
+ 1: bicycle
10
+ 2: car
11
+ 3: motorcycle
12
+ 4: airplane
13
+ 5: bus
14
+ 6: train
15
+ 7: truck
16
+ 8: boat
17
+ 9: traffic light
18
+ 10: fire hydrant
19
+ 11: stop sign
20
+ 12: parking meter
21
+ 13: bench
22
+ 14: bird
23
+ 15: cat
24
+ 16: dog
25
+ 17: horse
26
+ 18: sheep
27
+ 19: cow
28
+ 20: elephant
29
+ 21: bear
30
+ 22: zebra
31
+ 23: giraffe
32
+ 24: backpack
33
+ 25: umbrella
34
+ 26: handbag
35
+ 27: tie
36
+ 28: suitcase
37
+ 29: frisbee
38
+ 30: skis
39
+ 31: snowboard
40
+ 32: sports ball
41
+ 33: kite
42
+ 34: baseball bat
43
+ 35: baseball glove
44
+ 36: skateboard
45
+ 37: surfboard
46
+ 38: tennis racket
47
+ 39: bottle
48
+ 40: wine glass
49
+ 41: cup
50
+ 42: fork
51
+ 43: knife
52
+ 44: spoon
53
+ 45: bowl
54
+ 46: banana
55
+ 47: apple
56
+ 48: sandwich
57
+ 49: orange
58
+ 50: broccoli
59
+ 51: carrot
60
+ 52: hot dog
61
+ 53: pizza
62
+ 54: donut
63
+ 55: cake
64
+ 56: chair
65
+ 57: couch
66
+ 58: potted plant
67
+ 59: bed
68
+ 60: dining table
69
+ 61: toilet
70
+ 62: tv
71
+ 63: laptop
72
+ 64: mouse
73
+ 65: remote
74
+ 66: keyboard
75
+ 67: cell phone
76
+ 68: microwave
77
+ 69: oven
78
+ 70: toaster
79
+ 71: sink
80
+ 72: refrigerator
81
+ 73: book
82
+ 74: clock
83
+ 75: vase
84
+ 76: scissors
85
+ 77: teddy bear
86
+ 78: hair drier
87
+ 79: toothbrush
88
+
89
+
90
+ # Download script/URL (optional)
91
+ download: |
92
+ from utils.general import download, Path
93
+
94
+
95
+ # Download labels
96
+ #segments = True # segment or box labels
97
+ #dir = Path(yaml['path']) # dataset root dir
98
+ #url = 'https://github.com/WongKinYiu/yolov7/releases/download/v0.1/'
99
+ #urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
100
+ #download(urls, dir=dir.parent)
101
+
102
+ # Download data
103
+ #urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
104
+ # 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
105
+ # 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
106
+ #download(urls, dir=dir / 'images', threads=3)
data/hyps/hyp.scratch-high-ver2.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
2
+ lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
3
+ momentum: 0.937 # SGD momentum/Adam beta1
4
+ weight_decay: 0.0005 # optimizer weight decay 5e-4
5
+ warmup_epochs: 3.0 # warmup epochs (fractions ok)
6
+ warmup_momentum: 0.8 # warmup initial momentum
7
+ warmup_bias_lr: 0.1 # warmup initial bias lr
8
+ box: 7.5 # box loss gain
9
+ cls: 0.5 # cls loss gain
10
+ cls_pw: 1.0 # cls BCELoss positive_weight
11
+ obj: 0.7 # obj loss gain (scale with pixels)
12
+ obj_pw: 1.0 # obj BCELoss positive_weight
13
+ dfl: 1.5 # dfl loss gain
14
+ iou_t: 0.5 # IoU training threshold
15
+ anchor_t: 5.0 # anchor-multiple threshold
16
+ # anchors: 3 # anchors per output layer (0 to ignore)
17
+ fl_gamma: 0.5 # focal loss gamma (efficientDet default gamma=1.5)
18
+ hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
19
+ hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
20
+ hsv_v: 0.4 # image HSV-Value augmentation (fraction)
21
+ degrees: 0.0 # image rotation (+/- deg)
22
+ translate: 0.1 # image translation (+/- fraction)
23
+ scale: 0.9 # image scale (+/- gain)
24
+ shear: 0.0 # image shear (+/- deg)
25
+ perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
26
+ flipud: 0.0 # image flip up-down (probability)
27
+ fliplr: 0.5 # image flip left-right (probability)
28
+ mosaic: 1.0 # image mosaic (probability)
29
+ mixup: 0.15 # image mixup (probability)
30
+ copy_paste: 0.3 # segment copy-paste (probability)
data/hyps/hyp.scratch-high.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
2
+ lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
3
+ momentum: 0.937 # SGD momentum/Adam beta1
4
+ weight_decay: 0.0005 # optimizer weight decay 5e-4
5
+ warmup_epochs: 3.0 # warmup epochs (fractions ok)
6
+ warmup_momentum: 0.8 # warmup initial momentum
7
+ warmup_bias_lr: 0.1 # warmup initial bias lr
8
+ box: 7.5 # box loss gain
9
+ cls: 0.5 # cls loss gain
10
+ cls_pw: 1.0 # cls BCELoss positive_weight
11
+ obj: 0.7 # obj loss gain (scale with pixels)
12
+ obj_pw: 1.0 # obj BCELoss positive_weight
13
+ dfl: 1.5 # dfl loss gain
14
+ iou_t: 0.5 # IoU training threshold
15
+ anchor_t: 5.0 # anchor-multiple threshold
16
+ # anchors: 3 # anchors per output layer (0 to ignore)
17
+ fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
18
+ hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
19
+ hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
20
+ hsv_v: 0.4 # image HSV-Value augmentation (fraction)
21
+ degrees: 0.0 # image rotation (+/- deg)
22
+ translate: 0.1 # image translation (+/- fraction)
23
+ scale: 0.9 # image scale (+/- gain)
24
+ shear: 0.0 # image shear (+/- deg)
25
+ perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
26
+ flipud: 0.0 # image flip up-down (probability)
27
+ fliplr: 0.5 # image flip left-right (probability)
28
+ mosaic: 1.0 # image mosaic (probability)
29
+ mixup: 0.15 # image mixup (probability)
30
+ copy_paste: 0.3 # segment copy-paste (probability)
detect.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ FILE = Path(__file__).resolve()
10
+ ROOT = FILE.parents[0] # YOLO root directory
11
+ if str(ROOT) not in sys.path:
12
+ sys.path.append(str(ROOT)) # add ROOT to PATH
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import DetectMultiBackend
16
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
17
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
18
+ increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
19
+ from utils.plots import Annotator, colors, save_one_box
20
+ from utils.torch_utils import select_device, smart_inference_mode
21
+
22
+
23
+ @smart_inference_mode()
24
+ def run(
25
+ weights=ROOT / 'yolo.pt', # model path or triton URL
26
+ source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
27
+ data=ROOT / 'data/coco.yaml', # dataset.yaml path
28
+ imgsz=(640, 640), # inference size (height, width)
29
+ conf_thres=0.25, # confidence threshold
30
+ iou_thres=0.45, # NMS IOU threshold
31
+ max_det=1000, # maximum detections per image
32
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
33
+ view_img=False, # show results
34
+ save_txt=False, # save results to *.txt
35
+ save_conf=False, # save confidences in --save-txt labels
36
+ save_crop=False, # save cropped prediction boxes
37
+ nosave=False, # do not save images/videos
38
+ classes=None, # filter by class: --class 0, or --class 0 2 3
39
+ agnostic_nms=False, # class-agnostic NMS
40
+ augment=False, # augmented inference
41
+ visualize=False, # visualize features
42
+ update=False, # update all models
43
+ project=ROOT / 'runs/detect', # save results to project/name
44
+ name='exp', # save results to project/name
45
+ exist_ok=False, # existing project/name ok, do not increment
46
+ line_thickness=3, # bounding box thickness (pixels)
47
+ hide_labels=False, # hide labels
48
+ hide_conf=False, # hide confidences
49
+ half=False, # use FP16 half-precision inference
50
+ dnn=False, # use OpenCV DNN for ONNX inference
51
+ vid_stride=1, # video frame-rate stride
52
+ ):
53
+ source = str(source)
54
+ save_img = not nosave and not source.endswith('.txt') # save inference images
55
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
56
+ is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
57
+ webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
58
+ screenshot = source.lower().startswith('screen')
59
+ if is_url and is_file:
60
+ source = check_file(source) # download
61
+
62
+ # Directories
63
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
64
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
65
+
66
+ # Load model
67
+ device = select_device(device)
68
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
69
+ stride, names, pt = model.stride, model.names, model.pt
70
+ imgsz = check_img_size(imgsz, s=stride) # check image size
71
+
72
+ # Dataloader
73
+ bs = 1 # batch_size
74
+ if webcam:
75
+ view_img = check_imshow(warn=True)
76
+ dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
77
+ bs = len(dataset)
78
+ elif screenshot:
79
+ dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
80
+ else:
81
+ dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
82
+ vid_path, vid_writer = [None] * bs, [None] * bs
83
+
84
+ # Run inference
85
+ model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
86
+ seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
87
+ for path, im, im0s, vid_cap, s in dataset:
88
+ with dt[0]:
89
+ im = torch.from_numpy(im).to(model.device)
90
+ im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
91
+ im /= 255 # 0 - 255 to 0.0 - 1.0
92
+ if len(im.shape) == 3:
93
+ im = im[None] # expand for batch dim
94
+
95
+ # Inference
96
+ with dt[1]:
97
+ visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
98
+ pred = model(im, augment=augment, visualize=visualize)
99
+
100
+
101
+ # NMS
102
+ with dt[2]:
103
+ pred = pred[0][0] if isinstance(pred[0], list) else pred[0] # single model or ensemble
104
+ pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
105
+
106
+
107
+ # Second-stage classifier (optional)
108
+ # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
109
+
110
+ # Process predictions
111
+ for i, det in enumerate(pred): # per image
112
+ seen += 1
113
+ if webcam: # batch_size >= 1
114
+ p, im0, frame = path[i], im0s[i].copy(), dataset.count
115
+ s += f'{i}: '
116
+ else:
117
+ p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
118
+
119
+ p = Path(p) # to Path
120
+ save_path = str(save_dir / p.name) # im.jpg
121
+ txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
122
+ s += '%gx%g ' % im.shape[2:] # print string
123
+ gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
124
+ imc = im0.copy() if save_crop else im0 # for save_crop
125
+ annotator = Annotator(im0, line_width=line_thickness, example=str(names))
126
+ if len(det):
127
+ # Rescale boxes from img_size to im0 size
128
+ det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
129
+
130
+ # Print results
131
+ for c in det[:, 5].unique():
132
+ n = (det[:, 5] == c).sum() # detections per class
133
+ s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
134
+
135
+ # Write results
136
+ for *xyxy, conf, cls in reversed(det):
137
+ if save_txt: # Write to file
138
+ xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
139
+ line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
140
+ with open(f'{txt_path}.txt', 'a') as f:
141
+ f.write(('%g ' * len(line)).rstrip() % line + '\n')
142
+
143
+ if save_img or save_crop or view_img: # Add bbox to image
144
+ c = int(cls) # integer class
145
+ label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
146
+ annotator.box_label(xyxy, label, color=colors(c, True))
147
+ if save_crop:
148
+ save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
149
+
150
+ # Stream results
151
+ im0 = annotator.result()
152
+ if view_img:
153
+ if platform.system() == 'Linux' and p not in windows:
154
+ windows.append(p)
155
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
156
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
157
+ cv2.imshow(str(p), im0)
158
+ cv2.waitKey(1) # 1 millisecond
159
+
160
+ # Save results (image with detections)
161
+ if save_img:
162
+ if dataset.mode == 'image':
163
+ cv2.imwrite(save_path, im0)
164
+ else: # 'video' or 'stream'
165
+ if vid_path[i] != save_path: # new video
166
+ vid_path[i] = save_path
167
+ if isinstance(vid_writer[i], cv2.VideoWriter):
168
+ vid_writer[i].release() # release previous video writer
169
+ if vid_cap: # video
170
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
171
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
172
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
173
+ else: # stream
174
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
175
+ save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
176
+ vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'h264'), fps, (w, h))
177
+ vid_writer[i].write(im0)
178
+
179
+ # Print time (inference-only)
180
+ LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
181
+
182
+ #vid_writer.release()
183
+ # Print results
184
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
185
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
186
+ if save_txt or save_img:
187
+ s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
188
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
189
+ if update:
190
+ strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
191
+ return save_path
192
+
193
+ def parse_opt():
194
+ parser = argparse.ArgumentParser()
195
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model path or triton URL')
196
+ parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
197
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
198
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
199
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
200
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
201
+ parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
202
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
203
+ parser.add_argument('--view-img', action='store_true', help='show results')
204
+ parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
205
+ parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
206
+ parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
207
+ parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
208
+ parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
209
+ parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
210
+ parser.add_argument('--augment', action='store_true', help='augmented inference')
211
+ parser.add_argument('--visualize', action='store_true', help='visualize features')
212
+ parser.add_argument('--update', action='store_true', help='update all models')
213
+ parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
214
+ parser.add_argument('--name', default='exp', help='save results to project/name')
215
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
216
+ parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
217
+ parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
218
+ parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
219
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
220
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
221
+ parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
222
+ opt = parser.parse_args()
223
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
224
+ print_args(vars(opt))
225
+ return opt
226
+
227
+
228
+ def main(opt):
229
+ check_requirements(exclude=('tensorboard', 'thop'))
230
+ run(**vars(opt))
231
+
232
+
233
+ if __name__ == "__main__":
234
+ opt = parse_opt()
235
+ main(opt)
detect_dual.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ FILE = Path(__file__).resolve()
10
+ ROOT = FILE.parents[0] # YOLO root directory
11
+ if str(ROOT) not in sys.path:
12
+ sys.path.append(str(ROOT)) # add ROOT to PATH
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import DetectMultiBackend
16
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
17
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
18
+ increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
19
+ from utils.plots import Annotator, colors, save_one_box
20
+ from utils.torch_utils import select_device, smart_inference_mode
21
+
22
+
23
+ @smart_inference_mode()
24
+ def run(
25
+ weights=ROOT / 'yolo.pt', # model path or triton URL
26
+ source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
27
+ data=ROOT / 'data/coco.yaml', # dataset.yaml path
28
+ imgsz=(640, 640), # inference size (height, width)
29
+ conf_thres=0.25, # confidence threshold
30
+ iou_thres=0.45, # NMS IOU threshold
31
+ max_det=1000, # maximum detections per image
32
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
33
+ view_img=False, # show results
34
+ save_txt=False, # save results to *.txt
35
+ save_conf=False, # save confidences in --save-txt labels
36
+ save_crop=False, # save cropped prediction boxes
37
+ nosave=False, # do not save images/videos
38
+ classes=None, # filter by class: --class 0, or --class 0 2 3
39
+ agnostic_nms=False, # class-agnostic NMS
40
+ augment=False, # augmented inference
41
+ visualize=False, # visualize features
42
+ update=False, # update all models
43
+ project=ROOT / 'runs/detect', # save results to project/name
44
+ name='exp', # save results to project/name
45
+ exist_ok=False, # existing project/name ok, do not increment
46
+ line_thickness=3, # bounding box thickness (pixels)
47
+ hide_labels=False, # hide labels
48
+ hide_conf=False, # hide confidences
49
+ half=False, # use FP16 half-precision inference
50
+ dnn=False, # use OpenCV DNN for ONNX inference
51
+ vid_stride=1, # video frame-rate stride
52
+ ):
53
+ source = str(source)
54
+ save_img = not nosave and not source.endswith('.txt') # save inference images
55
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
56
+ is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
57
+ webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
58
+ screenshot = source.lower().startswith('screen')
59
+ if is_url and is_file:
60
+ source = check_file(source) # download
61
+
62
+ # Directories
63
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
64
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
65
+
66
+ # Load model
67
+ device = select_device(device)
68
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
69
+ stride, names, pt = model.stride, model.names, model.pt
70
+ imgsz = check_img_size(imgsz, s=stride) # check image size
71
+
72
+ # Dataloader
73
+ bs = 1 # batch_size
74
+ if webcam:
75
+ view_img = check_imshow(warn=True)
76
+ dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
77
+ bs = len(dataset)
78
+ elif screenshot:
79
+ dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
80
+ else:
81
+ dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
82
+ vid_path, vid_writer = [None] * bs, [None] * bs
83
+
84
+ # Run inference
85
+ model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
86
+ seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
87
+ for path, im, im0s, vid_cap, s in dataset:
88
+ with dt[0]:
89
+ im = torch.from_numpy(im).to(model.device)
90
+ im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
91
+ im /= 255 # 0 - 255 to 0.0 - 1.0
92
+ if len(im.shape) == 3:
93
+ im = im[None] # expand for batch dim
94
+
95
+ # Inference
96
+ with dt[1]:
97
+ visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
98
+ pred = model(im, augment=augment, visualize=visualize)
99
+ pred = pred[0][1]
100
+
101
+ # NMS
102
+ with dt[2]:
103
+ pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
104
+
105
+ # Second-stage classifier (optional)
106
+ # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
107
+
108
+ # Process predictions
109
+ for i, det in enumerate(pred): # per image
110
+ seen += 1
111
+ if webcam: # batch_size >= 1
112
+ p, im0, frame = path[i], im0s[i].copy(), dataset.count
113
+ s += f'{i}: '
114
+ else:
115
+ p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
116
+
117
+ p = Path(p) # to Path
118
+ save_path = str(save_dir / p.name) # im.jpg
119
+ txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
120
+ s += '%gx%g ' % im.shape[2:] # print string
121
+ gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
122
+ imc = im0.copy() if save_crop else im0 # for save_crop
123
+ annotator = Annotator(im0, line_width=line_thickness, example=str(names))
124
+ if len(det):
125
+ # Rescale boxes from img_size to im0 size
126
+ det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
127
+
128
+ # Print results
129
+ for c in det[:, 5].unique():
130
+ n = (det[:, 5] == c).sum() # detections per class
131
+ s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
132
+
133
+ # Write results
134
+ for *xyxy, conf, cls in reversed(det):
135
+ if save_txt: # Write to file
136
+ xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
137
+ line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
138
+ with open(f'{txt_path}.txt', 'a') as f:
139
+ f.write(('%g ' * len(line)).rstrip() % line + '\n')
140
+
141
+ if save_img or save_crop or view_img: # Add bbox to image
142
+ c = int(cls) # integer class
143
+ label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
144
+ annotator.box_label(xyxy, label, color=colors(c, True))
145
+ if save_crop:
146
+ save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
147
+
148
+ # Stream results
149
+ im0 = annotator.result()
150
+ if view_img:
151
+ if platform.system() == 'Linux' and p not in windows:
152
+ windows.append(p)
153
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
154
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
155
+ cv2.imshow(str(p), im0)
156
+ cv2.waitKey(1) # 1 millisecond
157
+
158
+ # Save results (image with detections)
159
+ if save_img:
160
+ if dataset.mode == 'image':
161
+ cv2.imwrite(save_path, im0)
162
+ else: # 'video' or 'stream'
163
+ if vid_path[i] != save_path: # new video
164
+ vid_path[i] = save_path
165
+ if isinstance(vid_writer[i], cv2.VideoWriter):
166
+ vid_writer[i].release() # release previous video writer
167
+ if vid_cap: # video
168
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
169
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
170
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
171
+ else: # stream
172
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
173
+ save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
174
+ vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'h264'), fps, (w, h))
175
+ vid_writer[i].write(im0)
176
+
177
+ # Print time (inference-only)
178
+ LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
179
+
180
+ # Print results
181
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
182
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
183
+ if save_txt or save_img:
184
+ s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
185
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
186
+ if update:
187
+ strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
188
+
189
+
190
+ def parse_opt():
191
+ parser = argparse.ArgumentParser()
192
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model path or triton URL')
193
+ parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
194
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
195
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
196
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
197
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
198
+ parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
199
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
200
+ parser.add_argument('--view-img', action='store_true', help='show results')
201
+ parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
202
+ parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
203
+ parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
204
+ parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
205
+ parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
206
+ parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
207
+ parser.add_argument('--augment', action='store_true', help='augmented inference')
208
+ parser.add_argument('--visualize', action='store_true', help='visualize features')
209
+ parser.add_argument('--update', action='store_true', help='update all models')
210
+ parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
211
+ parser.add_argument('--name', default='exp', help='save results to project/name')
212
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
213
+ parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
214
+ parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
215
+ parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
216
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
217
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
218
+ parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
219
+ opt = parser.parse_args()
220
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
221
+ print_args(vars(opt))
222
+ return opt
223
+
224
+
225
+ def main(opt):
226
+ check_requirements(exclude=('tensorboard', 'thop'))
227
+ run(**vars(opt))
228
+
229
+
230
+ if __name__ == "__main__":
231
+ opt = parse_opt()
232
+ main(opt)
export.py ADDED
@@ -0,0 +1,686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import contextlib
3
+ import json
4
+ import os
5
+ import platform
6
+ import re
7
+ import subprocess
8
+ import sys
9
+ import time
10
+ import warnings
11
+ from pathlib import Path
12
+
13
+ import pandas as pd
14
+ import torch
15
+ from torch.utils.mobile_optimizer import optimize_for_mobile
16
+
17
+ FILE = Path(__file__).resolve()
18
+ ROOT = FILE.parents[0] # YOLO root directory
19
+ if str(ROOT) not in sys.path:
20
+ sys.path.append(str(ROOT)) # add ROOT to PATH
21
+ if platform.system() != 'Windows':
22
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
23
+
24
+ from models.experimental import attempt_load, End2End
25
+ from models.yolo import ClassificationModel, Detect, DDetect, DualDetect, DualDDetect, IDualDDetect, DetectionModel, SegmentationModel
26
+ from utils.dataloaders import LoadImages
27
+ from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
28
+ check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
29
+ from utils.torch_utils import select_device, smart_inference_mode
30
+
31
+ MACOS = platform.system() == 'Darwin' # macOS environment
32
+
33
+
34
+ def export_formats():
35
+ # YOLO export formats
36
+ x = [
37
+ ['PyTorch', '-', '.pt', True, True],
38
+ ['TorchScript', 'torchscript', '.torchscript', True, True],
39
+ ['ONNX', 'onnx', '.onnx', True, True],
40
+ ['ONNX END2END', 'onnx_end2end', '_end2end.onnx', True, True],
41
+ ['OpenVINO', 'openvino', '_openvino_model', True, False],
42
+ ['TensorRT', 'engine', '.engine', False, True],
43
+ ['CoreML', 'coreml', '.mlmodel', True, False],
44
+ ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
45
+ ['TensorFlow GraphDef', 'pb', '.pb', True, True],
46
+ ['TensorFlow Lite', 'tflite', '.tflite', True, False],
47
+ ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
48
+ ['TensorFlow.js', 'tfjs', '_web_model', False, False],
49
+ ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
50
+ return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
51
+
52
+
53
+ def try_export(inner_func):
54
+ # YOLO export decorator, i..e @try_export
55
+ inner_args = get_default_args(inner_func)
56
+
57
+ def outer_func(*args, **kwargs):
58
+ prefix = inner_args['prefix']
59
+ try:
60
+ with Profile() as dt:
61
+ f, model = inner_func(*args, **kwargs)
62
+ LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
63
+ return f, model
64
+ except Exception as e:
65
+ LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
66
+ return None, None
67
+
68
+ return outer_func
69
+
70
+
71
+ @try_export
72
+ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
73
+ # YOLO TorchScript model export
74
+ LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
75
+ f = file.with_suffix('.torchscript')
76
+
77
+ ts = torch.jit.trace(model, im, strict=False)
78
+ d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
79
+ extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
80
+ if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
81
+ optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
82
+ else:
83
+ ts.save(str(f), _extra_files=extra_files)
84
+ return f, None
85
+
86
+
87
+ @try_export
88
+ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
89
+ # YOLO ONNX export
90
+ check_requirements('onnx')
91
+ import onnx
92
+
93
+ LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
94
+ f = file.with_suffix('.onnx')
95
+
96
+ output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
97
+ if dynamic:
98
+ dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
99
+ if isinstance(model, SegmentationModel):
100
+ dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
101
+ dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
102
+ elif isinstance(model, DetectionModel):
103
+ dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
104
+
105
+ torch.onnx.export(
106
+ model.cpu() if dynamic else model, # --dynamic only compatible with cpu
107
+ im.cpu() if dynamic else im,
108
+ f,
109
+ verbose=False,
110
+ opset_version=opset,
111
+ do_constant_folding=True,
112
+ input_names=['images'],
113
+ output_names=output_names,
114
+ dynamic_axes=dynamic or None)
115
+
116
+ # Checks
117
+ model_onnx = onnx.load(f) # load onnx model
118
+ onnx.checker.check_model(model_onnx) # check onnx model
119
+
120
+ # Metadata
121
+ d = {'stride': int(max(model.stride)), 'names': model.names}
122
+ for k, v in d.items():
123
+ meta = model_onnx.metadata_props.add()
124
+ meta.key, meta.value = k, str(v)
125
+ onnx.save(model_onnx, f)
126
+
127
+ # Simplify
128
+ if simplify:
129
+ try:
130
+ cuda = torch.cuda.is_available()
131
+ check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
132
+ import onnxsim
133
+
134
+ LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
135
+ model_onnx, check = onnxsim.simplify(model_onnx)
136
+ assert check, 'assert check failed'
137
+ onnx.save(model_onnx, f)
138
+ except Exception as e:
139
+ LOGGER.info(f'{prefix} simplifier failure: {e}')
140
+ return f, model_onnx
141
+
142
+
143
+ @try_export
144
+ def export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, labels, prefix=colorstr('ONNX END2END:')):
145
+ # YOLO ONNX export
146
+ check_requirements('onnx')
147
+ import onnx
148
+ LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
149
+ f = os.path.splitext(file)[0] + "-end2end.onnx"
150
+ batch_size = 'batch'
151
+
152
+ dynamic_axes = {'images': {0 : 'batch', 2: 'height', 3:'width'}, } # variable length axes
153
+
154
+ output_axes = {
155
+ 'num_dets': {0: 'batch'},
156
+ 'det_boxes': {0: 'batch'},
157
+ 'det_scores': {0: 'batch'},
158
+ 'det_classes': {0: 'batch'},
159
+ }
160
+ dynamic_axes.update(output_axes)
161
+ model = End2End(model, topk_all, iou_thres, conf_thres, None ,device, labels)
162
+
163
+ output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
164
+ shapes = [ batch_size, 1, batch_size, topk_all, 4,
165
+ batch_size, topk_all, batch_size, topk_all]
166
+
167
+ torch.onnx.export(model,
168
+ im,
169
+ f,
170
+ verbose=False,
171
+ export_params=True, # store the trained parameter weights inside the model file
172
+ opset_version=12,
173
+ do_constant_folding=True, # whether to execute constant folding for optimization
174
+ input_names=['images'],
175
+ output_names=output_names,
176
+ dynamic_axes=dynamic_axes)
177
+
178
+ # Checks
179
+ model_onnx = onnx.load(f) # load onnx model
180
+ onnx.checker.check_model(model_onnx) # check onnx model
181
+ for i in model_onnx.graph.output:
182
+ for j in i.type.tensor_type.shape.dim:
183
+ j.dim_param = str(shapes.pop(0))
184
+
185
+ if simplify:
186
+ try:
187
+ import onnxsim
188
+
189
+ print('\nStarting to simplify ONNX...')
190
+ model_onnx, check = onnxsim.simplify(model_onnx)
191
+ assert check, 'assert check failed'
192
+ except Exception as e:
193
+ print(f'Simplifier failure: {e}')
194
+
195
+ # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
196
+ onnx.save(model_onnx,f)
197
+ print('ONNX export success, saved as %s' % f)
198
+ return f, model_onnx
199
+
200
+
201
+ @try_export
202
+ def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
203
+ # YOLO OpenVINO export
204
+ check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/
205
+ import openvino.inference_engine as ie
206
+
207
+ LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
208
+ f = str(file).replace('.pt', f'_openvino_model{os.sep}')
209
+
210
+ #cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
211
+ #cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} {"--compress_to_fp16" if half else ""}"
212
+ half_arg = "--compress_to_fp16" if half else ""
213
+ cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} {half_arg}"
214
+ subprocess.run(cmd.split(), check=True, env=os.environ) # export
215
+ yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
216
+ return f, None
217
+
218
+
219
+ @try_export
220
+ def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
221
+ # YOLO Paddle export
222
+ check_requirements(('paddlepaddle', 'x2paddle'))
223
+ import x2paddle
224
+ from x2paddle.convert import pytorch2paddle
225
+
226
+ LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
227
+ f = str(file).replace('.pt', f'_paddle_model{os.sep}')
228
+
229
+ pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im]) # export
230
+ yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
231
+ return f, None
232
+
233
+
234
+ @try_export
235
+ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
236
+ # YOLO CoreML export
237
+ check_requirements('coremltools')
238
+ import coremltools as ct
239
+
240
+ LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
241
+ f = file.with_suffix('.mlmodel')
242
+
243
+ ts = torch.jit.trace(model, im, strict=False) # TorchScript model
244
+ ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
245
+ bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
246
+ if bits < 32:
247
+ if MACOS: # quantization only supported on macOS
248
+ with warnings.catch_warnings():
249
+ warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
250
+ ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
251
+ else:
252
+ print(f'{prefix} quantization only supported on macOS, skipping...')
253
+ ct_model.save(f)
254
+ return f, ct_model
255
+
256
+
257
+ @try_export
258
+ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
259
+ # YOLO TensorRT export https://developer.nvidia.com/tensorrt
260
+ assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
261
+ try:
262
+ import tensorrt as trt
263
+ except Exception:
264
+ if platform.system() == 'Linux':
265
+ check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
266
+ import tensorrt as trt
267
+
268
+ if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
269
+ grid = model.model[-1].anchor_grid
270
+ model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
271
+ export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
272
+ model.model[-1].anchor_grid = grid
273
+ else: # TensorRT >= 8
274
+ check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
275
+ export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
276
+ onnx = file.with_suffix('.onnx')
277
+
278
+ LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
279
+ assert onnx.exists(), f'failed to export ONNX file: {onnx}'
280
+ f = file.with_suffix('.engine') # TensorRT engine file
281
+ logger = trt.Logger(trt.Logger.INFO)
282
+ if verbose:
283
+ logger.min_severity = trt.Logger.Severity.VERBOSE
284
+
285
+ builder = trt.Builder(logger)
286
+ config = builder.create_builder_config()
287
+ config.max_workspace_size = workspace * 1 << 30
288
+ # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
289
+
290
+ flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
291
+ network = builder.create_network(flag)
292
+ parser = trt.OnnxParser(network, logger)
293
+ if not parser.parse_from_file(str(onnx)):
294
+ raise RuntimeError(f'failed to load ONNX file: {onnx}')
295
+
296
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
297
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
298
+ for inp in inputs:
299
+ LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
300
+ for out in outputs:
301
+ LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
302
+
303
+ if dynamic:
304
+ if im.shape[0] <= 1:
305
+ LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
306
+ profile = builder.create_optimization_profile()
307
+ for inp in inputs:
308
+ profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
309
+ config.add_optimization_profile(profile)
310
+
311
+ LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
312
+ if builder.platform_has_fast_fp16 and half:
313
+ config.set_flag(trt.BuilderFlag.FP16)
314
+ with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
315
+ t.write(engine.serialize())
316
+ return f, None
317
+
318
+
319
+ @try_export
320
+ def export_saved_model(model,
321
+ im,
322
+ file,
323
+ dynamic,
324
+ tf_nms=False,
325
+ agnostic_nms=False,
326
+ topk_per_class=100,
327
+ topk_all=100,
328
+ iou_thres=0.45,
329
+ conf_thres=0.25,
330
+ keras=False,
331
+ prefix=colorstr('TensorFlow SavedModel:')):
332
+ # YOLO TensorFlow SavedModel export
333
+ try:
334
+ import tensorflow as tf
335
+ except Exception:
336
+ check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
337
+ import tensorflow as tf
338
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
339
+
340
+ from models.tf import TFModel
341
+
342
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
343
+ f = str(file).replace('.pt', '_saved_model')
344
+ batch_size, ch, *imgsz = list(im.shape) # BCHW
345
+
346
+ tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
347
+ im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
348
+ _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
349
+ inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
350
+ outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
351
+ keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
352
+ keras_model.trainable = False
353
+ keras_model.summary()
354
+ if keras:
355
+ keras_model.save(f, save_format='tf')
356
+ else:
357
+ spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
358
+ m = tf.function(lambda x: keras_model(x)) # full model
359
+ m = m.get_concrete_function(spec)
360
+ frozen_func = convert_variables_to_constants_v2(m)
361
+ tfm = tf.Module()
362
+ tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
363
+ tfm.__call__(im)
364
+ tf.saved_model.save(tfm,
365
+ f,
366
+ options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
367
+ tf.__version__, '2.6') else tf.saved_model.SaveOptions())
368
+ return f, keras_model
369
+
370
+
371
+ @try_export
372
+ def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
373
+ # YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
374
+ import tensorflow as tf
375
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
376
+
377
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
378
+ f = file.with_suffix('.pb')
379
+
380
+ m = tf.function(lambda x: keras_model(x)) # full model
381
+ m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
382
+ frozen_func = convert_variables_to_constants_v2(m)
383
+ frozen_func.graph.as_graph_def()
384
+ tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
385
+ return f, None
386
+
387
+
388
+ @try_export
389
+ def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
390
+ # YOLOv5 TensorFlow Lite export
391
+ import tensorflow as tf
392
+
393
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
394
+ batch_size, ch, *imgsz = list(im.shape) # BCHW
395
+ f = str(file).replace('.pt', '-fp16.tflite')
396
+
397
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
398
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
399
+ converter.target_spec.supported_types = [tf.float16]
400
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
401
+ if int8:
402
+ from models.tf import representative_dataset_gen
403
+ dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
404
+ converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
405
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
406
+ converter.target_spec.supported_types = []
407
+ converter.inference_input_type = tf.uint8 # or tf.int8
408
+ converter.inference_output_type = tf.uint8 # or tf.int8
409
+ converter.experimental_new_quantizer = True
410
+ f = str(file).replace('.pt', '-int8.tflite')
411
+ if nms or agnostic_nms:
412
+ converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
413
+
414
+ tflite_model = converter.convert()
415
+ open(f, "wb").write(tflite_model)
416
+ return f, None
417
+
418
+
419
+ @try_export
420
+ def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
421
+ # YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
422
+ cmd = 'edgetpu_compiler --version'
423
+ help_url = 'https://coral.ai/docs/edgetpu/compiler/'
424
+ assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
425
+ if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
426
+ LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
427
+ sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
428
+ for c in (
429
+ 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
430
+ 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
431
+ 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
432
+ subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
433
+ ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
434
+
435
+ LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
436
+ f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model
437
+ f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
438
+
439
+ cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
440
+ subprocess.run(cmd.split(), check=True)
441
+ return f, None
442
+
443
+
444
+ @try_export
445
+ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
446
+ # YOLO TensorFlow.js export
447
+ check_requirements('tensorflowjs')
448
+ import tensorflowjs as tfjs
449
+
450
+ LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
451
+ f = str(file).replace('.pt', '_web_model') # js dir
452
+ f_pb = file.with_suffix('.pb') # *.pb path
453
+ f_json = f'{f}/model.json' # *.json path
454
+
455
+ cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
456
+ f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
457
+ subprocess.run(cmd.split())
458
+
459
+ json = Path(f_json).read_text()
460
+ with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
461
+ subst = re.sub(
462
+ r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
463
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
464
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
465
+ r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
466
+ r'"Identity_1": {"name": "Identity_1"}, '
467
+ r'"Identity_2": {"name": "Identity_2"}, '
468
+ r'"Identity_3": {"name": "Identity_3"}}}', json)
469
+ j.write(subst)
470
+ return f, None
471
+
472
+
473
+ def add_tflite_metadata(file, metadata, num_outputs):
474
+ # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
475
+ with contextlib.suppress(ImportError):
476
+ # check_requirements('tflite_support')
477
+ from tflite_support import flatbuffers
478
+ from tflite_support import metadata as _metadata
479
+ from tflite_support import metadata_schema_py_generated as _metadata_fb
480
+
481
+ tmp_file = Path('/tmp/meta.txt')
482
+ with open(tmp_file, 'w') as meta_f:
483
+ meta_f.write(str(metadata))
484
+
485
+ model_meta = _metadata_fb.ModelMetadataT()
486
+ label_file = _metadata_fb.AssociatedFileT()
487
+ label_file.name = tmp_file.name
488
+ model_meta.associatedFiles = [label_file]
489
+
490
+ subgraph = _metadata_fb.SubGraphMetadataT()
491
+ subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
492
+ subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
493
+ model_meta.subgraphMetadata = [subgraph]
494
+
495
+ b = flatbuffers.Builder(0)
496
+ b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
497
+ metadata_buf = b.Output()
498
+
499
+ populator = _metadata.MetadataPopulator.with_model_file(file)
500
+ populator.load_metadata_buffer(metadata_buf)
501
+ populator.load_associated_files([str(tmp_file)])
502
+ populator.populate()
503
+ tmp_file.unlink()
504
+
505
+
506
+ @smart_inference_mode()
507
+ def run(
508
+ data=ROOT / 'data/coco.yaml', # 'dataset.yaml path'
509
+ weights=ROOT / 'yolo.pt', # weights path
510
+ imgsz=(640, 640), # image (height, width)
511
+ batch_size=1, # batch size
512
+ device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
513
+ include=('torchscript', 'onnx'), # include formats
514
+ half=False, # FP16 half-precision export
515
+ inplace=False, # set YOLO Detect() inplace=True
516
+ keras=False, # use Keras
517
+ optimize=False, # TorchScript: optimize for mobile
518
+ int8=False, # CoreML/TF INT8 quantization
519
+ dynamic=False, # ONNX/TF/TensorRT: dynamic axes
520
+ simplify=False, # ONNX: simplify model
521
+ opset=12, # ONNX: opset version
522
+ verbose=False, # TensorRT: verbose log
523
+ workspace=4, # TensorRT: workspace size (GB)
524
+ nms=False, # TF: add NMS to model
525
+ agnostic_nms=False, # TF: add agnostic NMS to model
526
+ topk_per_class=100, # TF.js NMS: topk per class to keep
527
+ topk_all=100, # TF.js NMS: topk for all classes to keep
528
+ iou_thres=0.45, # TF.js NMS: IoU threshold
529
+ conf_thres=0.25, # TF.js NMS: confidence threshold
530
+ ):
531
+ t = time.time()
532
+ include = [x.lower() for x in include] # to lowercase
533
+ fmts = tuple(export_formats()['Argument'][1:]) # --include arguments
534
+ flags = [x in include for x in fmts]
535
+ assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
536
+ jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
537
+ file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
538
+
539
+ # Load PyTorch model
540
+ device = select_device(device)
541
+ if half:
542
+ assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
543
+ assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
544
+ model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
545
+
546
+ # Checks
547
+ imgsz *= 2 if len(imgsz) == 1 else 1 # expand
548
+ if optimize:
549
+ assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
550
+
551
+ # Input
552
+ gs = int(max(model.stride)) # grid size (max stride)
553
+ imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
554
+ im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
555
+
556
+ # Update model
557
+ model.eval()
558
+ for k, m in model.named_modules():
559
+ if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect, IDualDDetect)):
560
+ m.inplace = inplace
561
+ m.dynamic = dynamic
562
+ m.export = True
563
+
564
+ for _ in range(2):
565
+ y = model(im) # dry runs
566
+ if half and not coreml:
567
+ im, model = im.half(), model.half() # to FP16
568
+ shape = tuple((y[0] if isinstance(y, (tuple, list)) else y).shape) # model output shape
569
+ metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
570
+ LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
571
+
572
+ # Exports
573
+ f = [''] * len(fmts) # exported filenames
574
+ warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
575
+ if jit: # TorchScript
576
+ f[0], _ = export_torchscript(model, im, file, optimize)
577
+ if engine: # TensorRT required before ONNX
578
+ f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
579
+ if onnx or xml: # OpenVINO requires ONNX
580
+ f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
581
+ if onnx_end2end:
582
+ if isinstance(model, DetectionModel):
583
+ labels = model.names
584
+ f[2], _ = export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, len(labels))
585
+ else:
586
+ raise RuntimeError("The model is not a DetectionModel.")
587
+ if xml: # OpenVINO
588
+ f[3], _ = export_openvino(file, metadata, half)
589
+ if coreml: # CoreML
590
+ f[4], _ = export_coreml(model, im, file, int8, half)
591
+ if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
592
+ assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
593
+ assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
594
+ f[5], s_model = export_saved_model(model.cpu(),
595
+ im,
596
+ file,
597
+ dynamic,
598
+ tf_nms=nms or agnostic_nms or tfjs,
599
+ agnostic_nms=agnostic_nms or tfjs,
600
+ topk_per_class=topk_per_class,
601
+ topk_all=topk_all,
602
+ iou_thres=iou_thres,
603
+ conf_thres=conf_thres,
604
+ keras=keras)
605
+ if pb or tfjs: # pb prerequisite to tfjs
606
+ f[6], _ = export_pb(s_model, file)
607
+ if tflite or edgetpu:
608
+ f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
609
+ if edgetpu:
610
+ f[8], _ = export_edgetpu(file)
611
+ add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
612
+ if tfjs:
613
+ f[9], _ = export_tfjs(file)
614
+ if paddle: # PaddlePaddle
615
+ f[10], _ = export_paddle(model, im, file, metadata)
616
+
617
+ # Finish
618
+ f = [str(x) for x in f if x] # filter out '' and None
619
+ if any(f):
620
+ cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel)) # type
621
+ dir = Path('segment' if seg else 'classify' if cls else '')
622
+ h = '--half' if half else '' # --half FP16 inference arg
623
+ s = "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \
624
+ "# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else ''
625
+ if onnx_end2end:
626
+ LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
627
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
628
+ f"\nVisualize: https://netron.app")
629
+ else:
630
+ LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
631
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
632
+ f"\nDetect: python {dir / ('detect.py' if det else 'predict.py')} --weights {f[-1]} {h}"
633
+ f"\nValidate: python {dir / 'val.py'} --weights {f[-1]} {h}"
634
+ f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}') {s}"
635
+ f"\nVisualize: https://netron.app")
636
+ return f # return list of exported files/dirs
637
+
638
+
639
+ def parse_opt():
640
+ parser = argparse.ArgumentParser()
641
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
642
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model.pt path(s)')
643
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
644
+ parser.add_argument('--batch-size', type=int, default=1, help='batch size')
645
+ parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
646
+ parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
647
+ parser.add_argument('--inplace', action='store_true', help='set YOLO Detect() inplace=True')
648
+ parser.add_argument('--keras', action='store_true', help='TF: use Keras')
649
+ parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
650
+ parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
651
+ parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
652
+ parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
653
+ parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
654
+ parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
655
+ parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
656
+ parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
657
+ parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
658
+ parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
659
+ parser.add_argument('--topk-all', type=int, default=100, help='ONNX END2END/TF.js NMS: topk for all classes to keep')
660
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='ONNX END2END/TF.js NMS: IoU threshold')
661
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='ONNX END2END/TF.js NMS: confidence threshold')
662
+ parser.add_argument(
663
+ '--include',
664
+ nargs='+',
665
+ default=['torchscript'],
666
+ help='torchscript, onnx, onnx_end2end, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle')
667
+ opt = parser.parse_args()
668
+
669
+ if 'onnx_end2end' in opt.include:
670
+ opt.simplify = True
671
+ opt.dynamic = True
672
+ opt.inplace = True
673
+ opt.half = False
674
+
675
+ print_args(vars(opt))
676
+ return opt
677
+
678
+
679
+ def main(opt):
680
+ for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
681
+ run(**vars(opt))
682
+
683
+
684
+ if __name__ == "__main__":
685
+ opt = parse_opt()
686
+ main(opt)
hubconf.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
5
+ """Creates or loads a YOLO model
6
+
7
+ Arguments:
8
+ name (str): model name 'yolov3' or path 'path/to/best.pt'
9
+ pretrained (bool): load pretrained weights into the model
10
+ channels (int): number of input channels
11
+ classes (int): number of model classes
12
+ autoshape (bool): apply YOLO .autoshape() wrapper to model
13
+ verbose (bool): print all information to screen
14
+ device (str, torch.device, None): device to use for model parameters
15
+
16
+ Returns:
17
+ YOLO model
18
+ """
19
+ from pathlib import Path
20
+
21
+ from models.common import AutoShape, DetectMultiBackend
22
+ from models.experimental import attempt_load
23
+ from models.yolo import ClassificationModel, DetectionModel, SegmentationModel
24
+ from utils.downloads import attempt_download
25
+ from utils.general import LOGGER, check_requirements, intersect_dicts, logging
26
+ from utils.torch_utils import select_device
27
+
28
+ if not verbose:
29
+ LOGGER.setLevel(logging.WARNING)
30
+ check_requirements(exclude=('opencv-python', 'tensorboard', 'thop'))
31
+ name = Path(name)
32
+ path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
33
+ try:
34
+ device = select_device(device)
35
+ if pretrained and channels == 3 and classes == 80:
36
+ try:
37
+ model = DetectMultiBackend(path, device=device, fuse=autoshape) # detection model
38
+ if autoshape:
39
+ if model.pt and isinstance(model.model, ClassificationModel):
40
+ LOGGER.warning('WARNING ⚠️ YOLO ClassificationModel is not yet AutoShape compatible. '
41
+ 'You must pass torch tensors in BCHW to this model, i.e. shape(1,3,224,224).')
42
+ elif model.pt and isinstance(model.model, SegmentationModel):
43
+ LOGGER.warning('WARNING ⚠️ YOLO SegmentationModel is not yet AutoShape compatible. '
44
+ 'You will not be able to run inference with this model.')
45
+ else:
46
+ model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
47
+ except Exception:
48
+ model = attempt_load(path, device=device, fuse=False) # arbitrary model
49
+ else:
50
+ cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
51
+ model = DetectionModel(cfg, channels, classes) # create model
52
+ if pretrained:
53
+ ckpt = torch.load(attempt_download(path), map_location=device) # load
54
+ csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
55
+ csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect
56
+ model.load_state_dict(csd, strict=False) # load
57
+ if len(ckpt['model'].names) == classes:
58
+ model.names = ckpt['model'].names # set class names attribute
59
+ if not verbose:
60
+ LOGGER.setLevel(logging.INFO) # reset to default
61
+ return model.to(device)
62
+
63
+ except Exception as e:
64
+ help_url = 'https://github.com/ultralytics/yolov5/issues/36'
65
+ s = f'{e}. Cache may be out of date, try `force_reload=True` or see {help_url} for help.'
66
+ raise Exception(s) from e
67
+
68
+
69
+ def custom(path='path/to/model.pt', autoshape=True, _verbose=True, device=None):
70
+ # YOLO custom or local model
71
+ return _create(path, autoshape=autoshape, verbose=_verbose, device=device)
72
+
73
+
74
+ if __name__ == '__main__':
75
+ import argparse
76
+ from pathlib import Path
77
+
78
+ import numpy as np
79
+ from PIL import Image
80
+
81
+ from utils.general import cv2, print_args
82
+
83
+ # Argparser
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument('--model', type=str, default='yolo', help='model name')
86
+ opt = parser.parse_args()
87
+ print_args(vars(opt))
88
+
89
+ # Model
90
+ model = _create(name=opt.model, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)
91
+ # model = custom(path='path/to/model.pt') # custom
92
+
93
+ # Images
94
+ imgs = [
95
+ 'data/images/zidane.jpg', # filename
96
+ Path('data/images/zidane.jpg'), # Path
97
+ 'https://ultralytics.com/images/zidane.jpg', # URI
98
+ cv2.imread('data/images/bus.jpg')[:, :, ::-1], # OpenCV
99
+ Image.open('data/images/bus.jpg'), # PIL
100
+ np.zeros((320, 640, 3))] # numpy
101
+
102
+ # Inference
103
+ results = model(imgs, size=320) # batched inference
104
+
105
+ # Results
106
+ results.print()
107
+ results.save()
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # init
models/attention/blocks.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Shuffle
2
+ # CBAM
3
+ # -- GAM ECA SE SK LSK
4
+ from models.common import *
5
+
6
+ class RepNCBAM(nn.Module):
7
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
8
+ super().__init__()
9
+ c_ = int(c2 * e) # hidden channels
10
+ self.cv1 = Conv(c1, c_, 1, 1)
11
+ self.cv2 = Conv(c1, c_, 1, 1)
12
+ self.cv3 = Conv(2 * c_, c2, 1)
13
+ self.m = nn.Sequential(*(CBAMBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
14
+
15
+ def forward(self, x):
16
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
17
+
18
+ class RepNSA(nn.Module):
19
+ def __init__(self, c1, c2, n=1, shortcut=True, g=16, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
20
+ super().__init__()
21
+ c_ = int(c2 * e) # hidden channels
22
+ self.cv1 = Conv(c1, c_, 1, 1)
23
+ self.cv2 = Conv(c1, c_, 1, 1)
24
+ self.cv3 = Conv(2 * c_, c2, 1)
25
+ self.m = nn.Sequential(*(SABottleneck(c_, c_, 1, shortcut, g=g) for _ in range(n)))
26
+
27
+ def forward(self, x):
28
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
29
+
30
+ class RepNLSK(nn.Module):
31
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
32
+ super().__init__()
33
+ c_ = int(c2 * e) # hidden channels
34
+ self.cv1 = Conv(c1, c_, 1, 1)
35
+ self.cv2 = Conv(c1, c_, 1, 1)
36
+ self.cv3 = Conv(2 * c_, c2, 1)
37
+ self.m = nn.Sequential(*(LSKBottleneck(c_, c_, 1, shortcut, g=g) for _ in range(n)))
38
+
39
+ def forward(self, x):
40
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
41
+
42
+ class RepNECA(nn.Module):
43
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
44
+ super().__init__()
45
+ c_ = int(c2 * e) # hidden channels
46
+ self.cv1 = Conv(c1, c_, 1, 1)
47
+ self.cv2 = Conv(c1, c_, 1, 1)
48
+ self.cv3 = Conv(2 * c_, c2, 1)
49
+ self.m = nn.Sequential(*(ECABottleneck(c_, c_, shortcut, g=g) for _ in range(n)))
50
+
51
+ def forward(self, x):
52
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
53
+
54
+ # ----------------------- Attention Mechanism ---------------------------
55
+
56
+ ## CBAM ATTENTION
57
+ class ChannelAttention(nn.Module):
58
+
59
+ def __init__(self, in_planes, ratio=16):
60
+ super().__init__()
61
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
62
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
63
+ self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
64
+ self.act = nn.SiLU()
65
+
66
+ self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
67
+ self.sigmoid = nn.Sigmoid()
68
+
69
+ def forward(self, x):
70
+ avg_out = self.f2(self.act(self.f1(self.avg_pool(x))))
71
+ max_out = self.f2(self.act(self.f1(self.max_pool(x))))
72
+ out = self.sigmoid(avg_out + max_out)
73
+ return out
74
+
75
+
76
+ class SpatialAttention(nn.Module):
77
+
78
+ def __init__(self, kernel_size=3):
79
+ super().__init__()
80
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
81
+ padding = 3 if kernel_size == 7 else 1
82
+ self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
83
+ self.sigmoid = nn.Sigmoid()
84
+
85
+ def forward(self, x):
86
+ # 1*h*w
87
+ avg_out = torch.mean(x, dim=1, keepdim=True)
88
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
89
+ x = torch.cat([avg_out, max_out], dim=1)
90
+ #2*h*w
91
+ x = self.conv(x)
92
+ #1*h*w
93
+ return self.sigmoid(x)
94
+
95
+ class CBAMBottleneck(nn.Module):
96
+ def __init__(self,
97
+ c1,
98
+ c2,
99
+ shortcut=True,
100
+ g=1,
101
+ e=0.5,
102
+ ratio=16,
103
+ kernel_size=3): # ch_in, ch_out, shortcut, groups, expansion
104
+ super().__init__()
105
+ c_ = int(c2 * e) # hidden channels
106
+ self.cv1 = Conv(c1, c_, 1, 1)
107
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
108
+ self.add = shortcut and c1 == c2
109
+ self.channel_attention = ChannelAttention(c2, ratio)
110
+ self.spatial_attention = SpatialAttention(kernel_size)
111
+
112
+ def forward(self, x):
113
+ x1 = self.cv2(self.cv1(x))
114
+ out = self.channel_attention(x1) * x1
115
+ # print('outchannels:{}'.format(out.shape))
116
+ out = self.spatial_attention(out) * out
117
+ return x + out if self.add else out
118
+
119
+ class CBAMC4(nn.Module):
120
+ def __init__(self, c1, c2, c3, c4, c5=1):
121
+ super(CBAMC4, self).__init__()
122
+ self.c = c3 // 2
123
+ self.cv1 = Conv(c1, c3, 1, 1)
124
+ self.cv2 = nn.Sequential(RepNCSP(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
125
+ self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
126
+ self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
127
+ self.channel_attention = ChannelAttention(c2)
128
+ self.spatial_attention = SpatialAttention(kernel_size=3) # Specify kernel_size here
129
+
130
+ def forward(self, x):
131
+ y = list(self.cv1(x).chunk(2, 1))
132
+ y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
133
+ y = torch.cat(y, 1)
134
+
135
+ # Apply channel attention
136
+ y = y * self.channel_attention(y)
137
+
138
+ # Apply spatial attention
139
+ y = y * self.spatial_attention(y)
140
+
141
+ return self.cv4(y)
142
+
143
+ def forward_split(self, x):
144
+ y = list(self.cv1(x).split((self.c, self.c), 1))
145
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
146
+ y = torch.cat(y, 1)
147
+
148
+ # Apply channel attention
149
+ y = y * self.channel_attention(y)
150
+
151
+ # Apply spatial attention
152
+ y = y * self.spatial_attention(y)
153
+
154
+ return self.cv4(y)
155
+
156
+ class RepNCBAMELAN4(RepNCSPELAN4):
157
+ # C3 module with CBAMBottleneck()
158
+ def __init__(self, c1, c2, c3, c4, c5=1):
159
+ super().__init__(c1, c2, c3, c4, c5)
160
+ self.cv2 = nn.Sequential(RepNCBAM(c3//2, c4, c5), Conv(c4, c4, 3, 1))
161
+ self.cv3 = nn.Sequential(RepNCBAM(c4, c4, c5), Conv(c4, c4, 3, 1))
162
+
163
+ # c_ = int(c2 * e) # hidden channels
164
+ # self.m = nn.Sequential(*(RepCBAM(c_, c_, shortcut) for _ in range(n)))
165
+
166
+ ## GAM ATTETION
167
+ class GAMAttention(nn.Module):
168
+ #https://paperswithcode.com/paper/global-attention-mechanism-retain-information
169
+ def __init__(self, c1, c2, group=True,rate=4):
170
+ super(GAMAttention, self).__init__()
171
+
172
+ self.channel_attention = nn.Sequential(
173
+ nn.Linear(c1, int(c1 / rate)),
174
+ nn.ReLU(inplace=True),
175
+ nn.Linear(int(c1 / rate), c1)
176
+ )
177
+ self.spatial_attention = nn.Sequential(
178
+ nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3),
179
+ nn.BatchNorm2d(int(c1 /rate)),
180
+ nn.ReLU(inplace=True),
181
+ nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3),
182
+ nn.BatchNorm2d(c2)
183
+ )
184
+
185
+ def forward(self, x):
186
+ b, c, h, w = x.shape
187
+ x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
188
+ x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
189
+ x_channel_att = x_att_permute.permute(0, 3, 1, 2)
190
+ x = x * x_channel_att
191
+
192
+ x_spatial_att = self.spatial_attention(x).sigmoid()
193
+ x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle
194
+ out = x * x_spatial_att
195
+ return out
196
+
197
+ def channel_shuffle(x, groups=2):
198
+ B, C, H, W = x.size()
199
+ out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
200
+ out=out.view(B, C, H, W)
201
+ return out
202
+
203
+
204
+ ## SK ATTENTION
205
+
206
+ class SKAttention(nn.Module):
207
+
208
+ def __init__(self, channel=512,out_channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):
209
+ super().__init__()
210
+ self.d=max(L,channel//reduction)
211
+ self.convs=nn.ModuleList([])
212
+ for k in kernels:
213
+ self.convs.append(
214
+ nn.Sequential(OrderedDict([
215
+ ('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),
216
+ ('bn',nn.BatchNorm2d(channel)),
217
+ ('relu',nn.ReLU())
218
+ ]))
219
+ )
220
+ self.fc=nn.Linear(channel,self.d)
221
+ self.fcs=nn.ModuleList([])
222
+ for i in range(len(kernels)):
223
+ self.fcs.append(nn.Linear(self.d,channel))
224
+ self.softmax=nn.Softmax(dim=0)
225
+
226
+ def forward(self, x):
227
+ bs, c, _, _ = x.size()
228
+ conv_outs=[]
229
+ ### split
230
+ for conv in self.convs:
231
+ conv_outs.append(conv(x))
232
+ feats=torch.stack(conv_outs,0)#k,bs,channel,h,w
233
+
234
+ ### fuse
235
+ U=sum(conv_outs) #bs,c,h,w
236
+
237
+ ### reduction channel
238
+ S=U.mean(-1).mean(-1) #bs,c
239
+ Z=self.fc(S) #bs,d
240
+
241
+ ### calculate attention weight
242
+ weights=[]
243
+ for fc in self.fcs:
244
+ weight=fc(Z)
245
+ weights.append(weight.view(bs,c,1,1)) #bs,channel
246
+ attention_weughts=torch.stack(weights,0)#k,bs,channel,1,1
247
+ attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1
248
+
249
+ ### fuse
250
+ V=(attention_weughts*feats).sum(0)
251
+ return V
252
+
253
+ ## SHUFFLE ATTENTION
254
+ from torch.nn.parameter import Parameter
255
+ from torch.nn import init
256
+
257
+ class sa_layer(nn.Module):
258
+ """Constructs a Channel Spatial Group module.
259
+
260
+ Args:
261
+ k_size: Adaptive selection of kernel size
262
+ """
263
+
264
+ def __init__(self, channel, groups=16):
265
+ super(sa_layer, self).__init__()
266
+ self.groups = groups
267
+ self.channel = channel
268
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
269
+ self.gn = nn.GroupNorm(self.channel // (2 * self.groups), self.channel // (2 * self.groups))
270
+ self.cweight = Parameter(torch.zeros(1, self.channel // (2 * self.groups), 1, 1))
271
+ self.cbias = Parameter(torch.ones(1, self.channel // (2 * self.groups), 1, 1))
272
+ self.sweight = Parameter(torch.zeros(1, self.channel // (2 * self.groups), 1, 1))
273
+ self.sbias = Parameter(torch.ones(1, self.channel // (2 * self.groups), 1, 1))
274
+
275
+ self.sigmoid = nn.Sigmoid()
276
+ self.gn = nn.GroupNorm(self.channel // (2 * self.groups), self.channel // (2 * self.groups))
277
+
278
+ @staticmethod
279
+ def channel_shuffle(x, groups):
280
+ b, c, h, w = x.shape
281
+
282
+ x = x.reshape(b, groups, -1, h, w)
283
+ x = x.permute(0, 2, 1, 3, 4)
284
+
285
+ # flatten
286
+ x = x.reshape(b, -1, h, w)
287
+
288
+ return x
289
+
290
+ def forward(self, x):
291
+ b, c, h, w = x.shape
292
+ # group into subfeatures
293
+ x = x.reshape(b * self.groups, -1, h, w)
294
+ # channel_split
295
+ x_0, x_1 = x.chunk(2, dim=1)
296
+ # channel attention
297
+ xn = self.avg_pool(x_0)
298
+ xn = self.cweight * xn + self.cbias
299
+ xn = x_0 * self.sigmoid(xn)
300
+ # spatial attention
301
+ xs = self.gn(x_1)
302
+ xs = self.sweight * xs + self.sbias
303
+ xs = x_1 * self.sigmoid(xs)
304
+
305
+ # concatenate along channel axis
306
+ out = torch.cat([xn, xs], dim=1)
307
+ out = out.reshape(b, -1, h, w)
308
+
309
+ out = self.channel_shuffle(out, 2)
310
+ return out
311
+
312
+
313
+ class SABottleneck(nn.Module):
314
+ # expansion = 4
315
+ def __init__(self, c1, c2, s=1, shortcut=True, k=(1, 3), e=0.5, g=1):
316
+ super(SABottleneck, self).__init__()
317
+ c_ = c2 // 2
318
+ self.shortcut = shortcut
319
+
320
+ self.conv1 = Conv(c1, c_, k[0], s)
321
+ self.conv2 = Conv(c_, c2, k[1], s, g=g)
322
+ self.add = shortcut and c1 == c2
323
+ self.sa = sa_layer(c2, g)
324
+
325
+ def forward(self, x):
326
+ x1 = self.conv2(self.conv1(x))
327
+ y = self.sa(x1)
328
+ out = y
329
+
330
+ return x + out if self.add else out
331
+
332
+ class RepNSAELAN4(RepNCSPELAN4):
333
+ def __init__(self, c1, c2, c3, c4, c5=1):
334
+ super().__init__(c1, c2, c3, c4, c5)
335
+ self.cv2 = nn.Sequential(RepNSA(c3//2, c4, c5), Conv(c4, c4, 3, 1))
336
+ self.cv3 = nn.Sequential(RepNSA(c4, c4, c5), Conv(c4, c4, 3, 1))
337
+
338
+ ## ECA
339
+ class EfficientChannelAttention(nn.Module): # Efficient Channel Attention module
340
+ def __init__(self, c, b=1, gamma=2):
341
+ super(EfficientChannelAttention, self).__init__()
342
+ t = int(abs((math.log(c, 2) + b) / gamma))
343
+ k = t if t % 2 else t + 1
344
+
345
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
346
+ self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k/2), bias=False)
347
+ self.sigmoid = nn.Sigmoid()
348
+
349
+ def forward(self, x):
350
+ out = self.avg_pool(x)
351
+ out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
352
+ out = self.sigmoid(out)
353
+ return out * x
354
+
355
+ class ECABottleneck(nn.Module):
356
+ # Standard bottleneck
357
+ def __init__(self,
358
+ c1,
359
+ c2,
360
+ shortcut=True,
361
+ g=1,
362
+ e=0.5,
363
+ ratio=16,
364
+ k_size=3): # ch_in, ch_out, shortcut, groups, expansion
365
+ super().__init__()
366
+ c_ = int(c2 * e) # hidden channels
367
+ self.cv1 = Conv(c1, c_, 1, 1)
368
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
369
+ self.add = shortcut and c1 == c2
370
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
371
+ self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
372
+ self.sigmoid = nn.Sigmoid()
373
+
374
+ def forward(self, x):
375
+ x1 = self.cv2(self.cv1(x))
376
+ y = self.avg_pool(x1)
377
+ y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
378
+ y = self.sigmoid(y)
379
+ out = x1 * y.expand_as(x1)
380
+
381
+ return x + out if self.add else out
382
+
383
+ class RepNECALAN4(RepNCSPELAN4):
384
+ def __init__(self, c1, c2, c3, c4, c5=1):
385
+ super().__init__(c1, c2, c3, c4, c5)
386
+ self.cv2 = nn.Sequential(RepNECA(c3//2, c4, c5), Conv(c4, c4, 3, 1))
387
+ self.cv3 = nn.Sequential(RepNECA(c4, c4, c5), Conv(c4, c4, 3, 1))
388
+
389
+ ## LSK Attention
390
+ class LSKblock(nn.Module):
391
+ def __init__(self, c1):
392
+ super().__init__()
393
+ self.conv0 = nn.Conv2d(c1, c1, 5, padding=2, groups=c1)
394
+ self.conv_spatial = nn.Conv2d(c1, c1, 7, stride=1, padding=9, groups=c1, dilation=3)
395
+ self.conv1 = nn.Conv2d(c1, c1//2, 1)
396
+ self.conv2 = nn.Conv2d(c1, c1//2, 1)
397
+ # self.cv2 = nn.Sequential(RepNCSP(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
398
+ self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
399
+ self.conv = nn.Conv2d(c1//2, c1, 1)
400
+
401
+ def forward(self, x):
402
+ attn1 = self.conv0(x)
403
+ attn2 = self.conv_spatial(attn1)
404
+
405
+ attn1 = self.conv1(attn1)
406
+ attn2 = self.conv2(attn2)
407
+
408
+ attn = torch.cat([attn1, attn2], dim=1)
409
+ avg_attn = torch.mean(attn, dim=1, keepdim=True)
410
+ max_attn, _ = torch.max(attn, dim=1, keepdim=True)
411
+ agg = torch.cat([avg_attn, max_attn], dim=1)
412
+ sig = self.conv_squeeze(agg).sigmoid()
413
+ attn = attn1 * sig[:,0,:,:].unsqueeze(1) + attn2 * sig[:,1,:,:].unsqueeze(1)
414
+ attn = self.conv(attn)
415
+ return x * attn
416
+
417
+
418
+
419
+ # class LSKAttention(nn.Module):
420
+ # def __init__(self, c1, c2, shortcut = True):
421
+ # super().__init__()
422
+
423
+ # self.conv1 = Conv(c1, c1, 1)
424
+ # self.spatial_gating_unit = LSKblock(c1)
425
+ # self.conv2 = Conv(c1, c2, 1)
426
+ # self.add = shortcut and c1 == c2
427
+
428
+
429
+ # def forward(self, x):
430
+ # x1 = self.conv1(x)
431
+ # x = self.spatial_gating_unit(x)
432
+ # x = self.proj_2(x)
433
+ # x = x + shorcut
434
+ # return x
435
+
436
+ class LSKBottleneck(nn.Module):
437
+ # expansion = 4
438
+ def __init__(self, c1, c2, s=1, shortcut=True, g=1):
439
+ super(LSKBottleneck, self).__init__()
440
+ c_ = c2 // 2
441
+ self.shortcut = shortcut
442
+ self.add = shortcut and c1 == c2
443
+ self.conv1 = Conv(c1, c_, 1)
444
+ self.conv2 = Conv(c_, c2, 3, s, g= g)
445
+ self.lsk = LSKblock(c2)
446
+
447
+ def forward(self, x):
448
+ x1 = self.conv2(self.conv1(x))
449
+ y = self.lsk(x1)
450
+ out = y
451
+
452
+ return x + out if self.add else out
453
+
454
+
455
+ class RepNLSKELAN4(RepNCSPELAN4):
456
+ def __init__(self, c1, c2, c3, c4, c5=1):
457
+ super().__init__(c1, c2, c3, c4, c5)
458
+ self.cv2 = nn.Sequential(RepNLSK(c3//2, c4, c5), Conv(c4, c4, 3, 1))
459
+ self.cv3 = nn.Sequential(RepNLSK(c4, c4, c5), Conv(c4, c4, 3, 1))
460
+
461
+ ## SE Attention
462
+ class SEBottleneck(nn.Module):
463
+ # Standard bottleneck
464
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, ratio=16): # ch_in, ch_out, shortcut, groups, expansion
465
+ super().__init__()
466
+ c_ = int(c2 * e) # hidden channels
467
+ self.cv1 = Conv(c1, c_, 1, 1)
468
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
469
+ self.add = shortcut and c1 == c2
470
+ # self.se=SE(c1,c2,ratio)
471
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
472
+ self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
473
+ self.relu = nn.ReLU(inplace=True)
474
+ self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
475
+ self.sig = nn.Sigmoid()
476
+
477
+ def forward(self, x):
478
+ x1 = self.cv2(self.cv1(x))
479
+ b, c, _, _ = x.size()
480
+ y = self.avgpool(x1).view(b, c)
481
+ y = self.l1(y)
482
+ y = self.relu(y)
483
+ y = self.l2(y)
484
+ y = self.sig(y)
485
+ y = y.view(b, c, 1, 1)
486
+ out = x1 * y.expand_as(x1)
487
+
488
+ # out=self.se(x1)*x1
489
+ return x + out if self.add else out
490
+
491
+ ## SOCA Attention
492
+ from torch.autograd import Function
493
+
494
+ class Covpool(Function):
495
+ @staticmethod
496
+ def forward(ctx, input):
497
+ x = input
498
+ batchSize = x.data.shape[0]
499
+ dim = x.data.shape[1]
500
+ h = x.data.shape[2]
501
+ w = x.data.shape[3]
502
+ M = h*w
503
+ x = x.reshape(batchSize,dim,M)
504
+ I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device)
505
+ I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype)
506
+ y = x.bmm(I_hat).bmm(x.transpose(1,2))
507
+ ctx.save_for_backward(input,I_hat)
508
+ return y
509
+ @staticmethod
510
+ def backward(ctx, grad_output):
511
+ input,I_hat = ctx.saved_tensors
512
+ x = input
513
+ batchSize = x.data.shape[0]
514
+ dim = x.data.shape[1]
515
+ h = x.data.shape[2]
516
+ w = x.data.shape[3]
517
+ M = h*w
518
+ x = x.reshape(batchSize,dim,M)
519
+ grad_input = grad_output + grad_output.transpose(1,2)
520
+ grad_input = grad_input.bmm(x).bmm(I_hat)
521
+ grad_input = grad_input.reshape(batchSize,dim,h,w)
522
+ return grad_input
523
+
524
+ class Sqrtm(Function):
525
+ @staticmethod
526
+ def forward(ctx, input, iterN):
527
+ x = input
528
+ batchSize = x.data.shape[0]
529
+ dim = x.data.shape[1]
530
+ dtype = x.dtype
531
+ I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
532
+ normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1)
533
+ A = x.div(normA.view(batchSize,1,1).expand_as(x))
534
+ Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device)
535
+ Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1)
536
+ if iterN < 2:
537
+ ZY = 0.5*(I3 - A)
538
+ Y[:,0,:,:] = A.bmm(ZY)
539
+ else:
540
+ ZY = 0.5*(I3 - A)
541
+ Y[:,0,:,:] = A.bmm(ZY)
542
+ Z[:,0,:,:] = ZY
543
+ for i in range(1, iterN-1):
544
+ ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:]))
545
+ Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY)
546
+ Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:])
547
+ ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]))
548
+ y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
549
+ ctx.save_for_backward(input, A, ZY, normA, Y, Z)
550
+ ctx.iterN = iterN
551
+ return y
552
+ @staticmethod
553
+ def backward(ctx, grad_output):
554
+ input, A, ZY, normA, Y, Z = ctx.saved_tensors
555
+ iterN = ctx.iterN
556
+ x = input
557
+ batchSize = x.data.shape[0]
558
+ dim = x.data.shape[1]
559
+ dtype = x.dtype
560
+ der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
561
+ der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA))
562
+ I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
563
+ if iterN < 2:
564
+ der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace))
565
+ else:
566
+ dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) -
567
+ Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom))
568
+ dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:])
569
+ for i in range(iterN-3, -1, -1):
570
+ YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:])
571
+ ZY = Z[:,i,:,:].bmm(Y[:,i,:,:])
572
+ dldY_ = 0.5*(dldY.bmm(YZ) -
573
+ Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) -
574
+ ZY.bmm(dldY))
575
+ dldZ_ = 0.5*(YZ.bmm(dldZ) -
576
+ Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) -
577
+ dldZ.bmm(ZY))
578
+ dldY = dldY_
579
+ dldZ = dldZ_
580
+ der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
581
+ grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x))
582
+ grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
583
+ for i in range(batchSize):
584
+ grad_input[i,:,:] += (der_postComAux[i] \
585
+ - grad_aux[i] / (normA[i] * normA[i])) \
586
+ *torch.ones(dim,device = x.device).diag()
587
+ return grad_input, None
588
+
589
+ def CovpoolLayer(var):
590
+ return Covpool.apply(var)
591
+
592
+ def SqrtmLayer(var, iterN):
593
+ return Sqrtm.apply(var, iterN)
594
+
595
+ class SOCA(nn.Module):
596
+ # Second-order Channel Attention
597
+ def __init__(self, c1, c2, reduction=8):
598
+ super(SOCA, self).__init__()
599
+ self.max_pool = nn.MaxPool2d(kernel_size=2)
600
+
601
+ self.conv_du = nn.Sequential(
602
+ nn.Conv2d(c1, c1 // reduction, 1, padding=0, bias=True),
603
+ nn.SiLU(), # SiLU activation
604
+ nn.Conv2d(c1 // reduction, c1, 1, padding=0, bias=True),
605
+ nn.Sigmoid()
606
+ )
607
+
608
+ def forward(self, x):
609
+ batch_size, C, h, w = x.shape # x: NxCxHxW
610
+ N = int(h * w)
611
+ min_h = min(h, w)
612
+ h1 = 1000
613
+ w1 = 1000
614
+ if h < h1 and w < w1:
615
+ x_sub = x
616
+ elif h < h1 and w > w1:
617
+ W = (w - w1) // 2
618
+ x_sub = x[:, :, :, W:(W + w1)]
619
+ elif w < w1 and h > h1:
620
+ H = (h - h1) // 2
621
+ x_sub = x[:, :, H:H + h1, :]
622
+ else:
623
+ H = (h - h1) // 2
624
+ W = (w - w1) // 2
625
+ x_sub = x[:, :, H:(H + h1), W:(W + w1)]
626
+ cov_mat = CovpoolLayer(x_sub) # Global Covariance pooling layer
627
+ cov_mat_sqrt = SqrtmLayer(cov_mat, 5) # Matrix square root layer (including pre-norm, Newton-Schulz iter. and post-com. with 5 iterations)
628
+ cov_mat_sum = torch.mean(cov_mat_sqrt, 1)
629
+ cov_mat_sum = cov_mat_sum.view(batch_size, C, 1, 1)
630
+ y_cov = self.conv_du(cov_mat_sum)
631
+ return y_cov * x
models/common.py ADDED
@@ -0,0 +1,1360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import contextlib
3
+ import json
4
+ import math
5
+ import platform
6
+ import warnings
7
+ import zipfile
8
+ from collections import OrderedDict, namedtuple
9
+ from copy import copy
10
+ from pathlib import Path
11
+ from urllib.parse import urlparse
12
+
13
+ from typing import Optional
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import pandas as pd
18
+ import requests
19
+ import torch
20
+ import torch.nn as nn
21
+ from IPython.display import display
22
+ from PIL import Image
23
+ from torch.cuda import amp
24
+
25
+ from utils import TryExcept
26
+ from utils.dataloaders import exif_transpose, letterbox
27
+ from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
28
+ increment_path, is_notebook, make_divisible, non_max_suppression, scale_boxes,
29
+ xywh2xyxy, xyxy2xywh, yaml_load)
30
+ from utils.plots import Annotator, colors, save_one_box
31
+ from utils.torch_utils import copy_attr, smart_inference_mode
32
+
33
+
34
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
35
+ # Pad to 'same' shape outputs
36
+ if d > 1:
37
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
38
+ if p is None:
39
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
40
+ return p
41
+
42
+
43
+ class Conv(nn.Module):
44
+ # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
45
+ default_act = nn.SiLU() # default activation
46
+
47
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
48
+ super().__init__()
49
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
50
+ self.bn = nn.BatchNorm2d(c2)
51
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ return self.act(self.bn(self.conv(x)))
55
+
56
+ def forward_fuse(self, x):
57
+ return self.act(self.conv(x))
58
+
59
+ class Convb(nn.Module):
60
+ # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
61
+ default_act = nn.SiLU() # default activation
62
+
63
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
64
+ super().__init__()
65
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=True)
66
+ self.bn = nn.BatchNorm2d(c2)
67
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
68
+
69
+ def forward(self, x):
70
+ return self.act(self.bn(self.conv(x)))
71
+
72
+ def forward_fuse(self, x):
73
+ return self.act(self.conv(x))
74
+
75
+
76
+ class AConv(nn.Module):
77
+ def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
78
+ super().__init__()
79
+ self.cv1 = Conv(c1, c2, 3, 2, 1)
80
+
81
+ def forward(self, x):
82
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
83
+ return self.cv1(x)
84
+
85
+
86
+ class ADown(nn.Module):
87
+ def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
88
+ super().__init__()
89
+ self.c = c2 // 2
90
+ self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
91
+ self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
92
+
93
+ def forward(self, x):
94
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
95
+ x1,x2 = x.chunk(2, 1)
96
+ x1 = self.cv1(x1)
97
+ x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
98
+ x2 = self.cv2(x2)
99
+ return torch.cat((x1, x2), 1)
100
+
101
+
102
+ class RepConvN(nn.Module):
103
+ """RepConv is a basic rep-style block, including training and deploy status
104
+ This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
105
+ """
106
+ default_act = nn.SiLU() # default activation
107
+
108
+ def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
109
+ super().__init__()
110
+ assert k == 3 and p == 1
111
+ self.g = g
112
+ self.c1 = c1
113
+ self.c2 = c2
114
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
115
+
116
+ self.bn = None
117
+ self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
118
+ self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
119
+
120
+ def forward_fuse(self, x):
121
+ """Forward process"""
122
+ return self.act(self.conv(x))
123
+
124
+ def forward(self, x):
125
+ """Forward process"""
126
+ id_out = 0 if self.bn is None else self.bn(x)
127
+ return self.act(self.conv1(x) + self.conv2(x) + id_out)
128
+
129
+ def get_equivalent_kernel_bias(self):
130
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
131
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
132
+ kernelid, biasid = self._fuse_bn_tensor(self.bn)
133
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
134
+
135
+ def _avg_to_3x3_tensor(self, avgp):
136
+ channels = self.c1
137
+ groups = self.g
138
+ kernel_size = avgp.kernel_size
139
+ input_dim = channels // groups
140
+ k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
141
+ k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
142
+ return k
143
+
144
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
145
+ if kernel1x1 is None:
146
+ return 0
147
+ else:
148
+ return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
149
+
150
+ def _fuse_bn_tensor(self, branch):
151
+ if branch is None:
152
+ return 0, 0
153
+ if isinstance(branch, Conv):
154
+ kernel = branch.conv.weight
155
+ running_mean = branch.bn.running_mean
156
+ running_var = branch.bn.running_var
157
+ gamma = branch.bn.weight
158
+ beta = branch.bn.bias
159
+ eps = branch.bn.eps
160
+ elif isinstance(branch, nn.BatchNorm2d):
161
+ if not hasattr(self, 'id_tensor'):
162
+ input_dim = self.c1 // self.g
163
+ kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
164
+ for i in range(self.c1):
165
+ kernel_value[i, i % input_dim, 1, 1] = 1
166
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
167
+ kernel = self.id_tensor
168
+ running_mean = branch.running_mean
169
+ running_var = branch.running_var
170
+ gamma = branch.weight
171
+ beta = branch.bias
172
+ eps = branch.eps
173
+ std = (running_var + eps).sqrt()
174
+ t = (gamma / std).reshape(-1, 1, 1, 1)
175
+ return kernel * t, beta - running_mean * gamma / std
176
+
177
+ def fuse_convs(self):
178
+ if hasattr(self, 'conv'):
179
+ return
180
+ kernel, bias = self.get_equivalent_kernel_bias()
181
+ self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
182
+ out_channels=self.conv1.conv.out_channels,
183
+ kernel_size=self.conv1.conv.kernel_size,
184
+ stride=self.conv1.conv.stride,
185
+ padding=self.conv1.conv.padding,
186
+ dilation=self.conv1.conv.dilation,
187
+ groups=self.conv1.conv.groups,
188
+ bias=True).requires_grad_(False)
189
+ self.conv.weight.data = kernel
190
+ self.conv.bias.data = bias
191
+ for para in self.parameters():
192
+ para.detach_()
193
+ self.__delattr__('conv1')
194
+ self.__delattr__('conv2')
195
+ if hasattr(self, 'nm'):
196
+ self.__delattr__('nm')
197
+ if hasattr(self, 'bn'):
198
+ self.__delattr__('bn')
199
+ if hasattr(self, 'id_tensor'):
200
+ self.__delattr__('id_tensor')
201
+
202
+
203
+ class SP(nn.Module):
204
+ def __init__(self, k=3, s=1):
205
+ super(SP, self).__init__()
206
+ self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)
207
+
208
+ def forward(self, x):
209
+ return self.m(x)
210
+
211
+
212
+ class MP(nn.Module):
213
+ # Max pooling
214
+ def __init__(self, k=2):
215
+ super(MP, self).__init__()
216
+ self.m = nn.MaxPool2d(kernel_size=k, stride=k)
217
+
218
+ def forward(self, x):
219
+ return self.m(x)
220
+
221
+
222
+ class ConvTranspose(nn.Module):
223
+ # Convolution transpose 2d layer
224
+ default_act = nn.SiLU() # default activation
225
+
226
+ def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
227
+ super().__init__()
228
+ self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
229
+ self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
230
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
231
+
232
+ def forward(self, x):
233
+ return self.act(self.bn(self.conv_transpose(x)))
234
+
235
+
236
+ class DWConv(Conv):
237
+ # Depth-wise convolution
238
+ def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
239
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
240
+
241
+
242
+ class DWConvTranspose2d(nn.ConvTranspose2d):
243
+ # Depth-wise transpose convolution
244
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
245
+ super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
246
+
247
+ class DWConvTranspose(nn.Module):
248
+ # Convolution transpose 2d layer
249
+ default_act = nn.SiLU() # default activation
250
+
251
+ def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
252
+ super().__init__()
253
+ self.dwconv_transpose = DWConvTranspose2d(c1, c2, k, s, p)
254
+ self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
255
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
256
+
257
+ def forward(self, x):
258
+ return self.act(self.bn(self.dwconv_transpose(x)))
259
+
260
+
261
+ class DFL(nn.Module):
262
+ # DFL module
263
+ def __init__(self, c1=17):
264
+ super().__init__()
265
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
266
+ self.conv.weight.data[:] = nn.Parameter(torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)) # / 120.0
267
+ self.c1 = c1
268
+ # self.bn = nn.BatchNorm2d(4)
269
+
270
+ def forward(self, x):
271
+ b, c, a = x.shape # batch, channels, anchors
272
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
273
+ # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
274
+
275
+
276
+ class BottleneckBase(nn.Module):
277
+ # Standard bottleneck
278
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(1, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
279
+ super().__init__()
280
+ c_ = int(c2 * e) # hidden channels
281
+ self.cv1 = Conv(c1, c_, k[0], 1)
282
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
283
+ self.add = shortcut and c1 == c2
284
+
285
+ def forward(self, x):
286
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
287
+
288
+
289
+ class RBottleneckBase(nn.Module):
290
+ # Standard bottleneck
291
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
292
+ super().__init__()
293
+ c_ = int(c2 * e) # hidden channels
294
+ self.cv1 = Conv(c1, c_, k[0], 1)
295
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
296
+ self.add = shortcut and c1 == c2
297
+
298
+ def forward(self, x):
299
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
300
+
301
+
302
+ class RepNRBottleneckBase(nn.Module):
303
+ # Standard bottleneck
304
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
305
+ super().__init__()
306
+ c_ = int(c2 * e) # hidden channels
307
+ self.cv1 = RepConvN(c1, c_, k[0], 1)
308
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
309
+ self.add = shortcut and c1 == c2
310
+
311
+ def forward(self, x):
312
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
313
+
314
+
315
+ class Bottleneck(nn.Module):
316
+ # Standard bottleneck
317
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
318
+ super().__init__()
319
+ c_ = int(c2 * e) # hidden channels
320
+ self.cv1 = Conv(c1, c_, k[0], 1)
321
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
322
+ self.add = shortcut and c1 == c2
323
+
324
+ def forward(self, x):
325
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
326
+
327
+
328
+ class RepNBottleneck(nn.Module):
329
+ # Standard bottleneck
330
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
331
+ super().__init__()
332
+ c_ = int(c2 * e) # hidden channels
333
+ self.cv1 = RepConvN(c1, c_, k[0], 1)
334
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
335
+ self.add = shortcut and c1 == c2
336
+
337
+ def forward(self, x):
338
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
339
+
340
+
341
+ class Res(nn.Module):
342
+ # ResNet bottleneck
343
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
344
+ super(Res, self).__init__()
345
+ c_ = int(c2 * e) # hidden channels
346
+ self.cv1 = Conv(c1, c_, 1, 1)
347
+ self.cv2 = Conv(c_, c_, 3, 1, g=g)
348
+ self.cv3 = Conv(c_, c2, 1, 1)
349
+ self.add = shortcut and c1 == c2
350
+
351
+ def forward(self, x):
352
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
353
+
354
+
355
+ class RepNRes(nn.Module):
356
+ # ResNet bottleneck
357
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
358
+ super(RepNRes, self).__init__()
359
+ c_ = int(c2 * e) # hidden channels
360
+ self.cv1 = Conv(c1, c_, 1, 1)
361
+ self.cv2 = RepConvN(c_, c_, 3, 1, g=g)
362
+ self.cv3 = Conv(c_, c2, 1, 1)
363
+ self.add = shortcut and c1 == c2
364
+
365
+ def forward(self, x):
366
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
367
+
368
+
369
+ class BottleneckCSP(nn.Module):
370
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
371
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
372
+ super().__init__()
373
+ c_ = int(c2 * e) # hidden channels
374
+ self.cv1 = Conv(c1, c_, 1, 1)
375
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
376
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
377
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
378
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
379
+ self.act = nn.SiLU()
380
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
381
+
382
+ def forward(self, x):
383
+ y1 = self.cv3(self.m(self.cv1(x)))
384
+ y2 = self.cv2(x)
385
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
386
+
387
+
388
+ class CSP(nn.Module):
389
+ # CSP Bottleneck with 3 convolutions
390
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
391
+ super().__init__()
392
+ c_ = int(c2 * e) # hidden channels
393
+ self.cv1 = Conv(c1, c_, 1, 1)
394
+ self.cv2 = Conv(c1, c_, 1, 1)
395
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
396
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
397
+
398
+ def forward(self, x):
399
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
400
+
401
+
402
+ class RepNCSP(nn.Module):
403
+ # CSP Bottleneck with 3 convolutions
404
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
405
+ super().__init__()
406
+ c_ = int(c2 * e) # hidden channels
407
+ self.cv1 = Conv(c1, c_, 1, 1)
408
+ self.cv2 = Conv(c1, c_, 1, 1)
409
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
410
+ self.m = nn.Sequential(*(RepNBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
411
+
412
+ def forward(self, x):
413
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
414
+
415
+
416
+ class CSPBase(nn.Module):
417
+ # CSP Bottleneck with 3 convolutions
418
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
419
+ super().__init__()
420
+ c_ = int(c2 * e) # hidden channels
421
+ self.cv1 = Conv(c1, c_, 1, 1)
422
+ self.cv2 = Conv(c1, c_, 1, 1)
423
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
424
+ self.m = nn.Sequential(*(BottleneckBase(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
425
+
426
+ def forward(self, x):
427
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
428
+
429
+
430
+ class SPP(nn.Module):
431
+ # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
432
+ def __init__(self, c1, c2, k=(5, 9, 13)):
433
+ super().__init__()
434
+ c_ = c1 // 2 # hidden channels
435
+ self.cv1 = Conv(c1, c_, 1, 1)
436
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
437
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
438
+
439
+ def forward(self, x):
440
+ x = self.cv1(x)
441
+ with warnings.catch_warnings():
442
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
443
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
444
+
445
+
446
+ class ASPP(torch.nn.Module):
447
+
448
+ def __init__(self, in_channels, out_channels):
449
+ super().__init__()
450
+ kernel_sizes = [1, 3, 3, 1]
451
+ dilations = [1, 3, 6, 1]
452
+ paddings = [0, 3, 6, 0]
453
+ self.aspp = torch.nn.ModuleList()
454
+ for aspp_idx in range(len(kernel_sizes)):
455
+ conv = torch.nn.Conv2d(
456
+ in_channels,
457
+ out_channels,
458
+ kernel_size=kernel_sizes[aspp_idx],
459
+ stride=1,
460
+ dilation=dilations[aspp_idx],
461
+ padding=paddings[aspp_idx],
462
+ bias=True)
463
+ self.aspp.append(conv)
464
+ self.gap = torch.nn.AdaptiveAvgPool2d(1)
465
+ self.aspp_num = len(kernel_sizes)
466
+ for m in self.modules():
467
+ if isinstance(m, torch.nn.Conv2d):
468
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
469
+ m.weight.data.normal_(0, math.sqrt(2. / n))
470
+ m.bias.data.fill_(0)
471
+
472
+ def forward(self, x):
473
+ avg_x = self.gap(x)
474
+ out = []
475
+ for aspp_idx in range(self.aspp_num):
476
+ inp = avg_x if (aspp_idx == self.aspp_num - 1) else x
477
+ out.append(F.relu_(self.aspp[aspp_idx](inp)))
478
+ out[-1] = out[-1].expand_as(out[-2])
479
+ out = torch.cat(out, dim=1)
480
+ return out
481
+
482
+
483
+ class SPPCSPC(nn.Module):
484
+ # CSP SPP https://github.com/WongKinYiu/CrossStagePartialNetworks
485
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
486
+ super(SPPCSPC, self).__init__()
487
+ c_ = int(2 * c2 * e) # hidden channels
488
+ self.cv1 = Conv(c1, c_, 1, 1)
489
+ self.cv2 = Conv(c1, c_, 1, 1)
490
+ self.cv3 = Conv(c_, c_, 3, 1)
491
+ self.cv4 = Conv(c_, c_, 1, 1)
492
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
493
+ self.cv5 = Conv(4 * c_, c_, 1, 1)
494
+ self.cv6 = Conv(c_, c_, 3, 1)
495
+ self.cv7 = Conv(2 * c_, c2, 1, 1)
496
+
497
+ def forward(self, x):
498
+ x1 = self.cv4(self.cv3(self.cv1(x)))
499
+ y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
500
+ y2 = self.cv2(x)
501
+ return self.cv7(torch.cat((y1, y2), dim=1))
502
+
503
+
504
+ class SPPF(nn.Module):
505
+ # Spatial Pyramid Pooling - Fast (SPPF) layer by Glenn Jocher
506
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
507
+ super().__init__()
508
+ c_ = c1 // 2 # hidden channels
509
+ self.cv1 = Conv(c1, c_, 1, 1)
510
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
511
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
512
+ # self.m = SoftPool2d(kernel_size=k, stride=1, padding=k // 2)
513
+
514
+ def forward(self, x):
515
+ x = self.cv1(x)
516
+ with warnings.catch_warnings():
517
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
518
+ y1 = self.m(x)
519
+ y2 = self.m(y1)
520
+ return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
521
+
522
+
523
+ import torch.nn.functional as F
524
+ from torch.nn.modules.utils import _pair
525
+
526
+
527
+ class ReOrg(nn.Module):
528
+ # yolo
529
+ def __init__(self):
530
+ super(ReOrg, self).__init__()
531
+
532
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
533
+ return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
534
+
535
+
536
+ class Contract(nn.Module):
537
+ # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
538
+ def __init__(self, gain=2):
539
+ super().__init__()
540
+ self.gain = gain
541
+
542
+ def forward(self, x):
543
+ b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
544
+ s = self.gain
545
+ x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
546
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
547
+ return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
548
+
549
+
550
+ class Expand(nn.Module):
551
+ # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
552
+ def __init__(self, gain=2):
553
+ super().__init__()
554
+ self.gain = gain
555
+
556
+ def forward(self, x):
557
+ b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
558
+ s = self.gain
559
+ x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
560
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
561
+ return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
562
+
563
+
564
+ class Concat(nn.Module):
565
+ # Concatenate a list of tensors along dimension
566
+ def __init__(self, dimension=1):
567
+ super().__init__()
568
+ self.d = dimension
569
+
570
+ def forward(self, x):
571
+ return torch.cat(x, self.d)
572
+
573
+
574
+ class Shortcut(nn.Module):
575
+ def __init__(self, dimension=0):
576
+ super(Shortcut, self).__init__()
577
+ self.d = dimension
578
+
579
+ def forward(self, x):
580
+ return x[0]+x[1]
581
+
582
+
583
+ class Silence(nn.Module):
584
+ def __init__(self):
585
+ super(Silence, self).__init__()
586
+ def forward(self, x):
587
+ return x
588
+
589
+
590
+ ##### GELAN #####
591
+
592
+ class SPPELAN(nn.Module):
593
+ # spp-elan
594
+ def __init__(self, c1, c2, c3): # ch_in, ch_out, number, shortcut, groups, expansion
595
+ super().__init__()
596
+ self.c = c3
597
+ self.cv1 = Conv(c1, c3, 1, 1)
598
+ self.cv2 = SP(5)
599
+ self.cv3 = SP(5)
600
+ self.cv4 = SP(5)
601
+ self.cv5 = Conv(4*c3, c2, 1, 1)
602
+
603
+ def forward(self, x):
604
+ y = [self.cv1(x)]
605
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
606
+ return self.cv5(torch.cat(y, 1))
607
+
608
+
609
+ class RepNCSPELAN4(nn.Module):
610
+ # csp-elan
611
+ def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
612
+ super().__init__()
613
+ self.c = c3//2
614
+ self.cv1 = Conv(c1, c3, 1, 1)
615
+ self.cv2 = nn.Sequential(RepNCSP(c3//2, c4, c5), Conv(c4, c4, 3, 1))
616
+ self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
617
+ self.cv4 = Conv(c3+(2*c4), c2, 1, 1)
618
+
619
+ def forward(self, x):
620
+ y = list(self.cv1(x).chunk(2, 1))
621
+ y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
622
+ return self.cv4(torch.cat(y, 1))
623
+
624
+ def forward_split(self, x):
625
+ y = list(self.cv1(x).split((self.c, self.c), 1))
626
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
627
+ return self.cv4(torch.cat(y, 1))
628
+
629
+ #################
630
+
631
+
632
+ ##### YOLOR #####
633
+
634
+ class ImplicitA(nn.Module):
635
+ def __init__(self, channel):
636
+ super(ImplicitA, self).__init__()
637
+ self.channel = channel
638
+ self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
639
+ nn.init.normal_(self.implicit, std=.02)
640
+
641
+ def forward(self, x):
642
+ return self.implicit + x
643
+
644
+
645
+ class ImplicitM(nn.Module):
646
+ def __init__(self, channel):
647
+ super(ImplicitM, self).__init__()
648
+ self.channel = channel
649
+ self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
650
+ nn.init.normal_(self.implicit, mean=1., std=.02)
651
+
652
+ def forward(self, x):
653
+ return self.implicit * x
654
+
655
+ #################
656
+
657
+
658
+ ##### CBNet #####
659
+
660
+ class CBLinear(nn.Module):
661
+ def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): # ch_in, ch_outs, kernel, stride, padding, groups
662
+ super(CBLinear, self).__init__()
663
+ self.c2s = c2s
664
+ self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
665
+
666
+ def forward(self, x):
667
+ outs = self.conv(x).split(self.c2s, dim=1)
668
+ return outs
669
+
670
+ class CBFuse(nn.Module):
671
+ def __init__(self, idx):
672
+ super(CBFuse, self).__init__()
673
+ self.idx = idx
674
+
675
+ def forward(self, xs):
676
+ target_size = xs[-1].shape[2:]
677
+ res = [F.interpolate(x[self.idx[i]], size=target_size, mode='nearest') for i, x in enumerate(xs[:-1])]
678
+ out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
679
+ return out
680
+
681
+ #################
682
+
683
+
684
+ class DetectMultiBackend(nn.Module):
685
+ # YOLO MultiBackend class for python inference on various backends
686
+ def __init__(self, weights='yolo.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
687
+ # Usage:
688
+ # PyTorch: weights = *.pt
689
+ # TorchScript: *.torchscript
690
+ # ONNX Runtime: *.onnx
691
+ # ONNX OpenCV DNN: *.onnx --dnn
692
+ # OpenVINO: *_openvino_model
693
+ # CoreML: *.mlmodel
694
+ # TensorRT: *.engine
695
+ # TensorFlow SavedModel: *_saved_model
696
+ # TensorFlow GraphDef: *.pb
697
+ # TensorFlow Lite: *.tflite
698
+ # TensorFlow Edge TPU: *_edgetpu.tflite
699
+ # PaddlePaddle: *_paddle_model
700
+ from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
701
+
702
+ super().__init__()
703
+ w = str(weights[0] if isinstance(weights, list) else weights)
704
+ pt, jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
705
+ fp16 &= pt or jit or onnx or engine # FP16
706
+ nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
707
+ stride = 32 # default stride
708
+ cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
709
+ if not (pt or triton):
710
+ w = attempt_download(w) # download if not local
711
+
712
+ if pt: # PyTorch
713
+ model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
714
+ stride = max(int(model.stride.max()), 32) # model stride
715
+ names = model.module.names if hasattr(model, 'module') else model.names # get class names
716
+ model.half() if fp16 else model.float()
717
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
718
+ elif jit: # TorchScript
719
+ LOGGER.info(f'Loading {w} for TorchScript inference...')
720
+ extra_files = {'config.txt': ''} # model metadata
721
+ model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
722
+ model.half() if fp16 else model.float()
723
+ if extra_files['config.txt']: # load metadata dict
724
+ d = json.loads(extra_files['config.txt'],
725
+ object_hook=lambda d: {int(k) if k.isdigit() else k: v
726
+ for k, v in d.items()})
727
+ stride, names = int(d['stride']), d['names']
728
+ elif dnn: # ONNX OpenCV DNN
729
+ LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
730
+ check_requirements('opencv-python>=4.5.4')
731
+ net = cv2.dnn.readNetFromONNX(w)
732
+ elif onnx: # ONNX Runtime
733
+ LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
734
+ check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
735
+ import onnxruntime
736
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
737
+ session = onnxruntime.InferenceSession(w, providers=providers)
738
+ output_names = [x.name for x in session.get_outputs()]
739
+ meta = session.get_modelmeta().custom_metadata_map # metadata
740
+ if 'stride' in meta:
741
+ stride, names = int(meta['stride']), eval(meta['names'])
742
+ elif xml: # OpenVINO
743
+ LOGGER.info(f'Loading {w} for OpenVINO inference...')
744
+ check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
745
+ from openvino.runtime import Core, Layout, get_batch
746
+ ie = Core()
747
+ if not Path(w).is_file(): # if not *.xml
748
+ w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
749
+ network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
750
+ if network.get_parameters()[0].get_layout().empty:
751
+ network.get_parameters()[0].set_layout(Layout("NCHW"))
752
+ batch_dim = get_batch(network)
753
+ if batch_dim.is_static:
754
+ batch_size = batch_dim.get_length()
755
+ executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
756
+ stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
757
+ elif engine: # TensorRT
758
+ LOGGER.info(f'Loading {w} for TensorRT inference...')
759
+ import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
760
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
761
+ if device.type == 'cpu':
762
+ device = torch.device('cuda:0')
763
+ Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
764
+ logger = trt.Logger(trt.Logger.INFO)
765
+ with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
766
+ model = runtime.deserialize_cuda_engine(f.read())
767
+ context = model.create_execution_context()
768
+ bindings = OrderedDict()
769
+ output_names = []
770
+ fp16 = False # default updated below
771
+ dynamic = False
772
+ for i in range(model.num_bindings):
773
+ name = model.get_binding_name(i)
774
+ dtype = trt.nptype(model.get_binding_dtype(i))
775
+ if model.binding_is_input(i):
776
+ if -1 in tuple(model.get_binding_shape(i)): # dynamic
777
+ dynamic = True
778
+ context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
779
+ if dtype == np.float16:
780
+ fp16 = True
781
+ else: # output
782
+ output_names.append(name)
783
+ shape = tuple(context.get_binding_shape(i))
784
+ im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
785
+ bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
786
+ binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
787
+ batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
788
+ elif coreml: # CoreML
789
+ LOGGER.info(f'Loading {w} for CoreML inference...')
790
+ import coremltools as ct
791
+ model = ct.models.MLModel(w)
792
+ elif saved_model: # TF SavedModel
793
+ LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
794
+ import tensorflow as tf
795
+ keras = False # assume TF1 saved_model
796
+ model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
797
+ elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
798
+ LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
799
+ import tensorflow as tf
800
+
801
+ def wrap_frozen_graph(gd, inputs, outputs):
802
+ x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
803
+ ge = x.graph.as_graph_element
804
+ return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
805
+
806
+ def gd_outputs(gd):
807
+ name_list, input_list = [], []
808
+ for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
809
+ name_list.append(node.name)
810
+ input_list.extend(node.input)
811
+ return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
812
+
813
+ gd = tf.Graph().as_graph_def() # TF GraphDef
814
+ with open(w, 'rb') as f:
815
+ gd.ParseFromString(f.read())
816
+ frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
817
+ elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
818
+ try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
819
+ from tflite_runtime.interpreter import Interpreter, load_delegate
820
+ except ImportError:
821
+ import tensorflow as tf
822
+ Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
823
+ if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
824
+ LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
825
+ delegate = {
826
+ 'Linux': 'libedgetpu.so.1',
827
+ 'Darwin': 'libedgetpu.1.dylib',
828
+ 'Windows': 'edgetpu.dll'}[platform.system()]
829
+ interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
830
+ else: # TFLite
831
+ LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
832
+ interpreter = Interpreter(model_path=w) # load TFLite model
833
+ interpreter.allocate_tensors() # allocate
834
+ input_details = interpreter.get_input_details() # inputs
835
+ output_details = interpreter.get_output_details() # outputs
836
+ # load metadata
837
+ with contextlib.suppress(zipfile.BadZipFile):
838
+ with zipfile.ZipFile(w, "r") as model:
839
+ meta_file = model.namelist()[0]
840
+ meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
841
+ stride, names = int(meta['stride']), meta['names']
842
+ elif tfjs: # TF.js
843
+ raise NotImplementedError('ERROR: YOLO TF.js inference is not supported')
844
+ elif paddle: # PaddlePaddle
845
+ LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
846
+ check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
847
+ import paddle.inference as pdi
848
+ if not Path(w).is_file(): # if not *.pdmodel
849
+ w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
850
+ weights = Path(w).with_suffix('.pdiparams')
851
+ config = pdi.Config(str(w), str(weights))
852
+ if cuda:
853
+ config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
854
+ predictor = pdi.create_predictor(config)
855
+ input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
856
+ output_names = predictor.get_output_names()
857
+ elif triton: # NVIDIA Triton Inference Server
858
+ LOGGER.info(f'Using {w} as Triton Inference Server...')
859
+ check_requirements('tritonclient[all]')
860
+ from utils.triton import TritonRemoteModel
861
+ model = TritonRemoteModel(url=w)
862
+ nhwc = model.runtime.startswith("tensorflow")
863
+ else:
864
+ raise NotImplementedError(f'ERROR: {w} is not a supported format')
865
+
866
+ # class names
867
+ if 'names' not in locals():
868
+ names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
869
+ if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
870
+ names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
871
+
872
+ self.__dict__.update(locals()) # assign all variables to self
873
+
874
+ def forward(self, im, augment=False, visualize=False):
875
+ # YOLO MultiBackend inference
876
+ b, ch, h, w = im.shape # batch, channel, height, width
877
+ if self.fp16 and im.dtype != torch.float16:
878
+ im = im.half() # to FP16
879
+ if self.nhwc:
880
+ im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
881
+
882
+ if self.pt: # PyTorch
883
+ y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
884
+ elif self.jit: # TorchScript
885
+ y = self.model(im)
886
+ elif self.dnn: # ONNX OpenCV DNN
887
+ im = im.cpu().numpy() # torch to numpy
888
+ self.net.setInput(im)
889
+ y = self.net.forward()
890
+ elif self.onnx: # ONNX Runtime
891
+ im = im.cpu().numpy() # torch to numpy
892
+ y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
893
+ elif self.xml: # OpenVINO
894
+ im = im.cpu().numpy() # FP32
895
+ y = list(self.executable_network([im]).values())
896
+ elif self.engine: # TensorRT
897
+ if self.dynamic and im.shape != self.bindings['images'].shape:
898
+ i = self.model.get_binding_index('images')
899
+ self.context.set_binding_shape(i, im.shape) # reshape if dynamic
900
+ self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
901
+ for name in self.output_names:
902
+ i = self.model.get_binding_index(name)
903
+ self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
904
+ s = self.bindings['images'].shape
905
+ assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
906
+ self.binding_addrs['images'] = int(im.data_ptr())
907
+ self.context.execute_v2(list(self.binding_addrs.values()))
908
+ y = [self.bindings[x].data for x in sorted(self.output_names)]
909
+ elif self.coreml: # CoreML
910
+ im = im.cpu().numpy()
911
+ im = Image.fromarray((im[0] * 255).astype('uint8'))
912
+ # im = im.resize((192, 320), Image.ANTIALIAS)
913
+ y = self.model.predict({'image': im}) # coordinates are xywh normalized
914
+ if 'confidence' in y:
915
+ box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
916
+ conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
917
+ y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
918
+ else:
919
+ y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
920
+ elif self.paddle: # PaddlePaddle
921
+ im = im.cpu().numpy().astype(np.float32)
922
+ self.input_handle.copy_from_cpu(im)
923
+ self.predictor.run()
924
+ y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
925
+ elif self.triton: # NVIDIA Triton Inference Server
926
+ y = self.model(im)
927
+ else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
928
+ im = im.cpu().numpy()
929
+ if self.saved_model: # SavedModel
930
+ y = self.model(im, training=False) if self.keras else self.model(im)
931
+ elif self.pb: # GraphDef
932
+ y = self.frozen_func(x=self.tf.constant(im))
933
+ else: # Lite or Edge TPU
934
+ input = self.input_details[0]
935
+ int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
936
+ if int8:
937
+ scale, zero_point = input['quantization']
938
+ im = (im / scale + zero_point).astype(np.uint8) # de-scale
939
+ self.interpreter.set_tensor(input['index'], im)
940
+ self.interpreter.invoke()
941
+ y = []
942
+ for output in self.output_details:
943
+ x = self.interpreter.get_tensor(output['index'])
944
+ if int8:
945
+ scale, zero_point = output['quantization']
946
+ x = (x.astype(np.float32) - zero_point) * scale # re-scale
947
+ y.append(x)
948
+ y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
949
+ y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
950
+
951
+ if isinstance(y, (list, tuple)):
952
+ return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
953
+ else:
954
+ return self.from_numpy(y)
955
+
956
+ def from_numpy(self, x):
957
+ return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
958
+
959
+ def warmup(self, imgsz=(1, 3, 640, 640)):
960
+ # Warmup model by running inference once
961
+ warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
962
+ if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
963
+ im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
964
+ for _ in range(2 if self.jit else 1): #
965
+ self.forward(im) # warmup
966
+
967
+ @staticmethod
968
+ def _model_type(p='path/to/model.pt'):
969
+ # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
970
+ # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
971
+ from export import export_formats
972
+ from utils.downloads import is_url
973
+ sf = list(export_formats().Suffix) # export suffixes
974
+ if not is_url(p, check=False):
975
+ check_suffix(p, sf) # checks
976
+ url = urlparse(p) # if url may be Triton inference server
977
+ types = [s in Path(p).name for s in sf]
978
+ types[8] &= not types[9] # tflite &= not edgetpu
979
+ triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
980
+ return types + [triton]
981
+
982
+ @staticmethod
983
+ def _load_metadata(f=Path('path/to/meta.yaml')):
984
+ # Load metadata from meta.yaml if it exists
985
+ if f.exists():
986
+ d = yaml_load(f)
987
+ return d['stride'], d['names'] # assign stride, names
988
+ return None, None
989
+
990
+
991
+ class AutoShape(nn.Module):
992
+ # YOLO input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
993
+ conf = 0.25 # NMS confidence threshold
994
+ iou = 0.45 # NMS IoU threshold
995
+ agnostic = False # NMS class-agnostic
996
+ multi_label = False # NMS multiple labels per box
997
+ classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
998
+ max_det = 1000 # maximum number of detections per image
999
+ amp = False # Automatic Mixed Precision (AMP) inference
1000
+
1001
+ def __init__(self, model, verbose=True):
1002
+ super().__init__()
1003
+ if verbose:
1004
+ LOGGER.info('Adding AutoShape... ')
1005
+ copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
1006
+ self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
1007
+ self.pt = not self.dmb or model.pt # PyTorch model
1008
+ self.model = model.eval()
1009
+ if self.pt:
1010
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
1011
+ m.inplace = False # Detect.inplace=False for safe multithread inference
1012
+ m.export = True # do not output loss values
1013
+
1014
+ def _apply(self, fn):
1015
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
1016
+ self = super()._apply(fn)
1017
+ from models.yolo import Detect, Segment
1018
+ if self.pt:
1019
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
1020
+ if isinstance(m, (Detect, Segment)):
1021
+ for k in 'stride', 'anchor_grid', 'stride_grid', 'grid':
1022
+ x = getattr(m, k)
1023
+ setattr(m, k, list(map(fn, x))) if isinstance(x, (list, tuple)) else setattr(m, k, fn(x))
1024
+ return self
1025
+
1026
+ @smart_inference_mode()
1027
+ def forward(self, ims, size=640, augment=False, profile=False):
1028
+ # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
1029
+ # file: ims = 'data/images/zidane.jpg' # str or PosixPath
1030
+ # URI: = 'https://ultralytics.com/images/zidane.jpg'
1031
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
1032
+ # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
1033
+ # numpy: = np.zeros((640,1280,3)) # HWC
1034
+ # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
1035
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
1036
+
1037
+ dt = (Profile(), Profile(), Profile())
1038
+ with dt[0]:
1039
+ if isinstance(size, int): # expand
1040
+ size = (size, size)
1041
+ p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
1042
+ autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
1043
+ if isinstance(ims, torch.Tensor): # torch
1044
+ with amp.autocast(autocast):
1045
+ return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
1046
+
1047
+ # Pre-process
1048
+ n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
1049
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
1050
+ for i, im in enumerate(ims):
1051
+ f = f'image{i}' # filename
1052
+ if isinstance(im, (str, Path)): # filename or uri
1053
+ im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
1054
+ im = np.asarray(exif_transpose(im))
1055
+ elif isinstance(im, Image.Image): # PIL Image
1056
+ im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
1057
+ files.append(Path(f).with_suffix('.jpg').name)
1058
+ if im.shape[0] < 5: # image in CHW
1059
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
1060
+ im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
1061
+ s = im.shape[:2] # HWC
1062
+ shape0.append(s) # image shape
1063
+ g = max(size) / max(s) # gain
1064
+ shape1.append([int(y * g) for y in s])
1065
+ ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
1066
+ shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] # inf shape
1067
+ x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
1068
+ x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
1069
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
1070
+
1071
+ with amp.autocast(autocast):
1072
+ # Inference
1073
+ with dt[1]:
1074
+ y = self.model(x, augment=augment) # forward
1075
+
1076
+ # Post-process
1077
+ with dt[2]:
1078
+ y = non_max_suppression(y if self.dmb else y[0],
1079
+ self.conf,
1080
+ self.iou,
1081
+ self.classes,
1082
+ self.agnostic,
1083
+ self.multi_label,
1084
+ max_det=self.max_det) # NMS
1085
+ for i in range(n):
1086
+ scale_boxes(shape1, y[i][:, :4], shape0[i])
1087
+
1088
+ return Detections(ims, y, files, dt, self.names, x.shape)
1089
+
1090
+
1091
+ class Detections:
1092
+ # YOLO detections class for inference results
1093
+ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
1094
+ super().__init__()
1095
+ d = pred[0].device # device
1096
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
1097
+ self.ims = ims # list of images as numpy arrays
1098
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
1099
+ self.names = names # class names
1100
+ self.files = files # image filenames
1101
+ self.times = times # profiling times
1102
+ self.xyxy = pred # xyxy pixels
1103
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
1104
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
1105
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
1106
+ self.n = len(self.pred) # number of images (batch size)
1107
+ self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
1108
+ self.s = tuple(shape) # inference BCHW shape
1109
+
1110
+ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
1111
+ s, crops = '', []
1112
+ for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
1113
+ s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
1114
+ if pred.shape[0]:
1115
+ for c in pred[:, -1].unique():
1116
+ n = (pred[:, -1] == c).sum() # detections per class
1117
+ s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
1118
+ s = s.rstrip(', ')
1119
+ if show or save or render or crop:
1120
+ annotator = Annotator(im, example=str(self.names))
1121
+ for *box, conf, cls in reversed(pred): # xyxy, confidence, class
1122
+ label = f'{self.names[int(cls)]} {conf:.2f}'
1123
+ if crop:
1124
+ file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
1125
+ crops.append({
1126
+ 'box': box,
1127
+ 'conf': conf,
1128
+ 'cls': cls,
1129
+ 'label': label,
1130
+ 'im': save_one_box(box, im, file=file, save=save)})
1131
+ else: # all others
1132
+ annotator.box_label(box, label if labels else '', color=colors(cls))
1133
+ im = annotator.im
1134
+ else:
1135
+ s += '(no detections)'
1136
+
1137
+ im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
1138
+ if show:
1139
+ display(im) if is_notebook() else im.show(self.files[i])
1140
+ if save:
1141
+ f = self.files[i]
1142
+ im.save(save_dir / f) # save
1143
+ if i == self.n - 1:
1144
+ LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
1145
+ if render:
1146
+ self.ims[i] = np.asarray(im)
1147
+ if pprint:
1148
+ s = s.lstrip('\n')
1149
+ return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
1150
+ if crop:
1151
+ if save:
1152
+ LOGGER.info(f'Saved results to {save_dir}\n')
1153
+ return crops
1154
+
1155
+ @TryExcept('Showing images is not supported in this environment')
1156
+ def show(self, labels=True):
1157
+ self._run(show=True, labels=labels) # show results
1158
+
1159
+ def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
1160
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
1161
+ self._run(save=True, labels=labels, save_dir=save_dir) # save results
1162
+
1163
+ def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
1164
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
1165
+ return self._run(crop=True, save=save, save_dir=save_dir) # crop results
1166
+
1167
+ def render(self, labels=True):
1168
+ self._run(render=True, labels=labels) # render results
1169
+ return self.ims
1170
+
1171
+ def pandas(self):
1172
+ # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
1173
+ new = copy(self) # return copy
1174
+ ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
1175
+ cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
1176
+ for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
1177
+ a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
1178
+ setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
1179
+ return new
1180
+
1181
+ def tolist(self):
1182
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
1183
+ r = range(self.n) # iterable
1184
+ x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
1185
+ # for d in x:
1186
+ # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
1187
+ # setattr(d, k, getattr(d, k)[0]) # pop out of list
1188
+ return x
1189
+
1190
+ def print(self):
1191
+ LOGGER.info(self.__str__())
1192
+
1193
+ def __len__(self): # override len(results)
1194
+ return self.n
1195
+
1196
+ def __str__(self): # override print(results)
1197
+ return self._run(pprint=True) # print results
1198
+
1199
+ def __repr__(self):
1200
+ return f'YOLO {self.__class__} instance\n' + self.__str__()
1201
+
1202
+
1203
+ class Proto(nn.Module):
1204
+ # YOLO mask Proto module for segmentation models
1205
+ def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
1206
+ super().__init__()
1207
+ self.cv1 = Conv(c1, c_, k=3)
1208
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
1209
+ self.cv2 = Conv(c_, c_, k=3)
1210
+ self.cv3 = Conv(c_, c2)
1211
+
1212
+ def forward(self, x):
1213
+ return self.cv3(self.cv2(self.upsample(self.cv1(x))))
1214
+
1215
+
1216
+ class UConv(nn.Module):
1217
+ def __init__(self, c1, c_=256, c2=256): # ch_in, number of protos, number of masks
1218
+ super().__init__()
1219
+
1220
+ self.cv1 = Conv(c1, c_, k=3)
1221
+ self.cv2 = nn.Conv2d(c_, c2, 1, 1)
1222
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
1223
+
1224
+ def forward(self, x):
1225
+ return self.up(self.cv2(self.cv1(x)))
1226
+
1227
+
1228
+ class Classify(nn.Module):
1229
+ # YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)
1230
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
1231
+ super().__init__()
1232
+ c_ = 1280 # efficientnet_b0 size
1233
+ self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
1234
+ self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
1235
+ self.drop = nn.Dropout(p=0.0, inplace=True)
1236
+ self.linear = nn.Linear(c_, c2) # to x(b,c2)
1237
+
1238
+ def forward(self, x):
1239
+ if isinstance(x, list):
1240
+ x = torch.cat(x, 1)
1241
+ return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
1242
+
1243
+ # -------------- Deformable Convolution --------------
1244
+ class DCNv2(nn.Module):
1245
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
1246
+ padding=1, dilation=1, groups=1, deformable_groups=1):
1247
+ super(DCNv2, self).__init__()
1248
+
1249
+ self.in_channels = in_channels
1250
+ self.out_channels = out_channels
1251
+ self.kernel_size = (kernel_size, kernel_size)
1252
+ self.stride = (stride, stride)
1253
+ self.padding = (padding, padding)
1254
+ self.dilation = (dilation, dilation)
1255
+ self.groups = groups
1256
+ self.deformable_groups = deformable_groups
1257
+
1258
+ self.weight = nn.Parameter(
1259
+ torch.empty(out_channels, in_channels, *self.kernel_size)
1260
+ )
1261
+ self.bias = nn.Parameter(torch.empty(out_channels))
1262
+
1263
+ out_channels_offset_mask = (self.deformable_groups * 3 *
1264
+ self.kernel_size[0] * self.kernel_size[1])
1265
+ self.conv_offset_mask = nn.Conv2d(
1266
+ self.in_channels,
1267
+ out_channels_offset_mask,
1268
+ kernel_size=self.kernel_size,
1269
+ stride=self.stride,
1270
+ padding=self.padding,
1271
+ bias=True,
1272
+ )
1273
+ self.bn = nn.BatchNorm2d(out_channels)
1274
+ self.act = Conv.default_act
1275
+ self.reset_parameters()
1276
+
1277
+ def forward(self, x):
1278
+ offset_mask = self.conv_offset_mask(x)
1279
+ o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)
1280
+ offset = torch.cat((o1, o2), dim=1)
1281
+ mask = torch.sigmoid(mask)
1282
+ x = torch.ops.torchvision.deform_conv2d(
1283
+ x,
1284
+ self.weight,
1285
+ offset,
1286
+ mask,
1287
+ self.bias,
1288
+ self.stride[0], self.stride[1],
1289
+ self.padding[0], self.padding[1],
1290
+ self.dilation[0], self.dilation[1],
1291
+ self.groups,
1292
+ self.deformable_groups,
1293
+ True
1294
+ )
1295
+ x = self.bn(x)
1296
+ x = self.act(x)
1297
+ return x
1298
+
1299
+ def reset_parameters(self):
1300
+ n = self.in_channels
1301
+ for k in self.kernel_size:
1302
+ n *= k
1303
+ std = 1. / math.sqrt(n)
1304
+ self.weight.data.uniform_(-std, std)
1305
+ self.bias.data.zero_()
1306
+ self.conv_offset_mask.weight.data.zero_()
1307
+ self.conv_offset_mask.bias.data.zero_()
1308
+
1309
+
1310
+
1311
+ class Bottleneck_DCNv2(nn.Module):
1312
+ # Standard bottleneck with DCN
1313
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand
1314
+ super().__init__()
1315
+ c_ = int(c2 * e) # hidden channels
1316
+ if k[0] == 3:
1317
+ self.cv1 = DCNv2(c1, c_, k[0], 1)
1318
+ else:
1319
+ self.cv1 = Conv(c1, c_, k[0], 1)
1320
+ if k[1] == 3:
1321
+ self.cv2 = DCNv2(c_, c2, k[1], 1, groups=g)
1322
+ else:
1323
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
1324
+ self.add = shortcut and c1 == c2
1325
+
1326
+ def forward(self, x):
1327
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
1328
+
1329
+ class C2f_DCNv2(nn.Module):
1330
+ # CSP Bottleneck with 2 convolutions
1331
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
1332
+ super().__init__()
1333
+ self.c = int(c2 * e) # hidden channels
1334
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
1335
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
1336
+ self.m = nn.ModuleList(Bottleneck_DCNv2(self.c, self.c, shortcut, g, k=(3, 3), e=0.5) for _ in range(n))
1337
+
1338
+ def forward(self, x):
1339
+ y = list(self.cv1(x).split((self.c, self.c), 1))
1340
+ y.extend(m(y[-1]) for m in self.m)
1341
+ return self.cv2(torch.cat(y, 1))
1342
+
1343
+
1344
+ class RepDCNv2Bottleneck(RepNBottleneck):
1345
+ # Standard bottleneck
1346
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
1347
+ super().__init__(c1, c1, shortcut, g, k, e)
1348
+ self.cv1 = Bottleneck_DCNv2(c1, c1, shortcut, g, k, e)
1349
+
1350
+ class RepDCNv2CSP(RepNCSP):
1351
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
1352
+ super().__init__(c1, c2, n, shortcut, g, e)
1353
+ c_ = int(c2 * e)
1354
+ self.m = nn.Sequential(*(RepDCNv2Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
1355
+
1356
+ class RepDCNv2LEAN4(RepNCSPELAN4):
1357
+ def __init__(self, c1, c2, c3, c4, c5=1):
1358
+ super().__init__(c1, c2, c3, c4, c5)
1359
+ self.cv2 = nn.Sequential(RepDCNv2CSP(c3//2, c4, c5), Conv(c4, c4, 3, 1))
1360
+ self.cv3 = nn.Sequential(RepDCNv2CSP(c4, c4, c5), Conv(c4, c4, 3, 1))
models/experimental.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from utils.downloads import attempt_download
8
+
9
+
10
+ class Sum(nn.Module):
11
+ # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
12
+ def __init__(self, n, weight=False): # n: number of inputs
13
+ super().__init__()
14
+ self.weight = weight # apply weights boolean
15
+ self.iter = range(n - 1) # iter object
16
+ if weight:
17
+ self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
18
+
19
+ def forward(self, x):
20
+ y = x[0] # no weight
21
+ if self.weight:
22
+ w = torch.sigmoid(self.w) * 2
23
+ for i in self.iter:
24
+ y = y + x[i + 1] * w[i]
25
+ else:
26
+ for i in self.iter:
27
+ y = y + x[i + 1]
28
+ return y
29
+
30
+
31
+ class MixConv2d(nn.Module):
32
+ # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
33
+ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
34
+ super().__init__()
35
+ n = len(k) # number of convolutions
36
+ if equal_ch: # equal c_ per group
37
+ i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
38
+ c_ = [(i == g).sum() for g in range(n)] # intermediate channels
39
+ else: # equal weight.numel() per group
40
+ b = [c2] + [0] * n
41
+ a = np.eye(n + 1, n, k=-1)
42
+ a -= np.roll(a, 1, axis=1)
43
+ a *= np.array(k) ** 2
44
+ a[0] = 1
45
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
46
+
47
+ self.m = nn.ModuleList([
48
+ nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
49
+ self.bn = nn.BatchNorm2d(c2)
50
+ self.act = nn.SiLU()
51
+
52
+ def forward(self, x):
53
+ return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
54
+
55
+
56
+ class Ensemble(nn.ModuleList):
57
+ # Ensemble of models
58
+ def __init__(self):
59
+ super().__init__()
60
+
61
+ def forward(self, x, augment=False, profile=False, visualize=False):
62
+ y = [module(x, augment, profile, visualize)[0] for module in self]
63
+ # y = torch.stack(y).max(0)[0] # max ensemble
64
+ # y = torch.stack(y).mean(0) # mean ensemble
65
+ y = torch.cat(y, 1) # nms ensemble
66
+ return y, None # inference, train output
67
+
68
+
69
+ class ORT_NMS(torch.autograd.Function):
70
+ '''ONNX-Runtime NMS operation'''
71
+ @staticmethod
72
+ def forward(ctx,
73
+ boxes,
74
+ scores,
75
+ max_output_boxes_per_class=torch.tensor([100]),
76
+ iou_threshold=torch.tensor([0.45]),
77
+ score_threshold=torch.tensor([0.25])):
78
+ device = boxes.device
79
+ batch = scores.shape[0]
80
+ num_det = random.randint(0, 100)
81
+ batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
82
+ idxs = torch.arange(100, 100 + num_det).to(device)
83
+ zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
84
+ selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
85
+ selected_indices = selected_indices.to(torch.int64)
86
+ return selected_indices
87
+
88
+ @staticmethod
89
+ def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
90
+ return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
91
+
92
+
93
+ class TRT_NMS(torch.autograd.Function):
94
+ '''TensorRT NMS operation'''
95
+ @staticmethod
96
+ def forward(
97
+ ctx,
98
+ boxes,
99
+ scores,
100
+ background_class=-1,
101
+ box_coding=1,
102
+ iou_threshold=0.45,
103
+ max_output_boxes=100,
104
+ plugin_version="1",
105
+ score_activation=0,
106
+ score_threshold=0.25,
107
+ ):
108
+
109
+ batch_size, num_boxes, num_classes = scores.shape
110
+ num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
111
+ det_boxes = torch.randn(batch_size, max_output_boxes, 4)
112
+ det_scores = torch.randn(batch_size, max_output_boxes)
113
+ det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
114
+ return num_det, det_boxes, det_scores, det_classes
115
+
116
+ @staticmethod
117
+ def symbolic(g,
118
+ boxes,
119
+ scores,
120
+ background_class=-1,
121
+ box_coding=1,
122
+ iou_threshold=0.45,
123
+ max_output_boxes=100,
124
+ plugin_version="1",
125
+ score_activation=0,
126
+ score_threshold=0.25):
127
+ out = g.op("TRT::EfficientNMS_TRT",
128
+ boxes,
129
+ scores,
130
+ background_class_i=background_class,
131
+ box_coding_i=box_coding,
132
+ iou_threshold_f=iou_threshold,
133
+ max_output_boxes_i=max_output_boxes,
134
+ plugin_version_s=plugin_version,
135
+ score_activation_i=score_activation,
136
+ score_threshold_f=score_threshold,
137
+ outputs=4)
138
+ nums, boxes, scores, classes = out
139
+ return nums, boxes, scores, classes
140
+
141
+
142
+ class ONNX_ORT(nn.Module):
143
+ '''onnx module with ONNX-Runtime NMS operation.'''
144
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80):
145
+ super().__init__()
146
+ self.device = device if device else torch.device("cpu")
147
+ self.max_obj = torch.tensor([max_obj]).to(device)
148
+ self.iou_threshold = torch.tensor([iou_thres]).to(device)
149
+ self.score_threshold = torch.tensor([score_thres]).to(device)
150
+ self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic
151
+ self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
152
+ dtype=torch.float32,
153
+ device=self.device)
154
+ self.n_classes=n_classes
155
+
156
+ def forward(self, x):
157
+ ## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
158
+ ## thanks https://github.com/thaitc-hust
159
+ if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
160
+ x = x[1]
161
+ x = x.permute(0, 2, 1)
162
+ bboxes_x = x[..., 0:1]
163
+ bboxes_y = x[..., 1:2]
164
+ bboxes_w = x[..., 2:3]
165
+ bboxes_h = x[..., 3:4]
166
+ bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
167
+ bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
168
+ obj_conf = x[..., 4:]
169
+ scores = obj_conf
170
+ bboxes @= self.convert_matrix
171
+ max_score, category_id = scores.max(2, keepdim=True)
172
+ dis = category_id.float() * self.max_wh
173
+ nmsbox = bboxes + dis
174
+ max_score_tp = max_score.transpose(1, 2).contiguous()
175
+ selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
176
+ X, Y = selected_indices[:, 0], selected_indices[:, 2]
177
+ selected_boxes = bboxes[X, Y, :]
178
+ selected_categories = category_id[X, Y, :].float()
179
+ selected_scores = max_score[X, Y, :]
180
+ X = X.unsqueeze(1).float()
181
+ return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
182
+
183
+
184
+ class ONNX_TRT(nn.Module):
185
+ '''onnx module with TensorRT NMS operation.'''
186
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
187
+ super().__init__()
188
+ assert max_wh is None
189
+ self.device = device if device else torch.device('cpu')
190
+ self.background_class = -1,
191
+ self.box_coding = 1,
192
+ self.iou_threshold = iou_thres
193
+ self.max_obj = max_obj
194
+ self.plugin_version = '1'
195
+ self.score_activation = 0
196
+ self.score_threshold = score_thres
197
+ self.n_classes=n_classes
198
+
199
+ def forward(self, x):
200
+ ## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
201
+ ## thanks https://github.com/thaitc-hust
202
+ if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
203
+ x = x[1]
204
+ x = x.permute(0, 2, 1)
205
+ bboxes_x = x[..., 0:1]
206
+ bboxes_y = x[..., 1:2]
207
+ bboxes_w = x[..., 2:3]
208
+ bboxes_h = x[..., 3:4]
209
+ bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
210
+ bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
211
+ obj_conf = x[..., 4:]
212
+ scores = obj_conf
213
+ num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(bboxes, scores, self.background_class, self.box_coding,
214
+ self.iou_threshold, self.max_obj,
215
+ self.plugin_version, self.score_activation,
216
+ self.score_threshold)
217
+ return num_det, det_boxes, det_scores, det_classes
218
+
219
+ class End2End(nn.Module):
220
+ '''export onnx or tensorrt model with NMS operation.'''
221
+ def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
222
+ super().__init__()
223
+ device = device if device else torch.device('cpu')
224
+ assert isinstance(max_wh,(int)) or max_wh is None
225
+ self.model = model.to(device)
226
+ self.model.model[-1].end2end = True
227
+ self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
228
+ self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
229
+ self.end2end.eval()
230
+
231
+ def forward(self, x):
232
+ x = self.model(x)
233
+ x = self.end2end(x)
234
+ return x
235
+
236
+
237
+ def attempt_load(weights, device=None, inplace=True, fuse=True):
238
+ # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
239
+ from models.yolo import Detect, Model
240
+
241
+ model = Ensemble()
242
+ for w in weights if isinstance(weights, list) else [weights]:
243
+ ckpt = torch.load(attempt_download(w), map_location='cpu') # load
244
+ ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
245
+
246
+ # Model compatibility updates
247
+ if not hasattr(ckpt, 'stride'):
248
+ ckpt.stride = torch.tensor([32.])
249
+ if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
250
+ ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
251
+
252
+ model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
253
+
254
+ # Module compatibility updates
255
+ for m in model.modules():
256
+ t = type(m)
257
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
258
+ m.inplace = inplace # torch 1.7.0 compatibility
259
+ # if t is Detect and not isinstance(m.anchor_grid, list):
260
+ # delattr(m, 'anchor_grid')
261
+ # setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
262
+ elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
263
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
264
+
265
+ # Return model
266
+ if len(model) == 1:
267
+ return model[-1]
268
+
269
+ # Return detection ensemble
270
+ print(f'Ensemble created with {weights}\n')
271
+ for k in 'names', 'nc', 'yaml':
272
+ setattr(model, k, getattr(model[0], k))
273
+ model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
274
+ assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
275
+ return model
models/tf.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+
6
+ FILE = Path(__file__).resolve()
7
+ ROOT = FILE.parents[1] # YOLO root directory
8
+ if str(ROOT) not in sys.path:
9
+ sys.path.append(str(ROOT)) # add ROOT to PATH
10
+ # ROOT = ROOT.relative_to(Path.cwd()) # relative
11
+
12
+ import numpy as np
13
+ import tensorflow as tf
14
+ import torch
15
+ import torch.nn as nn
16
+ from tensorflow import keras
17
+
18
+ from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
19
+ DWConvTranspose2d, Focus, autopad)
20
+ from models.experimental import MixConv2d, attempt_load
21
+ from models.yolo import Detect, Segment
22
+ from utils.activations import SiLU
23
+ from utils.general import LOGGER, make_divisible, print_args
24
+
25
+
26
+ class TFBN(keras.layers.Layer):
27
+ # TensorFlow BatchNormalization wrapper
28
+ def __init__(self, w=None):
29
+ super().__init__()
30
+ self.bn = keras.layers.BatchNormalization(
31
+ beta_initializer=keras.initializers.Constant(w.bias.numpy()),
32
+ gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
33
+ moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
34
+ moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
35
+ epsilon=w.eps)
36
+
37
+ def call(self, inputs):
38
+ return self.bn(inputs)
39
+
40
+
41
+ class TFPad(keras.layers.Layer):
42
+ # Pad inputs in spatial dimensions 1 and 2
43
+ def __init__(self, pad):
44
+ super().__init__()
45
+ if isinstance(pad, int):
46
+ self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
47
+ else: # tuple/list
48
+ self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])
49
+
50
+ def call(self, inputs):
51
+ return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
52
+
53
+
54
+ class TFConv(keras.layers.Layer):
55
+ # Standard convolution
56
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
57
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
58
+ super().__init__()
59
+ assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
60
+ # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
61
+ # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
62
+ conv = keras.layers.Conv2D(
63
+ filters=c2,
64
+ kernel_size=k,
65
+ strides=s,
66
+ padding='SAME' if s == 1 else 'VALID',
67
+ use_bias=not hasattr(w, 'bn'),
68
+ kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
69
+ bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
70
+ self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
71
+ self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
72
+ self.act = activations(w.act) if act else tf.identity
73
+
74
+ def call(self, inputs):
75
+ return self.act(self.bn(self.conv(inputs)))
76
+
77
+
78
+ class TFDWConv(keras.layers.Layer):
79
+ # Depthwise convolution
80
+ def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
81
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
82
+ super().__init__()
83
+ assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels'
84
+ conv = keras.layers.DepthwiseConv2D(
85
+ kernel_size=k,
86
+ depth_multiplier=c2 // c1,
87
+ strides=s,
88
+ padding='SAME' if s == 1 else 'VALID',
89
+ use_bias=not hasattr(w, 'bn'),
90
+ depthwise_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
91
+ bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
92
+ self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
93
+ self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
94
+ self.act = activations(w.act) if act else tf.identity
95
+
96
+ def call(self, inputs):
97
+ return self.act(self.bn(self.conv(inputs)))
98
+
99
+
100
+ class TFDWConvTranspose2d(keras.layers.Layer):
101
+ # Depthwise ConvTranspose2d
102
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
103
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
104
+ super().__init__()
105
+ assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels'
106
+ assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1'
107
+ weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
108
+ self.c1 = c1
109
+ self.conv = [
110
+ keras.layers.Conv2DTranspose(filters=1,
111
+ kernel_size=k,
112
+ strides=s,
113
+ padding='VALID',
114
+ output_padding=p2,
115
+ use_bias=True,
116
+ kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]),
117
+ bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)]
118
+
119
+ def call(self, inputs):
120
+ return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]
121
+
122
+
123
+ class TFFocus(keras.layers.Layer):
124
+ # Focus wh information into c-space
125
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
126
+ # ch_in, ch_out, kernel, stride, padding, groups
127
+ super().__init__()
128
+ self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
129
+
130
+ def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
131
+ # inputs = inputs / 255 # normalize 0-255 to 0-1
132
+ inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
133
+ return self.conv(tf.concat(inputs, 3))
134
+
135
+
136
+ class TFBottleneck(keras.layers.Layer):
137
+ # Standard bottleneck
138
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
139
+ super().__init__()
140
+ c_ = int(c2 * e) # hidden channels
141
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
142
+ self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
143
+ self.add = shortcut and c1 == c2
144
+
145
+ def call(self, inputs):
146
+ return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
147
+
148
+
149
+ class TFCrossConv(keras.layers.Layer):
150
+ # Cross Convolution
151
+ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
152
+ super().__init__()
153
+ c_ = int(c2 * e) # hidden channels
154
+ self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1)
155
+ self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2)
156
+ self.add = shortcut and c1 == c2
157
+
158
+ def call(self, inputs):
159
+ return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
160
+
161
+
162
+ class TFConv2d(keras.layers.Layer):
163
+ # Substitution for PyTorch nn.Conv2D
164
+ def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
165
+ super().__init__()
166
+ assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
167
+ self.conv = keras.layers.Conv2D(filters=c2,
168
+ kernel_size=k,
169
+ strides=s,
170
+ padding='VALID',
171
+ use_bias=bias,
172
+ kernel_initializer=keras.initializers.Constant(
173
+ w.weight.permute(2, 3, 1, 0).numpy()),
174
+ bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None)
175
+
176
+ def call(self, inputs):
177
+ return self.conv(inputs)
178
+
179
+
180
+ class TFBottleneckCSP(keras.layers.Layer):
181
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
182
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
183
+ # ch_in, ch_out, number, shortcut, groups, expansion
184
+ super().__init__()
185
+ c_ = int(c2 * e) # hidden channels
186
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
187
+ self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
188
+ self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
189
+ self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
190
+ self.bn = TFBN(w.bn)
191
+ self.act = lambda x: keras.activations.swish(x)
192
+ self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
193
+
194
+ def call(self, inputs):
195
+ y1 = self.cv3(self.m(self.cv1(inputs)))
196
+ y2 = self.cv2(inputs)
197
+ return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
198
+
199
+
200
+ class TFC3(keras.layers.Layer):
201
+ # CSP Bottleneck with 3 convolutions
202
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
203
+ # ch_in, ch_out, number, shortcut, groups, expansion
204
+ super().__init__()
205
+ c_ = int(c2 * e) # hidden channels
206
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
207
+ self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
208
+ self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
209
+ self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
210
+
211
+ def call(self, inputs):
212
+ return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
213
+
214
+
215
+ class TFC3x(keras.layers.Layer):
216
+ # 3 module with cross-convolutions
217
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
218
+ # ch_in, ch_out, number, shortcut, groups, expansion
219
+ super().__init__()
220
+ c_ = int(c2 * e) # hidden channels
221
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
222
+ self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
223
+ self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
224
+ self.m = keras.Sequential([
225
+ TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)])
226
+
227
+ def call(self, inputs):
228
+ return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
229
+
230
+
231
+ class TFSPP(keras.layers.Layer):
232
+ # Spatial pyramid pooling layer used in YOLOv3-SPP
233
+ def __init__(self, c1, c2, k=(5, 9, 13), w=None):
234
+ super().__init__()
235
+ c_ = c1 // 2 # hidden channels
236
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
237
+ self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
238
+ self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
239
+
240
+ def call(self, inputs):
241
+ x = self.cv1(inputs)
242
+ return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
243
+
244
+
245
+ class TFSPPF(keras.layers.Layer):
246
+ # Spatial pyramid pooling-Fast layer
247
+ def __init__(self, c1, c2, k=5, w=None):
248
+ super().__init__()
249
+ c_ = c1 // 2 # hidden channels
250
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
251
+ self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
252
+ self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
253
+
254
+ def call(self, inputs):
255
+ x = self.cv1(inputs)
256
+ y1 = self.m(x)
257
+ y2 = self.m(y1)
258
+ return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
259
+
260
+
261
+ class TFDetect(keras.layers.Layer):
262
+ # TF YOLO Detect layer
263
+ def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
264
+ super().__init__()
265
+ self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
266
+ self.nc = nc # number of classes
267
+ self.no = nc + 5 # number of outputs per anchor
268
+ self.nl = len(anchors) # number of detection layers
269
+ self.na = len(anchors[0]) // 2 # number of anchors
270
+ self.grid = [tf.zeros(1)] * self.nl # init grid
271
+ self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
272
+ self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]), [self.nl, 1, -1, 1, 2])
273
+ self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
274
+ self.training = False # set to False after building model
275
+ self.imgsz = imgsz
276
+ for i in range(self.nl):
277
+ ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
278
+ self.grid[i] = self._make_grid(nx, ny)
279
+
280
+ def call(self, inputs):
281
+ z = [] # inference output
282
+ x = []
283
+ for i in range(self.nl):
284
+ x.append(self.m[i](inputs[i]))
285
+ # x(bs,20,20,255) to x(bs,3,20,20,85)
286
+ ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
287
+ x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])
288
+
289
+ if not self.training: # inference
290
+ y = x[i]
291
+ grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
292
+ anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
293
+ xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i] # xy
294
+ wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
295
+ # Normalize xywh to 0-1 to reduce calibration error
296
+ xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
297
+ wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
298
+ y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
299
+ z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
300
+
301
+ return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1),)
302
+
303
+ @staticmethod
304
+ def _make_grid(nx=20, ny=20):
305
+ # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
306
+ # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
307
+ xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
308
+ return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
309
+
310
+
311
+ class TFSegment(TFDetect):
312
+ # YOLO Segment head for segmentation models
313
+ def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None):
314
+ super().__init__(nc, anchors, ch, imgsz, w)
315
+ self.nm = nm # number of masks
316
+ self.npr = npr # number of protos
317
+ self.no = 5 + nc + self.nm # number of outputs per anchor
318
+ self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)] # output conv
319
+ self.proto = TFProto(ch[0], self.npr, self.nm, w=w.proto) # protos
320
+ self.detect = TFDetect.call
321
+
322
+ def call(self, x):
323
+ p = self.proto(x[0])
324
+ # p = TFUpsample(None, scale_factor=4, mode='nearest')(self.proto(x[0])) # (optional) full-size protos
325
+ p = tf.transpose(p, [0, 3, 1, 2]) # from shape(1,160,160,32) to shape(1,32,160,160)
326
+ x = self.detect(self, x)
327
+ return (x, p) if self.training else (x[0], p)
328
+
329
+
330
+ class TFProto(keras.layers.Layer):
331
+
332
+ def __init__(self, c1, c_=256, c2=32, w=None):
333
+ super().__init__()
334
+ self.cv1 = TFConv(c1, c_, k=3, w=w.cv1)
335
+ self.upsample = TFUpsample(None, scale_factor=2, mode='nearest')
336
+ self.cv2 = TFConv(c_, c_, k=3, w=w.cv2)
337
+ self.cv3 = TFConv(c_, c2, w=w.cv3)
338
+
339
+ def call(self, inputs):
340
+ return self.cv3(self.cv2(self.upsample(self.cv1(inputs))))
341
+
342
+
343
+ class TFUpsample(keras.layers.Layer):
344
+ # TF version of torch.nn.Upsample()
345
+ def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
346
+ super().__init__()
347
+ assert scale_factor % 2 == 0, "scale_factor must be multiple of 2"
348
+ self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * scale_factor, x.shape[2] * scale_factor), mode)
349
+ # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
350
+ # with default arguments: align_corners=False, half_pixel_centers=False
351
+ # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
352
+ # size=(x.shape[1] * 2, x.shape[2] * 2))
353
+
354
+ def call(self, inputs):
355
+ return self.upsample(inputs)
356
+
357
+
358
+ class TFConcat(keras.layers.Layer):
359
+ # TF version of torch.concat()
360
+ def __init__(self, dimension=1, w=None):
361
+ super().__init__()
362
+ assert dimension == 1, "convert only NCHW to NHWC concat"
363
+ self.d = 3
364
+
365
+ def call(self, inputs):
366
+ return tf.concat(inputs, self.d)
367
+
368
+
369
+ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
370
+ LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
371
+ anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
372
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
373
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
374
+
375
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
376
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
377
+ m_str = m
378
+ m = eval(m) if isinstance(m, str) else m # eval strings
379
+ for j, a in enumerate(args):
380
+ try:
381
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
382
+ except NameError:
383
+ pass
384
+
385
+ n = max(round(n * gd), 1) if n > 1 else n # depth gain
386
+ if m in [
387
+ nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
388
+ BottleneckCSP, C3, C3x]:
389
+ c1, c2 = ch[f], args[0]
390
+ c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
391
+
392
+ args = [c1, c2, *args[1:]]
393
+ if m in [BottleneckCSP, C3, C3x]:
394
+ args.insert(2, n)
395
+ n = 1
396
+ elif m is nn.BatchNorm2d:
397
+ args = [ch[f]]
398
+ elif m is Concat:
399
+ c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
400
+ elif m in [Detect, Segment]:
401
+ args.append([ch[x + 1] for x in f])
402
+ if isinstance(args[1], int): # number of anchors
403
+ args[1] = [list(range(args[1] * 2))] * len(f)
404
+ if m is Segment:
405
+ args[3] = make_divisible(args[3] * gw, 8)
406
+ args.append(imgsz)
407
+ else:
408
+ c2 = ch[f]
409
+
410
+ tf_m = eval('TF' + m_str.replace('nn.', ''))
411
+ m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
412
+ else tf_m(*args, w=model.model[i]) # module
413
+
414
+ torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
415
+ t = str(m)[8:-2].replace('__main__.', '') # module type
416
+ np = sum(x.numel() for x in torch_m_.parameters()) # number params
417
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
418
+ LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10} {t:<40}{str(args):<30}') # print
419
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
420
+ layers.append(m_)
421
+ ch.append(c2)
422
+ return keras.Sequential(layers), sorted(save)
423
+
424
+
425
+ class TFModel:
426
+ # TF YOLO model
427
+ def __init__(self, cfg='yolo.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes
428
+ super().__init__()
429
+ if isinstance(cfg, dict):
430
+ self.yaml = cfg # model dict
431
+ else: # is *.yaml
432
+ import yaml # for torch hub
433
+ self.yaml_file = Path(cfg).name
434
+ with open(cfg) as f:
435
+ self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
436
+
437
+ # Define model
438
+ if nc and nc != self.yaml['nc']:
439
+ LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
440
+ self.yaml['nc'] = nc # override yaml value
441
+ self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
442
+
443
+ def predict(self,
444
+ inputs,
445
+ tf_nms=False,
446
+ agnostic_nms=False,
447
+ topk_per_class=100,
448
+ topk_all=100,
449
+ iou_thres=0.45,
450
+ conf_thres=0.25):
451
+ y = [] # outputs
452
+ x = inputs
453
+ for m in self.model.layers:
454
+ if m.f != -1: # if not from previous layer
455
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
456
+
457
+ x = m(x) # run
458
+ y.append(x if m.i in self.savelist else None) # save output
459
+
460
+ # Add TensorFlow NMS
461
+ if tf_nms:
462
+ boxes = self._xywh2xyxy(x[0][..., :4])
463
+ probs = x[0][:, :, 4:5]
464
+ classes = x[0][:, :, 5:]
465
+ scores = probs * classes
466
+ if agnostic_nms:
467
+ nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
468
+ else:
469
+ boxes = tf.expand_dims(boxes, 2)
470
+ nms = tf.image.combined_non_max_suppression(boxes,
471
+ scores,
472
+ topk_per_class,
473
+ topk_all,
474
+ iou_thres,
475
+ conf_thres,
476
+ clip_boxes=False)
477
+ return (nms,)
478
+ return x # output [1,6300,85] = [xywh, conf, class0, class1, ...]
479
+ # x = x[0] # [x(1,6300,85), ...] to x(6300,85)
480
+ # xywh = x[..., :4] # x(6300,4) boxes
481
+ # conf = x[..., 4:5] # x(6300,1) confidences
482
+ # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
483
+ # return tf.concat([conf, cls, xywh], 1)
484
+
485
+ @staticmethod
486
+ def _xywh2xyxy(xywh):
487
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
488
+ x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
489
+ return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
490
+
491
+
492
+ class AgnosticNMS(keras.layers.Layer):
493
+ # TF Agnostic NMS
494
+ def call(self, input, topk_all, iou_thres, conf_thres):
495
+ # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
496
+ return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres),
497
+ input,
498
+ fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
499
+ name='agnostic_nms')
500
+
501
+ @staticmethod
502
+ def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
503
+ boxes, classes, scores = x
504
+ class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
505
+ scores_inp = tf.reduce_max(scores, -1)
506
+ selected_inds = tf.image.non_max_suppression(boxes,
507
+ scores_inp,
508
+ max_output_size=topk_all,
509
+ iou_threshold=iou_thres,
510
+ score_threshold=conf_thres)
511
+ selected_boxes = tf.gather(boxes, selected_inds)
512
+ padded_boxes = tf.pad(selected_boxes,
513
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
514
+ mode="CONSTANT",
515
+ constant_values=0.0)
516
+ selected_scores = tf.gather(scores_inp, selected_inds)
517
+ padded_scores = tf.pad(selected_scores,
518
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
519
+ mode="CONSTANT",
520
+ constant_values=-1.0)
521
+ selected_classes = tf.gather(class_inds, selected_inds)
522
+ padded_classes = tf.pad(selected_classes,
523
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
524
+ mode="CONSTANT",
525
+ constant_values=-1.0)
526
+ valid_detections = tf.shape(selected_inds)[0]
527
+ return padded_boxes, padded_scores, padded_classes, valid_detections
528
+
529
+
530
+ def activations(act=nn.SiLU):
531
+ # Returns TF activation from input PyTorch activation
532
+ if isinstance(act, nn.LeakyReLU):
533
+ return lambda x: keras.activations.relu(x, alpha=0.1)
534
+ elif isinstance(act, nn.Hardswish):
535
+ return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
536
+ elif isinstance(act, (nn.SiLU, SiLU)):
537
+ return lambda x: keras.activations.swish(x)
538
+ else:
539
+ raise Exception(f'no matching TensorFlow activation found for PyTorch activation {act}')
540
+
541
+
542
+ def representative_dataset_gen(dataset, ncalib=100):
543
+ # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
544
+ for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
545
+ im = np.transpose(img, [1, 2, 0])
546
+ im = np.expand_dims(im, axis=0).astype(np.float32)
547
+ im /= 255
548
+ yield [im]
549
+ if n >= ncalib:
550
+ break
551
+
552
+
553
+ def run(
554
+ weights=ROOT / 'yolo.pt', # weights path
555
+ imgsz=(640, 640), # inference size h,w
556
+ batch_size=1, # batch size
557
+ dynamic=False, # dynamic batch size
558
+ ):
559
+ # PyTorch model
560
+ im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
561
+ model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False)
562
+ _ = model(im) # inference
563
+ model.info()
564
+
565
+ # TensorFlow model
566
+ im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
567
+ tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
568
+ _ = tf_model.predict(im) # inference
569
+
570
+ # Keras model
571
+ im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
572
+ keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
573
+ keras_model.summary()
574
+
575
+ LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
576
+
577
+
578
+ def parse_opt():
579
+ parser = argparse.ArgumentParser()
580
+ parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='weights path')
581
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
582
+ parser.add_argument('--batch-size', type=int, default=1, help='batch size')
583
+ parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
584
+ opt = parser.parse_args()
585
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
586
+ print_args(vars(opt))
587
+ return opt
588
+
589
+
590
+ def main(opt):
591
+ run(**vars(opt))
592
+
593
+
594
+ if __name__ == "__main__":
595
+ opt = parse_opt()
596
+ main(opt)
models/yolo.py ADDED
@@ -0,0 +1,853 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ FILE = Path(__file__).resolve()
9
+ ROOT = FILE.parents[1] # YOLO root directory
10
+ if str(ROOT) not in sys.path:
11
+ sys.path.append(str(ROOT)) # add ROOT to PATH
12
+ if platform.system() != 'Windows':
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import *
16
+ from models.experimental import *
17
+ from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
18
+ from utils.plots import feature_visualization
19
+ from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
20
+ time_sync)
21
+ from utils.tal.anchor_generator import make_anchors, dist2bbox
22
+ from models.attention.blocks import *
23
+ try:
24
+ import thop # for FLOPs computation
25
+ except ImportError:
26
+ thop = None
27
+
28
+
29
+ class Detect(nn.Module):
30
+ # YOLO Detect head for detection models
31
+ dynamic = False # force grid reconstruction
32
+ export = False # export mode
33
+ shape = None
34
+ anchors = torch.empty(0) # init
35
+ strides = torch.empty(0) # init
36
+
37
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
38
+ super().__init__()
39
+ self.nc = nc # number of classes
40
+ self.nl = len(ch) # number of detection layers
41
+ self.reg_max = 16
42
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
43
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
44
+ self.stride = torch.zeros(self.nl) # strides computed during build
45
+
46
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
47
+ self.cv2 = nn.ModuleList(
48
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
49
+ self.cv3 = nn.ModuleList(
50
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
51
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ shape = x[0].shape # BCHW
55
+ for i in range(self.nl):
56
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
57
+ if self.training:
58
+ return x
59
+ elif self.dynamic or self.shape != shape:
60
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
61
+ self.shape = shape
62
+
63
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
64
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
65
+ y = torch.cat((dbox, cls.sigmoid()), 1)
66
+ return y if self.export else (y, x)
67
+
68
+ def bias_init(self):
69
+ # Initialize Detect() biases, WARNING: requires stride availability
70
+ m = self # self.model[-1] # Detect() module
71
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
72
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
73
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
74
+ a[-1].bias.data[:] = 1.0 # box
75
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
76
+
77
+
78
+ class DDetect(nn.Module):
79
+ # YOLO Detect head for detection models
80
+ dynamic = False # force grid reconstruction
81
+ export = False # export mode
82
+ shape = None
83
+ anchors = torch.empty(0) # init
84
+ strides = torch.empty(0) # init
85
+
86
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
87
+ super().__init__()
88
+ self.nc = nc # number of classes
89
+ self.nl = len(ch) # number of detection layers
90
+ self.reg_max = 16
91
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
92
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
93
+ self.stride = torch.zeros(self.nl) # strides computed during build
94
+
95
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
96
+ self.cv2 = nn.ModuleList(
97
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch)
98
+ self.cv3 = nn.ModuleList(
99
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
100
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
101
+
102
+ def forward(self, x):
103
+ shape = x[0].shape # BCHW
104
+ for i in range(self.nl):
105
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
106
+ if self.training:
107
+ return x
108
+ elif self.dynamic or self.shape != shape:
109
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
110
+ self.shape = shape
111
+
112
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
113
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
114
+ y = torch.cat((dbox, cls.sigmoid()), 1)
115
+ return y if self.export else (y, x)
116
+
117
+ def bias_init(self):
118
+ # Initialize Detect() biases, WARNING: requires stride availability
119
+ m = self # self.model[-1] # Detect() module
120
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
121
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
122
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
123
+ a[-1].bias.data[:] = 1.0 # box
124
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
125
+
126
+
127
+ class DualDetect(nn.Module):
128
+ # YOLO Detect head for detection models
129
+ dynamic = False # force grid reconstruction
130
+ export = False # export mode
131
+ shape = None
132
+ anchors = torch.empty(0) # init
133
+ strides = torch.empty(0) # init
134
+
135
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
136
+ super().__init__()
137
+ self.nc = nc # number of classes
138
+ self.nl = len(ch) // 2 # number of detection layers
139
+ self.reg_max = 16
140
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
141
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
142
+ self.stride = torch.zeros(self.nl) # strides computed during build
143
+
144
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
145
+ c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
146
+ self.cv2 = nn.ModuleList(
147
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
148
+ self.cv3 = nn.ModuleList(
149
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
150
+ self.cv4 = nn.ModuleList(
151
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:])
152
+ self.cv5 = nn.ModuleList(
153
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
154
+ self.dfl = DFL(self.reg_max)
155
+ self.dfl2 = DFL(self.reg_max)
156
+
157
+ def forward(self, x):
158
+ shape = x[0].shape # BCHW
159
+ d1 = []
160
+ d2 = []
161
+ for i in range(self.nl):
162
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
163
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
164
+ if self.training:
165
+ return [d1, d2]
166
+ elif self.dynamic or self.shape != shape:
167
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
168
+ self.shape = shape
169
+
170
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
171
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
172
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
173
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
174
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
175
+ return y if self.export else (y, [d1, d2])
176
+
177
+ def bias_init(self):
178
+ # Initialize Detect() biases, WARNING: requires stride availability
179
+ m = self # self.model[-1] # Detect() module
180
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
181
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
182
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
183
+ a[-1].bias.data[:] = 1.0 # box
184
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
185
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
186
+ a[-1].bias.data[:] = 1.0 # box
187
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
188
+
189
+
190
+ class DualDDetect(nn.Module):
191
+ # YOLO Detect head for detection models
192
+ dynamic = False # force grid reconstruction
193
+ export = False # export mode
194
+ shape = None
195
+ anchors = torch.empty(0) # init
196
+ strides = torch.empty(0) # init
197
+
198
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
199
+ super().__init__()
200
+ self.nc = nc # number of classes
201
+ self.nl = len(ch) // 2 # number of detection layers
202
+ self.reg_max = 16
203
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
204
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
205
+ self.stride = torch.zeros(self.nl) # strides computed during build
206
+
207
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
208
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
209
+ self.cv2 = nn.ModuleList(
210
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
211
+ self.cv3 = nn.ModuleList(
212
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
213
+ self.cv4 = nn.ModuleList(
214
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4), nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:])
215
+ self.cv5 = nn.ModuleList(
216
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
217
+ self.dfl = DFL(self.reg_max)
218
+ self.dfl2 = DFL(self.reg_max)
219
+
220
+ def forward(self, x):
221
+ shape = x[0].shape # BCHW
222
+ d1 = []
223
+ d2 = []
224
+ for i in range(self.nl):
225
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
226
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
227
+ if self.training:
228
+ return [d1, d2]
229
+ elif self.dynamic or self.shape != shape:
230
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
231
+ self.shape = shape
232
+
233
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
234
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
235
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
236
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
237
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
238
+ return y if self.export else (y, [d1, d2])
239
+ #y = torch.cat((dbox2, cls2.sigmoid()), 1)
240
+ #return y if self.export else (y, d2)
241
+ #y1 = torch.cat((dbox, cls.sigmoid()), 1)
242
+ #y2 = torch.cat((dbox2, cls2.sigmoid()), 1)
243
+ #return [y1, y2] if self.export else [(y1, d1), (y2, d2)]
244
+ #return [y1, y2] if self.export else [(y1, y2), (d1, d2)]
245
+
246
+ def bias_init(self):
247
+ # Initialize Detect() biases, WARNING: requires stride availability
248
+ m = self # self.model[-1] # Detect() module
249
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
250
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
251
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
252
+ a[-1].bias.data[:] = 1.0 # box
253
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
254
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
255
+ a[-1].bias.data[:] = 1.0 # box
256
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
257
+
258
+ class IDualDDetect(nn.Module):
259
+ # YOLO Detect head for detection models
260
+ dynamic = False # force grid reconstruction
261
+ export = False # export mode
262
+ shape = None
263
+ anchors = torch.empty(0) # init
264
+ strides = torch.empty(0) # init
265
+
266
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
267
+ super().__init__()
268
+ self.nc = nc # number of classes
269
+ self.nl = len(ch) // 2 # number of detection layers
270
+ self.reg_max = 16
271
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
272
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
273
+ self.stride = torch.zeros(self.nl) # strides computed during build
274
+
275
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
276
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
277
+ self.cv2 = nn.ModuleList(
278
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
279
+ self.cv3 = nn.ModuleList(
280
+ nn.Sequential(Convb(x, c3, 1), Conv(c3, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
281
+ self.cv4 = nn.ModuleList(
282
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4), nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:])
283
+ self.cv5 = nn.ModuleList(
284
+ nn.Sequential(Convb(x, c5, 1), Conv(c5, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
285
+ self.dfl = DFL(self.reg_max)
286
+ self.dfl2 = DFL(self.reg_max)
287
+ # Define ImplicitA and ImplicitM modules
288
+ self.implicita = nn.ModuleList(ImplicitA(x) for x in ch)
289
+ self.implicitm = nn.ModuleList(ImplicitM(self.nc) for _ in ch)
290
+
291
+ def fuse(self):
292
+ # Fuse weights with implicit knowledge
293
+ for i in range(self.nl):
294
+ # Fuse ImplicitA with Convolution
295
+ c1, c2, _, _ = self.cv3[i][0].conv.weight.shape
296
+ c1_, c2_, _, _ = self.implicita[i].implicit.shape
297
+ self.cv3[i][0].conv.bias.data += torch.matmul(self.cv3[i][0].conv.weight.reshape(c1, c2), self.implicita[i].implicit.reshape(c2_, c1_)).squeeze(1)
298
+ c1, c2, _, _ = self.cv5[i][0].conv.weight.shape
299
+ self.cv5[i][0].conv.bias.data += torch.matmul(self.cv5[i][0].conv.weight.reshape(c1, c2), self.implicita[i].implicit.reshape(c2_, c1_)).squeeze(1)
300
+ # Fuse ImplicitM with Convolution
301
+ c1,c2, _,_ = self.implicitm[i].implicit.shape
302
+ self.cv3[i][-1].weight.data = torch.mul(self.cv3[i][-1].weight, self.implicitm[i].implicit.transpose(0,1))
303
+ self.cv3[i][-1].bias.data = torch.matmul(self.cv3[i][-1].bias, self.implicitm[i].implicit.reshape(c2))
304
+ self.cv5[i][-1].weight.data = torch.mul(self.cv5[i][-1].weight, self.implicitm[i].implicit.transpose(0,1))
305
+ self.cv5[i][-1].bias.data = torch.matmul(self.cv5[i][-1].bias, self.implicitm[i].implicit.reshape(c2))
306
+
307
+
308
+ def forward(self, x):
309
+ shape = x[0].shape # BCHW
310
+ d1 = []
311
+ d2 = []
312
+ for i in range(self.nl):
313
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
314
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]),self.cv5[i](x[self.nl+i])), 1))
315
+ if self.training:
316
+ return [d1, d2]
317
+ elif self.dynamic or self.shape != shape:
318
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
319
+ self.shape = shape
320
+
321
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
322
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
323
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
324
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
325
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
326
+ return y if self.export else (y, [d1, d2])
327
+
328
+ def bias_init(self):
329
+ # Initialize Detect() biases, WARNING: requires stride availability
330
+ m = self # self.model[-1] # Detect() module
331
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
332
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
333
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
334
+ a[-1].bias.data[:] = 1.0 # box
335
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
336
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
337
+ a[-1].bias.data[:] = 1.0 # box
338
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
339
+
340
+
341
+ class TripleDetect(nn.Module):
342
+ # YOLO Detect head for detection models
343
+ dynamic = False # force grid reconstruction
344
+ export = False # export mode
345
+ shape = None
346
+ anchors = torch.empty(0) # init
347
+ strides = torch.empty(0) # init
348
+
349
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
350
+ super().__init__()
351
+ self.nc = nc # number of classes
352
+ self.nl = len(ch) // 3 # number of detection layers
353
+ self.reg_max = 16
354
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
355
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
356
+ self.stride = torch.zeros(self.nl) # strides computed during build
357
+
358
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
359
+ c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
360
+ c6, c7 = max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
361
+ self.cv2 = nn.ModuleList(
362
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
363
+ self.cv3 = nn.ModuleList(
364
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
365
+ self.cv4 = nn.ModuleList(
366
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:self.nl*2])
367
+ self.cv5 = nn.ModuleList(
368
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
369
+ self.cv6 = nn.ModuleList(
370
+ nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3), nn.Conv2d(c6, 4 * self.reg_max, 1)) for x in ch[self.nl*2:self.nl*3])
371
+ self.cv7 = nn.ModuleList(
372
+ nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
373
+ self.dfl = DFL(self.reg_max)
374
+ self.dfl2 = DFL(self.reg_max)
375
+ self.dfl3 = DFL(self.reg_max)
376
+
377
+ def forward(self, x):
378
+ shape = x[0].shape # BCHW
379
+ d1 = []
380
+ d2 = []
381
+ d3 = []
382
+ for i in range(self.nl):
383
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
384
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
385
+ d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
386
+ if self.training:
387
+ return [d1, d2, d3]
388
+ elif self.dynamic or self.shape != shape:
389
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
390
+ self.shape = shape
391
+
392
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
393
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
394
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
395
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
396
+ box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
397
+ dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
398
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
399
+ return y if self.export else (y, [d1, d2, d3])
400
+
401
+ def bias_init(self):
402
+ # Initialize Detect() biases, WARNING: requires stride availability
403
+ m = self # self.model[-1] # Detect() module
404
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
405
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
406
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
407
+ a[-1].bias.data[:] = 1.0 # box
408
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
409
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
410
+ a[-1].bias.data[:] = 1.0 # box
411
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
412
+ for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
413
+ a[-1].bias.data[:] = 1.0 # box
414
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
415
+
416
+
417
+ class TripleDDetect(nn.Module):
418
+ # YOLO Detect head for detection models
419
+ dynamic = False # force grid reconstruction
420
+ export = False # export mode
421
+ shape = None
422
+ anchors = torch.empty(0) # init
423
+ strides = torch.empty(0) # init
424
+
425
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
426
+ super().__init__()
427
+ self.nc = nc # number of classes
428
+ self.nl = len(ch) // 3 # number of detection layers
429
+ self.reg_max = 16
430
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
431
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
432
+ self.stride = torch.zeros(self.nl) # strides computed during build
433
+
434
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), \
435
+ max((ch[0], min((self.nc * 2, 128)))) # channels
436
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), \
437
+ max((ch[self.nl], min((self.nc * 2, 128)))) # channels
438
+ c6, c7 = make_divisible(max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), 4), \
439
+ max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
440
+ self.cv2 = nn.ModuleList(
441
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4),
442
+ nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
443
+ self.cv3 = nn.ModuleList(
444
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
445
+ self.cv4 = nn.ModuleList(
446
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4),
447
+ nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:self.nl*2])
448
+ self.cv5 = nn.ModuleList(
449
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
450
+ self.cv6 = nn.ModuleList(
451
+ nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3, g=4),
452
+ nn.Conv2d(c6, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl*2:self.nl*3])
453
+ self.cv7 = nn.ModuleList(
454
+ nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
455
+ self.dfl = DFL(self.reg_max)
456
+ self.dfl2 = DFL(self.reg_max)
457
+ self.dfl3 = DFL(self.reg_max)
458
+
459
+ def forward(self, x):
460
+ shape = x[0].shape # BCHW
461
+ d1 = []
462
+ d2 = []
463
+ d3 = []
464
+ for i in range(self.nl):
465
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
466
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
467
+ d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
468
+ if self.training:
469
+ return [d1, d2, d3]
470
+ elif self.dynamic or self.shape != shape:
471
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
472
+ self.shape = shape
473
+
474
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
475
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
476
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
477
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
478
+ box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
479
+ dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
480
+ #y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
481
+ #return y if self.export else (y, [d1, d2, d3])
482
+ y = torch.cat((dbox3, cls3.sigmoid()), 1)
483
+ return y if self.export else (y, d3)
484
+
485
+ def bias_init(self):
486
+ # Initialize Detect() biases, WARNING: requires stride availability
487
+ m = self # self.model[-1] # Detect() module
488
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
489
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
490
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
491
+ a[-1].bias.data[:] = 1.0 # box
492
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
493
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
494
+ a[-1].bias.data[:] = 1.0 # box
495
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
496
+ for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
497
+ a[-1].bias.data[:] = 1.0 # box
498
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
499
+
500
+
501
+ class Segment(Detect):
502
+ # YOLO Segment head for segmentation models
503
+ def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
504
+ super().__init__(nc, ch, inplace)
505
+ self.nm = nm # number of masks
506
+ self.npr = npr # number of protos
507
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
508
+ self.detect = Detect.forward
509
+
510
+ c4 = max(ch[0] // 4, self.nm)
511
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
512
+
513
+ def forward(self, x):
514
+ p = self.proto(x[0])
515
+ bs = p.shape[0]
516
+
517
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
518
+ x = self.detect(self, x)
519
+ if self.training:
520
+ return x, mc, p
521
+ return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
522
+
523
+
524
+ class Panoptic(Detect):
525
+ # YOLO Panoptic head for panoptic segmentation models
526
+ def __init__(self, nc=80, sem_nc=93, nm=32, npr=256, ch=(), inplace=True):
527
+ super().__init__(nc, ch, inplace)
528
+ self.sem_nc = sem_nc
529
+ self.nm = nm # number of masks
530
+ self.npr = npr # number of protos
531
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
532
+ self.uconv = UConv(ch[0], ch[0]//4, self.sem_nc+self.nc)
533
+ self.detect = Detect.forward
534
+
535
+ c4 = max(ch[0] // 4, self.nm)
536
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
537
+
538
+
539
+ def forward(self, x):
540
+ p = self.proto(x[0])
541
+ s = self.uconv(x[0])
542
+ bs = p.shape[0]
543
+
544
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
545
+ x = self.detect(self, x)
546
+ if self.training:
547
+ return x, mc, p, s
548
+ return (torch.cat([x, mc], 1), p, s) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p, s))
549
+
550
+
551
+ class BaseModel(nn.Module):
552
+ # YOLO base model
553
+ def forward(self, x, profile=False, visualize=False):
554
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
555
+
556
+ def _forward_once(self, x, profile=False, visualize=False):
557
+ y, dt = [], [] # outputs
558
+ for m in self.model:
559
+ if m.f != -1: # if not from previous layer
560
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
561
+ if profile:
562
+ self._profile_one_layer(m, x, dt)
563
+ x = m(x) # run
564
+ y.append(x if m.i in self.save else None) # save output
565
+ if visualize:
566
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
567
+ return x
568
+
569
+ def _profile_one_layer(self, m, x, dt):
570
+ c = m == self.model[-1] # is final layer, copy input as inplace fix
571
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
572
+ t = time_sync()
573
+ for _ in range(10):
574
+ m(x.copy() if c else x)
575
+ dt.append((time_sync() - t) * 100)
576
+ if m == self.model[0]:
577
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
578
+ LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
579
+ if c:
580
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
581
+
582
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
583
+ LOGGER.info('Fusing layers... ')
584
+ for m in self.model.modules():
585
+ if isinstance(m, (RepConvN)) and hasattr(m, 'fuse_convs'):
586
+ m.fuse_convs()
587
+ m.forward = m.forward_fuse # update forward
588
+ if isinstance(m, (Conv, Convb, DWConv)) and hasattr(m, 'bn'):
589
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
590
+ delattr(m, 'bn') # remove batchnorm
591
+ m.forward = m.forward_fuse # update forward
592
+ elif isinstance(m, (IDualDDetect)):
593
+ m.fuse()
594
+
595
+ self.info()
596
+ return self
597
+
598
+ def info(self, verbose=False, img_size=640): # print model information
599
+ model_info(self, verbose, img_size)
600
+
601
+ def _apply(self, fn):
602
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
603
+ self = super()._apply(fn)
604
+ m = self.model[-1] # Detect()
605
+ if isinstance(m, (Detect, DualDetect, TripleDetect, DDetect, DualDDetect, IDualDDetect, TripleDDetect, Segment)):
606
+ m.stride = fn(m.stride)
607
+ m.anchors = fn(m.anchors)
608
+ m.strides = fn(m.strides)
609
+ # m.grid = list(map(fn, m.grid))
610
+ return self
611
+
612
+
613
+ class DetectionModel(BaseModel):
614
+ # YOLO detection model
615
+ def __init__(self, cfg='yolo.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
616
+ super().__init__()
617
+ if isinstance(cfg, dict):
618
+ self.yaml = cfg # model dict
619
+ else: # is *.yaml
620
+ import yaml # for torch hub
621
+ self.yaml_file = Path(cfg).name
622
+ with open(cfg, encoding='ascii', errors='ignore') as f:
623
+ self.yaml = yaml.safe_load(f) # model dict
624
+
625
+ # Define model
626
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
627
+ if nc and nc != self.yaml['nc']:
628
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
629
+ self.yaml['nc'] = nc # override yaml value
630
+ if anchors:
631
+ LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
632
+ self.yaml['anchors'] = round(anchors) # override yaml value
633
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
634
+ self.names = [str(i) for i in range(self.yaml['nc'])] # default names
635
+ self.inplace = self.yaml.get('inplace', True)
636
+
637
+ # Build strides, anchors
638
+ m = self.model[-1] # Detect()
639
+ if isinstance(m, (Detect, DDetect, Segment)):
640
+ s = 256 # 2x min stride
641
+ m.inplace = self.inplace
642
+ forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment)) else self.forward(x)
643
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
644
+ # check_anchor_order(m)
645
+ # m.anchors /= m.stride.view(-1, 1, 1)
646
+ self.stride = m.stride
647
+ m.bias_init() # only run once
648
+ if isinstance(m, (DualDetect, TripleDetect, DualDDetect, IDualDDetect, TripleDDetect)):
649
+ s = 256 # 2x min stride
650
+ m.inplace = self.inplace
651
+ #forward = lambda x: self.forward(x)[0][0] if isinstance(m, (DualSegment)) else self.forward(x)[0]
652
+ forward = lambda x: self.forward(x)[0]
653
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
654
+ # check_anchor_order(m)
655
+ # m.anchors /= m.stride.view(-1, 1, 1)
656
+ self.stride = m.stride
657
+ m.bias_init() # only run once
658
+
659
+ # Init weights, biases
660
+ initialize_weights(self)
661
+ self.info()
662
+ LOGGER.info('')
663
+
664
+ def forward(self, x, augment=False, profile=False, visualize=False):
665
+ if augment:
666
+ return self._forward_augment(x) # augmented inference, None
667
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
668
+
669
+ def _forward_augment(self, x):
670
+ img_size = x.shape[-2:] # height, width
671
+ s = [1, 0.83, 0.67] # scales
672
+ f = [None, 3, None] # flips (2-ud, 3-lr)
673
+ y = [] # outputs
674
+ for si, fi in zip(s, f):
675
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
676
+ yi = self._forward_once(xi)[0] # forward
677
+ # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
678
+ yi = self._descale_pred(yi, fi, si, img_size)
679
+ y.append(yi)
680
+ y = self._clip_augmented(y) # clip augmented tails
681
+ return torch.cat(y, 1), None # augmented inference, train
682
+
683
+ def _descale_pred(self, p, flips, scale, img_size):
684
+ # de-scale predictions following augmented inference (inverse operation)
685
+ if self.inplace:
686
+ p[..., :4] /= scale # de-scale
687
+ if flips == 2:
688
+ p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
689
+ elif flips == 3:
690
+ p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
691
+ else:
692
+ x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
693
+ if flips == 2:
694
+ y = img_size[0] - y # de-flip ud
695
+ elif flips == 3:
696
+ x = img_size[1] - x # de-flip lr
697
+ p = torch.cat((x, y, wh, p[..., 4:]), -1)
698
+ return p
699
+
700
+ def _clip_augmented(self, y):
701
+ # Clip YOLO augmented inference tails
702
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
703
+ g = sum(4 ** x for x in range(nl)) # grid points
704
+ e = 1 # exclude layer count
705
+ i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
706
+ y[0] = y[0][:, :-i] # large
707
+ i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
708
+ y[-1] = y[-1][:, i:] # small
709
+ return y
710
+
711
+
712
+ Model = DetectionModel # retain YOLO 'Model' class for backwards compatibility
713
+
714
+
715
+ class SegmentationModel(DetectionModel):
716
+ # YOLO segmentation model
717
+ def __init__(self, cfg='yolo-seg.yaml', ch=3, nc=None, anchors=None):
718
+ super().__init__(cfg, ch, nc, anchors)
719
+
720
+
721
+ class ClassificationModel(BaseModel):
722
+ # YOLO classification model
723
+ def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
724
+ super().__init__()
725
+ self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
726
+
727
+ def _from_detection_model(self, model, nc=1000, cutoff=10):
728
+ # Create a YOLO classification model from a YOLO detection model
729
+ if isinstance(model, DetectMultiBackend):
730
+ model = model.model # unwrap DetectMultiBackend
731
+ model.model = model.model[:cutoff] # backbone
732
+ m = model.model[-1] # last layer
733
+ ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
734
+ c = Classify(ch, nc) # Classify()
735
+ c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
736
+ model.model[-1] = c # replace
737
+ self.model = model.model
738
+ self.stride = model.stride
739
+ self.save = []
740
+ self.nc = nc
741
+
742
+ def _from_yaml(self, cfg):
743
+ # Create a YOLO classification model from a *.yaml file
744
+ self.model = None
745
+
746
+
747
+ def parse_model(d, ch): # model_dict, input_channels(3)
748
+ # Parse a YOLO model.yaml dictionary
749
+ LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
750
+ anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
751
+ if act:
752
+ Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
753
+ RepConvN.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
754
+ LOGGER.info(f"{colorstr('activation:')} {act}") # print
755
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
756
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
757
+
758
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
759
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
760
+ m = eval(m) if isinstance(m, str) else m # eval strings
761
+ for j, a in enumerate(args):
762
+ with contextlib.suppress(NameError):
763
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
764
+
765
+ n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
766
+ if m in {
767
+ Conv, Convb, AConv, ConvTranspose,
768
+ Bottleneck, SPP, SPPF, DWConv, BottleneckCSP, nn.ConvTranspose2d, DWConvTranspose2d, SPPCSPC, ADown,
769
+ RepNCSPELAN4, SPPELAN, CBAMC4, RepNCBAMELAN4, SOCA, RepNSAELAN4, SABottleneck, LSKBottleneck, RepNLSKELAN4,
770
+ RepNECALAN4, C2f_DCNv2, RepDCNv2LEAN4}:
771
+ c1, c2 = ch[f], args[0]
772
+ if c2 != no: # if not output
773
+ c2 = make_divisible(c2 * gw, 8)
774
+
775
+ args = [c1, c2, *args[1:]]
776
+ if m in {BottleneckCSP, SPPCSPC}:
777
+ args.insert(2, n) # number of repeats
778
+ n = 1
779
+ elif m is nn.BatchNorm2d:
780
+ args = [ch[f]]
781
+ elif m is Concat:
782
+ c2 = sum(ch[x] for x in f)
783
+ elif m is Shortcut:
784
+ c2 = ch[f[0]]
785
+ elif m is ReOrg:
786
+ c2 = ch[f] * 4
787
+ elif m is CBLinear:
788
+ c2 = args[0]
789
+ c1 = ch[f]
790
+ args = [c1, c2, *args[1:]]
791
+ elif m is CBFuse:
792
+ c2 = ch[f[-1]]
793
+ # TODO: channel, gw, gd
794
+ elif m in {Detect, DualDetect, TripleDetect, DDetect, DualDDetect, IDualDDetect, TripleDDetect, Segment}:
795
+ args.append([ch[x] for x in f])
796
+ # if isinstance(args[1], int): # number of anchors
797
+ # args[1] = [list(range(args[1] * 2))] * len(f)
798
+ if m in {Segment}:
799
+ args[2] = make_divisible(args[2] * gw, 8)
800
+ elif m is Contract:
801
+ c2 = ch[f] * args[0] ** 2
802
+ elif m is Expand:
803
+ c2 = ch[f] // args[0] ** 2
804
+ else:
805
+ c2 = ch[f]
806
+
807
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
808
+ t = str(m)[8:-2].replace('__main__.', '') # module type
809
+ np = sum(x.numel() for x in m_.parameters()) # number params
810
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
811
+ LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
812
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
813
+ layers.append(m_)
814
+ if i == 0:
815
+ ch = []
816
+ ch.append(c2)
817
+ return nn.Sequential(*layers), sorted(save)
818
+
819
+
820
+ if __name__ == '__main__':
821
+ parser = argparse.ArgumentParser()
822
+ parser.add_argument('--cfg', type=str, default='./detect/attention/yolov9-e-repnlskelan4.yaml', help='model.yaml')
823
+ parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
824
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
825
+ parser.add_argument('--profile', action='store_true', help='profile model speed')
826
+ parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
827
+ parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
828
+ opt = parser.parse_args()
829
+ opt.cfg = check_yaml(opt.cfg) # check YAML
830
+ print_args(vars(opt))
831
+ device = select_device(opt.device)
832
+
833
+ # Create model
834
+ im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
835
+ model = Model(opt.cfg).to(device)
836
+ model.eval()
837
+
838
+ # Options
839
+ if opt.line_profile: # profile layer by layer
840
+ model(im, profile=True)
841
+
842
+ elif opt.profile: # profile forward-backward
843
+ results = profile(input=im, ops=[model], n=3)
844
+
845
+ elif opt.test: # test all models
846
+ for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
847
+ try:
848
+ _ = Model(cfg)
849
+ except Exception as e:
850
+ print(f'Error in {cfg}: {e}')
851
+
852
+ else: # report fused model summary
853
+ model.fuse()
requirements.txt ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # requirements
2
+ # Usage: pip install -r requirements.txt
3
+
4
+ # Base ------------------------------------------------------------------------
5
+ gitpython
6
+ ipython
7
+ matplotlib>=3.2.2
8
+ numpy>=1.18.5
9
+ opencv-python>=4.1.1
10
+ Pillow>=7.1.2
11
+ psutil
12
+ PyYAML>=5.3.1
13
+ requests>=2.23.0
14
+ scipy>=1.4.1
15
+ thop>=0.1.1
16
+ torch>=1.7.0
17
+ torchvision>=0.8.1
18
+ tqdm>=4.64.0
19
+ # protobuf<=3.20.1
20
+ gradio
21
+ # Logging ---------------------------------------------------------------------
22
+ tensorboard>=2.4.1
23
+ # clearml>=1.2.0
24
+ # comet
25
+
26
+ # Plotting --------------------------------------------------------------------
27
+ pandas>=1.1.4
28
+ seaborn>=0.11.0
29
+
30
+ # Export ----------------------------------------------------------------------
31
+ # coremltools>=6.0
32
+ # onnx>=1.9.0
33
+ # onnx-simplifier>=0.4.1
34
+ # nvidia-pyindex
35
+ # nvidia-tensorrt
36
+ # scikit-learn<=1.1.2
37
+ # tensorflow>=2.4.1
38
+ # tensorflowjs>=3.9.0
39
+ # openvino-dev
40
+
41
+ # Deploy ----------------------------------------------------------------------
42
+ # tritonclient[all]~=2.24.0
43
+
44
+ # Extras ----------------------------------------------------------------------
45
+ # mss
46
+ albumentations>=1.0.3
47
+ pycocotools>=2.0
utils/__init__.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import platform
3
+ import threading
4
+
5
+
6
+ def emojis(str=''):
7
+ # Return platform-dependent emoji-safe version of string
8
+ return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
9
+
10
+
11
+ class TryExcept(contextlib.ContextDecorator):
12
+ # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
13
+ def __init__(self, msg=''):
14
+ self.msg = msg
15
+
16
+ def __enter__(self):
17
+ pass
18
+
19
+ def __exit__(self, exc_type, value, traceback):
20
+ if value:
21
+ print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
22
+ return True
23
+
24
+
25
+ def threaded(func):
26
+ # Multi-threads a target function and returns thread. Usage: @threaded decorator
27
+ def wrapper(*args, **kwargs):
28
+ thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
29
+ thread.start()
30
+ return thread
31
+
32
+ return wrapper
33
+
34
+
35
+ def join_threads(verbose=False):
36
+ # Join all daemon threads, i.e. atexit.register(lambda: join_threads())
37
+ main_thread = threading.current_thread()
38
+ for t in threading.enumerate():
39
+ if t is not main_thread:
40
+ if verbose:
41
+ print(f'Joining thread {t.name}')
42
+ t.join()
43
+
44
+
45
+ def notebook_init(verbose=True):
46
+ # Check system software and hardware
47
+ print('Checking setup...')
48
+
49
+ import os
50
+ import shutil
51
+
52
+ from utils.general import check_font, check_requirements, is_colab
53
+ from utils.torch_utils import select_device # imports
54
+
55
+ check_font()
56
+
57
+ import psutil
58
+ from IPython import display # to display images and clear console output
59
+
60
+ if is_colab():
61
+ shutil.rmtree('/content/sample_data', ignore_errors=True) # remove colab /sample_data directory
62
+
63
+ # System info
64
+ if verbose:
65
+ gb = 1 << 30 # bytes to GiB (1024 ** 3)
66
+ ram = psutil.virtual_memory().total
67
+ total, used, free = shutil.disk_usage("/")
68
+ display.clear_output()
69
+ s = f'({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)'
70
+ else:
71
+ s = ''
72
+
73
+ select_device(newline=False)
74
+ print(emojis(f'Setup complete ✅ {s}'))
75
+ return display
utils/augmentations.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as T
8
+ import torchvision.transforms.functional as TF
9
+
10
+ from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box, xywhn2xyxy
11
+ from utils.metrics import bbox_ioa
12
+
13
+ IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
14
+ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
15
+
16
+
17
+ class Albumentations:
18
+ # YOLOv5 Albumentations class (optional, only used if package is installed)
19
+ def __init__(self, size=640):
20
+ self.transform = None
21
+ prefix = colorstr('albumentations: ')
22
+ try:
23
+ import albumentations as A
24
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
25
+
26
+ T = [
27
+ A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
28
+ A.Blur(p=0.01),
29
+ A.MedianBlur(p=0.01),
30
+ A.ToGray(p=0.01),
31
+ A.CLAHE(p=0.01),
32
+ A.RandomBrightnessContrast(p=0.0),
33
+ A.RandomGamma(p=0.0),
34
+ A.ImageCompression(quality_lower=75, p=0.0)] # transforms
35
+ self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
36
+
37
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
38
+ except ImportError: # package not installed, skip
39
+ pass
40
+ except Exception as e:
41
+ LOGGER.info(f'{prefix}{e}')
42
+
43
+ def __call__(self, im, labels, p=1.0):
44
+ if self.transform and random.random() < p:
45
+ new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
46
+ im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
47
+ return im, labels
48
+
49
+
50
+ def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
51
+ # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
52
+ return TF.normalize(x, mean, std, inplace=inplace)
53
+
54
+
55
+ def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
56
+ # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
57
+ for i in range(3):
58
+ x[:, i] = x[:, i] * std[i] + mean[i]
59
+ return x
60
+
61
+
62
+ def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
63
+ # HSV color-space augmentation
64
+ if hgain or sgain or vgain:
65
+ r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
66
+ hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
67
+ dtype = im.dtype # uint8
68
+
69
+ x = np.arange(0, 256, dtype=r.dtype)
70
+ lut_hue = ((x * r[0]) % 180).astype(dtype)
71
+ lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
72
+ lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
73
+
74
+ im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
75
+ cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
76
+
77
+
78
+ def hist_equalize(im, clahe=True, bgr=False):
79
+ # Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255
80
+ yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
81
+ if clahe:
82
+ c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
83
+ yuv[:, :, 0] = c.apply(yuv[:, :, 0])
84
+ else:
85
+ yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram
86
+ return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB
87
+
88
+
89
+ def replicate(im, labels):
90
+ # Replicate labels
91
+ h, w = im.shape[:2]
92
+ boxes = labels[:, 1:].astype(int)
93
+ x1, y1, x2, y2 = boxes.T
94
+ s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
95
+ for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
96
+ x1b, y1b, x2b, y2b = boxes[i]
97
+ bh, bw = y2b - y1b, x2b - x1b
98
+ yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
99
+ x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
100
+ im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b] # im4[ymin:ymax, xmin:xmax]
101
+ labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
102
+
103
+ return im, labels
104
+
105
+
106
+ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
107
+ # Resize and pad image while meeting stride-multiple constraints
108
+ shape = im.shape[:2] # current shape [height, width]
109
+ if isinstance(new_shape, int):
110
+ new_shape = (new_shape, new_shape)
111
+
112
+ # Scale ratio (new / old)
113
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
114
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
115
+ r = min(r, 1.0)
116
+
117
+ # Compute padding
118
+ ratio = r, r # width, height ratios
119
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
120
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
121
+ if auto: # minimum rectangle
122
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
123
+ elif scaleFill: # stretch
124
+ dw, dh = 0.0, 0.0
125
+ new_unpad = (new_shape[1], new_shape[0])
126
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
127
+
128
+ dw /= 2 # divide padding into 2 sides
129
+ dh /= 2
130
+
131
+ if shape[::-1] != new_unpad: # resize
132
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
133
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
134
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
135
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
136
+ return im, ratio, (dw, dh)
137
+
138
+
139
+ def random_perspective(im,
140
+ targets=(),
141
+ segments=(),
142
+ degrees=10,
143
+ translate=.1,
144
+ scale=.1,
145
+ shear=10,
146
+ perspective=0.0,
147
+ border=(0, 0)):
148
+ # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
149
+ # targets = [cls, xyxy]
150
+
151
+ height = im.shape[0] + border[0] * 2 # shape(h,w,c)
152
+ width = im.shape[1] + border[1] * 2
153
+
154
+ # Center
155
+ C = np.eye(3)
156
+ C[0, 2] = -im.shape[1] / 2 # x translation (pixels)
157
+ C[1, 2] = -im.shape[0] / 2 # y translation (pixels)
158
+
159
+ # Perspective
160
+ P = np.eye(3)
161
+ P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y)
162
+ P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x)
163
+
164
+ # Rotation and Scale
165
+ R = np.eye(3)
166
+ a = random.uniform(-degrees, degrees)
167
+ # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
168
+ s = random.uniform(1 - scale, 1 + scale)
169
+ # s = 2 ** random.uniform(-scale, scale)
170
+ R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
171
+
172
+ # Shear
173
+ S = np.eye(3)
174
+ S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
175
+ S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
176
+
177
+ # Translation
178
+ T = np.eye(3)
179
+ T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels)
180
+ T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels)
181
+
182
+ # Combined rotation matrix
183
+ M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
184
+ if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
185
+ if perspective:
186
+ im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
187
+ else: # affine
188
+ im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
189
+
190
+ # Visualize
191
+ # import matplotlib.pyplot as plt
192
+ # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
193
+ # ax[0].imshow(im[:, :, ::-1]) # base
194
+ # ax[1].imshow(im2[:, :, ::-1]) # warped
195
+
196
+ # Transform label coordinates
197
+ n = len(targets)
198
+ if n:
199
+ use_segments = any(x.any() for x in segments)
200
+ new = np.zeros((n, 4))
201
+ if use_segments: # warp segments
202
+ segments = resample_segments(segments) # upsample
203
+ for i, segment in enumerate(segments):
204
+ xy = np.ones((len(segment), 3))
205
+ xy[:, :2] = segment
206
+ xy = xy @ M.T # transform
207
+ xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine
208
+
209
+ # clip
210
+ new[i] = segment2box(xy, width, height)
211
+
212
+ else: # warp boxes
213
+ xy = np.ones((n * 4, 3))
214
+ xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
215
+ xy = xy @ M.T # transform
216
+ xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
217
+
218
+ # create new boxes
219
+ x = xy[:, [0, 2, 4, 6]]
220
+ y = xy[:, [1, 3, 5, 7]]
221
+ new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
222
+
223
+ # clip
224
+ new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
225
+ new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
226
+
227
+ # filter candidates
228
+ i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
229
+ targets = targets[i]
230
+ targets[:, 1:5] = new[i]
231
+
232
+ return im, targets
233
+
234
+
235
+ def copy_paste(im, labels, segments, p=0.5):
236
+ # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
237
+ n = len(segments)
238
+ if p and n:
239
+ h, w, c = im.shape # height, width, channels
240
+ im_new = np.zeros(im.shape, np.uint8)
241
+
242
+ # calculate ioa first then select indexes randomly
243
+ boxes = np.stack([w - labels[:, 3], labels[:, 2], w - labels[:, 1], labels[:, 4]], axis=-1) # (n, 4)
244
+ ioa = bbox_ioa(boxes, labels[:, 1:5]) # intersection over area
245
+ indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
246
+ n = len(indexes)
247
+ for j in random.sample(list(indexes), k=round(p * n)):
248
+ l, box, s = labels[j], boxes[j], segments[j]
249
+ labels = np.concatenate((labels, [[l[0], *box]]), 0)
250
+ segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
251
+ cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (1, 1, 1), cv2.FILLED)
252
+
253
+ result = cv2.flip(im, 1) # augment segments (flip left-right)
254
+ i = cv2.flip(im_new, 1).astype(bool)
255
+ im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
256
+
257
+ return im, labels, segments
258
+
259
+
260
+ def cutout(im, labels, p=0.5):
261
+ # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
262
+ if random.random() < p:
263
+ h, w = im.shape[:2]
264
+ scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
265
+ for s in scales:
266
+ mask_h = random.randint(1, int(h * s)) # create random masks
267
+ mask_w = random.randint(1, int(w * s))
268
+
269
+ # box
270
+ xmin = max(0, random.randint(0, w) - mask_w // 2)
271
+ ymin = max(0, random.randint(0, h) - mask_h // 2)
272
+ xmax = min(w, xmin + mask_w)
273
+ ymax = min(h, ymin + mask_h)
274
+
275
+ # apply random color mask
276
+ im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
277
+
278
+ # return unobscured labels
279
+ if len(labels) and s > 0.03:
280
+ box = np.array([[xmin, ymin, xmax, ymax]], dtype=np.float32)
281
+ ioa = bbox_ioa(box, xywhn2xyxy(labels[:, 1:5], w, h))[0] # intersection over area
282
+ labels = labels[ioa < 0.60] # remove >60% obscured labels
283
+
284
+ return labels
285
+
286
+
287
+ def mixup(im, labels, im2, labels2):
288
+ # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
289
+ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
290
+ im = (im * r + im2 * (1 - r)).astype(np.uint8)
291
+ labels = np.concatenate((labels, labels2), 0)
292
+ return im, labels
293
+
294
+
295
+ def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
296
+ # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
297
+ w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
298
+ w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
299
+ ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
300
+ return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
301
+
302
+
303
+ def classify_albumentations(
304
+ augment=True,
305
+ size=224,
306
+ scale=(0.08, 1.0),
307
+ ratio=(0.75, 1.0 / 0.75), # 0.75, 1.33
308
+ hflip=0.5,
309
+ vflip=0.0,
310
+ jitter=0.4,
311
+ mean=IMAGENET_MEAN,
312
+ std=IMAGENET_STD,
313
+ auto_aug=False):
314
+ # YOLOv5 classification Albumentations (optional, only used if package is installed)
315
+ prefix = colorstr('albumentations: ')
316
+ try:
317
+ import albumentations as A
318
+ from albumentations.pytorch import ToTensorV2
319
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
320
+ if augment: # Resize and crop
321
+ T = [A.RandomResizedCrop(height=size, width=size, scale=scale, ratio=ratio)]
322
+ if auto_aug:
323
+ # TODO: implement AugMix, AutoAug & RandAug in albumentation
324
+ LOGGER.info(f'{prefix}auto augmentations are currently not supported')
325
+ else:
326
+ if hflip > 0:
327
+ T += [A.HorizontalFlip(p=hflip)]
328
+ if vflip > 0:
329
+ T += [A.VerticalFlip(p=vflip)]
330
+ if jitter > 0:
331
+ color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, satuaration, 0 hue
332
+ T += [A.ColorJitter(*color_jitter, 0)]
333
+ else: # Use fixed crop for eval set (reproducibility)
334
+ T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
335
+ T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
336
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
337
+ return A.Compose(T)
338
+
339
+ except ImportError: # package not installed, skip
340
+ LOGGER.warning(f'{prefix}⚠️ not found, install with `pip install albumentations` (recommended)')
341
+ except Exception as e:
342
+ LOGGER.info(f'{prefix}{e}')
343
+
344
+
345
+ def classify_transforms(size=224):
346
+ # Transforms to apply if albumentations not installed
347
+ assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
348
+ # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
349
+ return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
350
+
351
+
352
+ class LetterBox:
353
+ # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
354
+ def __init__(self, size=(640, 640), auto=False, stride=32):
355
+ super().__init__()
356
+ self.h, self.w = (size, size) if isinstance(size, int) else size
357
+ self.auto = auto # pass max size integer, automatically solve for short side using stride
358
+ self.stride = stride # used with auto
359
+
360
+ def __call__(self, im): # im = np.array HWC
361
+ imh, imw = im.shape[:2]
362
+ r = min(self.h / imh, self.w / imw) # ratio of new/old
363
+ h, w = round(imh * r), round(imw * r) # resized image
364
+ hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
365
+ top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
366
+ im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
367
+ im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
368
+ return im_out
369
+
370
+
371
+ class CenterCrop:
372
+ # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
373
+ def __init__(self, size=640):
374
+ super().__init__()
375
+ self.h, self.w = (size, size) if isinstance(size, int) else size
376
+
377
+ def __call__(self, im): # im = np.array HWC
378
+ imh, imw = im.shape[:2]
379
+ m = min(imh, imw) # min dimension
380
+ top, left = (imh - m) // 2, (imw - m) // 2
381
+ return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
382
+
383
+
384
+ class ToTensor:
385
+ # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
386
+ def __init__(self, half=False):
387
+ super().__init__()
388
+ self.half = half
389
+
390
+ def __call__(self, im): # im = np.array HWC in BGR order
391
+ im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
392
+ im = torch.from_numpy(im) # to torch
393
+ im = im.half() if self.half else im.float() # uint8 to fp16/32
394
+ im /= 255.0 # 0-255 to 0.0-1.0
395
+ return im
utils/autoanchor.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ import yaml
6
+ from tqdm import tqdm
7
+
8
+ from utils import TryExcept
9
+ from utils.general import LOGGER, TQDM_BAR_FORMAT, colorstr
10
+
11
+ PREFIX = colorstr('AutoAnchor: ')
12
+
13
+
14
+ def check_anchor_order(m):
15
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
16
+ a = m.anchors.prod(-1).mean(-1).view(-1) # mean anchor area per output layer
17
+ da = a[-1] - a[0] # delta a
18
+ ds = m.stride[-1] - m.stride[0] # delta s
19
+ if da and (da.sign() != ds.sign()): # same order
20
+ LOGGER.info(f'{PREFIX}Reversing anchor order')
21
+ m.anchors[:] = m.anchors.flip(0)
22
+
23
+
24
+ @TryExcept(f'{PREFIX}ERROR')
25
+ def check_anchors(dataset, model, thr=4.0, imgsz=640):
26
+ # Check anchor fit to data, recompute if necessary
27
+ m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
28
+ shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
29
+ scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
30
+ wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
31
+
32
+ def metric(k): # compute metric
33
+ r = wh[:, None] / k[None]
34
+ x = torch.min(r, 1 / r).min(2)[0] # ratio metric
35
+ best = x.max(1)[0] # best_x
36
+ aat = (x > 1 / thr).float().sum(1).mean() # anchors above threshold
37
+ bpr = (best > 1 / thr).float().mean() # best possible recall
38
+ return bpr, aat
39
+
40
+ stride = m.stride.to(m.anchors.device).view(-1, 1, 1) # model strides
41
+ anchors = m.anchors.clone() * stride # current anchors
42
+ bpr, aat = metric(anchors.cpu().view(-1, 2))
43
+ s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
44
+ if bpr > 0.98: # threshold to recompute
45
+ LOGGER.info(f'{s}Current anchors are a good fit to dataset ✅')
46
+ else:
47
+ LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
48
+ na = m.anchors.numel() // 2 # number of anchors
49
+ anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
50
+ new_bpr = metric(anchors)[0]
51
+ if new_bpr > bpr: # replace anchors
52
+ anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
53
+ m.anchors[:] = anchors.clone().view_as(m.anchors)
54
+ check_anchor_order(m) # must be in pixel-space (not grid-space)
55
+ m.anchors /= stride
56
+ s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
57
+ else:
58
+ s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
59
+ LOGGER.info(s)
60
+
61
+
62
+ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
63
+ """ Creates kmeans-evolved anchors from training dataset
64
+
65
+ Arguments:
66
+ dataset: path to data.yaml, or a loaded dataset
67
+ n: number of anchors
68
+ img_size: image size used for training
69
+ thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
70
+ gen: generations to evolve anchors using genetic algorithm
71
+ verbose: print all results
72
+
73
+ Return:
74
+ k: kmeans evolved anchors
75
+
76
+ Usage:
77
+ from utils.autoanchor import *; _ = kmean_anchors()
78
+ """
79
+ from scipy.cluster.vq import kmeans
80
+
81
+ npr = np.random
82
+ thr = 1 / thr
83
+
84
+ def metric(k, wh): # compute metrics
85
+ r = wh[:, None] / k[None]
86
+ x = torch.min(r, 1 / r).min(2)[0] # ratio metric
87
+ # x = wh_iou(wh, torch.tensor(k)) # iou metric
88
+ return x, x.max(1)[0] # x, best_x
89
+
90
+ def anchor_fitness(k): # mutation fitness
91
+ _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
92
+ return (best * (best > thr).float()).mean() # fitness
93
+
94
+ def print_results(k, verbose=True):
95
+ k = k[np.argsort(k.prod(1))] # sort small to large
96
+ x, best = metric(k, wh0)
97
+ bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
98
+ s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
99
+ f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
100
+ f'past_thr={x[x > thr].mean():.3f}-mean: '
101
+ for x in k:
102
+ s += '%i,%i, ' % (round(x[0]), round(x[1]))
103
+ if verbose:
104
+ LOGGER.info(s[:-2])
105
+ return k
106
+
107
+ if isinstance(dataset, str): # *.yaml file
108
+ with open(dataset, errors='ignore') as f:
109
+ data_dict = yaml.safe_load(f) # model dict
110
+ from utils.dataloaders import LoadImagesAndLabels
111
+ dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
112
+
113
+ # Get label wh
114
+ shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
115
+ wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
116
+
117
+ # Filter
118
+ i = (wh0 < 3.0).any(1).sum()
119
+ if i:
120
+ LOGGER.info(f'{PREFIX}WARNING ⚠️ Extremely small objects found: {i} of {len(wh0)} labels are <3 pixels in size')
121
+ wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # filter > 2 pixels
122
+ # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
123
+
124
+ # Kmeans init
125
+ try:
126
+ LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
127
+ assert n <= len(wh) # apply overdetermined constraint
128
+ s = wh.std(0) # sigmas for whitening
129
+ k = kmeans(wh / s, n, iter=30)[0] * s # points
130
+ assert n == len(k) # kmeans may return fewer points than requested if wh is insufficient or too similar
131
+ except Exception:
132
+ LOGGER.warning(f'{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init')
133
+ k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size # random init
134
+ wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
135
+ k = print_results(k, verbose=False)
136
+
137
+ # Plot
138
+ # k, d = [None] * 20, [None] * 20
139
+ # for i in tqdm(range(1, 21)):
140
+ # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
141
+ # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
142
+ # ax = ax.ravel()
143
+ # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
144
+ # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
145
+ # ax[0].hist(wh[wh[:, 0]<100, 0],400)
146
+ # ax[1].hist(wh[wh[:, 1]<100, 1],400)
147
+ # fig.savefig('wh.png', dpi=200)
148
+
149
+ # Evolve
150
+ f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
151
+ pbar = tqdm(range(gen), bar_format=TQDM_BAR_FORMAT) # progress bar
152
+ for _ in pbar:
153
+ v = np.ones(sh)
154
+ while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
155
+ v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
156
+ kg = (k.copy() * v).clip(min=2.0)
157
+ fg = anchor_fitness(kg)
158
+ if fg > f:
159
+ f, k = fg, kg.copy()
160
+ pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
161
+ if verbose:
162
+ print_results(k, verbose)
163
+
164
+ return print_results(k).astype(np.float32)
utils/autobatch.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from utils.general import LOGGER, colorstr
7
+ from utils.torch_utils import profile
8
+
9
+
10
+ def check_train_batch_size(model, imgsz=640, amp=True):
11
+ # Check YOLOv5 training batch size
12
+ with torch.cuda.amp.autocast(amp):
13
+ return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
14
+
15
+
16
+ def autobatch(model, imgsz=640, fraction=0.8, batch_size=16):
17
+ # Automatically estimate best YOLOv5 batch size to use `fraction` of available CUDA memory
18
+ # Usage:
19
+ # import torch
20
+ # from utils.autobatch import autobatch
21
+ # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
22
+ # print(autobatch(model))
23
+
24
+ # Check device
25
+ prefix = colorstr('AutoBatch: ')
26
+ LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
27
+ device = next(model.parameters()).device # get model device
28
+ if device.type == 'cpu':
29
+ LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
30
+ return batch_size
31
+ if torch.backends.cudnn.benchmark:
32
+ LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
33
+ return batch_size
34
+
35
+ # Inspect CUDA memory
36
+ gb = 1 << 30 # bytes to GiB (1024 ** 3)
37
+ d = str(device).upper() # 'CUDA:0'
38
+ properties = torch.cuda.get_device_properties(device) # device properties
39
+ t = properties.total_memory / gb # GiB total
40
+ r = torch.cuda.memory_reserved(device) / gb # GiB reserved
41
+ a = torch.cuda.memory_allocated(device) / gb # GiB allocated
42
+ f = t - (r + a) # GiB free
43
+ LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
44
+
45
+ # Profile batch sizes
46
+ batch_sizes = [1, 2, 4, 8, 16]
47
+ try:
48
+ img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
49
+ results = profile(img, model, n=3, device=device)
50
+ except Exception as e:
51
+ LOGGER.warning(f'{prefix}{e}')
52
+
53
+ # Fit a solution
54
+ y = [x[2] for x in results if x] # memory [2]
55
+ p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
56
+ b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
57
+ if None in results: # some sizes failed
58
+ i = results.index(None) # first fail index
59
+ if b >= batch_sizes[i]: # y intercept above failure point
60
+ b = batch_sizes[max(i - 1, 0)] # select prior safe point
61
+ if b < 1 or b > 1024: # b outside of safe range
62
+ b = batch_size
63
+ LOGGER.warning(f'{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
64
+
65
+ fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
66
+ LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
67
+ return b
utils/callbacks.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+
3
+
4
+ class Callbacks:
5
+ """"
6
+ Handles all registered callbacks for YOLOv5 Hooks
7
+ """
8
+
9
+ def __init__(self):
10
+ # Define the available callbacks
11
+ self._callbacks = {
12
+ 'on_pretrain_routine_start': [],
13
+ 'on_pretrain_routine_end': [],
14
+ 'on_train_start': [],
15
+ 'on_train_epoch_start': [],
16
+ 'on_train_batch_start': [],
17
+ 'optimizer_step': [],
18
+ 'on_before_zero_grad': [],
19
+ 'on_train_batch_end': [],
20
+ 'on_train_epoch_end': [],
21
+ 'on_val_start': [],
22
+ 'on_val_batch_start': [],
23
+ 'on_val_image_end': [],
24
+ 'on_val_batch_end': [],
25
+ 'on_val_end': [],
26
+ 'on_fit_epoch_end': [], # fit = train + val
27
+ 'on_model_save': [],
28
+ 'on_train_end': [],
29
+ 'on_params_update': [],
30
+ 'teardown': [],}
31
+ self.stop_training = False # set True to interrupt training
32
+
33
+ def register_action(self, hook, name='', callback=None):
34
+ """
35
+ Register a new action to a callback hook
36
+
37
+ Args:
38
+ hook: The callback hook name to register the action to
39
+ name: The name of the action for later reference
40
+ callback: The callback to fire
41
+ """
42
+ assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
43
+ assert callable(callback), f"callback '{callback}' is not callable"
44
+ self._callbacks[hook].append({'name': name, 'callback': callback})
45
+
46
+ def get_registered_actions(self, hook=None):
47
+ """"
48
+ Returns all the registered actions by callback hook
49
+
50
+ Args:
51
+ hook: The name of the hook to check, defaults to all
52
+ """
53
+ return self._callbacks[hook] if hook else self._callbacks
54
+
55
+ def run(self, hook, *args, thread=False, **kwargs):
56
+ """
57
+ Loop through the registered actions and fire all callbacks on main thread
58
+
59
+ Args:
60
+ hook: The name of the hook to check, defaults to all
61
+ args: Arguments to receive from YOLOv5
62
+ thread: (boolean) Run callbacks in daemon thread
63
+ kwargs: Keyword Arguments to receive from YOLOv5
64
+ """
65
+
66
+ assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
67
+ for logger in self._callbacks[hook]:
68
+ if thread:
69
+ threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
70
+ else:
71
+ logger['callback'](*args, **kwargs)
utils/coco_utils.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+ from pycocotools.coco import COCO
4
+ from pycocotools import mask as maskUtils
5
+
6
+ # coco id: https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
7
+ all_instances_ids = [
8
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
9
+ 11, 13, 14, 15, 16, 17, 18, 19, 20,
10
+ 21, 22, 23, 24, 25, 27, 28,
11
+ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
12
+ 41, 42, 43, 44, 46, 47, 48, 49, 50,
13
+ 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
14
+ 61, 62, 63, 64, 65, 67, 70,
15
+ 72, 73, 74, 75, 76, 77, 78, 79, 80,
16
+ 81, 82, 84, 85, 86, 87, 88, 89, 90,
17
+ ]
18
+
19
+ all_stuff_ids = [
20
+ 92, 93, 94, 95, 96, 97, 98, 99, 100,
21
+ 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
22
+ 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
23
+ 121, 122, 123, 124, 125, 126, 127, 128, 129, 130,
24
+ 131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
25
+ 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
26
+ 151, 152, 153, 154, 155, 156, 157, 158, 159, 160,
27
+ 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
28
+ 171, 172, 173, 174, 175, 176, 177, 178, 179, 180,
29
+ 181, 182,
30
+ # other
31
+ 183,
32
+ # unlabeled
33
+ 0,
34
+ ]
35
+
36
+ # panoptic id: https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
37
+ panoptic_stuff_ids = [
38
+ 92, 93, 95, 100,
39
+ 107, 109,
40
+ 112, 118, 119,
41
+ 122, 125, 128, 130,
42
+ 133, 138,
43
+ 141, 144, 145, 147, 148, 149,
44
+ 151, 154, 155, 156, 159,
45
+ 161, 166, 168,
46
+ 171, 175, 176, 177, 178, 180,
47
+ 181, 184, 185, 186, 187, 188, 189, 190,
48
+ 191, 192, 193, 194, 195, 196, 197, 198, 199, 200,
49
+ # unlabeled
50
+ 0,
51
+ ]
52
+
53
+ def getCocoIds(name = 'semantic'):
54
+ if 'instances' == name:
55
+ return all_instances_ids
56
+ elif 'stuff' == name:
57
+ return all_stuff_ids
58
+ elif 'panoptic' == name:
59
+ return all_instances_ids + panoptic_stuff_ids
60
+ else: # semantic
61
+ return all_instances_ids + all_stuff_ids
62
+
63
+ def getMappingId(index, name = 'semantic'):
64
+ ids = getCocoIds(name = name)
65
+ return ids[index]
66
+
67
+ def getMappingIndex(id, name = 'semantic'):
68
+ ids = getCocoIds(name = name)
69
+ return ids.index(id)
70
+
71
+ # convert ann to rle encoded string
72
+ def annToRLE(ann, img_size):
73
+ h, w = img_size
74
+ segm = ann['segmentation']
75
+ if list == type(segm):
76
+ # polygon -- a single object might consist of multiple parts
77
+ # we merge all parts into one mask rle code
78
+ rles = maskUtils.frPyObjects(segm, h, w)
79
+ rle = maskUtils.merge(rles)
80
+ elif list == type(segm['counts']):
81
+ # uncompressed RLE
82
+ rle = maskUtils.frPyObjects(segm, h, w)
83
+ else:
84
+ # rle
85
+ rle = ann['segmentation']
86
+ return rle
87
+
88
+ # decode ann to mask martix
89
+ def annToMask(ann, img_size):
90
+ rle = annToRLE(ann, img_size)
91
+ m = maskUtils.decode(rle)
92
+ return m
93
+
94
+ # convert mask to polygans
95
+ def convert_to_polys(mask):
96
+ # opencv 3.2
97
+ contours, hierarchy = cv2.findContours((mask).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
98
+
99
+ # before opencv 3.2
100
+ # contours, hierarchy = cv2.findContours((mask).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
101
+
102
+ segmentation = []
103
+ for contour in contours:
104
+ contour = contour.flatten().tolist()
105
+ if 4 < len(contour):
106
+ segmentation.append(contour)
107
+
108
+ return segmentation
utils/dataloaders.py ADDED
@@ -0,0 +1,1217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import glob
3
+ import hashlib
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ import shutil
9
+ import time
10
+ from itertools import repeat
11
+ from multiprocessing.pool import Pool, ThreadPool
12
+ from pathlib import Path
13
+ from threading import Thread
14
+ from urllib.parse import urlparse
15
+
16
+ import numpy as np
17
+ import psutil
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torchvision
21
+ import yaml
22
+ from PIL import ExifTags, Image, ImageOps
23
+ from torch.utils.data import DataLoader, Dataset, dataloader, distributed
24
+ from tqdm import tqdm
25
+
26
+ from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
27
+ letterbox, mixup, random_perspective)
28
+ from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, check_dataset, check_requirements,
29
+ check_yaml, clean_str, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy,
30
+ xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
31
+ from utils.torch_utils import torch_distributed_zero_first
32
+
33
+ # Parameters
34
+ HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
35
+ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
36
+ VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
37
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
38
+ RANK = int(os.getenv('RANK', -1))
39
+ PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
40
+
41
+ # Get orientation exif tag
42
+ for orientation in ExifTags.TAGS.keys():
43
+ if ExifTags.TAGS[orientation] == 'Orientation':
44
+ break
45
+
46
+
47
+ def get_hash(paths):
48
+ # Returns a single hash value of a list of paths (files or dirs)
49
+ size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
50
+ h = hashlib.md5(str(size).encode()) # hash sizes
51
+ h.update(''.join(paths).encode()) # hash paths
52
+ return h.hexdigest() # return hash
53
+
54
+
55
+ def exif_size(img):
56
+ # Returns exif-corrected PIL size
57
+ s = img.size # (width, height)
58
+ with contextlib.suppress(Exception):
59
+ rotation = dict(img._getexif().items())[orientation]
60
+ if rotation in [6, 8]: # rotation 270 or 90
61
+ s = (s[1], s[0])
62
+ return s
63
+
64
+
65
+ def exif_transpose(image):
66
+ """
67
+ Transpose a PIL image accordingly if it has an EXIF Orientation tag.
68
+ Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
69
+
70
+ :param image: The image to transpose.
71
+ :return: An image.
72
+ """
73
+ exif = image.getexif()
74
+ orientation = exif.get(0x0112, 1) # default 1
75
+ if orientation > 1:
76
+ method = {
77
+ 2: Image.FLIP_LEFT_RIGHT,
78
+ 3: Image.ROTATE_180,
79
+ 4: Image.FLIP_TOP_BOTTOM,
80
+ 5: Image.TRANSPOSE,
81
+ 6: Image.ROTATE_270,
82
+ 7: Image.TRANSVERSE,
83
+ 8: Image.ROTATE_90}.get(orientation)
84
+ if method is not None:
85
+ image = image.transpose(method)
86
+ del exif[0x0112]
87
+ image.info["exif"] = exif.tobytes()
88
+ return image
89
+
90
+
91
+ def seed_worker(worker_id):
92
+ # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
93
+ worker_seed = torch.initial_seed() % 2 ** 32
94
+ np.random.seed(worker_seed)
95
+ random.seed(worker_seed)
96
+
97
+
98
+ def create_dataloader(path,
99
+ imgsz,
100
+ batch_size,
101
+ stride,
102
+ single_cls=False,
103
+ hyp=None,
104
+ augment=False,
105
+ cache=False,
106
+ pad=0.0,
107
+ rect=False,
108
+ rank=-1,
109
+ workers=8,
110
+ image_weights=False,
111
+ close_mosaic=False,
112
+ quad=False,
113
+ min_items=0,
114
+ prefix='',
115
+ shuffle=False):
116
+ if rect and shuffle:
117
+ LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
118
+ shuffle = False
119
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
120
+ dataset = LoadImagesAndLabels(
121
+ path,
122
+ imgsz,
123
+ batch_size,
124
+ augment=augment, # augmentation
125
+ hyp=hyp, # hyperparameters
126
+ rect=rect, # rectangular batches
127
+ cache_images=cache,
128
+ single_cls=single_cls,
129
+ stride=int(stride),
130
+ pad=pad,
131
+ image_weights=image_weights,
132
+ min_items=min_items,
133
+ prefix=prefix)
134
+
135
+ batch_size = min(batch_size, len(dataset))
136
+ nd = torch.cuda.device_count() # number of CUDA devices
137
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
138
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
139
+ #loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
140
+ loader = DataLoader if image_weights or close_mosaic else InfiniteDataLoader
141
+ generator = torch.Generator()
142
+ generator.manual_seed(6148914691236517205 + RANK)
143
+ return loader(dataset,
144
+ batch_size=batch_size,
145
+ shuffle=shuffle and sampler is None,
146
+ num_workers=nw,
147
+ sampler=sampler,
148
+ pin_memory=PIN_MEMORY,
149
+ collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
150
+ worker_init_fn=seed_worker,
151
+ generator=generator), dataset
152
+
153
+
154
+ class InfiniteDataLoader(dataloader.DataLoader):
155
+ """ Dataloader that reuses workers
156
+
157
+ Uses same syntax as vanilla DataLoader
158
+ """
159
+
160
+ def __init__(self, *args, **kwargs):
161
+ super().__init__(*args, **kwargs)
162
+ object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
163
+ self.iterator = super().__iter__()
164
+
165
+ def __len__(self):
166
+ return len(self.batch_sampler.sampler)
167
+
168
+ def __iter__(self):
169
+ for _ in range(len(self)):
170
+ yield next(self.iterator)
171
+
172
+
173
+ class _RepeatSampler:
174
+ """ Sampler that repeats forever
175
+
176
+ Args:
177
+ sampler (Sampler)
178
+ """
179
+
180
+ def __init__(self, sampler):
181
+ self.sampler = sampler
182
+
183
+ def __iter__(self):
184
+ while True:
185
+ yield from iter(self.sampler)
186
+
187
+
188
+ class LoadScreenshots:
189
+ # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
190
+ def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
191
+ # source = [screen_number left top width height] (pixels)
192
+ check_requirements('mss')
193
+ import mss
194
+
195
+ source, *params = source.split()
196
+ self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
197
+ if len(params) == 1:
198
+ self.screen = int(params[0])
199
+ elif len(params) == 4:
200
+ left, top, width, height = (int(x) for x in params)
201
+ elif len(params) == 5:
202
+ self.screen, left, top, width, height = (int(x) for x in params)
203
+ self.img_size = img_size
204
+ self.stride = stride
205
+ self.transforms = transforms
206
+ self.auto = auto
207
+ self.mode = 'stream'
208
+ self.frame = 0
209
+ self.sct = mss.mss()
210
+
211
+ # Parse monitor shape
212
+ monitor = self.sct.monitors[self.screen]
213
+ self.top = monitor["top"] if top is None else (monitor["top"] + top)
214
+ self.left = monitor["left"] if left is None else (monitor["left"] + left)
215
+ self.width = width or monitor["width"]
216
+ self.height = height or monitor["height"]
217
+ self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
218
+
219
+ def __iter__(self):
220
+ return self
221
+
222
+ def __next__(self):
223
+ # mss screen capture: get raw pixels from the screen as np array
224
+ im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
225
+ s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
226
+
227
+ if self.transforms:
228
+ im = self.transforms(im0) # transforms
229
+ else:
230
+ im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
231
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
232
+ im = np.ascontiguousarray(im) # contiguous
233
+ self.frame += 1
234
+ return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s
235
+
236
+
237
+ class LoadImages:
238
+ # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
239
+ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
240
+ files = []
241
+ for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
242
+ p = str(Path(p).resolve())
243
+ if '*' in p:
244
+ files.extend(sorted(glob.glob(p, recursive=True))) # glob
245
+ elif os.path.isdir(p):
246
+ files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
247
+ elif os.path.isfile(p):
248
+ files.append(p) # files
249
+ else:
250
+ raise FileNotFoundError(f'{p} does not exist')
251
+
252
+ images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
253
+ videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
254
+ ni, nv = len(images), len(videos)
255
+
256
+ self.img_size = img_size
257
+ self.stride = stride
258
+ self.files = images + videos
259
+ self.nf = ni + nv # number of files
260
+ self.video_flag = [False] * ni + [True] * nv
261
+ self.mode = 'image'
262
+ self.auto = auto
263
+ self.transforms = transforms # optional
264
+ self.vid_stride = vid_stride # video frame-rate stride
265
+ if any(videos):
266
+ self._new_video(videos[0]) # new video
267
+ else:
268
+ self.cap = None
269
+ assert self.nf > 0, f'No images or videos found in {p}. ' \
270
+ f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
271
+
272
+ def __iter__(self):
273
+ self.count = 0
274
+ return self
275
+
276
+ def __next__(self):
277
+ if self.count == self.nf:
278
+ raise StopIteration
279
+ path = self.files[self.count]
280
+
281
+ if self.video_flag[self.count]:
282
+ # Read video
283
+ self.mode = 'video'
284
+ for _ in range(self.vid_stride):
285
+ self.cap.grab()
286
+ ret_val, im0 = self.cap.retrieve()
287
+ while not ret_val:
288
+ self.count += 1
289
+ self.cap.release()
290
+ if self.count == self.nf: # last video
291
+ raise StopIteration
292
+ path = self.files[self.count]
293
+ self._new_video(path)
294
+ ret_val, im0 = self.cap.read()
295
+
296
+ self.frame += 1
297
+ # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
298
+ s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
299
+
300
+ else:
301
+ # Read image
302
+ self.count += 1
303
+ im0 = cv2.imread(path) # BGR
304
+ assert im0 is not None, f'Image Not Found {path}'
305
+ s = f'image {self.count}/{self.nf} {path}: '
306
+
307
+ if self.transforms:
308
+ im = self.transforms(im0) # transforms
309
+ else:
310
+ im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
311
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
312
+ im = np.ascontiguousarray(im) # contiguous
313
+
314
+ return path, im, im0, self.cap, s
315
+
316
+ def _new_video(self, path):
317
+ # Create a new video capture object
318
+ self.frame = 0
319
+ self.cap = cv2.VideoCapture(path)
320
+ self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
321
+ self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
322
+ # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
323
+
324
+ def _cv2_rotate(self, im):
325
+ # Rotate a cv2 video manually
326
+ if self.orientation == 0:
327
+ return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
328
+ elif self.orientation == 180:
329
+ return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
330
+ elif self.orientation == 90:
331
+ return cv2.rotate(im, cv2.ROTATE_180)
332
+ return im
333
+
334
+ def __len__(self):
335
+ return self.nf # number of files
336
+
337
+
338
+ class LoadStreams:
339
+ # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
340
+ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
341
+ torch.backends.cudnn.benchmark = True # faster for fixed-size inference
342
+ self.mode = 'stream'
343
+ self.img_size = img_size
344
+ self.stride = stride
345
+ self.vid_stride = vid_stride # video frame-rate stride
346
+ sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
347
+ n = len(sources)
348
+ self.sources = [clean_str(x) for x in sources] # clean source names for later
349
+ self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
350
+ for i, s in enumerate(sources): # index, source
351
+ # Start thread to read frames from video stream
352
+ st = f'{i + 1}/{n}: {s}... '
353
+ if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
354
+ # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
355
+ check_requirements(('pafy', 'youtube_dl==2020.12.2'))
356
+ import pafy
357
+ s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
358
+ s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
359
+ if s == 0:
360
+ assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
361
+ assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
362
+ cap = cv2.VideoCapture(s)
363
+ assert cap.isOpened(), f'{st}Failed to open {s}'
364
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
365
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
366
+ fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
367
+ self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
368
+ self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
369
+
370
+ _, self.imgs[i] = cap.read() # guarantee first frame
371
+ self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
372
+ LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
373
+ self.threads[i].start()
374
+ LOGGER.info('') # newline
375
+
376
+ # check for common shapes
377
+ s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
378
+ self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
379
+ self.auto = auto and self.rect
380
+ self.transforms = transforms # optional
381
+ if not self.rect:
382
+ LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
383
+
384
+ def update(self, i, cap, stream):
385
+ # Read stream `i` frames in daemon thread
386
+ n, f = 0, self.frames[i] # frame number, frame array
387
+ while cap.isOpened() and n < f:
388
+ n += 1
389
+ cap.grab() # .read() = .grab() followed by .retrieve()
390
+ if n % self.vid_stride == 0:
391
+ success, im = cap.retrieve()
392
+ if success:
393
+ self.imgs[i] = im
394
+ else:
395
+ LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
396
+ self.imgs[i] = np.zeros_like(self.imgs[i])
397
+ cap.open(stream) # re-open stream if signal was lost
398
+ time.sleep(0.0) # wait time
399
+
400
+ def __iter__(self):
401
+ self.count = -1
402
+ return self
403
+
404
+ def __next__(self):
405
+ self.count += 1
406
+ if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
407
+ cv2.destroyAllWindows()
408
+ raise StopIteration
409
+
410
+ im0 = self.imgs.copy()
411
+ if self.transforms:
412
+ im = np.stack([self.transforms(x) for x in im0]) # transforms
413
+ else:
414
+ im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
415
+ im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
416
+ im = np.ascontiguousarray(im) # contiguous
417
+
418
+ return self.sources, im, im0, None, ''
419
+
420
+ def __len__(self):
421
+ return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
422
+
423
+
424
+ def img2label_paths(img_paths):
425
+ # Define label paths as a function of image paths
426
+ sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
427
+ return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
428
+
429
+
430
+ class LoadImagesAndLabels(Dataset):
431
+ # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
432
+ cache_version = 0.6 # dataset labels *.cache version
433
+ rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
434
+
435
+ def __init__(self,
436
+ path,
437
+ img_size=640,
438
+ batch_size=16,
439
+ augment=False,
440
+ hyp=None,
441
+ rect=False,
442
+ image_weights=False,
443
+ cache_images=False,
444
+ single_cls=False,
445
+ stride=32,
446
+ pad=0.0,
447
+ min_items=0,
448
+ prefix=''):
449
+ self.img_size = img_size
450
+ self.augment = augment
451
+ self.hyp = hyp
452
+ self.image_weights = image_weights
453
+ self.rect = False if image_weights else rect
454
+ self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
455
+ self.mosaic_border = [-img_size // 2, -img_size // 2]
456
+ self.stride = stride
457
+ self.path = path
458
+ self.albumentations = Albumentations(size=img_size) if augment else None
459
+
460
+ try:
461
+ f = [] # image files
462
+ for p in path if isinstance(path, list) else [path]:
463
+ p = Path(p) # os-agnostic
464
+ if p.is_dir(): # dir
465
+ f += glob.glob(str(p / '**' / '*.*'), recursive=True)
466
+ # f = list(p.rglob('*.*')) # pathlib
467
+ elif p.is_file(): # file
468
+ with open(p) as t:
469
+ t = t.read().strip().splitlines()
470
+ parent = str(p.parent) + os.sep
471
+ f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t] # to global path
472
+ # f += [p.parent / x.lstrip(os.sep) for x in t] # to global path (pathlib)
473
+ else:
474
+ raise FileNotFoundError(f'{prefix}{p} does not exist')
475
+ self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
476
+ # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
477
+ assert self.im_files, f'{prefix}No images found'
478
+ except Exception as e:
479
+ raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
480
+
481
+ # Check cache
482
+ self.label_files = img2label_paths(self.im_files) # labels
483
+ cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
484
+ try:
485
+ cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
486
+ assert cache['version'] == self.cache_version # matches current version
487
+ assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
488
+ except Exception:
489
+ cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
490
+
491
+ # Display cache
492
+ nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
493
+ if exists and LOCAL_RANK in {-1, 0}:
494
+ d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
495
+ tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
496
+ if cache['msgs']:
497
+ LOGGER.info('\n'.join(cache['msgs'])) # display warnings
498
+ assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'
499
+
500
+ # Read cache
501
+ [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
502
+ labels, shapes, self.segments = zip(*cache.values())
503
+ nl = len(np.concatenate(labels, 0)) # number of labels
504
+ assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
505
+ self.labels = list(labels)
506
+ self.shapes = np.array(shapes)
507
+ self.im_files = list(cache.keys()) # update
508
+ self.label_files = img2label_paths(cache.keys()) # update
509
+
510
+ # Filter images
511
+ if min_items:
512
+ include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
513
+ LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
514
+ self.im_files = [self.im_files[i] for i in include]
515
+ self.label_files = [self.label_files[i] for i in include]
516
+ self.labels = [self.labels[i] for i in include]
517
+ self.segments = [self.segments[i] for i in include]
518
+ self.shapes = self.shapes[include] # wh
519
+
520
+ # Create indices
521
+ n = len(self.shapes) # number of images
522
+ bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
523
+ nb = bi[-1] + 1 # number of batches
524
+ self.batch = bi # batch index of image
525
+ self.n = n
526
+ self.indices = range(n)
527
+
528
+ # Update labels
529
+ include_class = [] # filter labels to include only these classes (optional)
530
+ include_class_array = np.array(include_class).reshape(1, -1)
531
+ for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
532
+ if include_class:
533
+ j = (label[:, 0:1] == include_class_array).any(1)
534
+ self.labels[i] = label[j]
535
+ if segment:
536
+ self.segments[i] = segment[j]
537
+ if single_cls: # single-class training, merge all classes into 0
538
+ self.labels[i][:, 0] = 0
539
+
540
+ # Rectangular Training
541
+ if self.rect:
542
+ # Sort by aspect ratio
543
+ s = self.shapes # wh
544
+ ar = s[:, 1] / s[:, 0] # aspect ratio
545
+ irect = ar.argsort()
546
+ self.im_files = [self.im_files[i] for i in irect]
547
+ self.label_files = [self.label_files[i] for i in irect]
548
+ self.labels = [self.labels[i] for i in irect]
549
+ self.segments = [self.segments[i] for i in irect]
550
+ self.shapes = s[irect] # wh
551
+ ar = ar[irect]
552
+
553
+ # Set training image shapes
554
+ shapes = [[1, 1]] * nb
555
+ for i in range(nb):
556
+ ari = ar[bi == i]
557
+ mini, maxi = ari.min(), ari.max()
558
+ if maxi < 1:
559
+ shapes[i] = [maxi, 1]
560
+ elif mini > 1:
561
+ shapes[i] = [1, 1 / mini]
562
+
563
+ self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
564
+
565
+ # Cache images into RAM/disk for faster training
566
+ if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
567
+ cache_images = False
568
+ self.ims = [None] * n
569
+ self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
570
+ if cache_images:
571
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
572
+ self.im_hw0, self.im_hw = [None] * n, [None] * n
573
+ fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
574
+ results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
575
+ pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
576
+ for i, x in pbar:
577
+ if cache_images == 'disk':
578
+ b += self.npy_files[i].stat().st_size
579
+ else: # 'ram'
580
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
581
+ b += self.ims[i].nbytes
582
+ pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
583
+ pbar.close()
584
+
585
+ def check_cache_ram(self, safety_margin=0.1, prefix=''):
586
+ # Check image caching requirements vs available memory
587
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
588
+ n = min(self.n, 30) # extrapolate from 30 random images
589
+ for _ in range(n):
590
+ im = cv2.imread(random.choice(self.im_files)) # sample image
591
+ ratio = self.img_size / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
592
+ b += im.nbytes * ratio ** 2
593
+ mem_required = b * self.n / n # GB required to cache dataset into RAM
594
+ mem = psutil.virtual_memory()
595
+ cache = mem_required * (1 + safety_margin) < mem.available # to cache or not to cache, that is the question
596
+ if not cache:
597
+ LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
598
+ f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
599
+ f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
600
+ return cache
601
+
602
+ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
603
+ # Cache dataset labels, check images and read shapes
604
+ x = {} # dict
605
+ nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
606
+ desc = f"{prefix}Scanning {path.parent / path.stem}..."
607
+ with Pool(NUM_THREADS) as pool:
608
+ pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
609
+ desc=desc,
610
+ total=len(self.im_files),
611
+ bar_format=TQDM_BAR_FORMAT)
612
+ for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
613
+ nm += nm_f
614
+ nf += nf_f
615
+ ne += ne_f
616
+ nc += nc_f
617
+ if im_file:
618
+ x[im_file] = [lb, shape, segments]
619
+ if msg:
620
+ msgs.append(msg)
621
+ pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
622
+
623
+ pbar.close()
624
+ if msgs:
625
+ LOGGER.info('\n'.join(msgs))
626
+ if nf == 0:
627
+ LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
628
+ x['hash'] = get_hash(self.label_files + self.im_files)
629
+ x['results'] = nf, nm, ne, nc, len(self.im_files)
630
+ x['msgs'] = msgs # warnings
631
+ x['version'] = self.cache_version # cache version
632
+ try:
633
+ np.save(path, x) # save cache for next time
634
+ path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
635
+ LOGGER.info(f'{prefix}New cache created: {path}')
636
+ except Exception as e:
637
+ LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}') # not writeable
638
+ return x
639
+
640
+ def __len__(self):
641
+ return len(self.im_files)
642
+
643
+ # def __iter__(self):
644
+ # self.count = -1
645
+ # print('ran dataset iter')
646
+ # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
647
+ # return self
648
+
649
+ def __getitem__(self, index):
650
+ index = self.indices[index] # linear, shuffled, or image_weights
651
+
652
+ hyp = self.hyp
653
+ mosaic = self.mosaic and random.random() < hyp['mosaic']
654
+ if mosaic:
655
+ # Load mosaic
656
+ img, labels = self.load_mosaic(index)
657
+ shapes = None
658
+
659
+ # MixUp augmentation
660
+ if random.random() < hyp['mixup']:
661
+ img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
662
+
663
+ else:
664
+ # Load image
665
+ img, (h0, w0), (h, w) = self.load_image(index)
666
+
667
+ # Letterbox
668
+ shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
669
+ img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
670
+ shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
671
+
672
+ labels = self.labels[index].copy()
673
+ if labels.size: # normalized xywh to pixel xyxy format
674
+ labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
675
+
676
+ if self.augment:
677
+ img, labels = random_perspective(img,
678
+ labels,
679
+ degrees=hyp['degrees'],
680
+ translate=hyp['translate'],
681
+ scale=hyp['scale'],
682
+ shear=hyp['shear'],
683
+ perspective=hyp['perspective'])
684
+
685
+ nl = len(labels) # number of labels
686
+ if nl:
687
+ labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
688
+
689
+ if self.augment:
690
+ # Albumentations
691
+ img, labels = self.albumentations(img, labels)
692
+ nl = len(labels) # update after albumentations
693
+
694
+ # HSV color-space
695
+ augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
696
+
697
+ # Flip up-down
698
+ if random.random() < hyp['flipud']:
699
+ img = np.flipud(img)
700
+ if nl:
701
+ labels[:, 2] = 1 - labels[:, 2]
702
+
703
+ # Flip left-right
704
+ if random.random() < hyp['fliplr']:
705
+ img = np.fliplr(img)
706
+ if nl:
707
+ labels[:, 1] = 1 - labels[:, 1]
708
+
709
+ # Cutouts
710
+ # labels = cutout(img, labels, p=0.5)
711
+ # nl = len(labels) # update after cutout
712
+
713
+ labels_out = torch.zeros((nl, 6))
714
+ if nl:
715
+ labels_out[:, 1:] = torch.from_numpy(labels)
716
+
717
+ # Convert
718
+ img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
719
+ img = np.ascontiguousarray(img)
720
+
721
+ return torch.from_numpy(img), labels_out, self.im_files[index], shapes
722
+
723
+ def load_image(self, i):
724
+ # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
725
+ im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
726
+ if im is None: # not cached in RAM
727
+ if fn.exists(): # load npy
728
+ im = np.load(fn)
729
+ else: # read image
730
+ im = cv2.imread(f) # BGR
731
+ assert im is not None, f'Image Not Found {f}'
732
+ h0, w0 = im.shape[:2] # orig hw
733
+ r = self.img_size / max(h0, w0) # ratio
734
+ if r != 1: # if sizes are not equal
735
+ interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
736
+ im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
737
+ return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
738
+ return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
739
+
740
+ def cache_images_to_disk(self, i):
741
+ # Saves an image as an *.npy file for faster loading
742
+ f = self.npy_files[i]
743
+ if not f.exists():
744
+ np.save(f.as_posix(), cv2.imread(self.im_files[i]))
745
+
746
+ def load_mosaic(self, index):
747
+ # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
748
+ labels4, segments4 = [], []
749
+ s = self.img_size
750
+ yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
751
+ indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
752
+ random.shuffle(indices)
753
+ for i, index in enumerate(indices):
754
+ # Load image
755
+ img, _, (h, w) = self.load_image(index)
756
+
757
+ # place img in img4
758
+ if i == 0: # top left
759
+ img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
760
+ x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
761
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
762
+ elif i == 1: # top right
763
+ x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
764
+ x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
765
+ elif i == 2: # bottom left
766
+ x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
767
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
768
+ elif i == 3: # bottom right
769
+ x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
770
+ x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
771
+
772
+ img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
773
+ padw = x1a - x1b
774
+ padh = y1a - y1b
775
+
776
+ # Labels
777
+ labels, segments = self.labels[index].copy(), self.segments[index].copy()
778
+ if labels.size:
779
+ labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
780
+ segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
781
+ labels4.append(labels)
782
+ segments4.extend(segments)
783
+
784
+ # Concat/clip labels
785
+ labels4 = np.concatenate(labels4, 0)
786
+ for x in (labels4[:, 1:], *segments4):
787
+ np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
788
+ # img4, labels4 = replicate(img4, labels4) # replicate
789
+
790
+ # Augment
791
+ img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
792
+ img4, labels4 = random_perspective(img4,
793
+ labels4,
794
+ segments4,
795
+ degrees=self.hyp['degrees'],
796
+ translate=self.hyp['translate'],
797
+ scale=self.hyp['scale'],
798
+ shear=self.hyp['shear'],
799
+ perspective=self.hyp['perspective'],
800
+ border=self.mosaic_border) # border to remove
801
+
802
+ return img4, labels4
803
+
804
+ def load_mosaic9(self, index):
805
+ # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
806
+ labels9, segments9 = [], []
807
+ s = self.img_size
808
+ indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
809
+ random.shuffle(indices)
810
+ hp, wp = -1, -1 # height, width previous
811
+ for i, index in enumerate(indices):
812
+ # Load image
813
+ img, _, (h, w) = self.load_image(index)
814
+
815
+ # place img in img9
816
+ if i == 0: # center
817
+ img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
818
+ h0, w0 = h, w
819
+ c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
820
+ elif i == 1: # top
821
+ c = s, s - h, s + w, s
822
+ elif i == 2: # top right
823
+ c = s + wp, s - h, s + wp + w, s
824
+ elif i == 3: # right
825
+ c = s + w0, s, s + w0 + w, s + h
826
+ elif i == 4: # bottom right
827
+ c = s + w0, s + hp, s + w0 + w, s + hp + h
828
+ elif i == 5: # bottom
829
+ c = s + w0 - w, s + h0, s + w0, s + h0 + h
830
+ elif i == 6: # bottom left
831
+ c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
832
+ elif i == 7: # left
833
+ c = s - w, s + h0 - h, s, s + h0
834
+ elif i == 8: # top left
835
+ c = s - w, s + h0 - hp - h, s, s + h0 - hp
836
+
837
+ padx, pady = c[:2]
838
+ x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
839
+
840
+ # Labels
841
+ labels, segments = self.labels[index].copy(), self.segments[index].copy()
842
+ if labels.size:
843
+ labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
844
+ segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
845
+ labels9.append(labels)
846
+ segments9.extend(segments)
847
+
848
+ # Image
849
+ img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
850
+ hp, wp = h, w # height, width previous
851
+
852
+ # Offset
853
+ yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
854
+ img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
855
+
856
+ # Concat/clip labels
857
+ labels9 = np.concatenate(labels9, 0)
858
+ labels9[:, [1, 3]] -= xc
859
+ labels9[:, [2, 4]] -= yc
860
+ c = np.array([xc, yc]) # centers
861
+ segments9 = [x - c for x in segments9]
862
+
863
+ for x in (labels9[:, 1:], *segments9):
864
+ np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
865
+ # img9, labels9 = replicate(img9, labels9) # replicate
866
+
867
+ # Augment
868
+ img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
869
+ img9, labels9 = random_perspective(img9,
870
+ labels9,
871
+ segments9,
872
+ degrees=self.hyp['degrees'],
873
+ translate=self.hyp['translate'],
874
+ scale=self.hyp['scale'],
875
+ shear=self.hyp['shear'],
876
+ perspective=self.hyp['perspective'],
877
+ border=self.mosaic_border) # border to remove
878
+
879
+ return img9, labels9
880
+
881
+ @staticmethod
882
+ def collate_fn(batch):
883
+ im, label, path, shapes = zip(*batch) # transposed
884
+ for i, lb in enumerate(label):
885
+ lb[:, 0] = i # add target image index for build_targets()
886
+ return torch.stack(im, 0), torch.cat(label, 0), path, shapes
887
+
888
+ @staticmethod
889
+ def collate_fn4(batch):
890
+ im, label, path, shapes = zip(*batch) # transposed
891
+ n = len(shapes) // 4
892
+ im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
893
+
894
+ ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
895
+ wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
896
+ s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
897
+ for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
898
+ i *= 4
899
+ if random.random() < 0.5:
900
+ im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
901
+ align_corners=False)[0].type(im[i].type())
902
+ lb = label[i]
903
+ else:
904
+ im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
905
+ lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
906
+ im4.append(im1)
907
+ label4.append(lb)
908
+
909
+ for i, lb in enumerate(label4):
910
+ lb[:, 0] = i # add target image index for build_targets()
911
+
912
+ return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
913
+
914
+
915
+ # Ancillary functions --------------------------------------------------------------------------------------------------
916
+ def flatten_recursive(path=DATASETS_DIR / 'coco128'):
917
+ # Flatten a recursive directory by bringing all files to top level
918
+ new_path = Path(f'{str(path)}_flat')
919
+ if os.path.exists(new_path):
920
+ shutil.rmtree(new_path) # delete output folder
921
+ os.makedirs(new_path) # make new output folder
922
+ for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
923
+ shutil.copyfile(file, new_path / Path(file).name)
924
+
925
+
926
+ def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
927
+ # Convert detection dataset into classification dataset, with one directory per class
928
+ path = Path(path) # images dir
929
+ shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None # remove existing
930
+ files = list(path.rglob('*.*'))
931
+ n = len(files) # number of files
932
+ for im_file in tqdm(files, total=n):
933
+ if im_file.suffix[1:] in IMG_FORMATS:
934
+ # image
935
+ im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
936
+ h, w = im.shape[:2]
937
+
938
+ # labels
939
+ lb_file = Path(img2label_paths([str(im_file)])[0])
940
+ if Path(lb_file).exists():
941
+ with open(lb_file) as f:
942
+ lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
943
+
944
+ for j, x in enumerate(lb):
945
+ c = int(x[0]) # class
946
+ f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
947
+ if not f.parent.is_dir():
948
+ f.parent.mkdir(parents=True)
949
+
950
+ b = x[1:] * [w, h, w, h] # box
951
+ # b[2:] = b[2:].max() # rectangle to square
952
+ b[2:] = b[2:] * 1.2 + 3 # pad
953
+ b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)
954
+
955
+ b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
956
+ b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
957
+ assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
958
+
959
+
960
+ def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
961
+ """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
962
+ Usage: from utils.dataloaders import *; autosplit()
963
+ Arguments
964
+ path: Path to images directory
965
+ weights: Train, val, test weights (list, tuple)
966
+ annotated_only: Only use images with an annotated txt file
967
+ """
968
+ path = Path(path) # images dir
969
+ files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
970
+ n = len(files) # number of files
971
+ random.seed(0) # for reproducibility
972
+ indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
973
+
974
+ txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
975
+ for x in txt:
976
+ if (path.parent / x).exists():
977
+ (path.parent / x).unlink() # remove existing
978
+
979
+ print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
980
+ for i, img in tqdm(zip(indices, files), total=n):
981
+ if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
982
+ with open(path.parent / txt[i], 'a') as f:
983
+ f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
984
+
985
+
986
+ def verify_image_label(args):
987
+ # Verify one image-label pair
988
+ im_file, lb_file, prefix = args
989
+ nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
990
+ try:
991
+ # verify images
992
+ im = Image.open(im_file)
993
+ im.verify() # PIL verify
994
+ shape = exif_size(im) # image size
995
+ assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
996
+ assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
997
+ if im.format.lower() in ('jpg', 'jpeg'):
998
+ with open(im_file, 'rb') as f:
999
+ f.seek(-2, 2)
1000
+ if f.read() != b'\xff\xd9': # corrupt JPEG
1001
+ ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
1002
+ msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
1003
+
1004
+ # verify labels
1005
+ if os.path.isfile(lb_file):
1006
+ nf = 1 # label found
1007
+ with open(lb_file) as f:
1008
+ lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
1009
+ if any(len(x) > 6 for x in lb): # is segment
1010
+ classes = np.array([x[0] for x in lb], dtype=np.float32)
1011
+ segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
1012
+ lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
1013
+ lb = np.array(lb, dtype=np.float32)
1014
+ nl = len(lb)
1015
+ if nl:
1016
+ assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
1017
+ assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
1018
+ assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
1019
+ _, i = np.unique(lb, axis=0, return_index=True)
1020
+ if len(i) < nl: # duplicate row check
1021
+ lb = lb[i] # remove duplicates
1022
+ if segments:
1023
+ segments = [segments[x] for x in i]
1024
+ msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
1025
+ else:
1026
+ ne = 1 # label empty
1027
+ lb = np.zeros((0, 5), dtype=np.float32)
1028
+ else:
1029
+ nm = 1 # label missing
1030
+ lb = np.zeros((0, 5), dtype=np.float32)
1031
+ return im_file, lb, shape, segments, nm, nf, ne, nc, msg
1032
+ except Exception as e:
1033
+ nc = 1
1034
+ msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
1035
+ return [None, None, None, None, nm, nf, ne, nc, msg]
1036
+
1037
+
1038
+ class HUBDatasetStats():
1039
+ """ Class for generating HUB dataset JSON and `-hub` dataset directory
1040
+
1041
+ Arguments
1042
+ path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
1043
+ autodownload: Attempt to download dataset if not found locally
1044
+
1045
+ Usage
1046
+ from utils.dataloaders import HUBDatasetStats
1047
+ stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1
1048
+ stats = HUBDatasetStats('path/to/coco128.zip') # usage 2
1049
+ stats.get_json(save=False)
1050
+ stats.process_images()
1051
+ """
1052
+
1053
+ def __init__(self, path='coco128.yaml', autodownload=False):
1054
+ # Initialize class
1055
+ zipped, data_dir, yaml_path = self._unzip(Path(path))
1056
+ try:
1057
+ with open(check_yaml(yaml_path), errors='ignore') as f:
1058
+ data = yaml.safe_load(f) # data dict
1059
+ if zipped:
1060
+ data['path'] = data_dir
1061
+ except Exception as e:
1062
+ raise Exception("error/HUB/dataset_stats/yaml_load") from e
1063
+
1064
+ check_dataset(data, autodownload) # download dataset if missing
1065
+ self.hub_dir = Path(data['path'] + '-hub')
1066
+ self.im_dir = self.hub_dir / 'images'
1067
+ self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
1068
+ self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
1069
+ self.data = data
1070
+
1071
+ @staticmethod
1072
+ def _find_yaml(dir):
1073
+ # Return data.yaml file
1074
+ files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
1075
+ assert files, f'No *.yaml file found in {dir}'
1076
+ if len(files) > 1:
1077
+ files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
1078
+ assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
1079
+ assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
1080
+ return files[0]
1081
+
1082
+ def _unzip(self, path):
1083
+ # Unzip data.zip
1084
+ if not str(path).endswith('.zip'): # path is data.yaml
1085
+ return False, None, path
1086
+ assert Path(path).is_file(), f'Error unzipping {path}, file not found'
1087
+ unzip_file(path, path=path.parent)
1088
+ dir = path.with_suffix('') # dataset directory == zip name
1089
+ assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
1090
+ return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
1091
+
1092
+ def _hub_ops(self, f, max_dim=1920):
1093
+ # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
1094
+ f_new = self.im_dir / Path(f).name # dataset-hub image filename
1095
+ try: # use PIL
1096
+ im = Image.open(f)
1097
+ r = max_dim / max(im.height, im.width) # ratio
1098
+ if r < 1.0: # image too large
1099
+ im = im.resize((int(im.width * r), int(im.height * r)))
1100
+ im.save(f_new, 'JPEG', quality=50, optimize=True) # save
1101
+ except Exception as e: # use OpenCV
1102
+ LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
1103
+ im = cv2.imread(f)
1104
+ im_height, im_width = im.shape[:2]
1105
+ r = max_dim / max(im_height, im_width) # ratio
1106
+ if r < 1.0: # image too large
1107
+ im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
1108
+ cv2.imwrite(str(f_new), im)
1109
+
1110
+ def get_json(self, save=False, verbose=False):
1111
+ # Return dataset JSON for Ultralytics HUB
1112
+ def _round(labels):
1113
+ # Update labels to integer class and 6 decimal place floats
1114
+ return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
1115
+
1116
+ for split in 'train', 'val', 'test':
1117
+ if self.data.get(split) is None:
1118
+ self.stats[split] = None # i.e. no test set
1119
+ continue
1120
+ dataset = LoadImagesAndLabels(self.data[split]) # load dataset
1121
+ x = np.array([
1122
+ np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
1123
+ for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')]) # shape(128x80)
1124
+ self.stats[split] = {
1125
+ 'instance_stats': {
1126
+ 'total': int(x.sum()),
1127
+ 'per_class': x.sum(0).tolist()},
1128
+ 'image_stats': {
1129
+ 'total': dataset.n,
1130
+ 'unlabelled': int(np.all(x == 0, 1).sum()),
1131
+ 'per_class': (x > 0).sum(0).tolist()},
1132
+ 'labels': [{
1133
+ str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
1134
+
1135
+ # Save, print and return
1136
+ if save:
1137
+ stats_path = self.hub_dir / 'stats.json'
1138
+ print(f'Saving {stats_path.resolve()}...')
1139
+ with open(stats_path, 'w') as f:
1140
+ json.dump(self.stats, f) # save stats.json
1141
+ if verbose:
1142
+ print(json.dumps(self.stats, indent=2, sort_keys=False))
1143
+ return self.stats
1144
+
1145
+ def process_images(self):
1146
+ # Compress images for Ultralytics HUB
1147
+ for split in 'train', 'val', 'test':
1148
+ if self.data.get(split) is None:
1149
+ continue
1150
+ dataset = LoadImagesAndLabels(self.data[split]) # load dataset
1151
+ desc = f'{split} images'
1152
+ for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
1153
+ pass
1154
+ print(f'Done. All images saved to {self.im_dir}')
1155
+ return self.im_dir
1156
+
1157
+
1158
+ # Classification dataloaders -------------------------------------------------------------------------------------------
1159
+ class ClassificationDataset(torchvision.datasets.ImageFolder):
1160
+ """
1161
+ YOLOv5 Classification Dataset.
1162
+ Arguments
1163
+ root: Dataset path
1164
+ transform: torchvision transforms, used by default
1165
+ album_transform: Albumentations transforms, used if installed
1166
+ """
1167
+
1168
+ def __init__(self, root, augment, imgsz, cache=False):
1169
+ super().__init__(root=root)
1170
+ self.torch_transforms = classify_transforms(imgsz)
1171
+ self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
1172
+ self.cache_ram = cache is True or cache == 'ram'
1173
+ self.cache_disk = cache == 'disk'
1174
+ self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
1175
+
1176
+ def __getitem__(self, i):
1177
+ f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
1178
+ if self.cache_ram and im is None:
1179
+ im = self.samples[i][3] = cv2.imread(f)
1180
+ elif self.cache_disk:
1181
+ if not fn.exists(): # load npy
1182
+ np.save(fn.as_posix(), cv2.imread(f))
1183
+ im = np.load(fn)
1184
+ else: # read image
1185
+ im = cv2.imread(f) # BGR
1186
+ if self.album_transforms:
1187
+ sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
1188
+ else:
1189
+ sample = self.torch_transforms(im)
1190
+ return sample, j
1191
+
1192
+
1193
+ def create_classification_dataloader(path,
1194
+ imgsz=224,
1195
+ batch_size=16,
1196
+ augment=True,
1197
+ cache=False,
1198
+ rank=-1,
1199
+ workers=8,
1200
+ shuffle=True):
1201
+ # Returns Dataloader object to be used with YOLOv5 Classifier
1202
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
1203
+ dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
1204
+ batch_size = min(batch_size, len(dataset))
1205
+ nd = torch.cuda.device_count()
1206
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
1207
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
1208
+ generator = torch.Generator()
1209
+ generator.manual_seed(6148914691236517205 + RANK)
1210
+ return InfiniteDataLoader(dataset,
1211
+ batch_size=batch_size,
1212
+ shuffle=shuffle and sampler is None,
1213
+ num_workers=nw,
1214
+ sampler=sampler,
1215
+ pin_memory=PIN_MEMORY,
1216
+ worker_init_fn=seed_worker,
1217
+ generator=generator) # or DataLoader(persistent_workers=True)
utils/downloads.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import subprocess
4
+ import urllib
5
+ from pathlib import Path
6
+
7
+ import requests
8
+ import torch
9
+
10
+
11
+ def is_url(url, check=True):
12
+ # Check if string is URL and check if URL exists
13
+ try:
14
+ url = str(url)
15
+ result = urllib.parse.urlparse(url)
16
+ assert all([result.scheme, result.netloc]) # check if is url
17
+ return (urllib.request.urlopen(url).getcode() == 200) if check else True # check if exists online
18
+ except (AssertionError, urllib.request.HTTPError):
19
+ return False
20
+
21
+
22
+ def gsutil_getsize(url=''):
23
+ # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du
24
+ s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8')
25
+ return eval(s.split(' ')[0]) if len(s) else 0 # bytes
26
+
27
+
28
+ def url_getsize(url='https://ultralytics.com/images/bus.jpg'):
29
+ # Return downloadable file size in bytes
30
+ response = requests.head(url, allow_redirects=True)
31
+ return int(response.headers.get('content-length', -1))
32
+
33
+
34
+ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
35
+ # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
36
+ from utils.general import LOGGER
37
+
38
+ file = Path(file)
39
+ assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
40
+ try: # url1
41
+ LOGGER.info(f'Downloading {url} to {file}...')
42
+ torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
43
+ assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
44
+ except Exception as e: # url2
45
+ if file.exists():
46
+ file.unlink() # remove partial downloads
47
+ LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
48
+ os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
49
+ finally:
50
+ if not file.exists() or file.stat().st_size < min_bytes: # check
51
+ if file.exists():
52
+ file.unlink() # remove partial downloads
53
+ LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
54
+ LOGGER.info('')
55
+
56
+
57
+ def attempt_download(file, repo='ultralytics/yolov5', release='v7.0'):
58
+ # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v7.0', etc.
59
+ from utils.general import LOGGER
60
+
61
+ def github_assets(repository, version='latest'):
62
+ # Return GitHub repo tag (i.e. 'v7.0') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
63
+ if version != 'latest':
64
+ version = f'tags/{version}' # i.e. tags/v7.0
65
+ response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
66
+ return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
67
+
68
+ file = Path(str(file).strip().replace("'", ''))
69
+ if not file.exists():
70
+ # URL specified
71
+ name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
72
+ if str(file).startswith(('http:/', 'https:/')): # download
73
+ url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
74
+ file = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
75
+ if Path(file).is_file():
76
+ LOGGER.info(f'Found {url} locally at {file}') # file already exists
77
+ else:
78
+ safe_download(file=file, url=url, min_bytes=1E5)
79
+ return file
80
+
81
+ # GitHub assets
82
+ assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
83
+ try:
84
+ tag, assets = github_assets(repo, release)
85
+ except Exception:
86
+ try:
87
+ tag, assets = github_assets(repo) # latest release
88
+ except Exception:
89
+ try:
90
+ tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
91
+ except Exception:
92
+ tag = release
93
+
94
+ file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
95
+ if name in assets:
96
+ url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl' # backup gdrive mirror
97
+ safe_download(
98
+ file,
99
+ url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
100
+ min_bytes=1E5,
101
+ error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')
102
+
103
+ return str(file)
utils/general.py ADDED
@@ -0,0 +1,1227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import glob
3
+ import inspect
4
+ import logging
5
+ import logging.config
6
+ import math
7
+ import os
8
+ import platform
9
+ import random
10
+ import re
11
+ import signal
12
+ import sys
13
+ import time
14
+ import urllib
15
+ from copy import deepcopy
16
+ from datetime import datetime
17
+ from itertools import repeat
18
+ from multiprocessing.pool import ThreadPool
19
+ from pathlib import Path
20
+ from subprocess import check_output
21
+ from tarfile import is_tarfile
22
+ from typing import Optional
23
+ from zipfile import ZipFile, is_zipfile
24
+
25
+ import cv2
26
+ import IPython
27
+ import numpy as np
28
+ import pandas as pd
29
+ import pkg_resources as pkg
30
+ import torch
31
+ import torchvision
32
+ import yaml
33
+
34
+ from utils import TryExcept, emojis
35
+ from utils.downloads import gsutil_getsize
36
+ from utils.metrics import box_iou, fitness
37
+
38
+ FILE = Path(__file__).resolve()
39
+ ROOT = FILE.parents[1] # YOLO root directory
40
+ RANK = int(os.getenv('RANK', -1))
41
+
42
+ # Settings
43
+ NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
44
+ DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
45
+ AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
46
+ VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
47
+ TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
48
+ FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
49
+
50
+ torch.set_printoptions(linewidth=320, precision=5, profile='long')
51
+ np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
52
+ pd.options.display.max_columns = 10
53
+ cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
54
+ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
55
+ os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
56
+
57
+
58
+ def is_ascii(s=''):
59
+ # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
60
+ s = str(s) # convert list, tuple, None, etc. to str
61
+ return len(s.encode().decode('ascii', 'ignore')) == len(s)
62
+
63
+
64
+ def is_chinese(s='人工智能'):
65
+ # Is string composed of any Chinese characters?
66
+ return bool(re.search('[\u4e00-\u9fff]', str(s)))
67
+
68
+
69
+ def is_colab():
70
+ # Is environment a Google Colab instance?
71
+ return 'google.colab' in sys.modules
72
+
73
+
74
+ def is_notebook():
75
+ # Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
76
+ ipython_type = str(type(IPython.get_ipython()))
77
+ return 'colab' in ipython_type or 'zmqshell' in ipython_type
78
+
79
+
80
+ def is_kaggle():
81
+ # Is environment a Kaggle Notebook?
82
+ return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
83
+
84
+
85
+ def is_docker() -> bool:
86
+ """Check if the process runs inside a docker container."""
87
+ if Path("/.dockerenv").exists():
88
+ return True
89
+ try: # check if docker is in control groups
90
+ with open("/proc/self/cgroup") as file:
91
+ return any("docker" in line for line in file)
92
+ except OSError:
93
+ return False
94
+
95
+
96
+ def is_writeable(dir, test=False):
97
+ # Return True if directory has write permissions, test opening a file with write permissions if test=True
98
+ if not test:
99
+ return os.access(dir, os.W_OK) # possible issues on Windows
100
+ file = Path(dir) / 'tmp.txt'
101
+ try:
102
+ with open(file, 'w'): # open file with write permissions
103
+ pass
104
+ file.unlink() # remove file
105
+ return True
106
+ except OSError:
107
+ return False
108
+
109
+
110
+ LOGGING_NAME = "yolov5"
111
+
112
+
113
+ def set_logging(name=LOGGING_NAME, verbose=True):
114
+ # sets up logging for the given name
115
+ rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
116
+ level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
117
+ logging.config.dictConfig({
118
+ "version": 1,
119
+ "disable_existing_loggers": False,
120
+ "formatters": {
121
+ name: {
122
+ "format": "%(message)s"}},
123
+ "handlers": {
124
+ name: {
125
+ "class": "logging.StreamHandler",
126
+ "formatter": name,
127
+ "level": level,}},
128
+ "loggers": {
129
+ name: {
130
+ "level": level,
131
+ "handlers": [name],
132
+ "propagate": False,}}})
133
+
134
+
135
+ set_logging(LOGGING_NAME) # run before defining LOGGER
136
+ LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
137
+ if platform.system() == 'Windows':
138
+ for fn in LOGGER.info, LOGGER.warning:
139
+ setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
140
+
141
+
142
+ def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
143
+ # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
144
+ env = os.getenv(env_var)
145
+ if env:
146
+ path = Path(env) # use environment variable
147
+ else:
148
+ cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
149
+ path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
150
+ path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
151
+ path.mkdir(exist_ok=True) # make if required
152
+ return path
153
+
154
+
155
+ CONFIG_DIR = user_config_dir() # Ultralytics settings dir
156
+
157
+
158
+ class Profile(contextlib.ContextDecorator):
159
+ # YOLO Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
160
+ def __init__(self, t=0.0):
161
+ self.t = t
162
+ self.cuda = torch.cuda.is_available()
163
+
164
+ def __enter__(self):
165
+ self.start = self.time()
166
+ return self
167
+
168
+ def __exit__(self, type, value, traceback):
169
+ self.dt = self.time() - self.start # delta-time
170
+ self.t += self.dt # accumulate dt
171
+
172
+ def time(self):
173
+ if self.cuda:
174
+ torch.cuda.synchronize()
175
+ return time.time()
176
+
177
+
178
+ class Timeout(contextlib.ContextDecorator):
179
+ # YOLO Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
180
+ def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
181
+ self.seconds = int(seconds)
182
+ self.timeout_message = timeout_msg
183
+ self.suppress = bool(suppress_timeout_errors)
184
+
185
+ def _timeout_handler(self, signum, frame):
186
+ raise TimeoutError(self.timeout_message)
187
+
188
+ def __enter__(self):
189
+ if platform.system() != 'Windows': # not supported on Windows
190
+ signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
191
+ signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
192
+
193
+ def __exit__(self, exc_type, exc_val, exc_tb):
194
+ if platform.system() != 'Windows':
195
+ signal.alarm(0) # Cancel SIGALRM if it's scheduled
196
+ if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
197
+ return True
198
+
199
+
200
+ class WorkingDirectory(contextlib.ContextDecorator):
201
+ # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
202
+ def __init__(self, new_dir):
203
+ self.dir = new_dir # new dir
204
+ self.cwd = Path.cwd().resolve() # current dir
205
+
206
+ def __enter__(self):
207
+ os.chdir(self.dir)
208
+
209
+ def __exit__(self, exc_type, exc_val, exc_tb):
210
+ os.chdir(self.cwd)
211
+
212
+
213
+ def methods(instance):
214
+ # Get class/instance methods
215
+ return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
216
+
217
+
218
+ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
219
+ # Print function arguments (optional args dict)
220
+ x = inspect.currentframe().f_back # previous frame
221
+ file, _, func, _, _ = inspect.getframeinfo(x)
222
+ if args is None: # get args automatically
223
+ args, _, _, frm = inspect.getargvalues(x)
224
+ args = {k: v for k, v in frm.items() if k in args}
225
+ try:
226
+ file = Path(file).resolve().relative_to(ROOT).with_suffix('')
227
+ except ValueError:
228
+ file = Path(file).stem
229
+ s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
230
+ LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
231
+
232
+
233
+ def init_seeds(seed=0, deterministic=False):
234
+ # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
235
+ random.seed(seed)
236
+ np.random.seed(seed)
237
+ torch.manual_seed(seed)
238
+ torch.cuda.manual_seed(seed)
239
+ torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
240
+ # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
241
+ if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
242
+ torch.use_deterministic_algorithms(True)
243
+ torch.backends.cudnn.deterministic = True
244
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
245
+ os.environ['PYTHONHASHSEED'] = str(seed)
246
+
247
+
248
+ def intersect_dicts(da, db, exclude=()):
249
+ # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
250
+ return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
251
+
252
+
253
+ def get_default_args(func):
254
+ # Get func() default arguments
255
+ signature = inspect.signature(func)
256
+ return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
257
+
258
+
259
+ def get_latest_run(search_dir='.'):
260
+ # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
261
+ last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
262
+ return max(last_list, key=os.path.getctime) if last_list else ''
263
+
264
+
265
+ def file_age(path=__file__):
266
+ # Return days since last file update
267
+ dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
268
+ return dt.days # + dt.seconds / 86400 # fractional days
269
+
270
+
271
+ def file_date(path=__file__):
272
+ # Return human-readable file modification date, i.e. '2021-3-26'
273
+ t = datetime.fromtimestamp(Path(path).stat().st_mtime)
274
+ return f'{t.year}-{t.month}-{t.day}'
275
+
276
+
277
+ def file_size(path):
278
+ # Return file/dir size (MB)
279
+ mb = 1 << 20 # bytes to MiB (1024 ** 2)
280
+ path = Path(path)
281
+ if path.is_file():
282
+ return path.stat().st_size / mb
283
+ elif path.is_dir():
284
+ return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
285
+ else:
286
+ return 0.0
287
+
288
+
289
+ def check_online():
290
+ # Check internet connectivity
291
+ import socket
292
+
293
+ def run_once():
294
+ # Check once
295
+ try:
296
+ socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
297
+ return True
298
+ except OSError:
299
+ return False
300
+
301
+ return run_once() or run_once() # check twice to increase robustness to intermittent connectivity issues
302
+
303
+
304
+ def git_describe(path=ROOT): # path must be a directory
305
+ # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
306
+ try:
307
+ assert (Path(path) / '.git').is_dir()
308
+ return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
309
+ except Exception:
310
+ return ''
311
+
312
+
313
+ @TryExcept()
314
+ @WorkingDirectory(ROOT)
315
+ def check_git_status(repo='WongKinYiu/yolov9', branch='main'):
316
+ # YOLO status check, recommend 'git pull' if code is out of date
317
+ url = f'https://github.com/{repo}'
318
+ msg = f', for updates see {url}'
319
+ s = colorstr('github: ') # string
320
+ assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
321
+ assert check_online(), s + 'skipping check (offline)' + msg
322
+
323
+ splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
324
+ matches = [repo in s for s in splits]
325
+ if any(matches):
326
+ remote = splits[matches.index(True) - 1]
327
+ else:
328
+ remote = 'ultralytics'
329
+ check_output(f'git remote add {remote} {url}', shell=True)
330
+ check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
331
+ local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
332
+ n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
333
+ if n > 0:
334
+ pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
335
+ s += f"⚠️ YOLO is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
336
+ else:
337
+ s += f'up to date with {url} ✅'
338
+ LOGGER.info(s)
339
+
340
+
341
+ @WorkingDirectory(ROOT)
342
+ def check_git_info(path='.'):
343
+ # YOLO git info check, return {remote, branch, commit}
344
+ check_requirements('gitpython')
345
+ import git
346
+ try:
347
+ repo = git.Repo(path)
348
+ remote = repo.remotes.origin.url.replace('.git', '') # i.e. 'https://github.com/WongKinYiu/yolov9'
349
+ commit = repo.head.commit.hexsha # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
350
+ try:
351
+ branch = repo.active_branch.name # i.e. 'main'
352
+ except TypeError: # not on any branch
353
+ branch = None # i.e. 'detached HEAD' state
354
+ return {'remote': remote, 'branch': branch, 'commit': commit}
355
+ except git.exc.InvalidGitRepositoryError: # path is not a git dir
356
+ return {'remote': None, 'branch': None, 'commit': None}
357
+
358
+
359
+ def check_python(minimum='3.7.0'):
360
+ # Check current python version vs. required python version
361
+ check_version(platform.python_version(), minimum, name='Python ', hard=True)
362
+
363
+
364
+ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
365
+ # Check version vs. required version
366
+ current, minimum = (pkg.parse_version(x) for x in (current, minimum))
367
+ result = (current == minimum) if pinned else (current >= minimum) # bool
368
+ s = f'WARNING ⚠️ {name}{minimum} is required by YOLO, but {name}{current} is currently installed' # string
369
+ if hard:
370
+ assert result, emojis(s) # assert min requirements met
371
+ if verbose and not result:
372
+ LOGGER.warning(s)
373
+ return result
374
+
375
+
376
+ @TryExcept()
377
+ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
378
+ # Check installed dependencies meet YOLO requirements (pass *.txt file or list of packages or single package str)
379
+ prefix = colorstr('red', 'bold', 'requirements:')
380
+ check_python() # check python version
381
+ if isinstance(requirements, Path): # requirements.txt file
382
+ file = requirements.resolve()
383
+ assert file.exists(), f"{prefix} {file} not found, check failed."
384
+ with file.open() as f:
385
+ requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
386
+ elif isinstance(requirements, str):
387
+ requirements = [requirements]
388
+
389
+ s = ''
390
+ n = 0
391
+ for r in requirements:
392
+ try:
393
+ pkg.require(r)
394
+ except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
395
+ s += f'"{r}" '
396
+ n += 1
397
+
398
+ if s and install and AUTOINSTALL: # check environment variable
399
+ LOGGER.info(f"{prefix} YOLO requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
400
+ try:
401
+ # assert check_online(), "AutoUpdate skipped (offline)"
402
+ LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
403
+ source = file if 'file' in locals() else requirements
404
+ s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
405
+ f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
406
+ LOGGER.info(s)
407
+ except Exception as e:
408
+ LOGGER.warning(f'{prefix} ❌ {e}')
409
+
410
+
411
+ def check_img_size(imgsz, s=32, floor=0):
412
+ # Verify image size is a multiple of stride s in each dimension
413
+ if isinstance(imgsz, int): # integer i.e. img_size=640
414
+ new_size = max(make_divisible(imgsz, int(s)), floor)
415
+ else: # list i.e. img_size=[640, 480]
416
+ imgsz = list(imgsz) # convert to list if tuple
417
+ new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
418
+ if new_size != imgsz:
419
+ LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
420
+ return new_size
421
+
422
+
423
+ def check_imshow(warn=False):
424
+ # Check if environment supports image displays
425
+ try:
426
+ assert not is_notebook()
427
+ assert not is_docker()
428
+ cv2.imshow('test', np.zeros((1, 1, 3)))
429
+ cv2.waitKey(1)
430
+ cv2.destroyAllWindows()
431
+ cv2.waitKey(1)
432
+ return True
433
+ except Exception as e:
434
+ if warn:
435
+ LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
436
+ return False
437
+
438
+
439
+ def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''):
440
+ # Check file(s) for acceptable suffix
441
+ if file and suffix:
442
+ if isinstance(suffix, str):
443
+ suffix = [suffix]
444
+ for f in file if isinstance(file, (list, tuple)) else [file]:
445
+ s = Path(f).suffix.lower() # file suffix
446
+ if len(s):
447
+ assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
448
+
449
+
450
+ def check_yaml(file, suffix=('.yaml', '.yml')):
451
+ # Search/download YAML file (if necessary) and return path, checking suffix
452
+ return check_file(file, suffix)
453
+
454
+
455
+ def check_file(file, suffix=''):
456
+ # Search/download file (if necessary) and return path
457
+ check_suffix(file, suffix) # optional
458
+ file = str(file) # convert to str()
459
+ if os.path.isfile(file) or not file: # exists
460
+ return file
461
+ elif file.startswith(('http:/', 'https:/')): # download
462
+ url = file # warning: Pathlib turns :// -> :/
463
+ file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
464
+ if os.path.isfile(file):
465
+ LOGGER.info(f'Found {url} locally at {file}') # file already exists
466
+ else:
467
+ LOGGER.info(f'Downloading {url} to {file}...')
468
+ torch.hub.download_url_to_file(url, file)
469
+ assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
470
+ return file
471
+ elif file.startswith('clearml://'): # ClearML Dataset ID
472
+ assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
473
+ return file
474
+ else: # search
475
+ files = []
476
+ for d in 'data', 'models', 'utils': # search directories
477
+ files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
478
+ assert len(files), f'File not found: {file}' # assert file was found
479
+ assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
480
+ return files[0] # return file
481
+
482
+
483
+ def check_font(font=FONT, progress=False):
484
+ # Download font to CONFIG_DIR if necessary
485
+ font = Path(font)
486
+ file = CONFIG_DIR / font.name
487
+ if not font.exists() and not file.exists():
488
+ url = f'https://ultralytics.com/assets/{font.name}'
489
+ LOGGER.info(f'Downloading {url} to {file}...')
490
+ torch.hub.download_url_to_file(url, str(file), progress=progress)
491
+
492
+
493
+ def check_dataset(data, autodownload=True):
494
+ # Download, check and/or unzip dataset if not found locally
495
+
496
+ # Download (optional)
497
+ extract_dir = ''
498
+ if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
499
+ download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
500
+ data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
501
+ extract_dir, autodownload = data.parent, False
502
+
503
+ # Read yaml (optional)
504
+ if isinstance(data, (str, Path)):
505
+ data = yaml_load(data) # dictionary
506
+
507
+ # Checks
508
+ for k in 'train', 'val', 'names':
509
+ assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
510
+ if isinstance(data['names'], (list, tuple)): # old array format
511
+ data['names'] = dict(enumerate(data['names'])) # convert to dict
512
+ assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
513
+ data['nc'] = len(data['names'])
514
+
515
+ # Resolve paths
516
+ path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
517
+ if not path.is_absolute():
518
+ path = (ROOT / path).resolve()
519
+ data['path'] = path # download scripts
520
+ for k in 'train', 'val', 'test':
521
+ if data.get(k): # prepend path
522
+ if isinstance(data[k], str):
523
+ x = (path / data[k]).resolve()
524
+ if not x.exists() and data[k].startswith('../'):
525
+ x = (path / data[k][3:]).resolve()
526
+ data[k] = str(x)
527
+ else:
528
+ data[k] = [str((path / x).resolve()) for x in data[k]]
529
+
530
+ # Parse yaml
531
+ train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
532
+ if val:
533
+ val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
534
+ if not all(x.exists() for x in val):
535
+ LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
536
+ if not s or not autodownload:
537
+ raise Exception('Dataset not found ❌')
538
+ t = time.time()
539
+ if s.startswith('http') and s.endswith('.zip'): # URL
540
+ f = Path(s).name # filename
541
+ LOGGER.info(f'Downloading {s} to {f}...')
542
+ torch.hub.download_url_to_file(s, f)
543
+ Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
544
+ unzip_file(f, path=DATASETS_DIR) # unzip
545
+ Path(f).unlink() # remove zip
546
+ r = None # success
547
+ elif s.startswith('bash '): # bash script
548
+ LOGGER.info(f'Running {s} ...')
549
+ r = os.system(s)
550
+ else: # python script
551
+ r = exec(s, {'yaml': data}) # return None
552
+ dt = f'({round(time.time() - t, 1)}s)'
553
+ s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
554
+ LOGGER.info(f"Dataset download {s}")
555
+ check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
556
+ return data # dictionary
557
+
558
+
559
+ def check_amp(model):
560
+ # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
561
+ from models.common import AutoShape, DetectMultiBackend
562
+
563
+ def amp_allclose(model, im):
564
+ # All close FP32 vs AMP results
565
+ m = AutoShape(model, verbose=False) # model
566
+ a = m(im).xywhn[0] # FP32 inference
567
+ m.amp = True
568
+ b = m(im).xywhn[0] # AMP inference
569
+ return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
570
+
571
+ prefix = colorstr('AMP: ')
572
+ device = next(model.parameters()).device # get model device
573
+ if device.type in ('cpu', 'mps'):
574
+ return False # AMP only used on CUDA devices
575
+ f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
576
+ im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
577
+ try:
578
+ #assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolo.pt', device), im)
579
+ LOGGER.info(f'{prefix}checks passed ✅')
580
+ return True
581
+ except Exception:
582
+ help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
583
+ LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
584
+ return False
585
+
586
+
587
+ def yaml_load(file='data.yaml'):
588
+ # Single-line safe yaml loading
589
+ with open(file, errors='ignore') as f:
590
+ return yaml.safe_load(f)
591
+
592
+
593
+ def yaml_save(file='data.yaml', data={}):
594
+ # Single-line safe yaml saving
595
+ with open(file, 'w') as f:
596
+ yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
597
+
598
+
599
+ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
600
+ # Unzip a *.zip file to path/, excluding files containing strings in exclude list
601
+ if path is None:
602
+ path = Path(file).parent # default path
603
+ with ZipFile(file) as zipObj:
604
+ for f in zipObj.namelist(): # list all archived filenames in the zip
605
+ if all(x not in f for x in exclude):
606
+ zipObj.extract(f, path=path)
607
+
608
+
609
+ def url2file(url):
610
+ # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
611
+ url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
612
+ return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
613
+
614
+
615
+ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
616
+ # Multithreaded file download and unzip function, used in data.yaml for autodownload
617
+ def download_one(url, dir):
618
+ # Download 1 file
619
+ success = True
620
+ if os.path.isfile(url):
621
+ f = Path(url) # filename
622
+ else: # does not exist
623
+ f = dir / Path(url).name
624
+ LOGGER.info(f'Downloading {url} to {f}...')
625
+ for i in range(retry + 1):
626
+ if curl:
627
+ s = 'sS' if threads > 1 else '' # silent
628
+ r = os.system(
629
+ f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
630
+ success = r == 0
631
+ else:
632
+ torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
633
+ success = f.is_file()
634
+ if success:
635
+ break
636
+ elif i < retry:
637
+ LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
638
+ else:
639
+ LOGGER.warning(f'❌ Failed to download {url}...')
640
+
641
+ if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
642
+ LOGGER.info(f'Unzipping {f}...')
643
+ if is_zipfile(f):
644
+ unzip_file(f, dir) # unzip
645
+ elif is_tarfile(f):
646
+ os.system(f'tar xf {f} --directory {f.parent}') # unzip
647
+ elif f.suffix == '.gz':
648
+ os.system(f'tar xfz {f} --directory {f.parent}') # unzip
649
+ if delete:
650
+ f.unlink() # remove zip
651
+
652
+ dir = Path(dir)
653
+ dir.mkdir(parents=True, exist_ok=True) # make directory
654
+ if threads > 1:
655
+ pool = ThreadPool(threads)
656
+ pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
657
+ pool.close()
658
+ pool.join()
659
+ else:
660
+ for u in [url] if isinstance(url, (str, Path)) else url:
661
+ download_one(u, dir)
662
+
663
+
664
+ def make_divisible(x, divisor):
665
+ # Returns nearest x divisible by divisor
666
+ if isinstance(divisor, torch.Tensor):
667
+ divisor = int(divisor.max()) # to int
668
+ return math.ceil(x / divisor) * divisor
669
+
670
+
671
+ def clean_str(s):
672
+ # Cleans a string by replacing special characters with underscore _
673
+ return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
674
+
675
+
676
+ def one_cycle(y1=0.0, y2=1.0, steps=100):
677
+ # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
678
+ return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
679
+
680
+
681
+ def one_flat_cycle(y1=0.0, y2=1.0, steps=100):
682
+ # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
683
+ #return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
684
+ return lambda x: ((1 - math.cos((x - (steps // 2)) * math.pi / (steps // 2))) / 2) * (y2 - y1) + y1 if (x > (steps // 2)) else y1
685
+
686
+
687
+ def colorstr(*input):
688
+ # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
689
+ *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
690
+ colors = {
691
+ 'black': '\033[30m', # basic colors
692
+ 'red': '\033[31m',
693
+ 'green': '\033[32m',
694
+ 'yellow': '\033[33m',
695
+ 'blue': '\033[34m',
696
+ 'magenta': '\033[35m',
697
+ 'cyan': '\033[36m',
698
+ 'white': '\033[37m',
699
+ 'bright_black': '\033[90m', # bright colors
700
+ 'bright_red': '\033[91m',
701
+ 'bright_green': '\033[92m',
702
+ 'bright_yellow': '\033[93m',
703
+ 'bright_blue': '\033[94m',
704
+ 'bright_magenta': '\033[95m',
705
+ 'bright_cyan': '\033[96m',
706
+ 'bright_white': '\033[97m',
707
+ 'end': '\033[0m', # misc
708
+ 'bold': '\033[1m',
709
+ 'underline': '\033[4m'}
710
+ return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
711
+
712
+
713
+ def labels_to_class_weights(labels, nc=80):
714
+ # Get class weights (inverse frequency) from training labels
715
+ if labels[0] is None: # no labels loaded
716
+ return torch.Tensor()
717
+
718
+ labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
719
+ classes = labels[:, 0].astype(int) # labels = [class xywh]
720
+ weights = np.bincount(classes, minlength=nc) # occurrences per class
721
+
722
+ # Prepend gridpoint count (for uCE training)
723
+ # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
724
+ # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
725
+
726
+ weights[weights == 0] = 1 # replace empty bins with 1
727
+ weights = 1 / weights # number of targets per class
728
+ weights /= weights.sum() # normalize
729
+ return torch.from_numpy(weights).float()
730
+
731
+
732
+ def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
733
+ # Produces image weights based on class_weights and image contents
734
+ # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
735
+ class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
736
+ return (class_weights.reshape(1, nc) * class_counts).sum(1)
737
+
738
+
739
+ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
740
+ # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
741
+ # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
742
+ # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
743
+ # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
744
+ # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
745
+ return [
746
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
747
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
748
+ 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
749
+
750
+
751
+ def xyxy2xywh(x):
752
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
753
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
754
+ y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
755
+ y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
756
+ y[..., 2] = x[..., 2] - x[..., 0] # width
757
+ y[..., 3] = x[..., 3] - x[..., 1] # height
758
+ return y
759
+
760
+
761
+ def xywh2xyxy(x):
762
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
763
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
764
+ y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
765
+ y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
766
+ y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
767
+ y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
768
+ return y
769
+
770
+
771
+ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
772
+ # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
773
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
774
+ y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
775
+ y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
776
+ y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
777
+ y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
778
+ return y
779
+
780
+
781
+ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
782
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
783
+ if clip:
784
+ clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
785
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
786
+ y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
787
+ y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
788
+ y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
789
+ y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
790
+ return y
791
+
792
+
793
+ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
794
+ # Convert normalized segments into pixel segments, shape (n,2)
795
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
796
+ y[..., 0] = w * x[..., 0] + padw # top left x
797
+ y[..., 1] = h * x[..., 1] + padh # top left y
798
+ return y
799
+
800
+
801
+ def segment2box(segment, width=640, height=640):
802
+ # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
803
+ x, y = segment.T # segment xy
804
+ inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
805
+ x, y, = x[inside], y[inside]
806
+ return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
807
+
808
+
809
+ def segments2boxes(segments):
810
+ # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
811
+ boxes = []
812
+ for s in segments:
813
+ x, y = s.T # segment xy
814
+ boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
815
+ return xyxy2xywh(np.array(boxes)) # cls, xywh
816
+
817
+
818
+ def resample_segments(segments, n=1000):
819
+ # Up-sample an (n,2) segment
820
+ for i, s in enumerate(segments):
821
+ s = np.concatenate((s, s[0:1, :]), axis=0)
822
+ x = np.linspace(0, len(s) - 1, n)
823
+ xp = np.arange(len(s))
824
+ segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
825
+ return segments
826
+
827
+
828
+ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
829
+ # Rescale boxes (xyxy) from img1_shape to img0_shape
830
+ if ratio_pad is None: # calculate from img0_shape
831
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
832
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
833
+ else:
834
+ gain = ratio_pad[0][0]
835
+ pad = ratio_pad[1]
836
+
837
+ boxes[:, [0, 2]] -= pad[0] # x padding
838
+ boxes[:, [1, 3]] -= pad[1] # y padding
839
+ boxes[:, :4] /= gain
840
+ clip_boxes(boxes, img0_shape)
841
+ return boxes
842
+
843
+
844
+ def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
845
+ # Rescale coords (xyxy) from img1_shape to img0_shape
846
+ if ratio_pad is None: # calculate from img0_shape
847
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
848
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
849
+ else:
850
+ gain = ratio_pad[0][0]
851
+ pad = ratio_pad[1]
852
+
853
+ segments[:, 0] -= pad[0] # x padding
854
+ segments[:, 1] -= pad[1] # y padding
855
+ segments /= gain
856
+ clip_segments(segments, img0_shape)
857
+ if normalize:
858
+ segments[:, 0] /= img0_shape[1] # width
859
+ segments[:, 1] /= img0_shape[0] # height
860
+ return segments
861
+
862
+
863
+ def clip_boxes(boxes, shape):
864
+ # Clip boxes (xyxy) to image shape (height, width)
865
+ if isinstance(boxes, torch.Tensor): # faster individually
866
+ boxes[:, 0].clamp_(0, shape[1]) # x1
867
+ boxes[:, 1].clamp_(0, shape[0]) # y1
868
+ boxes[:, 2].clamp_(0, shape[1]) # x2
869
+ boxes[:, 3].clamp_(0, shape[0]) # y2
870
+ else: # np.array (faster grouped)
871
+ boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
872
+ boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
873
+
874
+
875
+ def clip_segments(segments, shape):
876
+ # Clip segments (xy1,xy2,...) to image shape (height, width)
877
+ if isinstance(segments, torch.Tensor): # faster individually
878
+ segments[:, 0].clamp_(0, shape[1]) # x
879
+ segments[:, 1].clamp_(0, shape[0]) # y
880
+ else: # np.array (faster grouped)
881
+ segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
882
+ segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
883
+
884
+ def box_iou_for_nms(box1, box2, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIou=False, eps=1e-7):
885
+ # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
886
+
887
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
888
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
889
+ w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
890
+ w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)
891
+
892
+ # Intersection area
893
+ inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
894
+ (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
895
+
896
+ # Union Area
897
+ union = w1 * h1 + w2 * h2 - inter + eps
898
+
899
+ # IoU
900
+ iou = inter / union
901
+ if CIoU or DIoU or GIoU or EIou:
902
+ cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
903
+ ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
904
+ if CIoU or DIoU or EIou: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
905
+ c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
906
+ rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
907
+ if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
908
+ v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
909
+ with torch.no_grad():
910
+ alpha = v / (v - iou + (1 + eps))
911
+ return iou - (rho2 / c2 + v * alpha) # CIoU
912
+ elif EIou:
913
+ rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
914
+ rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
915
+ cw2 = cw ** 2 + eps
916
+ ch2 = ch ** 2 + eps
917
+ return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)
918
+ return iou - rho2 / c2 # DIoU
919
+ c_area = cw * ch + eps # convex area
920
+ return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
921
+ elif SIoU:
922
+ cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
923
+ ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
924
+ # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
925
+ s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
926
+ s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
927
+ sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
928
+ sin_alpha_1 = torch.abs(s_cw) / sigma
929
+ sin_alpha_2 = torch.abs(s_ch) / sigma
930
+ threshold = pow(2, 0.5) / 2
931
+ sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
932
+ angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
933
+ rho_x = (s_cw / cw) ** 2
934
+ rho_y = (s_ch / ch) ** 2
935
+ gamma = angle_cost - 2
936
+ distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
937
+ omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
938
+ omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
939
+ shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
940
+ return iou - 0.5 * (distance_cost + shape_cost)
941
+ return iou # IoU
942
+
943
+
944
+ def soft_nms(bboxes, scores, iou_thresh=0.5,sigma=0.5,score_threshold=0.25):
945
+ order = scores.argsort(descending=True).to(bboxes.device)
946
+ keep = []
947
+
948
+ while order.numel() > 1:
949
+ if order.numel() == 1:
950
+ keep.append(order[0])
951
+ break
952
+ else:
953
+ i = order[0]
954
+ keep.append(i)
955
+
956
+ iou = box_iou_for_nms(bboxes[i], bboxes[order[1:]]).squeeze()
957
+
958
+ idx = (iou > iou_thresh).nonzero().squeeze()
959
+ if idx.numel() > 0:
960
+ iou = iou[idx]
961
+ newScores = torch.exp(-torch.pow(iou,2)/sigma)
962
+ scores[order[idx+1]] *= newScores
963
+
964
+ newOrder = (scores[order[1:]] > score_threshold).nonzero().squeeze()
965
+ if newOrder.numel() == 0:
966
+ break
967
+ else:
968
+ maxScoreIndex = torch.argmax(scores[order[newOrder+1]])
969
+ if maxScoreIndex != 0:
970
+ newOrder[[0,maxScoreIndex],] = newOrder[[maxScoreIndex,0],]
971
+ order = order[newOrder+1]
972
+
973
+ return torch.LongTensor(keep)
974
+
975
+ def non_max_suppression(
976
+ prediction,
977
+ conf_thres=0.25,
978
+ iou_thres=0.45,
979
+ classes=None,
980
+ agnostic=False,
981
+ multi_label=False,
982
+ labels=(),
983
+ max_det=300,
984
+ nm=0, # number of masks
985
+ ):
986
+ """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
987
+
988
+ Returns:
989
+ list of detections, on (n,6) tensor per image [xyxy, conf, cls]
990
+ """
991
+
992
+ if isinstance(prediction, (list, tuple)): # YOLO model in validation model, output = (inference_out, loss_out)
993
+ prediction = prediction[0] # select only inference output
994
+
995
+
996
+ device = prediction.device
997
+ mps = 'mps' in device.type # Apple MPS
998
+ if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
999
+ prediction = prediction.cpu()
1000
+ bs = prediction.shape[0] # batch size
1001
+ nc = prediction.shape[1] - nm - 4 # number of classes
1002
+ mi = 4 + nc # mask start index
1003
+ xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
1004
+
1005
+ # Checks
1006
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
1007
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
1008
+
1009
+ # Settings
1010
+ # min_wh = 2 # (pixels) minimum box width and height
1011
+ max_wh = 7680 # (pixels) maximum box width and height
1012
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
1013
+ time_limit = 2.5 + 0.05 * bs # seconds to quit after
1014
+ redundant = True # require redundant detections
1015
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
1016
+ merge = False # use merge-NMS
1017
+
1018
+ t = time.time()
1019
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
1020
+ for xi, x in enumerate(prediction): # image index, image inference
1021
+ # Apply constraints
1022
+ # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
1023
+ x = x.T[xc[xi]] # confidence
1024
+
1025
+ # Cat apriori labels if autolabelling
1026
+ if labels and len(labels[xi]):
1027
+ lb = labels[xi]
1028
+ v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
1029
+ v[:, :4] = lb[:, 1:5] # box
1030
+ v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
1031
+ x = torch.cat((x, v), 0)
1032
+
1033
+ # If none remain process next image
1034
+ if not x.shape[0]:
1035
+ continue
1036
+
1037
+ # Detections matrix nx6 (xyxy, conf, cls)
1038
+ box, cls, mask = x.split((4, nc, nm), 1)
1039
+ box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
1040
+ if multi_label:
1041
+ i, j = (cls > conf_thres).nonzero(as_tuple=False).T
1042
+ x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
1043
+ else: # best class only
1044
+ conf, j = cls.max(1, keepdim=True)
1045
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
1046
+
1047
+ # Filter by class
1048
+ if classes is not None:
1049
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
1050
+
1051
+ # Apply finite constraint
1052
+ # if not torch.isfinite(x).all():
1053
+ # x = x[torch.isfinite(x).all(1)]
1054
+
1055
+ # Check shape
1056
+ n = x.shape[0] # number of boxes
1057
+ if not n: # no boxes
1058
+ continue
1059
+ elif n > max_nms: # excess boxes
1060
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
1061
+ else:
1062
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
1063
+
1064
+ # Batched NMS
1065
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
1066
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
1067
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
1068
+ # i = soft_nms(boxes, scores, iou_thres)
1069
+ if i.shape[0] > max_det: # limit detections
1070
+ i = i[:max_det]
1071
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
1072
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
1073
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
1074
+ weights = iou * scores[None] # box weights
1075
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
1076
+ if redundant:
1077
+ i = i[iou.sum(1) > 1] # require redundancy
1078
+
1079
+ output[xi] = x[i]
1080
+ if mps:
1081
+ output[xi] = output[xi].to(device)
1082
+ if (time.time() - t) > time_limit:
1083
+ LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
1084
+ break # time limit exceeded
1085
+
1086
+ return output
1087
+
1088
+
1089
+ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
1090
+ # Strip optimizer from 'f' to finalize training, optionally save as 's'
1091
+ x = torch.load(f, map_location=torch.device('cpu'))
1092
+ if x.get('ema'):
1093
+ x['model'] = x['ema'] # replace model with ema
1094
+ for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
1095
+ x[k] = None
1096
+ x['epoch'] = -1
1097
+ x['model'].half() # to FP16
1098
+ for p in x['model'].parameters():
1099
+ p.requires_grad = False
1100
+ torch.save(x, s or f)
1101
+ mb = os.path.getsize(s or f) / 1E6 # filesize
1102
+ LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
1103
+
1104
+
1105
+ def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
1106
+ evolve_csv = save_dir / 'evolve.csv'
1107
+ evolve_yaml = save_dir / 'hyp_evolve.yaml'
1108
+ keys = tuple(keys) + tuple(hyp.keys()) # [results + hyps]
1109
+ keys = tuple(x.strip() for x in keys)
1110
+ vals = results + tuple(hyp.values())
1111
+ n = len(keys)
1112
+
1113
+ # Download (optional)
1114
+ if bucket:
1115
+ url = f'gs://{bucket}/evolve.csv'
1116
+ if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
1117
+ os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
1118
+
1119
+ # Log to evolve.csv
1120
+ s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
1121
+ with open(evolve_csv, 'a') as f:
1122
+ f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
1123
+
1124
+ # Save yaml
1125
+ with open(evolve_yaml, 'w') as f:
1126
+ data = pd.read_csv(evolve_csv)
1127
+ data = data.rename(columns=lambda x: x.strip()) # strip keys
1128
+ i = np.argmax(fitness(data.values[:, :4])) #
1129
+ generations = len(data)
1130
+ f.write('# YOLO Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
1131
+ f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
1132
+ '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
1133
+ yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
1134
+
1135
+ # Print to screen
1136
+ LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
1137
+ ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
1138
+ for x in vals) + '\n\n')
1139
+
1140
+ if bucket:
1141
+ os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
1142
+
1143
+
1144
+ def apply_classifier(x, model, img, im0):
1145
+ # Apply a second stage classifier to YOLO outputs
1146
+ # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
1147
+ im0 = [im0] if isinstance(im0, np.ndarray) else im0
1148
+ for i, d in enumerate(x): # per image
1149
+ if d is not None and len(d):
1150
+ d = d.clone()
1151
+
1152
+ # Reshape and pad cutouts
1153
+ b = xyxy2xywh(d[:, :4]) # boxes
1154
+ b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
1155
+ b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
1156
+ d[:, :4] = xywh2xyxy(b).long()
1157
+
1158
+ # Rescale boxes from img_size to im0 size
1159
+ scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)
1160
+
1161
+ # Classes
1162
+ pred_cls1 = d[:, 5].long()
1163
+ ims = []
1164
+ for a in d:
1165
+ cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
1166
+ im = cv2.resize(cutout, (224, 224)) # BGR
1167
+
1168
+ im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
1169
+ im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
1170
+ im /= 255 # 0 - 255 to 0.0 - 1.0
1171
+ ims.append(im)
1172
+
1173
+ pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
1174
+ x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
1175
+
1176
+ return x
1177
+
1178
+
1179
+ def increment_path(path, exist_ok=False, sep='', mkdir=False):
1180
+ # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
1181
+ path = Path(path) # os-agnostic
1182
+ if path.exists() and not exist_ok:
1183
+ path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
1184
+
1185
+ # Method 1
1186
+ for n in range(2, 9999):
1187
+ p = f'{path}{sep}{n}{suffix}' # increment path
1188
+ if not os.path.exists(p): #
1189
+ break
1190
+ path = Path(p)
1191
+
1192
+ # Method 2 (deprecated)
1193
+ # dirs = glob.glob(f"{path}{sep}*") # similar paths
1194
+ # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
1195
+ # i = [int(m.groups()[0]) for m in matches if m] # indices
1196
+ # n = max(i) + 1 if i else 2 # increment number
1197
+ # path = Path(f"{path}{sep}{n}{suffix}") # increment path
1198
+
1199
+ if mkdir:
1200
+ path.mkdir(parents=True, exist_ok=True) # make directory
1201
+
1202
+ return path
1203
+
1204
+
1205
+ # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
1206
+ imshow_ = cv2.imshow # copy to avoid recursion errors
1207
+
1208
+
1209
+ def imread(path, flags=cv2.IMREAD_COLOR):
1210
+ return cv2.imdecode(np.fromfile(path, np.uint8), flags)
1211
+
1212
+
1213
+ def imwrite(path, im):
1214
+ try:
1215
+ cv2.imencode(Path(path).suffix, im)[1].tofile(path)
1216
+ return True
1217
+ except Exception:
1218
+ return False
1219
+
1220
+
1221
+ def imshow(path, im):
1222
+ imshow_(path.encode('unicode_escape').decode(), im)
1223
+
1224
+
1225
+ cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
1226
+
1227
+ # Variables ------------------------------------------------------------------------------------------------------------
utils/lion.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch implementation of the Lion optimizer."""
2
+ import torch
3
+ from torch.optim.optimizer import Optimizer
4
+
5
+
6
+ class Lion(Optimizer):
7
+ r"""Implements Lion algorithm."""
8
+
9
+ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
10
+ """Initialize the hyperparameters.
11
+ Args:
12
+ params (iterable): iterable of parameters to optimize or dicts defining
13
+ parameter groups
14
+ lr (float, optional): learning rate (default: 1e-4)
15
+ betas (Tuple[float, float], optional): coefficients used for computing
16
+ running averages of gradient and its square (default: (0.9, 0.99))
17
+ weight_decay (float, optional): weight decay coefficient (default: 0)
18
+ """
19
+
20
+ if not 0.0 <= lr:
21
+ raise ValueError('Invalid learning rate: {}'.format(lr))
22
+ if not 0.0 <= betas[0] < 1.0:
23
+ raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
24
+ if not 0.0 <= betas[1] < 1.0:
25
+ raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
26
+ defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
27
+ super().__init__(params, defaults)
28
+
29
+ @torch.no_grad()
30
+ def step(self, closure=None):
31
+ """Performs a single optimization step.
32
+ Args:
33
+ closure (callable, optional): A closure that reevaluates the model
34
+ and returns the loss.
35
+ Returns:
36
+ the loss.
37
+ """
38
+ loss = None
39
+ if closure is not None:
40
+ with torch.enable_grad():
41
+ loss = closure()
42
+
43
+ for group in self.param_groups:
44
+ for p in group['params']:
45
+ if p.grad is None:
46
+ continue
47
+
48
+ # Perform stepweight decay
49
+ p.data.mul_(1 - group['lr'] * group['weight_decay'])
50
+
51
+ grad = p.grad
52
+ state = self.state[p]
53
+ # State initialization
54
+ if len(state) == 0:
55
+ # Exponential moving average of gradient values
56
+ state['exp_avg'] = torch.zeros_like(p)
57
+
58
+ exp_avg = state['exp_avg']
59
+ beta1, beta2 = group['betas']
60
+
61
+ # Weight update
62
+ update = exp_avg * beta1 + grad * (1 - beta1)
63
+ p.add_(torch.sign(update), alpha=-group['lr'])
64
+ # Decay the momentum running average coefficient
65
+ exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
66
+
67
+ return loss
utils/loggers/__init__.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import pkg_resources as pkg
6
+ import torch
7
+ from torch.utils.tensorboard import SummaryWriter
8
+
9
+ from utils.general import LOGGER, colorstr, cv2
10
+ from utils.loggers.clearml.clearml_utils import ClearmlLogger
11
+ from utils.loggers.wandb.wandb_utils import WandbLogger
12
+ from utils.plots import plot_images, plot_labels, plot_results
13
+ from utils.torch_utils import de_parallel
14
+
15
+ LOGGERS = ('csv', 'tb', 'wandb', 'clearml', 'comet') # *.csv, TensorBoard, Weights & Biases, ClearML
16
+ RANK = int(os.getenv('RANK', -1))
17
+
18
+ try:
19
+ import wandb
20
+
21
+ assert hasattr(wandb, '__version__') # verify package import not local dir
22
+ if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in {0, -1}:
23
+ try:
24
+ wandb_login_success = wandb.login(timeout=30)
25
+ except wandb.errors.UsageError: # known non-TTY terminal issue
26
+ wandb_login_success = False
27
+ if not wandb_login_success:
28
+ wandb = None
29
+ except (ImportError, AssertionError):
30
+ wandb = None
31
+
32
+ try:
33
+ import clearml
34
+
35
+ assert hasattr(clearml, '__version__') # verify package import not local dir
36
+ except (ImportError, AssertionError):
37
+ clearml = None
38
+
39
+ try:
40
+ if RANK not in [0, -1]:
41
+ comet_ml = None
42
+ else:
43
+ import comet_ml
44
+
45
+ assert hasattr(comet_ml, '__version__') # verify package import not local dir
46
+ from utils.loggers.comet import CometLogger
47
+
48
+ except (ModuleNotFoundError, ImportError, AssertionError):
49
+ comet_ml = None
50
+
51
+
52
+ class Loggers():
53
+ # YOLO Loggers class
54
+ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
55
+ self.save_dir = save_dir
56
+ self.weights = weights
57
+ self.opt = opt
58
+ self.hyp = hyp
59
+ self.plots = not opt.noplots # plot results
60
+ self.logger = logger # for printing results to console
61
+ self.include = include
62
+ self.keys = [
63
+ 'train/box_loss',
64
+ 'train/cls_loss',
65
+ 'train/dfl_loss', # train loss
66
+ 'metrics/precision',
67
+ 'metrics/recall',
68
+ 'metrics/mAP_0.5',
69
+ 'metrics/mAP_0.5:0.95', # metrics
70
+ 'val/box_loss',
71
+ 'val/cls_loss',
72
+ 'val/dfl_loss', # val loss
73
+ 'x/lr0',
74
+ 'x/lr1',
75
+ 'x/lr2'] # params
76
+ self.best_keys = ['best/epoch', 'best/precision', 'best/recall', 'best/mAP_0.5', 'best/mAP_0.5:0.95']
77
+ for k in LOGGERS:
78
+ setattr(self, k, None) # init empty logger dictionary
79
+ self.csv = True # always log to csv
80
+
81
+ # Messages
82
+ # if not wandb:
83
+ # prefix = colorstr('Weights & Biases: ')
84
+ # s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLO 🚀 runs in Weights & Biases"
85
+ # self.logger.info(s)
86
+ if not clearml:
87
+ prefix = colorstr('ClearML: ')
88
+ s = f"{prefix}run 'pip install clearml' to automatically track, visualize and remotely train YOLO 🚀 in ClearML"
89
+ self.logger.info(s)
90
+ if not comet_ml:
91
+ prefix = colorstr('Comet: ')
92
+ s = f"{prefix}run 'pip install comet_ml' to automatically track and visualize YOLO 🚀 runs in Comet"
93
+ self.logger.info(s)
94
+ # TensorBoard
95
+ s = self.save_dir
96
+ if 'tb' in self.include and not self.opt.evolve:
97
+ prefix = colorstr('TensorBoard: ')
98
+ self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
99
+ self.tb = SummaryWriter(str(s))
100
+
101
+ # W&B
102
+ if wandb and 'wandb' in self.include:
103
+ wandb_artifact_resume = isinstance(self.opt.resume, str) and self.opt.resume.startswith('wandb-artifact://')
104
+ run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume and not wandb_artifact_resume else None
105
+ self.opt.hyp = self.hyp # add hyperparameters
106
+ self.wandb = WandbLogger(self.opt, run_id)
107
+ # temp warn. because nested artifacts not supported after 0.12.10
108
+ # if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.11'):
109
+ # s = "YOLO temporarily requires wandb version 0.12.10 or below. Some features may not work as expected."
110
+ # self.logger.warning(s)
111
+ else:
112
+ self.wandb = None
113
+
114
+ # ClearML
115
+ if clearml and 'clearml' in self.include:
116
+ self.clearml = ClearmlLogger(self.opt, self.hyp)
117
+ else:
118
+ self.clearml = None
119
+
120
+ # Comet
121
+ if comet_ml and 'comet' in self.include:
122
+ if isinstance(self.opt.resume, str) and self.opt.resume.startswith("comet://"):
123
+ run_id = self.opt.resume.split("/")[-1]
124
+ self.comet_logger = CometLogger(self.opt, self.hyp, run_id=run_id)
125
+
126
+ else:
127
+ self.comet_logger = CometLogger(self.opt, self.hyp)
128
+
129
+ else:
130
+ self.comet_logger = None
131
+
132
+ @property
133
+ def remote_dataset(self):
134
+ # Get data_dict if custom dataset artifact link is provided
135
+ data_dict = None
136
+ if self.clearml:
137
+ data_dict = self.clearml.data_dict
138
+ if self.wandb:
139
+ data_dict = self.wandb.data_dict
140
+ if self.comet_logger:
141
+ data_dict = self.comet_logger.data_dict
142
+
143
+ return data_dict
144
+
145
+ def on_train_start(self):
146
+ if self.comet_logger:
147
+ self.comet_logger.on_train_start()
148
+
149
+ def on_pretrain_routine_start(self):
150
+ if self.comet_logger:
151
+ self.comet_logger.on_pretrain_routine_start()
152
+
153
+ def on_pretrain_routine_end(self, labels, names):
154
+ # Callback runs on pre-train routine end
155
+ if self.plots:
156
+ plot_labels(labels, names, self.save_dir)
157
+ paths = self.save_dir.glob('*labels*.jpg') # training labels
158
+ if self.wandb:
159
+ self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
160
+ # if self.clearml:
161
+ # pass # ClearML saves these images automatically using hooks
162
+ if self.comet_logger:
163
+ self.comet_logger.on_pretrain_routine_end(paths)
164
+
165
+ def on_train_batch_end(self, model, ni, imgs, targets, paths, vals):
166
+ log_dict = dict(zip(self.keys[0:3], vals))
167
+ # Callback runs on train batch end
168
+ # ni: number integrated batches (since train start)
169
+ if self.plots:
170
+ if ni < 3:
171
+ f = self.save_dir / f'train_batch{ni}.jpg' # filename
172
+ plot_images(imgs, targets, paths, f)
173
+ if ni == 0 and self.tb and not self.opt.sync_bn:
174
+ log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
175
+ if ni == 10 and (self.wandb or self.clearml):
176
+ files = sorted(self.save_dir.glob('train*.jpg'))
177
+ if self.wandb:
178
+ self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
179
+ if self.clearml:
180
+ self.clearml.log_debug_samples(files, title='Mosaics')
181
+
182
+ if self.comet_logger:
183
+ self.comet_logger.on_train_batch_end(log_dict, step=ni)
184
+
185
+ def on_train_epoch_end(self, epoch):
186
+ # Callback runs on train epoch end
187
+ if self.wandb:
188
+ self.wandb.current_epoch = epoch + 1
189
+
190
+ if self.comet_logger:
191
+ self.comet_logger.on_train_epoch_end(epoch)
192
+
193
+ def on_val_start(self):
194
+ if self.comet_logger:
195
+ self.comet_logger.on_val_start()
196
+
197
+ def on_val_image_end(self, pred, predn, path, names, im):
198
+ # Callback runs on val image end
199
+ if self.wandb:
200
+ self.wandb.val_one_image(pred, predn, path, names, im)
201
+ if self.clearml:
202
+ self.clearml.log_image_with_boxes(path, pred, names, im)
203
+
204
+ def on_val_batch_end(self, batch_i, im, targets, paths, shapes, out):
205
+ if self.comet_logger:
206
+ self.comet_logger.on_val_batch_end(batch_i, im, targets, paths, shapes, out)
207
+
208
+ def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix):
209
+ # Callback runs on val end
210
+ if self.wandb or self.clearml:
211
+ files = sorted(self.save_dir.glob('val*.jpg'))
212
+ if self.wandb:
213
+ self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
214
+ if self.clearml:
215
+ self.clearml.log_debug_samples(files, title='Validation')
216
+
217
+ if self.comet_logger:
218
+ self.comet_logger.on_val_end(nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix)
219
+
220
+ def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
221
+ # Callback runs at the end of each fit (train+val) epoch
222
+ x = dict(zip(self.keys, vals))
223
+ if self.csv:
224
+ file = self.save_dir / 'results.csv'
225
+ n = len(x) + 1 # number of cols
226
+ s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header
227
+ with open(file, 'a') as f:
228
+ f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
229
+
230
+ if self.tb:
231
+ for k, v in x.items():
232
+ self.tb.add_scalar(k, v, epoch)
233
+ elif self.clearml: # log to ClearML if TensorBoard not used
234
+ for k, v in x.items():
235
+ title, series = k.split('/')
236
+ self.clearml.task.get_logger().report_scalar(title, series, v, epoch)
237
+
238
+ if self.wandb:
239
+ if best_fitness == fi:
240
+ best_results = [epoch] + vals[3:7]
241
+ for i, name in enumerate(self.best_keys):
242
+ self.wandb.wandb_run.summary[name] = best_results[i] # log best results in the summary
243
+ self.wandb.log(x)
244
+ self.wandb.end_epoch(best_result=best_fitness == fi)
245
+
246
+ if self.clearml:
247
+ self.clearml.current_epoch_logged_images = set() # reset epoch image limit
248
+ self.clearml.current_epoch += 1
249
+
250
+ if self.comet_logger:
251
+ self.comet_logger.on_fit_epoch_end(x, epoch=epoch)
252
+
253
+ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
254
+ # Callback runs on model save event
255
+ if (epoch + 1) % self.opt.save_period == 0 and not final_epoch and self.opt.save_period != -1:
256
+ if self.wandb:
257
+ self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
258
+ if self.clearml:
259
+ self.clearml.task.update_output_model(model_path=str(last),
260
+ model_name='Latest Model',
261
+ auto_delete_file=False)
262
+
263
+ if self.comet_logger:
264
+ self.comet_logger.on_model_save(last, epoch, final_epoch, best_fitness, fi)
265
+
266
+ def on_train_end(self, last, best, epoch, results):
267
+ # Callback runs on training end, i.e. saving best model
268
+ if self.plots:
269
+ plot_results(file=self.save_dir / 'results.csv') # save results.png
270
+ files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
271
+ files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
272
+ self.logger.info(f"Results saved to {colorstr('bold', self.save_dir)}")
273
+
274
+ if self.tb and not self.clearml: # These images are already captured by ClearML by now, we don't want doubles
275
+ for f in files:
276
+ self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
277
+
278
+ if self.wandb:
279
+ self.wandb.log(dict(zip(self.keys[3:10], results)))
280
+ self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
281
+ # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
282
+ if not self.opt.evolve:
283
+ wandb.log_artifact(str(best if best.exists() else last),
284
+ type='model',
285
+ name=f'run_{self.wandb.wandb_run.id}_model',
286
+ aliases=['latest', 'best', 'stripped'])
287
+ self.wandb.finish_run()
288
+
289
+ if self.clearml and not self.opt.evolve:
290
+ self.clearml.task.update_output_model(model_path=str(best if best.exists() else last),
291
+ name='Best Model',
292
+ auto_delete_file=False)
293
+
294
+ if self.comet_logger:
295
+ final_results = dict(zip(self.keys[3:10], results))
296
+ self.comet_logger.on_train_end(files, self.save_dir, last, best, epoch, final_results)
297
+
298
+ def on_params_update(self, params: dict):
299
+ # Update hyperparams or configs of the experiment
300
+ if self.wandb:
301
+ self.wandb.wandb_run.config.update(params, allow_val_change=True)
302
+ if self.comet_logger:
303
+ self.comet_logger.on_params_update(params)
304
+
305
+
306
+ class CustomedLoggers():
307
+ # YOLO Loggers class
308
+ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
309
+ self.save_dir = save_dir
310
+ self.weights = weights
311
+ self.opt = opt
312
+ self.hyp = hyp
313
+ self.plots = not opt.noplots # plot results
314
+ self.logger = logger # for printing results to console
315
+ self.include = include
316
+ self.keys = [
317
+ 'train/box_loss',
318
+ 'train/cls_loss',
319
+ 'train/dfl_loss', # train loss
320
+ 'metrics/precision',
321
+ 'metrics/recall',
322
+ 'metrics/mAP_0.5',
323
+ 'metrics/mAP_0.5:0.95', # metrics
324
+ 'val/box_loss',
325
+ 'val/cls_loss',
326
+ 'val/dfl_loss', # val loss
327
+ 'x/lr0',
328
+ 'x/lr1',
329
+ 'x/lr2'] # params
330
+ self.best_keys = ['best/epoch', 'best/precision', 'best/recall', 'best/mAP_0.5', 'best/mAP_0.5:0.95']
331
+ for k in LOGGERS:
332
+ setattr(self, k, None) # init empty logger dictionary
333
+ self.csv = True # always log to csv
334
+
335
+ # Messages
336
+ # if not wandb:
337
+ # prefix = colorstr('Weights & Biases: ')
338
+ # s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLO 🚀 runs in Weights & Biases"
339
+ # self.logger.info(s)
340
+ if not clearml:
341
+ prefix = colorstr('ClearML: ')
342
+ s = f"{prefix}run 'pip install clearml' to automatically track, visualize and remotely train YOLO 🚀 in ClearML"
343
+ self.logger.info(s)
344
+ if not comet_ml:
345
+ prefix = colorstr('Comet: ')
346
+ s = f"{prefix}run 'pip install comet_ml' to automatically track and visualize YOLO 🚀 runs in Comet"
347
+ self.logger.info(s)
348
+ # TensorBoard
349
+ s = self.save_dir
350
+ if 'tb' in self.include and not self.opt.evolve:
351
+ prefix = colorstr('TensorBoard: ')
352
+ self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/")
353
+ self.tb = SummaryWriter(str(s))
354
+
355
+ # W&B
356
+ if wandb and 'wandb' in self.include:
357
+ wandb_artifact_resume = isinstance(self.opt.resume, str) and self.opt.resume.startswith('wandb-artifact://')
358
+ run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume and not wandb_artifact_resume else None
359
+ self.opt.hyp = self.hyp # add hyperparameters
360
+ self.wandb = WandbLogger(self.opt, run_id)
361
+ # temp warn. because nested artifacts not supported after 0.12.10
362
+ # if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.11'):
363
+ # s = "YOLO temporarily requires wandb version 0.12.10 or below. Some features may not work as expected."
364
+ # self.logger.warning(s)
365
+ else:
366
+ self.wandb = None
367
+
368
+ # ClearML
369
+ if clearml and 'clearml' in self.include:
370
+ self.clearml = ClearmlLogger(self.opt, self.hyp)
371
+ else:
372
+ self.clearml = None
373
+
374
+ # Comet
375
+ if comet_ml and 'comet' in self.include:
376
+ if isinstance(self.opt.resume, str) and self.opt.resume.startswith("comet://"):
377
+ run_id = self.opt.resume.split("/")[-1]
378
+ self.comet_logger = CometLogger(self.opt, self.hyp, run_id=run_id)
379
+
380
+ else:
381
+ self.comet_logger = CometLogger(self.opt, self.hyp)
382
+
383
+ else:
384
+ self.comet_logger = None
385
+
386
+ @property
387
+ def remote_dataset(self):
388
+ # Get data_dict if custom dataset artifact link is provided
389
+ data_dict = None
390
+ if self.clearml:
391
+ data_dict = self.clearml.data_dict
392
+ if self.wandb:
393
+ data_dict = self.wandb.data_dict
394
+ if self.comet_logger:
395
+ data_dict = self.comet_logger.data_dict
396
+
397
+ return data_dict
398
+
399
+ def on_train_start(self):
400
+ if self.comet_logger:
401
+ self.comet_logger.on_train_start()
402
+
403
+ def on_pretrain_routine_start(self):
404
+ if self.comet_logger:
405
+ self.comet_logger.on_pretrain_routine_start()
406
+
407
+ def on_pretrain_routine_end(self, labels, names):
408
+ # Callback runs on pre-train routine end
409
+ if self.plots:
410
+ plot_labels(labels, names, self.save_dir)
411
+ paths = self.save_dir.glob('*labels*.jpg') # training labels
412
+ if self.wandb:
413
+ self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
414
+ # if self.clearml:
415
+ # pass # ClearML saves these images automatically using hooks
416
+ if self.comet_logger:
417
+ self.comet_logger.on_pretrain_routine_end(paths)
418
+
419
+ def on_train_batch_end(self, model, ni, imgs, targets, paths, vals):
420
+ log_dict = dict(zip(self.keys[0:3], vals))
421
+ # Callback runs on train batch end
422
+ # ni: number integrated batches (since train start)
423
+ if self.plots:
424
+ if ni < 3:
425
+ f = self.save_dir / f'train_batch{ni}.jpg' # filename
426
+ plot_images(imgs, targets, paths, f)
427
+ if ni == 0 and self.tb and not self.opt.sync_bn:
428
+ log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
429
+ if ni == 10 and (self.wandb or self.clearml):
430
+ files = sorted(self.save_dir.glob('train*.jpg'))
431
+ if self.wandb:
432
+ self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
433
+ if self.clearml:
434
+ self.clearml.log_debug_samples(files, title='Mosaics')
435
+
436
+ if self.comet_logger:
437
+ self.comet_logger.on_train_batch_end(log_dict, step=ni)
438
+
439
+ def on_train_epoch_end(self, epoch):
440
+ # Callback runs on train epoch end
441
+ if self.wandb:
442
+ self.wandb.current_epoch = epoch + 1
443
+
444
+ if self.comet_logger:
445
+ self.comet_logger.on_train_epoch_end(epoch)
446
+
447
+ def on_val_start(self):
448
+ if self.comet_logger:
449
+ self.comet_logger.on_val_start()
450
+
451
+ def on_val_image_end(self, pred, predn, path, names, im):
452
+ # Callback runs on val image end
453
+ if self.wandb:
454
+ self.wandb.val_one_image(pred, predn, path, names, im)
455
+ if self.clearml:
456
+ self.clearml.log_image_with_boxes(path, pred, names, im)
457
+
458
+ def on_val_batch_end(self, batch_i, im, targets, paths, shapes, out):
459
+ if self.comet_logger:
460
+ self.comet_logger.on_val_batch_end(batch_i, im, targets, paths, shapes, out)
461
+
462
+ def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix):
463
+ # Callback runs on val end
464
+ if self.wandb or self.clearml:
465
+ files = sorted(self.save_dir.glob('val*.jpg'))
466
+ if self.wandb:
467
+ self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
468
+ if self.clearml:
469
+ self.clearml.log_debug_samples(files, title='Validation')
470
+
471
+ if self.comet_logger:
472
+ self.comet_logger.on_val_end(nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix)
473
+
474
+ def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
475
+ # Callback runs at the end of each fit (train+val) epoch
476
+ x = dict(zip(self.keys, vals))
477
+ if self.csv:
478
+ file = self.save_dir / 'results.csv'
479
+ n = len(x) + 1 # number of cols
480
+ s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header
481
+ with open(file, 'a') as f:
482
+ f.write(s + ('%20.5g,' * (n + 1) % tuple([epoch] + vals)).rstrip(',') + '\n')
483
+
484
+ if self.tb:
485
+ for k, v in x.items():
486
+ self.tb.add_scalar(k, v, epoch)
487
+ elif self.clearml: # log to ClearML if TensorBoard not used
488
+ for k, v in x.items():
489
+ title, series = k.split('/')
490
+ self.clearml.task.get_logger().report_scalar(title, series, v, epoch)
491
+
492
+ if self.wandb:
493
+ if best_fitness == fi:
494
+ best_results = [epoch] + vals[3:7]
495
+ for i, name in enumerate(self.best_keys):
496
+ self.wandb.wandb_run.summary[name] = best_results[i] # log best results in the summary
497
+ self.wandb.log(x)
498
+ self.wandb.end_epoch(best_result=best_fitness == fi)
499
+
500
+ if self.clearml:
501
+ self.clearml.current_epoch_logged_images = set() # reset epoch image limit
502
+ self.clearml.current_epoch += 1
503
+
504
+ if self.comet_logger:
505
+ self.comet_logger.on_fit_epoch_end(x, epoch=epoch)
506
+
507
+ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
508
+ # Callback runs on model save event
509
+ if (epoch + 1) % self.opt.save_period == 0 and not final_epoch and self.opt.save_period != -1:
510
+ if self.wandb:
511
+ self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
512
+ if self.clearml:
513
+ self.clearml.task.update_output_model(model_path=str(last),
514
+ model_name='Latest Model',
515
+ auto_delete_file=False)
516
+
517
+ if self.comet_logger:
518
+ self.comet_logger.on_model_save(last, epoch, final_epoch, best_fitness, fi)
519
+
520
+ def on_train_end(self, last, best, epoch, results):
521
+ # Callback runs on training end, i.e. saving best model
522
+ if self.plots:
523
+ plot_results(file=self.save_dir / 'results.csv') # save results.png
524
+ files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
525
+ files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
526
+ self.logger.info(f"Results saved to {colorstr('bold', self.save_dir)}")
527
+
528
+ if self.tb and not self.clearml: # These images are already captured by ClearML by now, we don't want doubles
529
+ for f in files:
530
+ self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
531
+
532
+ if self.wandb:
533
+ self.wandb.log(dict(zip(self.keys[3:10], results)))
534
+ self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
535
+ # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
536
+ if not self.opt.evolve:
537
+ wandb.log_artifact(str(best if best.exists() else last),
538
+ type='model',
539
+ name=f'run_{self.wandb.wandb_run.id}_model',
540
+ aliases=['latest', 'best', 'stripped'])
541
+ self.wandb.finish_run()
542
+
543
+ if self.clearml and not self.opt.evolve:
544
+ self.clearml.task.update_output_model(model_path=str(best if best.exists() else last),
545
+ name='Best Model',
546
+ auto_delete_file=False)
547
+
548
+ if self.comet_logger:
549
+ final_results = dict(zip(self.keys[3:10], results))
550
+ self.comet_logger.on_train_end(files, self.save_dir, last, best, epoch, final_results)
551
+
552
+ def on_params_update(self, params: dict):
553
+ # Update hyperparams or configs of the experiment
554
+ if self.wandb:
555
+ self.wandb.wandb_run.config.update(params, allow_val_change=True)
556
+ if self.comet_logger:
557
+ self.comet_logger.on_params_update(params)
558
+
559
+
560
+ class GenericLogger:
561
+ """
562
+ YOLO General purpose logger for non-task specific logging
563
+ Usage: from utils.loggers import GenericLogger; logger = GenericLogger(...)
564
+ Arguments
565
+ opt: Run arguments
566
+ console_logger: Console logger
567
+ include: loggers to include
568
+ """
569
+
570
+ def __init__(self, opt, console_logger, include=('tb', 'wandb')):
571
+ # init default loggers
572
+ self.save_dir = Path(opt.save_dir)
573
+ self.include = include
574
+ self.console_logger = console_logger
575
+ self.csv = self.save_dir / 'results.csv' # CSV logger
576
+ if 'tb' in self.include:
577
+ prefix = colorstr('TensorBoard: ')
578
+ self.console_logger.info(
579
+ f"{prefix}Start with 'tensorboard --logdir {self.save_dir.parent}', view at http://localhost:6006/")
580
+ self.tb = SummaryWriter(str(self.save_dir))
581
+
582
+ if wandb and 'wandb' in self.include:
583
+ self.wandb = wandb.init(project=web_project_name(str(opt.project)),
584
+ name=None if opt.name == "exp" else opt.name,
585
+ config=opt)
586
+ else:
587
+ self.wandb = None
588
+
589
+ def log_metrics(self, metrics, epoch):
590
+ # Log metrics dictionary to all loggers
591
+ if self.csv:
592
+ keys, vals = list(metrics.keys()), list(metrics.values())
593
+ n = len(metrics) + 1 # number of cols
594
+ s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
595
+ with open(self.csv, 'a') as f:
596
+ f.write(s + ('%23.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
597
+
598
+ if self.tb:
599
+ for k, v in metrics.items():
600
+ self.tb.add_scalar(k, v, epoch)
601
+
602
+ if self.wandb:
603
+ self.wandb.log(metrics, step=epoch)
604
+
605
+ def log_images(self, files, name='Images', epoch=0):
606
+ # Log images to all loggers
607
+ files = [Path(f) for f in (files if isinstance(files, (tuple, list)) else [files])] # to Path
608
+ files = [f for f in files if f.exists()] # filter by exists
609
+
610
+ if self.tb:
611
+ for f in files:
612
+ self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
613
+
614
+ if self.wandb:
615
+ self.wandb.log({name: [wandb.Image(str(f), caption=f.name) for f in files]}, step=epoch)
616
+
617
+ def log_graph(self, model, imgsz=(640, 640)):
618
+ # Log model graph to all loggers
619
+ if self.tb:
620
+ log_tensorboard_graph(self.tb, model, imgsz)
621
+
622
+ def log_model(self, model_path, epoch=0, metadata={}):
623
+ # Log model to all loggers
624
+ if self.wandb:
625
+ art = wandb.Artifact(name=f"run_{wandb.run.id}_model", type="model", metadata=metadata)
626
+ art.add_file(str(model_path))
627
+ wandb.log_artifact(art)
628
+
629
+ def update_params(self, params):
630
+ # Update the paramters logged
631
+ if self.wandb:
632
+ wandb.run.config.update(params, allow_val_change=True)
633
+
634
+
635
+ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
636
+ # Log model graph to TensorBoard
637
+ try:
638
+ p = next(model.parameters()) # for device, type
639
+ imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand
640
+ im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p) # input image (WARNING: must be zeros, not empty)
641
+ with warnings.catch_warnings():
642
+ warnings.simplefilter('ignore') # suppress jit trace warning
643
+ tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
644
+ except Exception as e:
645
+ LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
646
+
647
+
648
+ def web_project_name(project):
649
+ # Convert local project name to web project name
650
+ if not project.startswith('runs/train'):
651
+ return project
652
+ suffix = '-Classify' if project.endswith('-cls') else '-Segment' if project.endswith('-seg') else ''
653
+ return f'YOLO{suffix}'
utils/loggers/clearml/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # init
utils/loggers/clearml/clearml_utils.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main Logger class for ClearML experiment tracking."""
2
+ import glob
3
+ import re
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import yaml
8
+
9
+ from utils.plots import Annotator, colors
10
+
11
+ try:
12
+ import clearml
13
+ from clearml import Dataset, Task
14
+
15
+ assert hasattr(clearml, '__version__') # verify package import not local dir
16
+ except (ImportError, AssertionError):
17
+ clearml = None
18
+
19
+
20
+ def construct_dataset(clearml_info_string):
21
+ """Load in a clearml dataset and fill the internal data_dict with its contents.
22
+ """
23
+ dataset_id = clearml_info_string.replace('clearml://', '')
24
+ dataset = Dataset.get(dataset_id=dataset_id)
25
+ dataset_root_path = Path(dataset.get_local_copy())
26
+
27
+ # We'll search for the yaml file definition in the dataset
28
+ yaml_filenames = list(glob.glob(str(dataset_root_path / "*.yaml")) + glob.glob(str(dataset_root_path / "*.yml")))
29
+ if len(yaml_filenames) > 1:
30
+ raise ValueError('More than one yaml file was found in the dataset root, cannot determine which one contains '
31
+ 'the dataset definition this way.')
32
+ elif len(yaml_filenames) == 0:
33
+ raise ValueError('No yaml definition found in dataset root path, check that there is a correct yaml file '
34
+ 'inside the dataset root path.')
35
+ with open(yaml_filenames[0]) as f:
36
+ dataset_definition = yaml.safe_load(f)
37
+
38
+ assert set(dataset_definition.keys()).issuperset(
39
+ {'train', 'test', 'val', 'nc', 'names'}
40
+ ), "The right keys were not found in the yaml file, make sure it at least has the following keys: ('train', 'test', 'val', 'nc', 'names')"
41
+
42
+ data_dict = dict()
43
+ data_dict['train'] = str(
44
+ (dataset_root_path / dataset_definition['train']).resolve()) if dataset_definition['train'] else None
45
+ data_dict['test'] = str(
46
+ (dataset_root_path / dataset_definition['test']).resolve()) if dataset_definition['test'] else None
47
+ data_dict['val'] = str(
48
+ (dataset_root_path / dataset_definition['val']).resolve()) if dataset_definition['val'] else None
49
+ data_dict['nc'] = dataset_definition['nc']
50
+ data_dict['names'] = dataset_definition['names']
51
+
52
+ return data_dict
53
+
54
+
55
+ class ClearmlLogger:
56
+ """Log training runs, datasets, models, and predictions to ClearML.
57
+
58
+ This logger sends information to ClearML at app.clear.ml or to your own hosted server. By default,
59
+ this information includes hyperparameters, system configuration and metrics, model metrics, code information and
60
+ basic data metrics and analyses.
61
+
62
+ By providing additional command line arguments to train.py, datasets,
63
+ models and predictions can also be logged.
64
+ """
65
+
66
+ def __init__(self, opt, hyp):
67
+ """
68
+ - Initialize ClearML Task, this object will capture the experiment
69
+ - Upload dataset version to ClearML Data if opt.upload_dataset is True
70
+
71
+ arguments:
72
+ opt (namespace) -- Commandline arguments for this run
73
+ hyp (dict) -- Hyperparameters for this run
74
+
75
+ """
76
+ self.current_epoch = 0
77
+ # Keep tracked of amount of logged images to enforce a limit
78
+ self.current_epoch_logged_images = set()
79
+ # Maximum number of images to log to clearML per epoch
80
+ self.max_imgs_to_log_per_epoch = 16
81
+ # Get the interval of epochs when bounding box images should be logged
82
+ self.bbox_interval = opt.bbox_interval
83
+ self.clearml = clearml
84
+ self.task = None
85
+ self.data_dict = None
86
+ if self.clearml:
87
+ self.task = Task.init(
88
+ project_name=opt.project if opt.project != 'runs/train' else 'YOLOv5',
89
+ task_name=opt.name if opt.name != 'exp' else 'Training',
90
+ tags=['YOLOv5'],
91
+ output_uri=True,
92
+ auto_connect_frameworks={'pytorch': False}
93
+ # We disconnect pytorch auto-detection, because we added manual model save points in the code
94
+ )
95
+ # ClearML's hooks will already grab all general parameters
96
+ # Only the hyperparameters coming from the yaml config file
97
+ # will have to be added manually!
98
+ self.task.connect(hyp, name='Hyperparameters')
99
+
100
+ # Get ClearML Dataset Version if requested
101
+ if opt.data.startswith('clearml://'):
102
+ # data_dict should have the following keys:
103
+ # names, nc (number of classes), test, train, val (all three relative paths to ../datasets)
104
+ self.data_dict = construct_dataset(opt.data)
105
+ # Set data to data_dict because wandb will crash without this information and opt is the best way
106
+ # to give it to them
107
+ opt.data = self.data_dict
108
+
109
+ def log_debug_samples(self, files, title='Debug Samples'):
110
+ """
111
+ Log files (images) as debug samples in the ClearML task.
112
+
113
+ arguments:
114
+ files (List(PosixPath)) a list of file paths in PosixPath format
115
+ title (str) A title that groups together images with the same values
116
+ """
117
+ for f in files:
118
+ if f.exists():
119
+ it = re.search(r'_batch(\d+)', f.name)
120
+ iteration = int(it.groups()[0]) if it else 0
121
+ self.task.get_logger().report_image(title=title,
122
+ series=f.name.replace(it.group(), ''),
123
+ local_path=str(f),
124
+ iteration=iteration)
125
+
126
+ def log_image_with_boxes(self, image_path, boxes, class_names, image, conf_threshold=0.25):
127
+ """
128
+ Draw the bounding boxes on a single image and report the result as a ClearML debug sample.
129
+
130
+ arguments:
131
+ image_path (PosixPath) the path the original image file
132
+ boxes (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class]
133
+ class_names (dict): dict containing mapping of class int to class name
134
+ image (Tensor): A torch tensor containing the actual image data
135
+ """
136
+ if len(self.current_epoch_logged_images) < self.max_imgs_to_log_per_epoch and self.current_epoch >= 0:
137
+ # Log every bbox_interval times and deduplicate for any intermittend extra eval runs
138
+ if self.current_epoch % self.bbox_interval == 0 and image_path not in self.current_epoch_logged_images:
139
+ im = np.ascontiguousarray(np.moveaxis(image.mul(255).clamp(0, 255).byte().cpu().numpy(), 0, 2))
140
+ annotator = Annotator(im=im, pil=True)
141
+ for i, (conf, class_nr, box) in enumerate(zip(boxes[:, 4], boxes[:, 5], boxes[:, :4])):
142
+ color = colors(i)
143
+
144
+ class_name = class_names[int(class_nr)]
145
+ confidence_percentage = round(float(conf) * 100, 2)
146
+ label = f"{class_name}: {confidence_percentage}%"
147
+
148
+ if conf > conf_threshold:
149
+ annotator.rectangle(box.cpu().numpy(), outline=color)
150
+ annotator.box_label(box.cpu().numpy(), label=label, color=color)
151
+
152
+ annotated_image = annotator.result()
153
+ self.task.get_logger().report_image(title='Bounding Boxes',
154
+ series=image_path.name,
155
+ iteration=self.current_epoch,
156
+ image=annotated_image)
157
+ self.current_epoch_logged_images.add(image_path)
utils/loggers/clearml/hpo.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from clearml import Task
2
+ # Connecting ClearML with the current process,
3
+ # from here on everything is logged automatically
4
+ from clearml.automation import HyperParameterOptimizer, UniformParameterRange
5
+ from clearml.automation.optuna import OptimizerOptuna
6
+
7
+ task = Task.init(project_name='Hyper-Parameter Optimization',
8
+ task_name='YOLOv5',
9
+ task_type=Task.TaskTypes.optimizer,
10
+ reuse_last_task_id=False)
11
+
12
+ # Example use case:
13
+ optimizer = HyperParameterOptimizer(
14
+ # This is the experiment we want to optimize
15
+ base_task_id='<your_template_task_id>',
16
+ # here we define the hyper-parameters to optimize
17
+ # Notice: The parameter name should exactly match what you see in the UI: <section_name>/<parameter>
18
+ # For Example, here we see in the base experiment a section Named: "General"
19
+ # under it a parameter named "batch_size", this becomes "General/batch_size"
20
+ # If you have `argparse` for example, then arguments will appear under the "Args" section,
21
+ # and you should instead pass "Args/batch_size"
22
+ hyper_parameters=[
23
+ UniformParameterRange('Hyperparameters/lr0', min_value=1e-5, max_value=1e-1),
24
+ UniformParameterRange('Hyperparameters/lrf', min_value=0.01, max_value=1.0),
25
+ UniformParameterRange('Hyperparameters/momentum', min_value=0.6, max_value=0.98),
26
+ UniformParameterRange('Hyperparameters/weight_decay', min_value=0.0, max_value=0.001),
27
+ UniformParameterRange('Hyperparameters/warmup_epochs', min_value=0.0, max_value=5.0),
28
+ UniformParameterRange('Hyperparameters/warmup_momentum', min_value=0.0, max_value=0.95),
29
+ UniformParameterRange('Hyperparameters/warmup_bias_lr', min_value=0.0, max_value=0.2),
30
+ UniformParameterRange('Hyperparameters/box', min_value=0.02, max_value=0.2),
31
+ UniformParameterRange('Hyperparameters/cls', min_value=0.2, max_value=4.0),
32
+ UniformParameterRange('Hyperparameters/cls_pw', min_value=0.5, max_value=2.0),
33
+ UniformParameterRange('Hyperparameters/obj', min_value=0.2, max_value=4.0),
34
+ UniformParameterRange('Hyperparameters/obj_pw', min_value=0.5, max_value=2.0),
35
+ UniformParameterRange('Hyperparameters/iou_t', min_value=0.1, max_value=0.7),
36
+ UniformParameterRange('Hyperparameters/anchor_t', min_value=2.0, max_value=8.0),
37
+ UniformParameterRange('Hyperparameters/fl_gamma', min_value=0.0, max_value=4.0),
38
+ UniformParameterRange('Hyperparameters/hsv_h', min_value=0.0, max_value=0.1),
39
+ UniformParameterRange('Hyperparameters/hsv_s', min_value=0.0, max_value=0.9),
40
+ UniformParameterRange('Hyperparameters/hsv_v', min_value=0.0, max_value=0.9),
41
+ UniformParameterRange('Hyperparameters/degrees', min_value=0.0, max_value=45.0),
42
+ UniformParameterRange('Hyperparameters/translate', min_value=0.0, max_value=0.9),
43
+ UniformParameterRange('Hyperparameters/scale', min_value=0.0, max_value=0.9),
44
+ UniformParameterRange('Hyperparameters/shear', min_value=0.0, max_value=10.0),
45
+ UniformParameterRange('Hyperparameters/perspective', min_value=0.0, max_value=0.001),
46
+ UniformParameterRange('Hyperparameters/flipud', min_value=0.0, max_value=1.0),
47
+ UniformParameterRange('Hyperparameters/fliplr', min_value=0.0, max_value=1.0),
48
+ UniformParameterRange('Hyperparameters/mosaic', min_value=0.0, max_value=1.0),
49
+ UniformParameterRange('Hyperparameters/mixup', min_value=0.0, max_value=1.0),
50
+ UniformParameterRange('Hyperparameters/copy_paste', min_value=0.0, max_value=1.0)],
51
+ # this is the objective metric we want to maximize/minimize
52
+ objective_metric_title='metrics',
53
+ objective_metric_series='mAP_0.5',
54
+ # now we decide if we want to maximize it or minimize it (accuracy we maximize)
55
+ objective_metric_sign='max',
56
+ # let us limit the number of concurrent experiments,
57
+ # this in turn will make sure we do dont bombard the scheduler with experiments.
58
+ # if we have an auto-scaler connected, this, by proxy, will limit the number of machine
59
+ max_number_of_concurrent_tasks=1,
60
+ # this is the optimizer class (actually doing the optimization)
61
+ # Currently, we can choose from GridSearch, RandomSearch or OptimizerBOHB (Bayesian optimization Hyper-Band)
62
+ optimizer_class=OptimizerOptuna,
63
+ # If specified only the top K performing Tasks will be kept, the others will be automatically archived
64
+ save_top_k_tasks_only=5, # 5,
65
+ compute_time_limit=None,
66
+ total_max_jobs=20,
67
+ min_iteration_per_job=None,
68
+ max_iteration_per_job=None,
69
+ )
70
+
71
+ # report every 10 seconds, this is way too often, but we are testing here
72
+ optimizer.set_report_period(10 / 60)
73
+ # You can also use the line below instead to run all the optimizer tasks locally, without using queues or agent
74
+ # an_optimizer.start_locally(job_complete_callback=job_complete_callback)
75
+ # set the time limit for the optimization process (2 hours)
76
+ optimizer.set_time_limit(in_minutes=120.0)
77
+ # Start the optimization process in the local environment
78
+ optimizer.start_locally()
79
+ # wait until process is done (notice we are controlling the optimization process in the background)
80
+ optimizer.wait()
81
+ # make sure background optimization stopped
82
+ optimizer.stop()
83
+
84
+ print('We are done, good bye')
utils/loggers/comet/__init__.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ FILE = Path(__file__).resolve()
11
+ ROOT = FILE.parents[3] # YOLOv5 root directory
12
+ if str(ROOT) not in sys.path:
13
+ sys.path.append(str(ROOT)) # add ROOT to PATH
14
+
15
+ try:
16
+ import comet_ml
17
+
18
+ # Project Configuration
19
+ config = comet_ml.config.get_config()
20
+ COMET_PROJECT_NAME = config.get_string(os.getenv("COMET_PROJECT_NAME"), "comet.project_name", default="yolov5")
21
+ except (ModuleNotFoundError, ImportError):
22
+ comet_ml = None
23
+ COMET_PROJECT_NAME = None
24
+
25
+ import PIL
26
+ import torch
27
+ import torchvision.transforms as T
28
+ import yaml
29
+
30
+ from utils.dataloaders import img2label_paths
31
+ from utils.general import check_dataset, scale_boxes, xywh2xyxy
32
+ from utils.metrics import box_iou
33
+
34
+ COMET_PREFIX = "comet://"
35
+
36
+ COMET_MODE = os.getenv("COMET_MODE", "online")
37
+
38
+ # Model Saving Settings
39
+ COMET_MODEL_NAME = os.getenv("COMET_MODEL_NAME", "yolov5")
40
+
41
+ # Dataset Artifact Settings
42
+ COMET_UPLOAD_DATASET = os.getenv("COMET_UPLOAD_DATASET", "false").lower() == "true"
43
+
44
+ # Evaluation Settings
45
+ COMET_LOG_CONFUSION_MATRIX = os.getenv("COMET_LOG_CONFUSION_MATRIX", "true").lower() == "true"
46
+ COMET_LOG_PREDICTIONS = os.getenv("COMET_LOG_PREDICTIONS", "true").lower() == "true"
47
+ COMET_MAX_IMAGE_UPLOADS = int(os.getenv("COMET_MAX_IMAGE_UPLOADS", 100))
48
+
49
+ # Confusion Matrix Settings
50
+ CONF_THRES = float(os.getenv("CONF_THRES", 0.001))
51
+ IOU_THRES = float(os.getenv("IOU_THRES", 0.6))
52
+
53
+ # Batch Logging Settings
54
+ COMET_LOG_BATCH_METRICS = os.getenv("COMET_LOG_BATCH_METRICS", "false").lower() == "true"
55
+ COMET_BATCH_LOGGING_INTERVAL = os.getenv("COMET_BATCH_LOGGING_INTERVAL", 1)
56
+ COMET_PREDICTION_LOGGING_INTERVAL = os.getenv("COMET_PREDICTION_LOGGING_INTERVAL", 1)
57
+ COMET_LOG_PER_CLASS_METRICS = os.getenv("COMET_LOG_PER_CLASS_METRICS", "false").lower() == "true"
58
+
59
+ RANK = int(os.getenv("RANK", -1))
60
+
61
+ to_pil = T.ToPILImage()
62
+
63
+
64
+ class CometLogger:
65
+ """Log metrics, parameters, source code, models and much more
66
+ with Comet
67
+ """
68
+
69
+ def __init__(self, opt, hyp, run_id=None, job_type="Training", **experiment_kwargs) -> None:
70
+ self.job_type = job_type
71
+ self.opt = opt
72
+ self.hyp = hyp
73
+
74
+ # Comet Flags
75
+ self.comet_mode = COMET_MODE
76
+
77
+ self.save_model = opt.save_period > -1
78
+ self.model_name = COMET_MODEL_NAME
79
+
80
+ # Batch Logging Settings
81
+ self.log_batch_metrics = COMET_LOG_BATCH_METRICS
82
+ self.comet_log_batch_interval = COMET_BATCH_LOGGING_INTERVAL
83
+
84
+ # Dataset Artifact Settings
85
+ self.upload_dataset = self.opt.upload_dataset if self.opt.upload_dataset else COMET_UPLOAD_DATASET
86
+ self.resume = self.opt.resume
87
+
88
+ # Default parameters to pass to Experiment objects
89
+ self.default_experiment_kwargs = {
90
+ "log_code": False,
91
+ "log_env_gpu": True,
92
+ "log_env_cpu": True,
93
+ "project_name": COMET_PROJECT_NAME,}
94
+ self.default_experiment_kwargs.update(experiment_kwargs)
95
+ self.experiment = self._get_experiment(self.comet_mode, run_id)
96
+
97
+ self.data_dict = self.check_dataset(self.opt.data)
98
+ self.class_names = self.data_dict["names"]
99
+ self.num_classes = self.data_dict["nc"]
100
+
101
+ self.logged_images_count = 0
102
+ self.max_images = COMET_MAX_IMAGE_UPLOADS
103
+
104
+ if run_id is None:
105
+ self.experiment.log_other("Created from", "YOLOv5")
106
+ if not isinstance(self.experiment, comet_ml.OfflineExperiment):
107
+ workspace, project_name, experiment_id = self.experiment.url.split("/")[-3:]
108
+ self.experiment.log_other(
109
+ "Run Path",
110
+ f"{workspace}/{project_name}/{experiment_id}",
111
+ )
112
+ self.log_parameters(vars(opt))
113
+ self.log_parameters(self.opt.hyp)
114
+ self.log_asset_data(
115
+ self.opt.hyp,
116
+ name="hyperparameters.json",
117
+ metadata={"type": "hyp-config-file"},
118
+ )
119
+ self.log_asset(
120
+ f"{self.opt.save_dir}/opt.yaml",
121
+ metadata={"type": "opt-config-file"},
122
+ )
123
+
124
+ self.comet_log_confusion_matrix = COMET_LOG_CONFUSION_MATRIX
125
+
126
+ if hasattr(self.opt, "conf_thres"):
127
+ self.conf_thres = self.opt.conf_thres
128
+ else:
129
+ self.conf_thres = CONF_THRES
130
+ if hasattr(self.opt, "iou_thres"):
131
+ self.iou_thres = self.opt.iou_thres
132
+ else:
133
+ self.iou_thres = IOU_THRES
134
+
135
+ self.log_parameters({"val_iou_threshold": self.iou_thres, "val_conf_threshold": self.conf_thres})
136
+
137
+ self.comet_log_predictions = COMET_LOG_PREDICTIONS
138
+ if self.opt.bbox_interval == -1:
139
+ self.comet_log_prediction_interval = 1 if self.opt.epochs < 10 else self.opt.epochs // 10
140
+ else:
141
+ self.comet_log_prediction_interval = self.opt.bbox_interval
142
+
143
+ if self.comet_log_predictions:
144
+ self.metadata_dict = {}
145
+ self.logged_image_names = []
146
+
147
+ self.comet_log_per_class_metrics = COMET_LOG_PER_CLASS_METRICS
148
+
149
+ self.experiment.log_others({
150
+ "comet_mode": COMET_MODE,
151
+ "comet_max_image_uploads": COMET_MAX_IMAGE_UPLOADS,
152
+ "comet_log_per_class_metrics": COMET_LOG_PER_CLASS_METRICS,
153
+ "comet_log_batch_metrics": COMET_LOG_BATCH_METRICS,
154
+ "comet_log_confusion_matrix": COMET_LOG_CONFUSION_MATRIX,
155
+ "comet_model_name": COMET_MODEL_NAME,})
156
+
157
+ # Check if running the Experiment with the Comet Optimizer
158
+ if hasattr(self.opt, "comet_optimizer_id"):
159
+ self.experiment.log_other("optimizer_id", self.opt.comet_optimizer_id)
160
+ self.experiment.log_other("optimizer_objective", self.opt.comet_optimizer_objective)
161
+ self.experiment.log_other("optimizer_metric", self.opt.comet_optimizer_metric)
162
+ self.experiment.log_other("optimizer_parameters", json.dumps(self.hyp))
163
+
164
+ def _get_experiment(self, mode, experiment_id=None):
165
+ if mode == "offline":
166
+ if experiment_id is not None:
167
+ return comet_ml.ExistingOfflineExperiment(
168
+ previous_experiment=experiment_id,
169
+ **self.default_experiment_kwargs,
170
+ )
171
+
172
+ return comet_ml.OfflineExperiment(**self.default_experiment_kwargs,)
173
+
174
+ else:
175
+ try:
176
+ if experiment_id is not None:
177
+ return comet_ml.ExistingExperiment(
178
+ previous_experiment=experiment_id,
179
+ **self.default_experiment_kwargs,
180
+ )
181
+
182
+ return comet_ml.Experiment(**self.default_experiment_kwargs)
183
+
184
+ except ValueError:
185
+ logger.warning("COMET WARNING: "
186
+ "Comet credentials have not been set. "
187
+ "Comet will default to offline logging. "
188
+ "Please set your credentials to enable online logging.")
189
+ return self._get_experiment("offline", experiment_id)
190
+
191
+ return
192
+
193
+ def log_metrics(self, log_dict, **kwargs):
194
+ self.experiment.log_metrics(log_dict, **kwargs)
195
+
196
+ def log_parameters(self, log_dict, **kwargs):
197
+ self.experiment.log_parameters(log_dict, **kwargs)
198
+
199
+ def log_asset(self, asset_path, **kwargs):
200
+ self.experiment.log_asset(asset_path, **kwargs)
201
+
202
+ def log_asset_data(self, asset, **kwargs):
203
+ self.experiment.log_asset_data(asset, **kwargs)
204
+
205
+ def log_image(self, img, **kwargs):
206
+ self.experiment.log_image(img, **kwargs)
207
+
208
+ def log_model(self, path, opt, epoch, fitness_score, best_model=False):
209
+ if not self.save_model:
210
+ return
211
+
212
+ model_metadata = {
213
+ "fitness_score": fitness_score[-1],
214
+ "epochs_trained": epoch + 1,
215
+ "save_period": opt.save_period,
216
+ "total_epochs": opt.epochs,}
217
+
218
+ model_files = glob.glob(f"{path}/*.pt")
219
+ for model_path in model_files:
220
+ name = Path(model_path).name
221
+
222
+ self.experiment.log_model(
223
+ self.model_name,
224
+ file_or_folder=model_path,
225
+ file_name=name,
226
+ metadata=model_metadata,
227
+ overwrite=True,
228
+ )
229
+
230
+ def check_dataset(self, data_file):
231
+ with open(data_file) as f:
232
+ data_config = yaml.safe_load(f)
233
+
234
+ if data_config['path'].startswith(COMET_PREFIX):
235
+ path = data_config['path'].replace(COMET_PREFIX, "")
236
+ data_dict = self.download_dataset_artifact(path)
237
+
238
+ return data_dict
239
+
240
+ self.log_asset(self.opt.data, metadata={"type": "data-config-file"})
241
+
242
+ return check_dataset(data_file)
243
+
244
+ def log_predictions(self, image, labelsn, path, shape, predn):
245
+ if self.logged_images_count >= self.max_images:
246
+ return
247
+ detections = predn[predn[:, 4] > self.conf_thres]
248
+ iou = box_iou(labelsn[:, 1:], detections[:, :4])
249
+ mask, _ = torch.where(iou > self.iou_thres)
250
+ if len(mask) == 0:
251
+ return
252
+
253
+ filtered_detections = detections[mask]
254
+ filtered_labels = labelsn[mask]
255
+
256
+ image_id = path.split("/")[-1].split(".")[0]
257
+ image_name = f"{image_id}_curr_epoch_{self.experiment.curr_epoch}"
258
+ if image_name not in self.logged_image_names:
259
+ native_scale_image = PIL.Image.open(path)
260
+ self.log_image(native_scale_image, name=image_name)
261
+ self.logged_image_names.append(image_name)
262
+
263
+ metadata = []
264
+ for cls, *xyxy in filtered_labels.tolist():
265
+ metadata.append({
266
+ "label": f"{self.class_names[int(cls)]}-gt",
267
+ "score": 100,
268
+ "box": {
269
+ "x": xyxy[0],
270
+ "y": xyxy[1],
271
+ "x2": xyxy[2],
272
+ "y2": xyxy[3]},})
273
+ for *xyxy, conf, cls in filtered_detections.tolist():
274
+ metadata.append({
275
+ "label": f"{self.class_names[int(cls)]}",
276
+ "score": conf * 100,
277
+ "box": {
278
+ "x": xyxy[0],
279
+ "y": xyxy[1],
280
+ "x2": xyxy[2],
281
+ "y2": xyxy[3]},})
282
+
283
+ self.metadata_dict[image_name] = metadata
284
+ self.logged_images_count += 1
285
+
286
+ return
287
+
288
+ def preprocess_prediction(self, image, labels, shape, pred):
289
+ nl, _ = labels.shape[0], pred.shape[0]
290
+
291
+ # Predictions
292
+ if self.opt.single_cls:
293
+ pred[:, 5] = 0
294
+
295
+ predn = pred.clone()
296
+ scale_boxes(image.shape[1:], predn[:, :4], shape[0], shape[1])
297
+
298
+ labelsn = None
299
+ if nl:
300
+ tbox = xywh2xyxy(labels[:, 1:5]) # target boxes
301
+ scale_boxes(image.shape[1:], tbox, shape[0], shape[1]) # native-space labels
302
+ labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
303
+ scale_boxes(image.shape[1:], predn[:, :4], shape[0], shape[1]) # native-space pred
304
+
305
+ return predn, labelsn
306
+
307
+ def add_assets_to_artifact(self, artifact, path, asset_path, split):
308
+ img_paths = sorted(glob.glob(f"{asset_path}/*"))
309
+ label_paths = img2label_paths(img_paths)
310
+
311
+ for image_file, label_file in zip(img_paths, label_paths):
312
+ image_logical_path, label_logical_path = map(lambda x: os.path.relpath(x, path), [image_file, label_file])
313
+
314
+ try:
315
+ artifact.add(image_file, logical_path=image_logical_path, metadata={"split": split})
316
+ artifact.add(label_file, logical_path=label_logical_path, metadata={"split": split})
317
+ except ValueError as e:
318
+ logger.error('COMET ERROR: Error adding file to Artifact. Skipping file.')
319
+ logger.error(f"COMET ERROR: {e}")
320
+ continue
321
+
322
+ return artifact
323
+
324
+ def upload_dataset_artifact(self):
325
+ dataset_name = self.data_dict.get("dataset_name", "yolov5-dataset")
326
+ path = str((ROOT / Path(self.data_dict["path"])).resolve())
327
+
328
+ metadata = self.data_dict.copy()
329
+ for key in ["train", "val", "test"]:
330
+ split_path = metadata.get(key)
331
+ if split_path is not None:
332
+ metadata[key] = split_path.replace(path, "")
333
+
334
+ artifact = comet_ml.Artifact(name=dataset_name, artifact_type="dataset", metadata=metadata)
335
+ for key in metadata.keys():
336
+ if key in ["train", "val", "test"]:
337
+ if isinstance(self.upload_dataset, str) and (key != self.upload_dataset):
338
+ continue
339
+
340
+ asset_path = self.data_dict.get(key)
341
+ if asset_path is not None:
342
+ artifact = self.add_assets_to_artifact(artifact, path, asset_path, key)
343
+
344
+ self.experiment.log_artifact(artifact)
345
+
346
+ return
347
+
348
+ def download_dataset_artifact(self, artifact_path):
349
+ logged_artifact = self.experiment.get_artifact(artifact_path)
350
+ artifact_save_dir = str(Path(self.opt.save_dir) / logged_artifact.name)
351
+ logged_artifact.download(artifact_save_dir)
352
+
353
+ metadata = logged_artifact.metadata
354
+ data_dict = metadata.copy()
355
+ data_dict["path"] = artifact_save_dir
356
+
357
+ metadata_names = metadata.get("names")
358
+ if type(metadata_names) == dict:
359
+ data_dict["names"] = {int(k): v for k, v in metadata.get("names").items()}
360
+ elif type(metadata_names) == list:
361
+ data_dict["names"] = {int(k): v for k, v in zip(range(len(metadata_names)), metadata_names)}
362
+ else:
363
+ raise "Invalid 'names' field in dataset yaml file. Please use a list or dictionary"
364
+
365
+ data_dict = self.update_data_paths(data_dict)
366
+ return data_dict
367
+
368
+ def update_data_paths(self, data_dict):
369
+ path = data_dict.get("path", "")
370
+
371
+ for split in ["train", "val", "test"]:
372
+ if data_dict.get(split):
373
+ split_path = data_dict.get(split)
374
+ data_dict[split] = (f"{path}/{split_path}" if isinstance(split, str) else [
375
+ f"{path}/{x}" for x in split_path])
376
+
377
+ return data_dict
378
+
379
+ def on_pretrain_routine_end(self, paths):
380
+ if self.opt.resume:
381
+ return
382
+
383
+ for path in paths:
384
+ self.log_asset(str(path))
385
+
386
+ if self.upload_dataset:
387
+ if not self.resume:
388
+ self.upload_dataset_artifact()
389
+
390
+ return
391
+
392
+ def on_train_start(self):
393
+ self.log_parameters(self.hyp)
394
+
395
+ def on_train_epoch_start(self):
396
+ return
397
+
398
+ def on_train_epoch_end(self, epoch):
399
+ self.experiment.curr_epoch = epoch
400
+
401
+ return
402
+
403
+ def on_train_batch_start(self):
404
+ return
405
+
406
+ def on_train_batch_end(self, log_dict, step):
407
+ self.experiment.curr_step = step
408
+ if self.log_batch_metrics and (step % self.comet_log_batch_interval == 0):
409
+ self.log_metrics(log_dict, step=step)
410
+
411
+ return
412
+
413
+ def on_train_end(self, files, save_dir, last, best, epoch, results):
414
+ if self.comet_log_predictions:
415
+ curr_epoch = self.experiment.curr_epoch
416
+ self.experiment.log_asset_data(self.metadata_dict, "image-metadata.json", epoch=curr_epoch)
417
+
418
+ for f in files:
419
+ self.log_asset(f, metadata={"epoch": epoch})
420
+ self.log_asset(f"{save_dir}/results.csv", metadata={"epoch": epoch})
421
+
422
+ if not self.opt.evolve:
423
+ model_path = str(best if best.exists() else last)
424
+ name = Path(model_path).name
425
+ if self.save_model:
426
+ self.experiment.log_model(
427
+ self.model_name,
428
+ file_or_folder=model_path,
429
+ file_name=name,
430
+ overwrite=True,
431
+ )
432
+
433
+ # Check if running Experiment with Comet Optimizer
434
+ if hasattr(self.opt, 'comet_optimizer_id'):
435
+ metric = results.get(self.opt.comet_optimizer_metric)
436
+ self.experiment.log_other('optimizer_metric_value', metric)
437
+
438
+ self.finish_run()
439
+
440
+ def on_val_start(self):
441
+ return
442
+
443
+ def on_val_batch_start(self):
444
+ return
445
+
446
+ def on_val_batch_end(self, batch_i, images, targets, paths, shapes, outputs):
447
+ if not (self.comet_log_predictions and ((batch_i + 1) % self.comet_log_prediction_interval == 0)):
448
+ return
449
+
450
+ for si, pred in enumerate(outputs):
451
+ if len(pred) == 0:
452
+ continue
453
+
454
+ image = images[si]
455
+ labels = targets[targets[:, 0] == si, 1:]
456
+ shape = shapes[si]
457
+ path = paths[si]
458
+ predn, labelsn = self.preprocess_prediction(image, labels, shape, pred)
459
+ if labelsn is not None:
460
+ self.log_predictions(image, labelsn, path, shape, predn)
461
+
462
+ return
463
+
464
+ def on_val_end(self, nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix):
465
+ if self.comet_log_per_class_metrics:
466
+ if self.num_classes > 1:
467
+ for i, c in enumerate(ap_class):
468
+ class_name = self.class_names[c]
469
+ self.experiment.log_metrics(
470
+ {
471
+ 'mAP@.5': ap50[i],
472
+ 'mAP@.5:.95': ap[i],
473
+ 'precision': p[i],
474
+ 'recall': r[i],
475
+ 'f1': f1[i],
476
+ 'true_positives': tp[i],
477
+ 'false_positives': fp[i],
478
+ 'support': nt[c]},
479
+ prefix=class_name)
480
+
481
+ if self.comet_log_confusion_matrix:
482
+ epoch = self.experiment.curr_epoch
483
+ class_names = list(self.class_names.values())
484
+ class_names.append("background")
485
+ num_classes = len(class_names)
486
+
487
+ self.experiment.log_confusion_matrix(
488
+ matrix=confusion_matrix.matrix,
489
+ max_categories=num_classes,
490
+ labels=class_names,
491
+ epoch=epoch,
492
+ column_label='Actual Category',
493
+ row_label='Predicted Category',
494
+ file_name=f"confusion-matrix-epoch-{epoch}.json",
495
+ )
496
+
497
+ def on_fit_epoch_end(self, result, epoch):
498
+ self.log_metrics(result, epoch=epoch)
499
+
500
+ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
501
+ if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
502
+ self.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
503
+
504
+ def on_params_update(self, params):
505
+ self.log_parameters(params)
506
+
507
+ def finish_run(self):
508
+ self.experiment.end()
utils/loggers/comet/comet_utils.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from urllib.parse import urlparse
4
+
5
+ try:
6
+ import comet_ml
7
+ except (ModuleNotFoundError, ImportError):
8
+ comet_ml = None
9
+
10
+ import yaml
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ COMET_PREFIX = "comet://"
15
+ COMET_MODEL_NAME = os.getenv("COMET_MODEL_NAME", "yolov5")
16
+ COMET_DEFAULT_CHECKPOINT_FILENAME = os.getenv("COMET_DEFAULT_CHECKPOINT_FILENAME", "last.pt")
17
+
18
+
19
+ def download_model_checkpoint(opt, experiment):
20
+ model_dir = f"{opt.project}/{experiment.name}"
21
+ os.makedirs(model_dir, exist_ok=True)
22
+
23
+ model_name = COMET_MODEL_NAME
24
+ model_asset_list = experiment.get_model_asset_list(model_name)
25
+
26
+ if len(model_asset_list) == 0:
27
+ logger.error(f"COMET ERROR: No checkpoints found for model name : {model_name}")
28
+ return
29
+
30
+ model_asset_list = sorted(
31
+ model_asset_list,
32
+ key=lambda x: x["step"],
33
+ reverse=True,
34
+ )
35
+ logged_checkpoint_map = {asset["fileName"]: asset["assetId"] for asset in model_asset_list}
36
+
37
+ resource_url = urlparse(opt.weights)
38
+ checkpoint_filename = resource_url.query
39
+
40
+ if checkpoint_filename:
41
+ asset_id = logged_checkpoint_map.get(checkpoint_filename)
42
+ else:
43
+ asset_id = logged_checkpoint_map.get(COMET_DEFAULT_CHECKPOINT_FILENAME)
44
+ checkpoint_filename = COMET_DEFAULT_CHECKPOINT_FILENAME
45
+
46
+ if asset_id is None:
47
+ logger.error(f"COMET ERROR: Checkpoint {checkpoint_filename} not found in the given Experiment")
48
+ return
49
+
50
+ try:
51
+ logger.info(f"COMET INFO: Downloading checkpoint {checkpoint_filename}")
52
+ asset_filename = checkpoint_filename
53
+
54
+ model_binary = experiment.get_asset(asset_id, return_type="binary", stream=False)
55
+ model_download_path = f"{model_dir}/{asset_filename}"
56
+ with open(model_download_path, "wb") as f:
57
+ f.write(model_binary)
58
+
59
+ opt.weights = model_download_path
60
+
61
+ except Exception as e:
62
+ logger.warning("COMET WARNING: Unable to download checkpoint from Comet")
63
+ logger.exception(e)
64
+
65
+
66
+ def set_opt_parameters(opt, experiment):
67
+ """Update the opts Namespace with parameters
68
+ from Comet's ExistingExperiment when resuming a run
69
+
70
+ Args:
71
+ opt (argparse.Namespace): Namespace of command line options
72
+ experiment (comet_ml.APIExperiment): Comet API Experiment object
73
+ """
74
+ asset_list = experiment.get_asset_list()
75
+ resume_string = opt.resume
76
+
77
+ for asset in asset_list:
78
+ if asset["fileName"] == "opt.yaml":
79
+ asset_id = asset["assetId"]
80
+ asset_binary = experiment.get_asset(asset_id, return_type="binary", stream=False)
81
+ opt_dict = yaml.safe_load(asset_binary)
82
+ for key, value in opt_dict.items():
83
+ setattr(opt, key, value)
84
+ opt.resume = resume_string
85
+
86
+ # Save hyperparameters to YAML file
87
+ # Necessary to pass checks in training script
88
+ save_dir = f"{opt.project}/{experiment.name}"
89
+ os.makedirs(save_dir, exist_ok=True)
90
+
91
+ hyp_yaml_path = f"{save_dir}/hyp.yaml"
92
+ with open(hyp_yaml_path, "w") as f:
93
+ yaml.dump(opt.hyp, f)
94
+ opt.hyp = hyp_yaml_path
95
+
96
+
97
+ def check_comet_weights(opt):
98
+ """Downloads model weights from Comet and updates the
99
+ weights path to point to saved weights location
100
+
101
+ Args:
102
+ opt (argparse.Namespace): Command Line arguments passed
103
+ to YOLOv5 training script
104
+
105
+ Returns:
106
+ None/bool: Return True if weights are successfully downloaded
107
+ else return None
108
+ """
109
+ if comet_ml is None:
110
+ return
111
+
112
+ if isinstance(opt.weights, str):
113
+ if opt.weights.startswith(COMET_PREFIX):
114
+ api = comet_ml.API()
115
+ resource = urlparse(opt.weights)
116
+ experiment_path = f"{resource.netloc}{resource.path}"
117
+ experiment = api.get(experiment_path)
118
+ download_model_checkpoint(opt, experiment)
119
+ return True
120
+
121
+ return None
122
+
123
+
124
+ def check_comet_resume(opt):
125
+ """Restores run parameters to its original state based on the model checkpoint
126
+ and logged Experiment parameters.
127
+
128
+ Args:
129
+ opt (argparse.Namespace): Command Line arguments passed
130
+ to YOLOv5 training script
131
+
132
+ Returns:
133
+ None/bool: Return True if the run is restored successfully
134
+ else return None
135
+ """
136
+ if comet_ml is None:
137
+ return
138
+
139
+ if isinstance(opt.resume, str):
140
+ if opt.resume.startswith(COMET_PREFIX):
141
+ api = comet_ml.API()
142
+ resource = urlparse(opt.resume)
143
+ experiment_path = f"{resource.netloc}{resource.path}"
144
+ experiment = api.get(experiment_path)
145
+ set_opt_parameters(opt, experiment)
146
+ download_model_checkpoint(opt, experiment)
147
+
148
+ return True
149
+
150
+ return None
utils/loggers/comet/hpo.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import comet_ml
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ FILE = Path(__file__).resolve()
13
+ ROOT = FILE.parents[3] # YOLOv5 root directory
14
+ if str(ROOT) not in sys.path:
15
+ sys.path.append(str(ROOT)) # add ROOT to PATH
16
+
17
+ from train import train
18
+ from utils.callbacks import Callbacks
19
+ from utils.general import increment_path
20
+ from utils.torch_utils import select_device
21
+
22
+ # Project Configuration
23
+ config = comet_ml.config.get_config()
24
+ COMET_PROJECT_NAME = config.get_string(os.getenv("COMET_PROJECT_NAME"), "comet.project_name", default="yolov5")
25
+
26
+
27
+ def get_args(known=False):
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
30
+ parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
31
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
32
+ parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
33
+ parser.add_argument('--epochs', type=int, default=300, help='total training epochs')
34
+ parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
35
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
36
+ parser.add_argument('--rect', action='store_true', help='rectangular training')
37
+ parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
38
+ parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
39
+ parser.add_argument('--noval', action='store_true', help='only validate final epoch')
40
+ parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
41
+ parser.add_argument('--noplots', action='store_true', help='save no plot files')
42
+ parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
43
+ parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
44
+ parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
45
+ parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
46
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
47
+ parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
48
+ parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
49
+ parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
50
+ parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
51
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
52
+ parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
53
+ parser.add_argument('--name', default='exp', help='save to project/name')
54
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
55
+ parser.add_argument('--quad', action='store_true', help='quad dataloader')
56
+ parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
57
+ parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
58
+ parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
59
+ parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
60
+ parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
61
+ parser.add_argument('--seed', type=int, default=0, help='Global training seed')
62
+ parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
63
+
64
+ # Weights & Biases arguments
65
+ parser.add_argument('--entity', default=None, help='W&B: Entity')
66
+ parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
67
+ parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
68
+ parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
69
+
70
+ # Comet Arguments
71
+ parser.add_argument("--comet_optimizer_config", type=str, help="Comet: Path to a Comet Optimizer Config File.")
72
+ parser.add_argument("--comet_optimizer_id", type=str, help="Comet: ID of the Comet Optimizer sweep.")
73
+ parser.add_argument("--comet_optimizer_objective", type=str, help="Comet: Set to 'minimize' or 'maximize'.")
74
+ parser.add_argument("--comet_optimizer_metric", type=str, help="Comet: Metric to Optimize.")
75
+ parser.add_argument("--comet_optimizer_workers",
76
+ type=int,
77
+ default=1,
78
+ help="Comet: Number of Parallel Workers to use with the Comet Optimizer.")
79
+
80
+ return parser.parse_known_args()[0] if known else parser.parse_args()
81
+
82
+
83
+ def run(parameters, opt):
84
+ hyp_dict = {k: v for k, v in parameters.items() if k not in ["epochs", "batch_size"]}
85
+
86
+ opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
87
+ opt.batch_size = parameters.get("batch_size")
88
+ opt.epochs = parameters.get("epochs")
89
+
90
+ device = select_device(opt.device, batch_size=opt.batch_size)
91
+ train(hyp_dict, opt, device, callbacks=Callbacks())
92
+
93
+
94
+ if __name__ == "__main__":
95
+ opt = get_args(known=True)
96
+
97
+ opt.weights = str(opt.weights)
98
+ opt.cfg = str(opt.cfg)
99
+ opt.data = str(opt.data)
100
+ opt.project = str(opt.project)
101
+
102
+ optimizer_id = os.getenv("COMET_OPTIMIZER_ID")
103
+ if optimizer_id is None:
104
+ with open(opt.comet_optimizer_config) as f:
105
+ optimizer_config = json.load(f)
106
+ optimizer = comet_ml.Optimizer(optimizer_config)
107
+ else:
108
+ optimizer = comet_ml.Optimizer(optimizer_id)
109
+
110
+ opt.comet_optimizer_id = optimizer.id
111
+ status = optimizer.status()
112
+
113
+ opt.comet_optimizer_objective = status["spec"]["objective"]
114
+ opt.comet_optimizer_metric = status["spec"]["metric"]
115
+
116
+ logger.info("COMET INFO: Starting Hyperparameter Sweep")
117
+ for parameter in optimizer.get_parameters():
118
+ run(parameter["parameters"], opt)
utils/loggers/comet/optimizer_config.json ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "algorithm": "random",
3
+ "parameters": {
4
+ "anchor_t": {
5
+ "type": "discrete",
6
+ "values": [
7
+ 2,
8
+ 8
9
+ ]
10
+ },
11
+ "batch_size": {
12
+ "type": "discrete",
13
+ "values": [
14
+ 16,
15
+ 32,
16
+ 64
17
+ ]
18
+ },
19
+ "box": {
20
+ "type": "discrete",
21
+ "values": [
22
+ 0.02,
23
+ 0.2
24
+ ]
25
+ },
26
+ "cls": {
27
+ "type": "discrete",
28
+ "values": [
29
+ 0.2
30
+ ]
31
+ },
32
+ "cls_pw": {
33
+ "type": "discrete",
34
+ "values": [
35
+ 0.5
36
+ ]
37
+ },
38
+ "copy_paste": {
39
+ "type": "discrete",
40
+ "values": [
41
+ 1
42
+ ]
43
+ },
44
+ "degrees": {
45
+ "type": "discrete",
46
+ "values": [
47
+ 0,
48
+ 45
49
+ ]
50
+ },
51
+ "epochs": {
52
+ "type": "discrete",
53
+ "values": [
54
+ 5
55
+ ]
56
+ },
57
+ "fl_gamma": {
58
+ "type": "discrete",
59
+ "values": [
60
+ 0
61
+ ]
62
+ },
63
+ "fliplr": {
64
+ "type": "discrete",
65
+ "values": [
66
+ 0
67
+ ]
68
+ },
69
+ "flipud": {
70
+ "type": "discrete",
71
+ "values": [
72
+ 0
73
+ ]
74
+ },
75
+ "hsv_h": {
76
+ "type": "discrete",
77
+ "values": [
78
+ 0
79
+ ]
80
+ },
81
+ "hsv_s": {
82
+ "type": "discrete",
83
+ "values": [
84
+ 0
85
+ ]
86
+ },
87
+ "hsv_v": {
88
+ "type": "discrete",
89
+ "values": [
90
+ 0
91
+ ]
92
+ },
93
+ "iou_t": {
94
+ "type": "discrete",
95
+ "values": [
96
+ 0.7
97
+ ]
98
+ },
99
+ "lr0": {
100
+ "type": "discrete",
101
+ "values": [
102
+ 1e-05,
103
+ 0.1
104
+ ]
105
+ },
106
+ "lrf": {
107
+ "type": "discrete",
108
+ "values": [
109
+ 0.01,
110
+ 1
111
+ ]
112
+ },
113
+ "mixup": {
114
+ "type": "discrete",
115
+ "values": [
116
+ 1
117
+ ]
118
+ },
119
+ "momentum": {
120
+ "type": "discrete",
121
+ "values": [
122
+ 0.6
123
+ ]
124
+ },
125
+ "mosaic": {
126
+ "type": "discrete",
127
+ "values": [
128
+ 0
129
+ ]
130
+ },
131
+ "obj": {
132
+ "type": "discrete",
133
+ "values": [
134
+ 0.2
135
+ ]
136
+ },
137
+ "obj_pw": {
138
+ "type": "discrete",
139
+ "values": [
140
+ 0.5
141
+ ]
142
+ },
143
+ "optimizer": {
144
+ "type": "categorical",
145
+ "values": [
146
+ "SGD",
147
+ "Adam",
148
+ "AdamW"
149
+ ]
150
+ },
151
+ "perspective": {
152
+ "type": "discrete",
153
+ "values": [
154
+ 0
155
+ ]
156
+ },
157
+ "scale": {
158
+ "type": "discrete",
159
+ "values": [
160
+ 0
161
+ ]
162
+ },
163
+ "shear": {
164
+ "type": "discrete",
165
+ "values": [
166
+ 0
167
+ ]
168
+ },
169
+ "translate": {
170
+ "type": "discrete",
171
+ "values": [
172
+ 0
173
+ ]
174
+ },
175
+ "warmup_bias_lr": {
176
+ "type": "discrete",
177
+ "values": [
178
+ 0,
179
+ 0.2
180
+ ]
181
+ },
182
+ "warmup_epochs": {
183
+ "type": "discrete",
184
+ "values": [
185
+ 5
186
+ ]
187
+ },
188
+ "warmup_momentum": {
189
+ "type": "discrete",
190
+ "values": [
191
+ 0,
192
+ 0.95
193
+ ]
194
+ },
195
+ "weight_decay": {
196
+ "type": "discrete",
197
+ "values": [
198
+ 0,
199
+ 0.001
200
+ ]
201
+ }
202
+ },
203
+ "spec": {
204
+ "maxCombo": 0,
205
+ "metric": "metrics/mAP_0.5",
206
+ "objective": "maximize"
207
+ },
208
+ "trials": 1
209
+ }
utils/loggers/wandb/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # init
utils/loggers/wandb/log_dataset.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from wandb_utils import WandbLogger
4
+
5
+ from utils.general import LOGGER
6
+
7
+ WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
8
+
9
+
10
+ def create_dataset_artifact(opt):
11
+ logger = WandbLogger(opt, None, job_type='Dataset Creation') # TODO: return value unused
12
+ if not logger.wandb:
13
+ LOGGER.info("install wandb using `pip install wandb` to log the dataset")
14
+
15
+
16
+ if __name__ == '__main__':
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
19
+ parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
20
+ parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
21
+ parser.add_argument('--entity', default=None, help='W&B entity')
22
+ parser.add_argument('--name', type=str, default='log dataset', help='name of W&B run')
23
+
24
+ opt = parser.parse_args()
25
+ opt.resume = False # Explicitly disallow resume check for dataset upload job
26
+
27
+ create_dataset_artifact(opt)
utils/loggers/wandb/sweep.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import wandb
5
+
6
+ FILE = Path(__file__).resolve()
7
+ ROOT = FILE.parents[3] # YOLOv5 root directory
8
+ if str(ROOT) not in sys.path:
9
+ sys.path.append(str(ROOT)) # add ROOT to PATH
10
+
11
+ from train import parse_opt, train
12
+ from utils.callbacks import Callbacks
13
+ from utils.general import increment_path
14
+ from utils.torch_utils import select_device
15
+
16
+
17
+ def sweep():
18
+ wandb.init()
19
+ # Get hyp dict from sweep agent. Copy because train() modifies parameters which confused wandb.
20
+ hyp_dict = vars(wandb.config).get("_items").copy()
21
+
22
+ # Workaround: get necessary opt args
23
+ opt = parse_opt(known=True)
24
+ opt.batch_size = hyp_dict.get("batch_size")
25
+ opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
26
+ opt.epochs = hyp_dict.get("epochs")
27
+ opt.nosave = True
28
+ opt.data = hyp_dict.get("data")
29
+ opt.weights = str(opt.weights)
30
+ opt.cfg = str(opt.cfg)
31
+ opt.data = str(opt.data)
32
+ opt.hyp = str(opt.hyp)
33
+ opt.project = str(opt.project)
34
+ device = select_device(opt.device, batch_size=opt.batch_size)
35
+
36
+ # train
37
+ train(hyp_dict, opt, device, callbacks=Callbacks())
38
+
39
+
40
+ if __name__ == "__main__":
41
+ sweep()
utils/loggers/wandb/sweep.yaml ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hyperparameters for training
2
+ # To set range-
3
+ # Provide min and max values as:
4
+ # parameter:
5
+ #
6
+ # min: scalar
7
+ # max: scalar
8
+ # OR
9
+ #
10
+ # Set a specific list of search space-
11
+ # parameter:
12
+ # values: [scalar1, scalar2, scalar3...]
13
+ #
14
+ # You can use grid, bayesian and hyperopt search strategy
15
+ # For more info on configuring sweeps visit - https://docs.wandb.ai/guides/sweeps/configuration
16
+
17
+ program: utils/loggers/wandb/sweep.py
18
+ method: random
19
+ metric:
20
+ name: metrics/mAP_0.5
21
+ goal: maximize
22
+
23
+ parameters:
24
+ # hyperparameters: set either min, max range or values list
25
+ data:
26
+ value: "data/coco128.yaml"
27
+ batch_size:
28
+ values: [64]
29
+ epochs:
30
+ values: [10]
31
+
32
+ lr0:
33
+ distribution: uniform
34
+ min: 1e-5
35
+ max: 1e-1
36
+ lrf:
37
+ distribution: uniform
38
+ min: 0.01
39
+ max: 1.0
40
+ momentum:
41
+ distribution: uniform
42
+ min: 0.6
43
+ max: 0.98
44
+ weight_decay:
45
+ distribution: uniform
46
+ min: 0.0
47
+ max: 0.001
48
+ warmup_epochs:
49
+ distribution: uniform
50
+ min: 0.0
51
+ max: 5.0
52
+ warmup_momentum:
53
+ distribution: uniform
54
+ min: 0.0
55
+ max: 0.95
56
+ warmup_bias_lr:
57
+ distribution: uniform
58
+ min: 0.0
59
+ max: 0.2
60
+ box:
61
+ distribution: uniform
62
+ min: 0.02
63
+ max: 0.2
64
+ cls:
65
+ distribution: uniform
66
+ min: 0.2
67
+ max: 4.0
68
+ cls_pw:
69
+ distribution: uniform
70
+ min: 0.5
71
+ max: 2.0
72
+ obj:
73
+ distribution: uniform
74
+ min: 0.2
75
+ max: 4.0
76
+ obj_pw:
77
+ distribution: uniform
78
+ min: 0.5
79
+ max: 2.0
80
+ iou_t:
81
+ distribution: uniform
82
+ min: 0.1
83
+ max: 0.7
84
+ anchor_t:
85
+ distribution: uniform
86
+ min: 2.0
87
+ max: 8.0
88
+ fl_gamma:
89
+ distribution: uniform
90
+ min: 0.0
91
+ max: 4.0
92
+ hsv_h:
93
+ distribution: uniform
94
+ min: 0.0
95
+ max: 0.1
96
+ hsv_s:
97
+ distribution: uniform
98
+ min: 0.0
99
+ max: 0.9
100
+ hsv_v:
101
+ distribution: uniform
102
+ min: 0.0
103
+ max: 0.9
104
+ degrees:
105
+ distribution: uniform
106
+ min: 0.0
107
+ max: 45.0
108
+ translate:
109
+ distribution: uniform
110
+ min: 0.0
111
+ max: 0.9
112
+ scale:
113
+ distribution: uniform
114
+ min: 0.0
115
+ max: 0.9
116
+ shear:
117
+ distribution: uniform
118
+ min: 0.0
119
+ max: 10.0
120
+ perspective:
121
+ distribution: uniform
122
+ min: 0.0
123
+ max: 0.001
124
+ flipud:
125
+ distribution: uniform
126
+ min: 0.0
127
+ max: 1.0
128
+ fliplr:
129
+ distribution: uniform
130
+ min: 0.0
131
+ max: 1.0
132
+ mosaic:
133
+ distribution: uniform
134
+ min: 0.0
135
+ max: 1.0
136
+ mixup:
137
+ distribution: uniform
138
+ min: 0.0
139
+ max: 1.0
140
+ copy_paste:
141
+ distribution: uniform
142
+ min: 0.0
143
+ max: 1.0
utils/loggers/wandb/wandb_utils.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities and tools for tracking runs with Weights & Biases."""
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+ from contextlib import contextmanager
7
+ from pathlib import Path
8
+ from typing import Dict
9
+
10
+ import yaml
11
+ from tqdm import tqdm
12
+
13
+ FILE = Path(__file__).resolve()
14
+ ROOT = FILE.parents[3] # YOLOv5 root directory
15
+ if str(ROOT) not in sys.path:
16
+ sys.path.append(str(ROOT)) # add ROOT to PATH
17
+
18
+ from utils.dataloaders import LoadImagesAndLabels, img2label_paths
19
+ from utils.general import LOGGER, check_dataset, check_file
20
+
21
+ try:
22
+ import wandb
23
+
24
+ assert hasattr(wandb, '__version__') # verify package import not local dir
25
+ except (ImportError, AssertionError):
26
+ wandb = None
27
+
28
+ RANK = int(os.getenv('RANK', -1))
29
+ WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
30
+
31
+
32
+ def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
33
+ return from_string[len(prefix):]
34
+
35
+
36
+ def check_wandb_config_file(data_config_file):
37
+ wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path
38
+ if Path(wandb_config).is_file():
39
+ return wandb_config
40
+ return data_config_file
41
+
42
+
43
+ def check_wandb_dataset(data_file):
44
+ is_trainset_wandb_artifact = False
45
+ is_valset_wandb_artifact = False
46
+ if isinstance(data_file, dict):
47
+ # In that case another dataset manager has already processed it and we don't have to
48
+ return data_file
49
+ if check_file(data_file) and data_file.endswith('.yaml'):
50
+ with open(data_file, errors='ignore') as f:
51
+ data_dict = yaml.safe_load(f)
52
+ is_trainset_wandb_artifact = isinstance(data_dict['train'],
53
+ str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX)
54
+ is_valset_wandb_artifact = isinstance(data_dict['val'],
55
+ str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX)
56
+ if is_trainset_wandb_artifact or is_valset_wandb_artifact:
57
+ return data_dict
58
+ else:
59
+ return check_dataset(data_file)
60
+
61
+
62
+ def get_run_info(run_path):
63
+ run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
64
+ run_id = run_path.stem
65
+ project = run_path.parent.stem
66
+ entity = run_path.parent.parent.stem
67
+ model_artifact_name = 'run_' + run_id + '_model'
68
+ return entity, project, run_id, model_artifact_name
69
+
70
+
71
+ def check_wandb_resume(opt):
72
+ process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
73
+ if isinstance(opt.resume, str):
74
+ if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
75
+ if RANK not in [-1, 0]: # For resuming DDP runs
76
+ entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
77
+ api = wandb.Api()
78
+ artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
79
+ modeldir = artifact.download()
80
+ opt.weights = str(Path(modeldir) / "last.pt")
81
+ return True
82
+ return None
83
+
84
+
85
+ def process_wandb_config_ddp_mode(opt):
86
+ with open(check_file(opt.data), errors='ignore') as f:
87
+ data_dict = yaml.safe_load(f) # data dict
88
+ train_dir, val_dir = None, None
89
+ if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
90
+ api = wandb.Api()
91
+ train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
92
+ train_dir = train_artifact.download()
93
+ train_path = Path(train_dir) / 'data/images/'
94
+ data_dict['train'] = str(train_path)
95
+
96
+ if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
97
+ api = wandb.Api()
98
+ val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
99
+ val_dir = val_artifact.download()
100
+ val_path = Path(val_dir) / 'data/images/'
101
+ data_dict['val'] = str(val_path)
102
+ if train_dir or val_dir:
103
+ ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml')
104
+ with open(ddp_data_path, 'w') as f:
105
+ yaml.safe_dump(data_dict, f)
106
+ opt.data = ddp_data_path
107
+
108
+
109
+ class WandbLogger():
110
+ """Log training runs, datasets, models, and predictions to Weights & Biases.
111
+
112
+ This logger sends information to W&B at wandb.ai. By default, this information
113
+ includes hyperparameters, system configuration and metrics, model metrics,
114
+ and basic data metrics and analyses.
115
+
116
+ By providing additional command line arguments to train.py, datasets,
117
+ models and predictions can also be logged.
118
+
119
+ For more on how this logger is used, see the Weights & Biases documentation:
120
+ https://docs.wandb.com/guides/integrations/yolov5
121
+ """
122
+
123
+ def __init__(self, opt, run_id=None, job_type='Training'):
124
+ """
125
+ - Initialize WandbLogger instance
126
+ - Upload dataset if opt.upload_dataset is True
127
+ - Setup training processes if job_type is 'Training'
128
+
129
+ arguments:
130
+ opt (namespace) -- Commandline arguments for this run
131
+ run_id (str) -- Run ID of W&B run to be resumed
132
+ job_type (str) -- To set the job_type for this run
133
+
134
+ """
135
+ # Temporary-fix
136
+ if opt.upload_dataset:
137
+ opt.upload_dataset = False
138
+ # LOGGER.info("Uploading Dataset functionality is not being supported temporarily due to a bug.")
139
+
140
+ # Pre-training routine --
141
+ self.job_type = job_type
142
+ self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
143
+ self.val_artifact, self.train_artifact = None, None
144
+ self.train_artifact_path, self.val_artifact_path = None, None
145
+ self.result_artifact = None
146
+ self.val_table, self.result_table = None, None
147
+ self.bbox_media_panel_images = []
148
+ self.val_table_path_map = None
149
+ self.max_imgs_to_log = 16
150
+ self.wandb_artifact_data_dict = None
151
+ self.data_dict = None
152
+ # It's more elegant to stick to 1 wandb.init call,
153
+ # but useful config data is overwritten in the WandbLogger's wandb.init call
154
+ if isinstance(opt.resume, str): # checks resume from artifact
155
+ if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
156
+ entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
157
+ model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
158
+ assert wandb, 'install wandb to resume wandb runs'
159
+ # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
160
+ self.wandb_run = wandb.init(id=run_id,
161
+ project=project,
162
+ entity=entity,
163
+ resume='allow',
164
+ allow_val_change=True)
165
+ opt.resume = model_artifact_name
166
+ elif self.wandb:
167
+ self.wandb_run = wandb.init(config=opt,
168
+ resume="allow",
169
+ project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
170
+ entity=opt.entity,
171
+ name=opt.name if opt.name != 'exp' else None,
172
+ job_type=job_type,
173
+ id=run_id,
174
+ allow_val_change=True) if not wandb.run else wandb.run
175
+ if self.wandb_run:
176
+ if self.job_type == 'Training':
177
+ if opt.upload_dataset:
178
+ if not opt.resume:
179
+ self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)
180
+
181
+ if isinstance(opt.data, dict):
182
+ # This means another dataset manager has already processed the dataset info (e.g. ClearML)
183
+ # and they will have stored the already processed dict in opt.data
184
+ self.data_dict = opt.data
185
+ elif opt.resume:
186
+ # resume from artifact
187
+ if isinstance(opt.resume, str) and opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
188
+ self.data_dict = dict(self.wandb_run.config.data_dict)
189
+ else: # local resume
190
+ self.data_dict = check_wandb_dataset(opt.data)
191
+ else:
192
+ self.data_dict = check_wandb_dataset(opt.data)
193
+ self.wandb_artifact_data_dict = self.wandb_artifact_data_dict or self.data_dict
194
+
195
+ # write data_dict to config. useful for resuming from artifacts. Do this only when not resuming.
196
+ self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict}, allow_val_change=True)
197
+ self.setup_training(opt)
198
+
199
+ if self.job_type == 'Dataset Creation':
200
+ self.wandb_run.config.update({"upload_dataset": True})
201
+ self.data_dict = self.check_and_upload_dataset(opt)
202
+
203
+ def check_and_upload_dataset(self, opt):
204
+ """
205
+ Check if the dataset format is compatible and upload it as W&B artifact
206
+
207
+ arguments:
208
+ opt (namespace)-- Commandline arguments for current run
209
+
210
+ returns:
211
+ Updated dataset info dictionary where local dataset paths are replaced by WAND_ARFACT_PREFIX links.
212
+ """
213
+ assert wandb, 'Install wandb to upload dataset'
214
+ config_path = self.log_dataset_artifact(opt.data, opt.single_cls,
215
+ 'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
216
+ with open(config_path, errors='ignore') as f:
217
+ wandb_data_dict = yaml.safe_load(f)
218
+ return wandb_data_dict
219
+
220
+ def setup_training(self, opt):
221
+ """
222
+ Setup the necessary processes for training YOLO models:
223
+ - Attempt to download model checkpoint and dataset artifacts if opt.resume stats with WANDB_ARTIFACT_PREFIX
224
+ - Update data_dict, to contain info of previous run if resumed and the paths of dataset artifact if downloaded
225
+ - Setup log_dict, initialize bbox_interval
226
+
227
+ arguments:
228
+ opt (namespace) -- commandline arguments for this run
229
+
230
+ """
231
+ self.log_dict, self.current_epoch = {}, 0
232
+ self.bbox_interval = opt.bbox_interval
233
+ if isinstance(opt.resume, str):
234
+ modeldir, _ = self.download_model_artifact(opt)
235
+ if modeldir:
236
+ self.weights = Path(modeldir) / "last.pt"
237
+ config = self.wandb_run.config
238
+ opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp, opt.imgsz = str(
239
+ self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs,\
240
+ config.hyp, config.imgsz
241
+ data_dict = self.data_dict
242
+ if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download
243
+ self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(
244
+ data_dict.get('train'), opt.artifact_alias)
245
+ self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(
246
+ data_dict.get('val'), opt.artifact_alias)
247
+
248
+ if self.train_artifact_path is not None:
249
+ train_path = Path(self.train_artifact_path) / 'data/images/'
250
+ data_dict['train'] = str(train_path)
251
+ if self.val_artifact_path is not None:
252
+ val_path = Path(self.val_artifact_path) / 'data/images/'
253
+ data_dict['val'] = str(val_path)
254
+
255
+ if self.val_artifact is not None:
256
+ self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
257
+ columns = ["epoch", "id", "ground truth", "prediction"]
258
+ columns.extend(self.data_dict['names'])
259
+ self.result_table = wandb.Table(columns)
260
+ self.val_table = self.val_artifact.get("val")
261
+ if self.val_table_path_map is None:
262
+ self.map_val_table_path()
263
+ if opt.bbox_interval == -1:
264
+ self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
265
+ if opt.evolve or opt.noplots:
266
+ self.bbox_interval = opt.bbox_interval = opt.epochs + 1 # disable bbox_interval
267
+ train_from_artifact = self.train_artifact_path is not None and self.val_artifact_path is not None
268
+ # Update the the data_dict to point to local artifacts dir
269
+ if train_from_artifact:
270
+ self.data_dict = data_dict
271
+
272
+ def download_dataset_artifact(self, path, alias):
273
+ """
274
+ download the model checkpoint artifact if the path starts with WANDB_ARTIFACT_PREFIX
275
+
276
+ arguments:
277
+ path -- path of the dataset to be used for training
278
+ alias (str)-- alias of the artifact to be download/used for training
279
+
280
+ returns:
281
+ (str, wandb.Artifact) -- path of the downladed dataset and it's corresponding artifact object if dataset
282
+ is found otherwise returns (None, None)
283
+ """
284
+ if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX):
285
+ artifact_path = Path(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
286
+ dataset_artifact = wandb.use_artifact(artifact_path.as_posix().replace("\\", "/"))
287
+ assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
288
+ datadir = dataset_artifact.download()
289
+ return datadir, dataset_artifact
290
+ return None, None
291
+
292
+ def download_model_artifact(self, opt):
293
+ """
294
+ download the model checkpoint artifact if the resume path starts with WANDB_ARTIFACT_PREFIX
295
+
296
+ arguments:
297
+ opt (namespace) -- Commandline arguments for this run
298
+ """
299
+ if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
300
+ model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest")
301
+ assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
302
+ modeldir = model_artifact.download()
303
+ # epochs_trained = model_artifact.metadata.get('epochs_trained')
304
+ total_epochs = model_artifact.metadata.get('total_epochs')
305
+ is_finished = total_epochs is None
306
+ assert not is_finished, 'training is finished, can only resume incomplete runs.'
307
+ return modeldir, model_artifact
308
+ return None, None
309
+
310
+ def log_model(self, path, opt, epoch, fitness_score, best_model=False):
311
+ """
312
+ Log the model checkpoint as W&B artifact
313
+
314
+ arguments:
315
+ path (Path) -- Path of directory containing the checkpoints
316
+ opt (namespace) -- Command line arguments for this run
317
+ epoch (int) -- Current epoch number
318
+ fitness_score (float) -- fitness score for current epoch
319
+ best_model (boolean) -- Boolean representing if the current checkpoint is the best yet.
320
+ """
321
+ model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model',
322
+ type='model',
323
+ metadata={
324
+ 'original_url': str(path),
325
+ 'epochs_trained': epoch + 1,
326
+ 'save period': opt.save_period,
327
+ 'project': opt.project,
328
+ 'total_epochs': opt.epochs,
329
+ 'fitness_score': fitness_score})
330
+ model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
331
+ wandb.log_artifact(model_artifact,
332
+ aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
333
+ LOGGER.info(f"Saving model artifact on epoch {epoch + 1}")
334
+
335
+ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
336
+ """
337
+ Log the dataset as W&B artifact and return the new data file with W&B links
338
+
339
+ arguments:
340
+ data_file (str) -- the .yaml file with information about the dataset like - path, classes etc.
341
+ single_class (boolean) -- train multi-class data as single-class
342
+ project (str) -- project name. Used to construct the artifact path
343
+ overwrite_config (boolean) -- overwrites the data.yaml file if set to true otherwise creates a new
344
+ file with _wandb postfix. Eg -> data_wandb.yaml
345
+
346
+ returns:
347
+ the new .yaml file with artifact links. it can be used to start training directly from artifacts
348
+ """
349
+ upload_dataset = self.wandb_run.config.upload_dataset
350
+ log_val_only = isinstance(upload_dataset, str) and upload_dataset == 'val'
351
+ self.data_dict = check_dataset(data_file) # parse and check
352
+ data = dict(self.data_dict)
353
+ nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
354
+ names = {k: v for k, v in enumerate(names)} # to index dictionary
355
+
356
+ # log train set
357
+ if not log_val_only:
358
+ self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(data['train'], rect=True, batch_size=1),
359
+ names,
360
+ name='train') if data.get('train') else None
361
+ if data.get('train'):
362
+ data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
363
+
364
+ self.val_artifact = self.create_dataset_table(
365
+ LoadImagesAndLabels(data['val'], rect=True, batch_size=1), names, name='val') if data.get('val') else None
366
+ if data.get('val'):
367
+ data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
368
+
369
+ path = Path(data_file)
370
+ # create a _wandb.yaml file with artifacts links if both train and test set are logged
371
+ if not log_val_only:
372
+ path = (path.stem if overwrite_config else path.stem + '_wandb') + '.yaml' # updated data.yaml path
373
+ path = ROOT / 'data' / path
374
+ data.pop('download', None)
375
+ data.pop('path', None)
376
+ with open(path, 'w') as f:
377
+ yaml.safe_dump(data, f)
378
+ LOGGER.info(f"Created dataset config file {path}")
379
+
380
+ if self.job_type == 'Training': # builds correct artifact pipeline graph
381
+ if not log_val_only:
382
+ self.wandb_run.log_artifact(
383
+ self.train_artifact) # calling use_artifact downloads the dataset. NOT NEEDED!
384
+ self.wandb_run.use_artifact(self.val_artifact)
385
+ self.val_artifact.wait()
386
+ self.val_table = self.val_artifact.get('val')
387
+ self.map_val_table_path()
388
+ else:
389
+ self.wandb_run.log_artifact(self.train_artifact)
390
+ self.wandb_run.log_artifact(self.val_artifact)
391
+ return path
392
+
393
+ def map_val_table_path(self):
394
+ """
395
+ Map the validation dataset Table like name of file -> it's id in the W&B Table.
396
+ Useful for - referencing artifacts for evaluation.
397
+ """
398
+ self.val_table_path_map = {}
399
+ LOGGER.info("Mapping dataset")
400
+ for i, data in enumerate(tqdm(self.val_table.data)):
401
+ self.val_table_path_map[data[3]] = data[0]
402
+
403
+ def create_dataset_table(self, dataset: LoadImagesAndLabels, class_to_id: Dict[int, str], name: str = 'dataset'):
404
+ """
405
+ Create and return W&B artifact containing W&B Table of the dataset.
406
+
407
+ arguments:
408
+ dataset -- instance of LoadImagesAndLabels class used to iterate over the data to build Table
409
+ class_to_id -- hash map that maps class ids to labels
410
+ name -- name of the artifact
411
+
412
+ returns:
413
+ dataset artifact to be logged or used
414
+ """
415
+ # TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
416
+ artifact = wandb.Artifact(name=name, type="dataset")
417
+ img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
418
+ img_files = tqdm(dataset.im_files) if not img_files else img_files
419
+ for img_file in img_files:
420
+ if Path(img_file).is_dir():
421
+ artifact.add_dir(img_file, name='data/images')
422
+ labels_path = 'labels'.join(dataset.path.rsplit('images', 1))
423
+ artifact.add_dir(labels_path, name='data/labels')
424
+ else:
425
+ artifact.add_file(img_file, name='data/images/' + Path(img_file).name)
426
+ label_file = Path(img2label_paths([img_file])[0])
427
+ artifact.add_file(str(label_file), name='data/labels/' +
428
+ label_file.name) if label_file.exists() else None
429
+ table = wandb.Table(columns=["id", "train_image", "Classes", "name"])
430
+ class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
431
+ for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)):
432
+ box_data, img_classes = [], {}
433
+ for cls, *xywh in labels[:, 1:].tolist():
434
+ cls = int(cls)
435
+ box_data.append({
436
+ "position": {
437
+ "middle": [xywh[0], xywh[1]],
438
+ "width": xywh[2],
439
+ "height": xywh[3]},
440
+ "class_id": cls,
441
+ "box_caption": "%s" % (class_to_id[cls])})
442
+ img_classes[cls] = class_to_id[cls]
443
+ boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space
444
+ table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), list(img_classes.values()),
445
+ Path(paths).name)
446
+ artifact.add(table, name)
447
+ return artifact
448
+
449
+ def log_training_progress(self, predn, path, names):
450
+ """
451
+ Build evaluation Table. Uses reference from validation dataset table.
452
+
453
+ arguments:
454
+ predn (list): list of predictions in the native space in the format - [xmin, ymin, xmax, ymax, confidence, class]
455
+ path (str): local path of the current evaluation image
456
+ names (dict(int, str)): hash map that maps class ids to labels
457
+ """
458
+ class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
459
+ box_data = []
460
+ avg_conf_per_class = [0] * len(self.data_dict['names'])
461
+ pred_class_count = {}
462
+ for *xyxy, conf, cls in predn.tolist():
463
+ if conf >= 0.25:
464
+ cls = int(cls)
465
+ box_data.append({
466
+ "position": {
467
+ "minX": xyxy[0],
468
+ "minY": xyxy[1],
469
+ "maxX": xyxy[2],
470
+ "maxY": xyxy[3]},
471
+ "class_id": cls,
472
+ "box_caption": f"{names[cls]} {conf:.3f}",
473
+ "scores": {
474
+ "class_score": conf},
475
+ "domain": "pixel"})
476
+ avg_conf_per_class[cls] += conf
477
+
478
+ if cls in pred_class_count:
479
+ pred_class_count[cls] += 1
480
+ else:
481
+ pred_class_count[cls] = 1
482
+
483
+ for pred_class in pred_class_count.keys():
484
+ avg_conf_per_class[pred_class] = avg_conf_per_class[pred_class] / pred_class_count[pred_class]
485
+
486
+ boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
487
+ id = self.val_table_path_map[Path(path).name]
488
+ self.result_table.add_data(self.current_epoch, id, self.val_table.data[id][1],
489
+ wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
490
+ *avg_conf_per_class)
491
+
492
+ def val_one_image(self, pred, predn, path, names, im):
493
+ """
494
+ Log validation data for one image. updates the result Table if validation dataset is uploaded and log bbox media panel
495
+
496
+ arguments:
497
+ pred (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class]
498
+ predn (list): list of predictions in the native space - [xmin, ymin, xmax, ymax, confidence, class]
499
+ path (str): local path of the current evaluation image
500
+ """
501
+ if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact
502
+ self.log_training_progress(predn, path, names)
503
+
504
+ if len(self.bbox_media_panel_images) < self.max_imgs_to_log and self.current_epoch > 0:
505
+ if self.current_epoch % self.bbox_interval == 0:
506
+ box_data = [{
507
+ "position": {
508
+ "minX": xyxy[0],
509
+ "minY": xyxy[1],
510
+ "maxX": xyxy[2],
511
+ "maxY": xyxy[3]},
512
+ "class_id": int(cls),
513
+ "box_caption": f"{names[int(cls)]} {conf:.3f}",
514
+ "scores": {
515
+ "class_score": conf},
516
+ "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
517
+ boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
518
+ self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))
519
+
520
+ def log(self, log_dict):
521
+ """
522
+ save the metrics to the logging dictionary
523
+
524
+ arguments:
525
+ log_dict (Dict) -- metrics/media to be logged in current step
526
+ """
527
+ if self.wandb_run:
528
+ for key, value in log_dict.items():
529
+ self.log_dict[key] = value
530
+
531
+ def end_epoch(self, best_result=False):
532
+ """
533
+ commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.
534
+
535
+ arguments:
536
+ best_result (boolean): Boolean representing if the result of this evaluation is best or not
537
+ """
538
+ if self.wandb_run:
539
+ with all_logging_disabled():
540
+ if self.bbox_media_panel_images:
541
+ self.log_dict["BoundingBoxDebugger"] = self.bbox_media_panel_images
542
+ try:
543
+ wandb.log(self.log_dict)
544
+ except BaseException as e:
545
+ LOGGER.info(
546
+ f"An error occurred in wandb logger. The training will proceed without interruption. More info\n{e}"
547
+ )
548
+ self.wandb_run.finish()
549
+ self.wandb_run = None
550
+
551
+ self.log_dict = {}
552
+ self.bbox_media_panel_images = []
553
+ if self.result_artifact:
554
+ self.result_artifact.add(self.result_table, 'result')
555
+ wandb.log_artifact(self.result_artifact,
556
+ aliases=[
557
+ 'latest', 'last', 'epoch ' + str(self.current_epoch),
558
+ ('best' if best_result else '')])
559
+
560
+ wandb.log({"evaluation": self.result_table})
561
+ columns = ["epoch", "id", "ground truth", "prediction"]
562
+ columns.extend(self.data_dict['names'])
563
+ self.result_table = wandb.Table(columns)
564
+ self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
565
+
566
+ def finish_run(self):
567
+ """
568
+ Log metrics if any and finish the current W&B run
569
+ """
570
+ if self.wandb_run:
571
+ if self.log_dict:
572
+ with all_logging_disabled():
573
+ wandb.log(self.log_dict)
574
+ wandb.run.finish()
575
+
576
+
577
+ @contextmanager
578
+ def all_logging_disabled(highest_level=logging.CRITICAL):
579
+ """ source - https://gist.github.com/simon-weber/7853144
580
+ A context manager that will prevent any logging messages triggered during the body from being processed.
581
+ :param highest_level: the maximum logging level in use.
582
+ This would only need to be changed if a custom level greater than CRITICAL is defined.
583
+ """
584
+ previous_level = logging.root.manager.disable
585
+ logging.disable(highest_level)
586
+ try:
587
+ yield
588
+ finally:
589
+ logging.disable(previous_level)
utils/metrics.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+
9
+ from utils import TryExcept, threaded
10
+
11
+
12
+ def fitness(x):
13
+ # Model fitness as a weighted combination of metrics
14
+ w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
15
+ return (x[:, :4] * w).sum(1)
16
+
17
+
18
+ def smooth(y, f=0.05):
19
+ # Box filter of fraction f
20
+ nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
21
+ p = np.ones(nf // 2) # ones padding
22
+ yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
23
+ return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
24
+
25
+
26
+ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=""):
27
+ """ Compute the average precision, given the recall and precision curves.
28
+ Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
29
+ # Arguments
30
+ tp: True positives (nparray, nx1 or nx10).
31
+ conf: Objectness value from 0-1 (nparray).
32
+ pred_cls: Predicted object classes (nparray).
33
+ target_cls: True object classes (nparray).
34
+ plot: Plot precision-recall curve at mAP@0.5
35
+ save_dir: Plot save directory
36
+ # Returns
37
+ The average precision as computed in py-faster-rcnn.
38
+ """
39
+
40
+ # Sort by objectness
41
+ i = np.argsort(-conf)
42
+ tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
43
+
44
+ # Find unique classes
45
+ unique_classes, nt = np.unique(target_cls, return_counts=True)
46
+ nc = unique_classes.shape[0] # number of classes, number of detections
47
+
48
+ # Create Precision-Recall curve and compute AP for each class
49
+ px, py = np.linspace(0, 1, 1000), [] # for plotting
50
+ ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
51
+ for ci, c in enumerate(unique_classes):
52
+ i = pred_cls == c
53
+ n_l = nt[ci] # number of labels
54
+ n_p = i.sum() # number of predictions
55
+ if n_p == 0 or n_l == 0:
56
+ continue
57
+
58
+ # Accumulate FPs and TPs
59
+ fpc = (1 - tp[i]).cumsum(0)
60
+ tpc = tp[i].cumsum(0)
61
+
62
+ # Recall
63
+ recall = tpc / (n_l + eps) # recall curve
64
+ r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
65
+
66
+ # Precision
67
+ precision = tpc / (tpc + fpc) # precision curve
68
+ p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
69
+
70
+ # AP from recall-precision curve
71
+ for j in range(tp.shape[1]):
72
+ ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
73
+ if plot and j == 0:
74
+ py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
75
+
76
+ # Compute F1 (harmonic mean of precision and recall)
77
+ f1 = 2 * p * r / (p + r + eps)
78
+ names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
79
+ names = dict(enumerate(names)) # to dict
80
+ if plot:
81
+ plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
82
+ plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
83
+ plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
84
+ plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')
85
+
86
+ i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
87
+ p, r, f1 = p[:, i], r[:, i], f1[:, i]
88
+ tp = (r * nt).round() # true positives
89
+ fp = (tp / (p + eps) - tp).round() # false positives
90
+ return tp, fp, p, r, f1, ap, unique_classes.astype(int)
91
+
92
+
93
+ def compute_ap(recall, precision):
94
+ """ Compute the average precision, given the recall and precision curves
95
+ # Arguments
96
+ recall: The recall curve (list)
97
+ precision: The precision curve (list)
98
+ # Returns
99
+ Average precision, precision curve, recall curve
100
+ """
101
+
102
+ # Append sentinel values to beginning and end
103
+ mrec = np.concatenate(([0.0], recall, [1.0]))
104
+ mpre = np.concatenate(([1.0], precision, [0.0]))
105
+
106
+ # Compute the precision envelope
107
+ mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
108
+
109
+ # Integrate area under curve
110
+ method = 'interp' # methods: 'continuous', 'interp'
111
+ if method == 'interp':
112
+ x = np.linspace(0, 1, 101) # 101-point interp (COCO)
113
+ ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
114
+ else: # 'continuous'
115
+ i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
116
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
117
+
118
+ return ap, mpre, mrec
119
+
120
+
121
+ class ConfusionMatrix:
122
+ # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
123
+ def __init__(self, nc, conf=0.25, iou_thres=0.45):
124
+ self.matrix = np.zeros((nc + 1, nc + 1))
125
+ self.nc = nc # number of classes
126
+ self.conf = conf
127
+ self.iou_thres = iou_thres
128
+
129
+ def process_batch(self, detections, labels):
130
+ """
131
+ Return intersection-over-union (Jaccard index) of boxes.
132
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
133
+ Arguments:
134
+ detections (Array[N, 6]), x1, y1, x2, y2, conf, class
135
+ labels (Array[M, 5]), class, x1, y1, x2, y2
136
+ Returns:
137
+ None, updates confusion matrix accordingly
138
+ """
139
+ if detections is None:
140
+ gt_classes = labels.int()
141
+ for gc in gt_classes:
142
+ self.matrix[self.nc, gc] += 1 # background FN
143
+ return
144
+
145
+ detections = detections[detections[:, 4] > self.conf]
146
+ gt_classes = labels[:, 0].int()
147
+ detection_classes = detections[:, 5].int()
148
+ iou = box_iou(labels[:, 1:], detections[:, :4])
149
+
150
+ x = torch.where(iou > self.iou_thres)
151
+ if x[0].shape[0]:
152
+ matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
153
+ if x[0].shape[0] > 1:
154
+ matches = matches[matches[:, 2].argsort()[::-1]]
155
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
156
+ matches = matches[matches[:, 2].argsort()[::-1]]
157
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
158
+ else:
159
+ matches = np.zeros((0, 3))
160
+
161
+ n = matches.shape[0] > 0
162
+ m0, m1, _ = matches.transpose().astype(int)
163
+ for i, gc in enumerate(gt_classes):
164
+ j = m0 == i
165
+ if n and sum(j) == 1:
166
+ self.matrix[detection_classes[m1[j]], gc] += 1 # correct
167
+ else:
168
+ self.matrix[self.nc, gc] += 1 # true background
169
+
170
+ if n:
171
+ for i, dc in enumerate(detection_classes):
172
+ if not any(m1 == i):
173
+ self.matrix[dc, self.nc] += 1 # predicted background
174
+
175
+ def matrix(self):
176
+ return self.matrix
177
+
178
+ def tp_fp(self):
179
+ tp = self.matrix.diagonal() # true positives
180
+ fp = self.matrix.sum(1) - tp # false positives
181
+ # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
182
+ return tp[:-1], fp[:-1] # remove background class
183
+
184
+ @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
185
+ def plot(self, normalize=True, save_dir='', names=()):
186
+ import seaborn as sn
187
+
188
+ array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
189
+ array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
190
+
191
+ fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
192
+ nc, nn = self.nc, len(names) # number of classes, names
193
+ sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
194
+ labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
195
+ ticklabels = (names + ['background']) if labels else "auto"
196
+ with warnings.catch_warnings():
197
+ warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
198
+ sn.heatmap(array,
199
+ ax=ax,
200
+ annot=nc < 30,
201
+ annot_kws={
202
+ "size": 8},
203
+ cmap='Blues',
204
+ fmt='.2f',
205
+ square=True,
206
+ vmin=0.0,
207
+ xticklabels=ticklabels,
208
+ yticklabels=ticklabels).set_facecolor((1, 1, 1))
209
+ ax.set_ylabel('True')
210
+ ax.set_ylabel('Predicted')
211
+ ax.set_title('Confusion Matrix')
212
+ fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
213
+ plt.close(fig)
214
+
215
+ def print(self):
216
+ for i in range(self.nc + 1):
217
+ print(' '.join(map(str, self.matrix[i])))
218
+
219
+
220
+ class WIoU_Scale:
221
+ ''' monotonous: {
222
+ None: origin v1
223
+ True: monotonic FM v2
224
+ False: non-monotonic FM v3
225
+ }
226
+ momentum: The momentum of running mean'''
227
+
228
+ iou_mean = 1.
229
+ monotonous = False
230
+ _momentum = 1 - 0.5 ** (1 / 7000)
231
+ _is_train = True
232
+
233
+ def __init__(self, iou):
234
+ self.iou = iou
235
+ self._update(self)
236
+
237
+ @classmethod
238
+ def _update(cls, self):
239
+ if cls._is_train: cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
240
+ cls._momentum * self.iou.detach().mean().item()
241
+
242
+ @classmethod
243
+ def _scaled_loss(cls, self, gamma=1.9, delta=3):
244
+ if isinstance(self.monotonous, bool):
245
+ if self.monotonous:
246
+ return (self.iou.detach() / self.iou_mean).sqrt()
247
+ else:
248
+ beta = self.iou.detach() / self.iou_mean
249
+ alpha = delta * torch.pow(gamma, beta - delta)
250
+ return beta / alpha
251
+ return 1
252
+
253
+
254
+ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, MDPIoU=False, feat_h=640, feat_w=640, eps=1e-7):
255
+ # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
256
+
257
+ # Get the coordinates of bounding boxes
258
+ if xywh: # transform from xywh to xyxy
259
+ (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
260
+ w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
261
+ b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
262
+ b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
263
+ else: # x1, y1, x2, y2 = box1
264
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
265
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
266
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
267
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
268
+
269
+ # Intersection area
270
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
271
+ (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
272
+
273
+ # Union Area
274
+ union = w1 * h1 + w2 * h2 - inter + eps
275
+
276
+ # IoU
277
+ iou = inter / union
278
+ if CIoU or DIoU or GIoU:
279
+ cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
280
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
281
+ if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
282
+ c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
283
+ rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
284
+ if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
285
+ v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
286
+ with torch.no_grad():
287
+ alpha = v / (v - iou + (1 + eps))
288
+ return iou - (rho2 / c2 + v * alpha) # CIoU
289
+ return iou - rho2 / c2 # DIoU
290
+ c_area = cw * ch + eps # convex area
291
+ return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
292
+ elif MDPIoU:
293
+ d1 = (b2_x1 - b1_x1) ** 2 + (b2_y1 - b1_y1) ** 2
294
+ d2 = (b2_x2 - b1_x2) ** 2 + (b2_y2 - b1_y2) ** 2
295
+ mpdiou_hw_pow = feat_h ** 2 + feat_w ** 2
296
+ return iou - d1 / mpdiou_hw_pow - d2 / mpdiou_hw_pow # MPDIoU
297
+ return iou # IoU
298
+
299
+
300
+ def box_iou(box1, box2, eps=1e-7):
301
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
302
+ """
303
+ Return intersection-over-union (Jaccard index) of boxes.
304
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
305
+ Arguments:
306
+ box1 (Tensor[N, 4])
307
+ box2 (Tensor[M, 4])
308
+ Returns:
309
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
310
+ IoU values for every element in boxes1 and boxes2
311
+ """
312
+
313
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
314
+ (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
315
+ inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
316
+
317
+ # IoU = inter / (area1 + area2 - inter)
318
+ return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
319
+
320
+
321
+ def bbox_ioa(box1, box2, eps=1e-7):
322
+ """Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
323
+ box1: np.array of shape(nx4)
324
+ box2: np.array of shape(mx4)
325
+ returns: np.array of shape(nxm)
326
+ """
327
+
328
+ # Get the coordinates of bounding boxes
329
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
330
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
331
+
332
+ # Intersection area
333
+ inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
334
+ (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
335
+
336
+ # box2 area
337
+ box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
338
+
339
+ # Intersection over box2 area
340
+ return inter_area / box2_area
341
+
342
+
343
+ def wh_iou(wh1, wh2, eps=1e-7):
344
+ # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
345
+ wh1 = wh1[:, None] # [N,1,2]
346
+ wh2 = wh2[None] # [1,M,2]
347
+ inter = torch.min(wh1, wh2).prod(2) # [N,M]
348
+ return inter / (wh1.prod(2) + wh2.prod(2) - inter + eps) # iou = inter / (area1 + area2 - inter)
349
+
350
+
351
+ # Plots ----------------------------------------------------------------------------------------------------------------
352
+
353
+
354
+ @threaded
355
+ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
356
+ # Precision-recall curve
357
+ fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
358
+ py = np.stack(py, axis=1)
359
+
360
+ if 0 < len(names) < 21: # display per-class legend if < 21 classes
361
+ for i, y in enumerate(py.T):
362
+ ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
363
+ else:
364
+ ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
365
+
366
+ ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
367
+ ax.set_xlabel('Recall')
368
+ ax.set_ylabel('Precision')
369
+ ax.set_xlim(0, 1)
370
+ ax.set_ylim(0, 1)
371
+ ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
372
+ ax.set_title('Precision-Recall Curve')
373
+ fig.savefig(save_dir, dpi=250)
374
+ plt.close(fig)
375
+
376
+
377
+ @threaded
378
+ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
379
+ # Metric-confidence curve
380
+ fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
381
+
382
+ if 0 < len(names) < 21: # display per-class legend if < 21 classes
383
+ for i, y in enumerate(py):
384
+ ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
385
+ else:
386
+ ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
387
+
388
+ y = smooth(py.mean(0), 0.05)
389
+ ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
390
+ ax.set_xlabel(xlabel)
391
+ ax.set_ylabel(ylabel)
392
+ ax.set_xlim(0, 1)
393
+ ax.set_ylim(0, 1)
394
+ ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
395
+ ax.set_title(f'{ylabel}-Confidence Curve')
396
+ fig.savefig(save_dir, dpi=250)
397
+ plt.close(fig)
utils/panoptic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # init
utils/panoptic/augmentations.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from ..augmentations import box_candidates
8
+ from ..general import resample_segments, segment2box
9
+ from ..metrics import bbox_ioa
10
+
11
+
12
+ def mixup(im, labels, segments, seg_cls, semantic_masks, im2, labels2, segments2, seg_cls2, semantic_masks2):
13
+ # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
14
+ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
15
+ im = (im * r + im2 * (1 - r)).astype(np.uint8)
16
+ labels = np.concatenate((labels, labels2), 0)
17
+ segments = np.concatenate((segments, segments2), 0)
18
+ seg_cls = np.concatenate((seg_cls, seg_cls2), 0)
19
+ semantic_masks = np.concatenate((semantic_masks, semantic_masks2), 0)
20
+ return im, labels, segments, seg_cls, semantic_masks
21
+
22
+
23
+ def random_perspective(im,
24
+ targets=(),
25
+ segments=(),
26
+ semantic_masks = (),
27
+ degrees=10,
28
+ translate=.1,
29
+ scale=.1,
30
+ shear=10,
31
+ perspective=0.0,
32
+ border=(0, 0)):
33
+ # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
34
+ # targets = [cls, xyxy]
35
+
36
+ height = im.shape[0] + border[0] * 2 # shape(h,w,c)
37
+ width = im.shape[1] + border[1] * 2
38
+
39
+ # Center
40
+ C = np.eye(3)
41
+ C[0, 2] = -im.shape[1] / 2 # x translation (pixels)
42
+ C[1, 2] = -im.shape[0] / 2 # y translation (pixels)
43
+
44
+ # Perspective
45
+ P = np.eye(3)
46
+ P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y)
47
+ P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x)
48
+
49
+ # Rotation and Scale
50
+ R = np.eye(3)
51
+ a = random.uniform(-degrees, degrees)
52
+ # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
53
+ s = random.uniform(1 - scale, 1 + scale)
54
+ # s = 2 ** random.uniform(-scale, scale)
55
+ R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
56
+
57
+ # Shear
58
+ S = np.eye(3)
59
+ S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
60
+ S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
61
+
62
+ # Translation
63
+ T = np.eye(3)
64
+ T[0, 2] = (random.uniform(0.5 - translate, 0.5 + translate) * width) # x translation (pixels)
65
+ T[1, 2] = (random.uniform(0.5 - translate, 0.5 + translate) * height) # y translation (pixels)
66
+
67
+ # Combined rotation matrix
68
+ M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
69
+ if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
70
+ if perspective:
71
+ im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
72
+ else: # affine
73
+ im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
74
+
75
+ # Visualize
76
+ # import matplotlib.pyplot as plt
77
+ # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
78
+ # ax[0].imshow(im[:, :, ::-1]) # base
79
+ # ax[1].imshow(im2[:, :, ::-1]) # warped
80
+
81
+ # Transform label coordinates
82
+ n = len(targets)
83
+ new_segments = []
84
+ new_semantic_masks = []
85
+ if n:
86
+ new = np.zeros((n, 4))
87
+ segments = resample_segments(segments) # upsample
88
+ for i, segment in enumerate(segments):
89
+ xy = np.ones((len(segment), 3))
90
+ xy[:, :2] = segment
91
+ xy = xy @ M.T # transform
92
+ xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]) # perspective rescale or affine
93
+
94
+ # clip
95
+ new[i] = segment2box(xy, width, height)
96
+ new_segments.append(xy)
97
+
98
+ semantic_masks = resample_segments(semantic_masks)
99
+ for i, semantic_mask in enumerate(semantic_masks):
100
+ #if i < n:
101
+ # xy = np.ones((len(segments[i]), 3))
102
+ # xy[:, :2] = segments[i]
103
+ # xy = xy @ M.T # transform
104
+ # xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]) # perspective rescale or affine
105
+
106
+ # new[i] = segment2box(xy, width, height)
107
+ # new_segments.append(xy)
108
+
109
+ xy_s = np.ones((len(semantic_mask), 3))
110
+ xy_s[:, :2] = semantic_mask
111
+ xy_s = xy_s @ M.T # transform
112
+ xy_s = (xy_s[:, :2] / xy_s[:, 2:3] if perspective else xy_s[:, :2]) # perspective rescale or affine
113
+
114
+ new_semantic_masks.append(xy_s)
115
+
116
+ # filter candidates
117
+ i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01)
118
+ targets = targets[i]
119
+ targets[:, 1:5] = new[i]
120
+ new_segments = np.array(new_segments)[i]
121
+ new_semantic_masks = np.array(new_semantic_masks)
122
+
123
+ return im, targets, new_segments, new_semantic_masks
124
+
125
+
126
+ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
127
+ # Resize and pad image while meeting stride-multiple constraints
128
+ shape = im.shape[:2] # current shape [height, width]
129
+ if isinstance(new_shape, int):
130
+ new_shape = (new_shape, new_shape)
131
+
132
+ # Scale ratio (new / old)
133
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
134
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
135
+ r = min(r, 1.0)
136
+
137
+ # Compute padding
138
+ ratio = r, r # width, height ratios
139
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
140
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
141
+ if auto: # minimum rectangle
142
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
143
+ elif scaleFill: # stretch
144
+ dw, dh = 0.0, 0.0
145
+ new_unpad = (new_shape[1], new_shape[0])
146
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
147
+
148
+ dw /= 2 # divide padding into 2 sides
149
+ dh /= 2
150
+
151
+ if shape[::-1] != new_unpad: # resize
152
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
153
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
154
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
155
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
156
+ return im, ratio, (dw, dh)
157
+
158
+
159
+ def copy_paste(im, labels, segments, seg_cls, semantic_masks, p=0.5):
160
+ # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
161
+ n = len(segments)
162
+ if p and n:
163
+ h, w, _ = im.shape # height, width, channels
164
+ im_new = np.zeros(im.shape, np.uint8)
165
+
166
+ # calculate ioa first then select indexes randomly
167
+ boxes = np.stack([w - labels[:, 3], labels[:, 2], w - labels[:, 1], labels[:, 4]], axis=-1) # (n, 4)
168
+ ioa = bbox_ioa(boxes, labels[:, 1:5]) # intersection over area
169
+ indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
170
+ n = len(indexes)
171
+ for j in random.sample(list(indexes), k=round(p * n)):
172
+ l, box, s = labels[j], boxes[j], segments[j]
173
+ labels = np.concatenate((labels, [[l[0], *box]]), 0)
174
+ segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
175
+ seg_cls.append(l[0].astype(int))
176
+ semantic_masks.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
177
+ cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (1, 1, 1), cv2.FILLED)
178
+
179
+ result = cv2.flip(im, 1) # augment segments (flip left-right)
180
+ i = cv2.flip(im_new, 1).astype(bool)
181
+ im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
182
+
183
+ return im, labels, segments, seg_cls, semantic_masks
utils/panoptic/dataloaders.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import pickle
5
+ from pathlib import Path
6
+
7
+ from itertools import repeat
8
+ from multiprocessing.pool import Pool, ThreadPool
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ from torch.utils.data import DataLoader, distributed
14
+ from tqdm import tqdm
15
+
16
+ from ..augmentations import augment_hsv
17
+ from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, seed_worker, get_hash, verify_image_label, HELP_URL, TQDM_BAR_FORMAT, LOCAL_RANK
18
+ from ..general import NUM_THREADS, LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn
19
+ from ..torch_utils import torch_distributed_zero_first
20
+ from ..coco_utils import annToMask, getCocoIds
21
+ from .augmentations import mixup, random_perspective, copy_paste, letterbox
22
+
23
+ RANK = int(os.getenv('RANK', -1))
24
+
25
+
26
+ def create_dataloader(path,
27
+ imgsz,
28
+ batch_size,
29
+ stride,
30
+ single_cls=False,
31
+ hyp=None,
32
+ augment=False,
33
+ cache=False,
34
+ pad=0.0,
35
+ rect=False,
36
+ rank=-1,
37
+ workers=8,
38
+ image_weights=False,
39
+ close_mosaic=False,
40
+ quad=False,
41
+ prefix='',
42
+ shuffle=False,
43
+ mask_downsample_ratio=1,
44
+ overlap_mask=False):
45
+ if rect and shuffle:
46
+ LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
47
+ shuffle = False
48
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
49
+ dataset = LoadImagesAndLabelsAndMasks(
50
+ path,
51
+ imgsz,
52
+ batch_size,
53
+ augment=augment, # augmentation
54
+ hyp=hyp, # hyperparameters
55
+ rect=rect, # rectangular batches
56
+ cache_images=cache,
57
+ single_cls=single_cls,
58
+ stride=int(stride),
59
+ pad=pad,
60
+ image_weights=image_weights,
61
+ prefix=prefix,
62
+ downsample_ratio=mask_downsample_ratio,
63
+ overlap=overlap_mask)
64
+
65
+ batch_size = min(batch_size, len(dataset))
66
+ nd = torch.cuda.device_count() # number of CUDA devices
67
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
68
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
69
+ #loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
70
+ loader = DataLoader if image_weights or close_mosaic else InfiniteDataLoader
71
+ generator = torch.Generator()
72
+ generator.manual_seed(6148914691236517205 + RANK)
73
+ return loader(
74
+ dataset,
75
+ batch_size=batch_size,
76
+ shuffle=shuffle and sampler is None,
77
+ num_workers=nw,
78
+ sampler=sampler,
79
+ pin_memory=True,
80
+ collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn,
81
+ worker_init_fn=seed_worker,
82
+ generator=generator,
83
+ ), dataset
84
+
85
+ def img2stuff_paths(img_paths):
86
+ # Define label paths as a function of image paths
87
+ sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}stuff{os.sep}' # /images/, /segmentations/ substrings
88
+ return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
89
+
90
+
91
+ class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels): # for training/testing
92
+
93
+ def __init__(
94
+ self,
95
+ path,
96
+ img_size=640,
97
+ batch_size=16,
98
+ augment=False,
99
+ hyp=None,
100
+ rect=False,
101
+ image_weights=False,
102
+ cache_images=False,
103
+ single_cls=False,
104
+ stride=32,
105
+ pad=0,
106
+ min_items=0,
107
+ prefix="",
108
+ downsample_ratio=1,
109
+ overlap=False,
110
+ ):
111
+ super().__init__(
112
+ path,
113
+ img_size,
114
+ batch_size,
115
+ augment,
116
+ hyp,
117
+ rect,
118
+ image_weights,
119
+ cache_images,
120
+ single_cls,
121
+ stride,
122
+ pad,
123
+ min_items,
124
+ prefix)
125
+ self.downsample_ratio = downsample_ratio
126
+ self.overlap = overlap
127
+
128
+ # semantic segmentation
129
+ self.coco_ids = getCocoIds()
130
+
131
+ # Check cache
132
+ self.seg_files = img2stuff_paths(self.im_files) # labels
133
+ p = Path(path)
134
+ cache_path = (p.with_suffix('') if p.is_file() else Path(self.seg_files[0]).parent)
135
+ cache_path = Path(str(cache_path) + '_stuff').with_suffix('.cache')
136
+ try:
137
+ cache, exists = np.load(cache_path, allow_pickle = True).item(), True # load dict
138
+ #assert cache['version'] == self.cache_version # matches current version
139
+ #assert cache['hash'] == get_hash(self.seg_files + self.im_files) # identical hash
140
+ except Exception:
141
+ cache, exists = self.cache_seg_labels(cache_path, prefix), False # run cache ops
142
+
143
+ # Display cache
144
+ nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
145
+ if exists and LOCAL_RANK in {-1, 0}:
146
+ d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
147
+ tqdm(None, desc = (prefix + d), total = n, initial = n, bar_format = TQDM_BAR_FORMAT) # display cache results
148
+ if cache['msgs']:
149
+ LOGGER.info('\n'.join(cache['msgs'])) # display warnings
150
+ assert (0 < nf) or (not augment), f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'
151
+
152
+ # Read cache
153
+ [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
154
+ seg_labels, _, self.semantic_masks = zip(*cache.values())
155
+ nl = len(np.concatenate(seg_labels, 0)) # number of labels
156
+ assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
157
+
158
+ # Update labels
159
+ self.seg_cls = []
160
+ include_class = [] # filter labels to include only these classes (optional)
161
+ include_class_array = np.array(include_class).reshape(1, -1)
162
+ for i, (label, semantic_masks) in enumerate(zip(seg_labels, self.semantic_masks)):
163
+ self.seg_cls.append((label[:, 0].astype(int)).tolist())
164
+ if include_class:
165
+ j = (label[:, 0:1] == include_class_array).any(1)
166
+ if semantic_masks:
167
+ self.semantic_masks[i] = semantic_masks[j]
168
+ if single_cls: # single-class training, merge all classes into 0
169
+ if semantic_masks:
170
+ self.semantic_masks[i][:, 0] = 0
171
+
172
+ def __getitem__(self, index):
173
+ index = self.indices[index] # linear, shuffled, or image_weights
174
+
175
+ hyp = self.hyp
176
+ mosaic = self.mosaic and random.random() < hyp['mosaic']
177
+ masks = []
178
+ if mosaic:
179
+ # Load mosaic
180
+ img, labels, segments, seg_cls, semantic_masks = self.load_mosaic(index)
181
+ shapes = None
182
+
183
+ # MixUp augmentation
184
+ if random.random() < hyp["mixup"]:
185
+ img, labels, segments, seg_cls, semantic_masks = mixup(img, labels, segments, seg_cls, semantic_masks,
186
+ *self.load_mosaic(random.randint(0, self.n - 1)))
187
+
188
+ else:
189
+ # Load image
190
+ img, (h0, w0), (h, w) = self.load_image(index)
191
+
192
+ # Letterbox
193
+ shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
194
+ img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
195
+ shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
196
+
197
+ labels = self.labels[index].copy()
198
+ # [array, array, ....], array.shape=(num_points, 2), xyxyxyxy
199
+ segments = self.segments[index].copy()
200
+ if len(segments):
201
+ for i_s in range(len(segments)):
202
+ segments[i_s] = xyn2xy(
203
+ segments[i_s],
204
+ ratio[0] * w,
205
+ ratio[1] * h,
206
+ padw=pad[0],
207
+ padh=pad[1],
208
+ )
209
+
210
+ seg_cls = self.seg_cls[index].copy()
211
+ semantic_masks = self.semantic_masks[index].copy()
212
+ #semantic_masks = [xyn2xy(x, ratio[0] * w, ratio[1] * h, padw = pad[0], padh = pad[1]) for x in semantic_masks]
213
+ if len(semantic_masks):
214
+ for ss in range(len(semantic_masks)):
215
+ semantic_masks[ss] = xyn2xy(
216
+ semantic_masks[ss],
217
+ ratio[0] * w,
218
+ ratio[1] * h,
219
+ padw = pad[0],
220
+ padh = pad[1],
221
+ )
222
+
223
+ if labels.size: # normalized xywh to pixel xyxy format
224
+ labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
225
+
226
+ if self.augment:
227
+ img, labels, segments, semantic_masks = random_perspective(
228
+ img,
229
+ labels,
230
+ segments=segments,
231
+ semantic_masks = semantic_masks,
232
+ degrees=hyp["degrees"],
233
+ translate=hyp["translate"],
234
+ scale=hyp["scale"],
235
+ shear=hyp["shear"],
236
+ perspective=hyp["perspective"])
237
+
238
+ nl = len(labels) # number of labels
239
+ if nl:
240
+ labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)
241
+ if self.overlap:
242
+ masks, sorted_idx = polygons2masks_overlap(img.shape[:2],
243
+ segments,
244
+ downsample_ratio=self.downsample_ratio)
245
+ masks = masks[None] # (640, 640) -> (1, 640, 640)
246
+ labels = labels[sorted_idx]
247
+ else:
248
+ masks = polygons2masks(img.shape[:2], segments, color=1, downsample_ratio=self.downsample_ratio)
249
+
250
+ masks = (torch.from_numpy(masks) if len(masks) else torch.zeros(1 if self.overlap else nl, img.shape[0] //
251
+ self.downsample_ratio, img.shape[1] //
252
+ self.downsample_ratio))
253
+ semantic_masks = polygons2masks(img.shape[:2], semantic_masks, color = 1, downsample_ratio=self.downsample_ratio)
254
+ #semantic_masks = polygons2masks(img.shape[:2], semantic_masks, color = 1, downsample_ratio=1)
255
+ semantic_masks = torch.from_numpy(semantic_masks)
256
+ # TODO: albumentations support
257
+ if self.augment:
258
+ # Albumentations
259
+ # there are some augmentation that won't change boxes and masks,
260
+ # so just be it for now.
261
+ img, labels = self.albumentations(img, labels)
262
+ nl = len(labels) # update after albumentations
263
+ ns = len(semantic_masks)
264
+
265
+ # HSV color-space
266
+ augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])
267
+
268
+ # Flip up-down
269
+ if random.random() < hyp["flipud"]:
270
+ img = np.flipud(img)
271
+ if nl:
272
+ labels[:, 2] = 1 - labels[:, 2]
273
+ masks = torch.flip(masks, dims=[1])
274
+ if ns:
275
+ semantic_masks = torch.flip(semantic_masks, dims = [1])
276
+
277
+ # Flip left-right
278
+ if random.random() < hyp["fliplr"]:
279
+ img = np.fliplr(img)
280
+ if nl:
281
+ labels[:, 1] = 1 - labels[:, 1]
282
+ masks = torch.flip(masks, dims=[2])
283
+ if ns:
284
+ semantic_masks = torch.flip(semantic_masks, dims = [2])
285
+
286
+ # Cutouts # labels = cutout(img, labels, p=0.5)
287
+
288
+ labels_out = torch.zeros((nl, 6))
289
+ if nl:
290
+ labels_out[:, 1:] = torch.from_numpy(labels)
291
+
292
+ # Combine semantic masks
293
+ semantic_seg_masks = torch.zeros((len(self.coco_ids), img.shape[0] // self.downsample_ratio,
294
+ img.shape[1] // self.downsample_ratio), dtype = torch.uint8)
295
+ #semantic_seg_masks = torch.zeros((len(self.coco_ids), img.shape[0], img.shape[1]), dtype = torch.uint8)
296
+ for cls_id, semantic_mask in zip(seg_cls, semantic_masks):
297
+ semantic_seg_masks[cls_id] = (semantic_seg_masks[cls_id].logical_or(semantic_mask)).int()
298
+
299
+
300
+ # Convert
301
+ img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
302
+ img = np.ascontiguousarray(img)
303
+
304
+ return (torch.from_numpy(img), labels_out, self.im_files[index], shapes, masks, semantic_seg_masks)
305
+
306
+ def load_mosaic(self, index):
307
+ # YOLO 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
308
+ labels4, segments4, seg_cls, semantic_masks4 = [], [], [], []
309
+ s = self.img_size
310
+ yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
311
+
312
+ # 3 additional image indices
313
+ indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
314
+ for i, index in enumerate(indices):
315
+ # Load image
316
+ img, _, (h, w) = self.load_image(index)
317
+
318
+ # place img in img4
319
+ if i == 0: # top left
320
+ img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
321
+ x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
322
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
323
+ elif i == 1: # top right
324
+ x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
325
+ x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
326
+ elif i == 2: # bottom left
327
+ x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
328
+ x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
329
+ elif i == 3: # bottom right
330
+ x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
331
+ x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
332
+
333
+ img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
334
+ padw = x1a - x1b
335
+ padh = y1a - y1b
336
+
337
+ labels, segments, semantic_masks = self.labels[index].copy(), self.segments[index].copy(), self.semantic_masks[index].copy()
338
+
339
+ if labels.size:
340
+ labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
341
+ segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
342
+ semantic_masks = [xyn2xy(x, w, h, padw, padh) for x in semantic_masks]
343
+ labels4.append(labels)
344
+ segments4.extend(segments)
345
+ seg_cls.extend(self.seg_cls[index].copy())
346
+ semantic_masks4.extend(semantic_masks)
347
+
348
+ # Concat/clip labels
349
+ labels4 = np.concatenate(labels4, 0)
350
+ for i in range(len(semantic_masks4)):
351
+ if i < len(segments4):
352
+ np.clip(labels4[:, 1:][i], 0, 2 * s, out = labels4[:, 1:][i])
353
+ np.clip(segments4[i], 0, 2 * s, out = segments4[i])
354
+ np.clip(semantic_masks4[i], 0, 2 * s, out = semantic_masks4[i])
355
+ # img4, labels4 = replicate(img4, labels4) # replicate
356
+
357
+ # 3 additional image indices
358
+ # Augment
359
+ img4, labels4, segments4, seg_cls, semantic_masks4 = copy_paste(img4, labels4, segments4, seg_cls, semantic_masks4, p=self.hyp["copy_paste"])
360
+ img4, labels4, segments4, semantic_masks4 = random_perspective(img4,
361
+ labels4,
362
+ segments4,
363
+ semantic_masks4,
364
+ degrees=self.hyp["degrees"],
365
+ translate=self.hyp["translate"],
366
+ scale=self.hyp["scale"],
367
+ shear=self.hyp["shear"],
368
+ perspective=self.hyp["perspective"],
369
+ border=self.mosaic_border) # border to remove
370
+
371
+ return img4, labels4, segments4, seg_cls, semantic_masks4
372
+
373
+ def cache_seg_labels(self, path = Path('./labels_stuff.cache'), prefix = ''):
374
+ # Cache dataset labels, check images and read shapes
375
+ x = {} # dict
376
+ nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
377
+ desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
378
+ with Pool(NUM_THREADS) as pool:
379
+ pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.seg_files, repeat(prefix))),
380
+ desc = desc,
381
+ total = len(self.im_files),
382
+ bar_format = TQDM_BAR_FORMAT)
383
+ for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
384
+ nm += nm_f
385
+ nf += nf_f
386
+ ne += ne_f
387
+ nc += nc_f
388
+ if im_file:
389
+ x[im_file] = [lb, shape, segments]
390
+ if msg:
391
+ msgs.append(msg)
392
+ pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
393
+
394
+ pbar.close()
395
+ if msgs:
396
+ LOGGER.info('\n'.join(msgs))
397
+ if nf == 0:
398
+ LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. {HELP_URL}')
399
+ x['hash'] = get_hash(self.seg_files + self.im_files)
400
+ x['results'] = nf, nm, ne, nc, len(self.im_files)
401
+ x['msgs'] = msgs # warnings
402
+ x['version'] = self.cache_version # cache version
403
+ try:
404
+ np.save(path, x) # save cache for next time
405
+ path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
406
+ LOGGER.info(f'{prefix}New cache created: {path}')
407
+ except Exception as e:
408
+ LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # not writeable
409
+ return x
410
+
411
+ @staticmethod
412
+ def collate_fn(batch):
413
+ img, label, path, shapes, masks, semantic_masks = zip(*batch) # transposed
414
+ batched_masks = torch.cat(masks, 0)
415
+ for i, l in enumerate(label):
416
+ l[:, 0] = i # add target image index for build_targets()
417
+ return torch.stack(img, 0), torch.cat(label, 0), path, shapes, batched_masks, torch.stack(semantic_masks, 0)
418
+
419
+
420
+
421
+ def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
422
+ """
423
+ Args:
424
+ img_size (tuple): The image size.
425
+ polygons (np.ndarray): [N, M], N is the number of polygons,
426
+ M is the number of points(Be divided by 2).
427
+ """
428
+ mask = np.zeros(img_size, dtype=np.uint8)
429
+ polygons = np.asarray(polygons)
430
+ polygons = polygons.astype(np.int32)
431
+ shape = polygons.shape
432
+ polygons = polygons.reshape(shape[0], -1, 2)
433
+ cv2.fillPoly(mask, polygons, color=color)
434
+ nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio)
435
+ # NOTE: fillPoly firstly then resize is trying the keep the same way
436
+ # of loss calculation when mask-ratio=1.
437
+ mask = cv2.resize(mask, (nw, nh))
438
+ return mask
439
+
440
+
441
+ def polygons2masks(img_size, polygons, color, downsample_ratio=1):
442
+ """
443
+ Args:
444
+ img_size (tuple): The image size.
445
+ polygons (list[np.ndarray]): each polygon is [N, M],
446
+ N is the number of polygons,
447
+ M is the number of points(Be divided by 2).
448
+ """
449
+ masks = []
450
+ for si in range(len(polygons)):
451
+ mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio)
452
+ masks.append(mask)
453
+ return np.array(masks)
454
+
455
+
456
+ def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
457
+ """Return a (640, 640) overlap mask."""
458
+ masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
459
+ dtype=np.int32 if len(segments) > 255 else np.uint8)
460
+ areas = []
461
+ ms = []
462
+ for si in range(len(segments)):
463
+ mask = polygon2mask(
464
+ img_size,
465
+ [segments[si].reshape(-1)],
466
+ downsample_ratio=downsample_ratio,
467
+ color=1,
468
+ )
469
+ ms.append(mask)
470
+ areas.append(mask.sum())
471
+ areas = np.asarray(areas)
472
+ index = np.argsort(-areas)
473
+ ms = np.array(ms)[index]
474
+ for i in range(len(segments)):
475
+ mask = ms[i] * (i + 1)
476
+ masks = masks + mask
477
+ masks = np.clip(masks, a_min=0, a_max=i + 1)
478
+ return masks, index
utils/panoptic/general.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def crop_mask(masks, boxes):
8
+ """
9
+ "Crop" predicted masks by zeroing out everything not in the predicted bbox.
10
+ Vectorized by Chong (thanks Chong).
11
+
12
+ Args:
13
+ - masks should be a size [h, w, n] tensor of masks
14
+ - boxes should be a size [n, 4] tensor of bbox coords in relative point form
15
+ """
16
+
17
+ n, h, w = masks.shape
18
+ x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(1,1,n)
19
+ r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1)
20
+ c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1)
21
+
22
+ return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
23
+
24
+
25
+ def process_mask_upsample(protos, masks_in, bboxes, shape):
26
+ """
27
+ Crop after upsample.
28
+ proto_out: [mask_dim, mask_h, mask_w]
29
+ out_masks: [n, mask_dim], n is number of masks after nms
30
+ bboxes: [n, 4], n is number of masks after nms
31
+ shape:input_image_size, (h, w)
32
+
33
+ return: h, w, n
34
+ """
35
+
36
+ c, mh, mw = protos.shape # CHW
37
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
38
+ masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
39
+ masks = crop_mask(masks, bboxes) # CHW
40
+ return masks.gt_(0.5)
41
+
42
+
43
+ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
44
+ """
45
+ Crop before upsample.
46
+ proto_out: [mask_dim, mask_h, mask_w]
47
+ out_masks: [n, mask_dim], n is number of masks after nms
48
+ bboxes: [n, 4], n is number of masks after nms
49
+ shape:input_image_size, (h, w)
50
+
51
+ return: h, w, n
52
+ """
53
+
54
+ c, mh, mw = protos.shape # CHW
55
+ ih, iw = shape
56
+ masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
57
+
58
+ downsampled_bboxes = bboxes.clone()
59
+ downsampled_bboxes[:, 0] *= mw / iw
60
+ downsampled_bboxes[:, 2] *= mw / iw
61
+ downsampled_bboxes[:, 3] *= mh / ih
62
+ downsampled_bboxes[:, 1] *= mh / ih
63
+
64
+ masks = crop_mask(masks, downsampled_bboxes) # CHW
65
+ if upsample:
66
+ masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
67
+ return masks.gt_(0.5)
68
+
69
+
70
+ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
71
+ """
72
+ img1_shape: model input shape, [h, w]
73
+ img0_shape: origin pic shape, [h, w, 3]
74
+ masks: [h, w, num]
75
+ """
76
+ # Rescale coordinates (xyxy) from im1_shape to im0_shape
77
+ if ratio_pad is None: # calculate from im0_shape
78
+ gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
79
+ pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
80
+ else:
81
+ pad = ratio_pad[1]
82
+ top, left = int(pad[1]), int(pad[0]) # y, x
83
+ bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
84
+
85
+ if len(masks.shape) < 2:
86
+ raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
87
+ masks = masks[top:bottom, left:right]
88
+ # masks = masks.permute(2, 0, 1).contiguous()
89
+ # masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
90
+ # masks = masks.permute(1, 2, 0).contiguous()
91
+ masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
92
+
93
+ if len(masks.shape) == 2:
94
+ masks = masks[:, :, None]
95
+ return masks
96
+
97
+
98
+ def mask_iou(mask1, mask2, eps=1e-7):
99
+ """
100
+ mask1: [N, n] m1 means number of predicted objects
101
+ mask2: [M, n] m2 means number of gt objects
102
+ Note: n means image_w x image_h
103
+
104
+ return: masks iou, [N, M]
105
+ """
106
+ intersection = torch.matmul(mask1, mask2.t()).clamp(0)
107
+ union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
108
+ return intersection / (union + eps)
109
+
110
+
111
+ def masks_iou(mask1, mask2, eps=1e-7):
112
+ """
113
+ mask1: [N, n] m1 means number of predicted objects
114
+ mask2: [N, n] m2 means number of gt objects
115
+ Note: n means image_w x image_h
116
+
117
+ return: masks iou, (N, )
118
+ """
119
+ intersection = (mask1 * mask2).sum(1).clamp(0) # (N, )
120
+ union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection
121
+ return intersection / (union + eps)
122
+
123
+
124
+ def masks2segments(masks, strategy='largest'):
125
+ # Convert masks(n,160,160) into segments(n,xy)
126
+ segments = []
127
+ for x in masks.int().cpu().numpy().astype('uint8'):
128
+ c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
129
+ if c:
130
+ if strategy == 'concat': # concatenate all segments
131
+ c = np.concatenate([x.reshape(-1, 2) for x in c])
132
+ elif strategy == 'largest': # select largest segment
133
+ c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
134
+ else:
135
+ c = np.zeros((0, 2)) # no segments found
136
+ segments.append(c.astype('float32'))
137
+ return segments
utils/panoptic/loss.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from ..general import xywh2xyxy
6
+ from ..loss import FocalLoss, smooth_BCE
7
+ from ..metrics import bbox_iou
8
+ from ..torch_utils import de_parallel
9
+ from .general import crop_mask
10
+
11
+
12
+ class ComputeLoss:
13
+ # Compute losses
14
+ def __init__(self, model, autobalance=False, overlap=False):
15
+ self.sort_obj_iou = False
16
+ self.overlap = overlap
17
+ device = next(model.parameters()).device # get model device
18
+ h = model.hyp # hyperparameters
19
+ self.device = device
20
+
21
+ # Define criteria
22
+ BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
23
+ BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
24
+
25
+ # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
26
+ self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
27
+
28
+ # Focal loss
29
+ g = h['fl_gamma'] # focal loss gamma
30
+ if g > 0:
31
+ BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
32
+
33
+ m = de_parallel(model).model[-1] # Detect() module
34
+ self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
35
+ self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index
36
+ self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
37
+ self.na = m.na # number of anchors
38
+ self.nc = m.nc # number of classes
39
+ self.nl = m.nl # number of layers
40
+ self.nm = m.nm # number of masks
41
+ self.anchors = m.anchors
42
+ self.device = device
43
+
44
+ def __call__(self, preds, targets, masks): # predictions, targets, model
45
+ p, proto = preds
46
+ bs, nm, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
47
+ lcls = torch.zeros(1, device=self.device)
48
+ lbox = torch.zeros(1, device=self.device)
49
+ lobj = torch.zeros(1, device=self.device)
50
+ lseg = torch.zeros(1, device=self.device)
51
+ tcls, tbox, indices, anchors, tidxs, xywhn = self.build_targets(p, targets) # targets
52
+
53
+ # Losses
54
+ for i, pi in enumerate(p): # layer index, layer predictions
55
+ b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
56
+ tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
57
+
58
+ n = b.shape[0] # number of targets
59
+ if n:
60
+ pxy, pwh, _, pcls, pmask = pi[b, a, gj, gi].split((2, 2, 1, self.nc, nm), 1) # subset of predictions
61
+
62
+ # Box regression
63
+ pxy = pxy.sigmoid() * 2 - 0.5
64
+ pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
65
+ pbox = torch.cat((pxy, pwh), 1) # predicted box
66
+ iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target)
67
+ lbox += (1.0 - iou).mean() # iou loss
68
+
69
+ # Objectness
70
+ iou = iou.detach().clamp(0).type(tobj.dtype)
71
+ if self.sort_obj_iou:
72
+ j = iou.argsort()
73
+ b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
74
+ if self.gr < 1:
75
+ iou = (1.0 - self.gr) + self.gr * iou
76
+ tobj[b, a, gj, gi] = iou # iou ratio
77
+
78
+ # Classification
79
+ if self.nc > 1: # cls loss (only if multiple classes)
80
+ t = torch.full_like(pcls, self.cn, device=self.device) # targets
81
+ t[range(n), tcls[i]] = self.cp
82
+ lcls += self.BCEcls(pcls, t) # BCE
83
+
84
+ # Mask regression
85
+ if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
86
+ masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
87
+ marea = xywhn[i][:, 2:].prod(1) # mask width, height normalized
88
+ mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device))
89
+ for bi in b.unique():
90
+ j = b == bi # matching index
91
+ if self.overlap:
92
+ mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0)
93
+ else:
94
+ mask_gti = masks[tidxs[i]][j]
95
+ lseg += self.single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
96
+
97
+ obji = self.BCEobj(pi[..., 4], tobj)
98
+ lobj += obji * self.balance[i] # obj loss
99
+ if self.autobalance:
100
+ self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
101
+
102
+ if self.autobalance:
103
+ self.balance = [x / self.balance[self.ssi] for x in self.balance]
104
+ lbox *= self.hyp["box"]
105
+ lobj *= self.hyp["obj"]
106
+ lcls *= self.hyp["cls"]
107
+ lseg *= self.hyp["box"] / bs
108
+
109
+ loss = lbox + lobj + lcls + lseg
110
+ return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
111
+
112
+ def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
113
+ # Mask loss for one image
114
+ pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n,32) @ (32,80,80) -> (n,80,80)
115
+ loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
116
+ return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
117
+
118
+ def build_targets(self, p, targets):
119
+ # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
120
+ na, nt = self.na, targets.shape[0] # number of anchors, targets
121
+ tcls, tbox, indices, anch, tidxs, xywhn = [], [], [], [], [], []
122
+ gain = torch.ones(8, device=self.device) # normalized to gridspace gain
123
+ ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
124
+ if self.overlap:
125
+ batch = p[0].shape[0]
126
+ ti = []
127
+ for i in range(batch):
128
+ num = (targets[:, 0] == i).sum() # find number of targets of each image
129
+ ti.append(torch.arange(num, device=self.device).float().view(1, num).repeat(na, 1) + 1) # (na, num)
130
+ ti = torch.cat(ti, 1) # (na, nt)
131
+ else:
132
+ ti = torch.arange(nt, device=self.device).float().view(1, nt).repeat(na, 1)
133
+ targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None], ti[..., None]), 2) # append anchor indices
134
+
135
+ g = 0.5 # bias
136
+ off = torch.tensor(
137
+ [
138
+ [0, 0],
139
+ [1, 0],
140
+ [0, 1],
141
+ [-1, 0],
142
+ [0, -1], # j,k,l,m
143
+ # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
144
+ ],
145
+ device=self.device).float() * g # offsets
146
+
147
+ for i in range(self.nl):
148
+ anchors, shape = self.anchors[i], p[i].shape
149
+ gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain
150
+
151
+ # Match targets to anchors
152
+ t = targets * gain # shape(3,n,7)
153
+ if nt:
154
+ # Matches
155
+ r = t[..., 4:6] / anchors[:, None] # wh ratio
156
+ j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t'] # compare
157
+ # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
158
+ t = t[j] # filter
159
+
160
+ # Offsets
161
+ gxy = t[:, 2:4] # grid xy
162
+ gxi = gain[[2, 3]] - gxy # inverse
163
+ j, k = ((gxy % 1 < g) & (gxy > 1)).T
164
+ l, m = ((gxi % 1 < g) & (gxi > 1)).T
165
+ j = torch.stack((torch.ones_like(j), j, k, l, m))
166
+ t = t.repeat((5, 1, 1))[j]
167
+ offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
168
+ else:
169
+ t = targets[0]
170
+ offsets = 0
171
+
172
+ # Define
173
+ bc, gxy, gwh, at = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors
174
+ (a, tidx), (b, c) = at.long().T, bc.long().T # anchors, image, class
175
+ gij = (gxy - offsets).long()
176
+ gi, gj = gij.T # grid indices
177
+
178
+ # Append
179
+ indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid
180
+ tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
181
+ anch.append(anchors[a]) # anchors
182
+ tcls.append(c) # class
183
+ tidxs.append(tidx)
184
+ xywhn.append(torch.cat((gxy, gwh), 1) / gain[2:6]) # xywh normalized
185
+
186
+ return tcls, tbox, indices, anch, tidxs, xywhn
utils/panoptic/loss_tal.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from torchvision.ops import sigmoid_focal_loss
8
+
9
+ from utils.general import xywh2xyxy, xyxy2xywh
10
+ from utils.metrics import bbox_iou
11
+ from utils.panoptic.tal.anchor_generator import dist2bbox, make_anchors, bbox2dist
12
+ from utils.panoptic.tal.assigner import TaskAlignedAssigner
13
+ from utils.torch_utils import de_parallel
14
+ from utils.panoptic.general import crop_mask
15
+
16
+
17
+ def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
18
+ # return positive, negative label smoothing BCE targets
19
+ return 1.0 - 0.5 * eps, 0.5 * eps
20
+
21
+
22
+ class VarifocalLoss(nn.Module):
23
+ # Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367
24
+ def __init__(self):
25
+ super().__init__()
26
+
27
+ def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
28
+ weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
29
+ with torch.cuda.amp.autocast(enabled=False):
30
+ loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(),
31
+ reduction="none") * weight).sum()
32
+ return loss
33
+
34
+
35
+ class FocalLoss(nn.Module):
36
+ # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
37
+ def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
38
+ super().__init__()
39
+ self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
40
+ self.gamma = gamma
41
+ self.alpha = alpha
42
+ self.reduction = loss_fcn.reduction
43
+ self.loss_fcn.reduction = "none" # required to apply FL to each element
44
+
45
+ def forward(self, pred, true):
46
+ loss = self.loss_fcn(pred, true)
47
+ # p_t = torch.exp(-loss)
48
+ # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
49
+
50
+ # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
51
+ pred_prob = torch.sigmoid(pred) # prob from logits
52
+ p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
53
+ alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
54
+ modulating_factor = (1.0 - p_t) ** self.gamma
55
+ loss *= alpha_factor * modulating_factor
56
+
57
+ if self.reduction == "mean":
58
+ return loss.mean()
59
+ elif self.reduction == "sum":
60
+ return loss.sum()
61
+ else: # 'none'
62
+ return loss
63
+
64
+
65
+ class BboxLoss(nn.Module):
66
+ def __init__(self, reg_max, use_dfl=False):
67
+ super().__init__()
68
+ self.reg_max = reg_max
69
+ self.use_dfl = use_dfl
70
+
71
+ def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
72
+ # iou loss
73
+ bbox_mask = fg_mask.unsqueeze(-1).repeat([1, 1, 4]) # (b, h*w, 4)
74
+ pred_bboxes_pos = torch.masked_select(pred_bboxes, bbox_mask).view(-1, 4)
75
+ target_bboxes_pos = torch.masked_select(target_bboxes, bbox_mask).view(-1, 4)
76
+ bbox_weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
77
+
78
+ iou = bbox_iou(pred_bboxes_pos, target_bboxes_pos, xywh=False, CIoU=True)
79
+ loss_iou = 1.0 - iou
80
+
81
+ #### wiou
82
+ #iou = bbox_iou(pred_bboxes_pos, target_bboxes_pos, xywh=False, WIoU=True, scale=True)
83
+ #if type(iou) is tuple:
84
+ # if len(iou) == 2:
85
+ # loss_iou = (iou[1].detach() * (1 - iou[0]))
86
+ # iou = iou[0]
87
+ # else:
88
+ # loss_iou = (iou[0] * iou[1])
89
+ # iou = iou[-1]
90
+ #else:
91
+ # loss_iou = (1.0 - iou) # iou loss
92
+
93
+ loss_iou *= bbox_weight
94
+ loss_iou = loss_iou.sum() / target_scores_sum
95
+ # loss_iou = loss_iou.mean()
96
+
97
+ # dfl loss
98
+ if self.use_dfl:
99
+ dist_mask = fg_mask.unsqueeze(-1).repeat([1, 1, (self.reg_max + 1) * 4])
100
+ pred_dist_pos = torch.masked_select(pred_dist, dist_mask).view(-1, 4, self.reg_max + 1)
101
+ target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
102
+ target_ltrb_pos = torch.masked_select(target_ltrb, bbox_mask).view(-1, 4)
103
+ loss_dfl = self._df_loss(pred_dist_pos, target_ltrb_pos) * bbox_weight
104
+ loss_dfl = loss_dfl.sum() / target_scores_sum
105
+ else:
106
+ loss_dfl = torch.tensor(0.0).to(pred_dist.device)
107
+
108
+ return loss_iou, loss_dfl, iou
109
+
110
+ def _df_loss(self, pred_dist, target):
111
+ target_left = target.to(torch.long)
112
+ target_right = target_left + 1
113
+ weight_left = target_right.to(torch.float) - target
114
+ weight_right = 1 - weight_left
115
+ loss_left = F.cross_entropy(pred_dist.view(-1, self.reg_max + 1), target_left.view(-1), reduction="none").view(
116
+ target_left.shape) * weight_left
117
+ loss_right = F.cross_entropy(pred_dist.view(-1, self.reg_max + 1), target_right.view(-1),
118
+ reduction="none").view(target_left.shape) * weight_right
119
+ return (loss_left + loss_right).mean(-1, keepdim=True)
120
+
121
+
122
+ class ComputeLoss:
123
+ # Compute losses
124
+ def __init__(self, model, use_dfl=True, overlap=True):
125
+ device = next(model.parameters()).device # get model device
126
+ h = model.hyp # hyperparameters
127
+
128
+ # Define criteria
129
+ BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["cls_pw"]], device=device), reduction='none')
130
+
131
+ # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
132
+ self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0)) # positive, negative BCE targets
133
+
134
+ # Focal loss
135
+ g = h["fl_gamma"] # focal loss gamma
136
+ if g > 0:
137
+ BCEcls = FocalLoss(BCEcls, g)
138
+
139
+ m = de_parallel(model).model[-1] # Detect() module
140
+ self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
141
+ self.BCEcls = BCEcls
142
+ self.hyp = h
143
+ self.stride = m.stride # model strides
144
+ self.nc = m.nc # number of classes
145
+ self.nl = m.nl # number of layers
146
+ self.no = m.no
147
+ self.nm = m.nm
148
+ self.overlap = overlap
149
+ self.reg_max = m.reg_max
150
+ self.device = device
151
+
152
+ self.assigner = TaskAlignedAssigner(topk=int(os.getenv('YOLOM', 10)),
153
+ num_classes=self.nc,
154
+ alpha=float(os.getenv('YOLOA', 0.5)),
155
+ beta=float(os.getenv('YOLOB', 6.0)))
156
+ self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=use_dfl).to(device)
157
+ self.proj = torch.arange(m.reg_max).float().to(device) # / 120.0
158
+ self.use_dfl = use_dfl
159
+
160
+ def preprocess(self, targets, batch_size, scale_tensor):
161
+ if targets.shape[0] == 0:
162
+ out = torch.zeros(batch_size, 0, 5, device=self.device)
163
+ else:
164
+ i = targets[:, 0] # image index
165
+ _, counts = i.unique(return_counts=True)
166
+ out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
167
+ for j in range(batch_size):
168
+ matches = i == j
169
+ n = matches.sum()
170
+ if n:
171
+ out[j, :n] = targets[matches, 1:]
172
+ out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
173
+ return out
174
+
175
+ def bbox_decode(self, anchor_points, pred_dist):
176
+ if self.use_dfl:
177
+ b, a, c = pred_dist.shape # batch, anchors, channels
178
+ pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
179
+ # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
180
+ # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
181
+ return dist2bbox(pred_dist, anchor_points, xywh=False)
182
+
183
+ def __call__(self, p, targets, masks, semasks, img=None, epoch=0):
184
+ loss = torch.zeros(6, device=self.device) # box, cls, dfl
185
+ feats, pred_masks, proto, psemasks = p if len(p) == 4 else p[1]
186
+ batch_size, _, mask_h, mask_w = proto.shape
187
+ pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
188
+ (self.reg_max * 4, self.nc), 1)
189
+ pred_scores = pred_scores.permute(0, 2, 1).contiguous()
190
+ pred_distri = pred_distri.permute(0, 2, 1).contiguous()
191
+ pred_masks = pred_masks.permute(0, 2, 1).contiguous()
192
+
193
+ dtype = pred_scores.dtype
194
+ batch_size, grid_size = pred_scores.shape[:2]
195
+ imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
196
+ anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
197
+
198
+ # targets
199
+ try:
200
+ batch_idx = targets[:, 0].view(-1, 1)
201
+ targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
202
+ gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
203
+ mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
204
+ except RuntimeError as e:
205
+ raise TypeError('ERROR.') from e
206
+
207
+
208
+ # pboxes
209
+ pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
210
+
211
+ target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
212
+ pred_scores.detach().sigmoid(),
213
+ (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
214
+ anchor_points * stride_tensor,
215
+ gt_labels,
216
+ gt_bboxes,
217
+ mask_gt)
218
+
219
+ target_scores_sum = target_scores.sum()
220
+
221
+ # cls loss
222
+ # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
223
+ loss[2] = self.BCEcls(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
224
+
225
+ # bbox loss
226
+ if fg_mask.sum():
227
+ loss[0], loss[3], _ = self.bbox_loss(pred_distri,
228
+ pred_bboxes,
229
+ anchor_points,
230
+ target_bboxes / stride_tensor,
231
+ target_scores,
232
+ target_scores_sum,
233
+ fg_mask)
234
+
235
+ # masks loss
236
+ if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
237
+ masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
238
+
239
+ for i in range(batch_size):
240
+ if fg_mask[i].sum():
241
+ mask_idx = target_gt_idx[i][fg_mask[i]]
242
+ if self.overlap:
243
+ gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
244
+ else:
245
+ gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
246
+ xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
247
+ marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
248
+ mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
249
+ loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy,
250
+ marea) # seg loss
251
+ # Semantic Segmentation
252
+ # focal loss
253
+ pt = torch.flatten(psemasks, start_dim = 2).permute(0, 2, 1)
254
+ gt = torch.flatten(semasks, start_dim = 2).permute(0, 2, 1)
255
+
256
+ bs, _, _ = gt.shape
257
+ #torch.clamp(torch.sigmoid(logits), min=eps, max= 1 - eps)
258
+ #total_loss = (sigmoid_focal_loss(pt.float(), gt.float(), alpha = .25, gamma = 2., reduction = 'mean')) / 2.
259
+ #total_loss = (sigmoid_focal_loss(pt.clamp(-16., 16.), gt, alpha = .25, gamma = 2., reduction = 'mean')) / 2.
260
+ total_loss = (sigmoid_focal_loss(pt, gt, alpha = .25, gamma = 2., reduction = 'mean')) / 2.
261
+ loss[4] += total_loss * 20.
262
+
263
+ # dice loss
264
+ pt = torch.flatten(psemasks.softmax(dim = 1))
265
+ gt = torch.flatten(semasks)
266
+
267
+ inter_mask = torch.sum(torch.mul(pt, gt))
268
+ union_mask = torch.sum(torch.add(pt, gt))
269
+ dice_coef = (2. * inter_mask + 1.) / (union_mask + 1.)
270
+ loss[5] += (1. - dice_coef) / 2.
271
+
272
+ loss[0] *= 7.5 # box gain
273
+ loss[1] *= 2.5 / batch_size
274
+ loss[2] *= 0.5 # cls gain
275
+ loss[3] *= 1.5 # dfl gain
276
+ loss[4] *= 2.5 #/ batch_size
277
+ loss[5] *= 2.5 #/ batch_size
278
+
279
+ return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
280
+
281
+ def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
282
+ # Mask loss for one image
283
+ pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
284
+ loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
285
+ return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
utils/panoptic/metrics.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from ..metrics import ap_per_class
5
+
6
+
7
+ def fitness(x):
8
+ # Model fitness as a weighted combination of metrics
9
+ w = [0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.1, 0.9, 0.1, 0.9]
10
+ return (x[:, :len(w)] * w).sum(1)
11
+
12
+
13
+ def ap_per_class_box_and_mask(
14
+ tp_m,
15
+ tp_b,
16
+ conf,
17
+ pred_cls,
18
+ target_cls,
19
+ plot=False,
20
+ save_dir=".",
21
+ names=(),
22
+ ):
23
+ """
24
+ Args:
25
+ tp_b: tp of boxes.
26
+ tp_m: tp of masks.
27
+ other arguments see `func: ap_per_class`.
28
+ """
29
+ results_boxes = ap_per_class(tp_b,
30
+ conf,
31
+ pred_cls,
32
+ target_cls,
33
+ plot=plot,
34
+ save_dir=save_dir,
35
+ names=names,
36
+ prefix="Box")[2:]
37
+ results_masks = ap_per_class(tp_m,
38
+ conf,
39
+ pred_cls,
40
+ target_cls,
41
+ plot=plot,
42
+ save_dir=save_dir,
43
+ names=names,
44
+ prefix="Mask")[2:]
45
+
46
+ results = {
47
+ "boxes": {
48
+ "p": results_boxes[0],
49
+ "r": results_boxes[1],
50
+ "ap": results_boxes[3],
51
+ "f1": results_boxes[2],
52
+ "ap_class": results_boxes[4]},
53
+ "masks": {
54
+ "p": results_masks[0],
55
+ "r": results_masks[1],
56
+ "ap": results_masks[3],
57
+ "f1": results_masks[2],
58
+ "ap_class": results_masks[4]}}
59
+ return results
60
+
61
+
62
+ class Metric:
63
+
64
+ def __init__(self) -> None:
65
+ self.p = [] # (nc, )
66
+ self.r = [] # (nc, )
67
+ self.f1 = [] # (nc, )
68
+ self.all_ap = [] # (nc, 10)
69
+ self.ap_class_index = [] # (nc, )
70
+
71
+ @property
72
+ def ap50(self):
73
+ """AP@0.5 of all classes.
74
+ Return:
75
+ (nc, ) or [].
76
+ """
77
+ return self.all_ap[:, 0] if len(self.all_ap) else []
78
+
79
+ @property
80
+ def ap(self):
81
+ """AP@0.5:0.95
82
+ Return:
83
+ (nc, ) or [].
84
+ """
85
+ return self.all_ap.mean(1) if len(self.all_ap) else []
86
+
87
+ @property
88
+ def mp(self):
89
+ """mean precision of all classes.
90
+ Return:
91
+ float.
92
+ """
93
+ return self.p.mean() if len(self.p) else 0.0
94
+
95
+ @property
96
+ def mr(self):
97
+ """mean recall of all classes.
98
+ Return:
99
+ float.
100
+ """
101
+ return self.r.mean() if len(self.r) else 0.0
102
+
103
+ @property
104
+ def map50(self):
105
+ """Mean AP@0.5 of all classes.
106
+ Return:
107
+ float.
108
+ """
109
+ return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
110
+
111
+ @property
112
+ def map(self):
113
+ """Mean AP@0.5:0.95 of all classes.
114
+ Return:
115
+ float.
116
+ """
117
+ return self.all_ap.mean() if len(self.all_ap) else 0.0
118
+
119
+ def mean_results(self):
120
+ """Mean of results, return mp, mr, map50, map"""
121
+ return (self.mp, self.mr, self.map50, self.map)
122
+
123
+ def class_result(self, i):
124
+ """class-aware result, return p[i], r[i], ap50[i], ap[i]"""
125
+ return (self.p[i], self.r[i], self.ap50[i], self.ap[i])
126
+
127
+ def get_maps(self, nc):
128
+ maps = np.zeros(nc) + self.map
129
+ for i, c in enumerate(self.ap_class_index):
130
+ maps[c] = self.ap[i]
131
+ return maps
132
+
133
+ def update(self, results):
134
+ """
135
+ Args:
136
+ results: tuple(p, r, ap, f1, ap_class)
137
+ """
138
+ p, r, all_ap, f1, ap_class_index = results
139
+ self.p = p
140
+ self.r = r
141
+ self.all_ap = all_ap
142
+ self.f1 = f1
143
+ self.ap_class_index = ap_class_index
144
+
145
+
146
+ class Metrics:
147
+ """Metric for boxes and masks."""
148
+
149
+ def __init__(self) -> None:
150
+ self.metric_box = Metric()
151
+ self.metric_mask = Metric()
152
+
153
+ def update(self, results):
154
+ """
155
+ Args:
156
+ results: Dict{'boxes': Dict{}, 'masks': Dict{}}
157
+ """
158
+ self.metric_box.update(list(results["boxes"].values()))
159
+ self.metric_mask.update(list(results["masks"].values()))
160
+
161
+ def mean_results(self):
162
+ return self.metric_box.mean_results() + self.metric_mask.mean_results()
163
+
164
+ def class_result(self, i):
165
+ return self.metric_box.class_result(i) + self.metric_mask.class_result(i)
166
+
167
+ def get_maps(self, nc):
168
+ return self.metric_box.get_maps(nc) + self.metric_mask.get_maps(nc)
169
+
170
+ @property
171
+ def ap_class_index(self):
172
+ # boxes and masks have the same ap_class_index
173
+ return self.metric_box.ap_class_index
174
+
175
+
176
+ class Semantic_Metrics:
177
+ def __init__(self, nc, device):
178
+ self.nc = nc # number of classes
179
+ self.device = device
180
+ self.iou = []
181
+ self.c_bit_counts = torch.zeros(nc, dtype = torch.long).to(device)
182
+ self.c_intersection_counts = torch.zeros(nc, dtype = torch.long).to(device)
183
+ self.c_union_counts = torch.zeros(nc, dtype = torch.long).to(device)
184
+
185
+ def update(self, pred_masks, target_masks):
186
+ nb, nc, h, w = pred_masks.shape
187
+ device = pred_masks.device
188
+
189
+ for b in range(nb):
190
+ onehot_mask = pred_masks[b].to(device)
191
+ # convert predict mask to one hot
192
+ semantic_mask = torch.flatten(onehot_mask, start_dim = 1).permute(1, 0) # class x h x w -> (h x w) x class
193
+ max_idx = semantic_mask.argmax(1)
194
+ output_masks = (torch.zeros(semantic_mask.shape).to(self.device)).scatter(1, max_idx.unsqueeze(1), 1.0) # one hot: (h x w) x class
195
+ output_masks = torch.reshape(output_masks.permute(1, 0), (nc, h, w)) # (h x w) x class -> class x h x w
196
+ onehot_mask = output_masks.int()
197
+
198
+ for c in range(self.nc):
199
+ pred_mask = onehot_mask[c].to(device)
200
+ target_mask = target_masks[b, c].to(device)
201
+
202
+ # calculate IoU
203
+ intersection = (torch.logical_and(pred_mask, target_mask).sum()).item()
204
+ union = (torch.logical_or(pred_mask, target_mask).sum()).item()
205
+ iou = 0. if (0 == union) else (intersection / union)
206
+
207
+ # record class pixel counts, intersection counts, union counts
208
+ self.c_bit_counts[c] += target_mask.int().sum()
209
+ self.c_intersection_counts[c] += intersection
210
+ self.c_union_counts[c] += union
211
+
212
+ self.iou.append(iou)
213
+
214
+ def results(self):
215
+ # Mean IoU
216
+ miou = 0. if (0 == len(self.iou)) else np.sum(self.iou) / (len(self.iou) * self.nc)
217
+
218
+ # Frequency Weighted IoU
219
+ c_iou = self.c_intersection_counts / (self.c_union_counts + 1) # add smooth
220
+ # c_bit_counts = self.c_bit_counts.astype(int)
221
+ total_c_bit_counts = self.c_bit_counts.sum()
222
+ freq_ious = torch.zeros(1, dtype = torch.long).to(self.device) if (0 == total_c_bit_counts) else (self.c_bit_counts / total_c_bit_counts) * c_iou
223
+ fwiou = (freq_ious.sum()).item()
224
+
225
+ return (miou, fwiou)
226
+
227
+ def reset(self):
228
+ self.iou = []
229
+ self.c_bit_counts = torch.zeros(self.nc, dtype = torch.long).to(self.device)
230
+ self.c_intersection_counts = torch.zeros(self.nc, dtype = torch.long).to(self.device)
231
+ self.c_union_counts = torch.zeros(self.nc, dtype = torch.long).to(self.device)
232
+
233
+
234
+ KEYS = [
235
+ "train/box_loss",
236
+ "train/seg_loss", # train loss
237
+ "train/cls_loss",
238
+ "train/dfl_loss",
239
+ "train/fcl_loss",
240
+ "train/dic_loss",
241
+ "metrics/precision(B)",
242
+ "metrics/recall(B)",
243
+ "metrics/mAP_0.5(B)",
244
+ "metrics/mAP_0.5:0.95(B)", # metrics
245
+ "metrics/precision(M)",
246
+ "metrics/recall(M)",
247
+ "metrics/mAP_0.5(M)",
248
+ "metrics/mAP_0.5:0.95(M)", # metrics
249
+ "metrics/MIOUS(S)",
250
+ "metrics/FWIOUS(S)", # metrics
251
+ "val/box_loss",
252
+ "val/seg_loss", # val loss
253
+ "val/cls_loss",
254
+ "val/dfl_loss",
255
+ "val/fcl_loss",
256
+ "val/dic_loss",
257
+ "x/lr0",
258
+ "x/lr1",
259
+ "x/lr2",]
260
+
261
+ BEST_KEYS = [
262
+ "best/epoch",
263
+ "best/precision(B)",
264
+ "best/recall(B)",
265
+ "best/mAP_0.5(B)",
266
+ "best/mAP_0.5:0.95(B)",
267
+ "best/precision(M)",
268
+ "best/recall(M)",
269
+ "best/mAP_0.5(M)",
270
+ "best/mAP_0.5:0.95(M)",
271
+ "best/MIOUS(S)",
272
+ "best/FWIOUS(S)",]
utils/panoptic/plots.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import math
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ from torchvision.utils import draw_segmentation_masks, save_image
11
+
12
+ from .. import threaded
13
+ from ..general import xywh2xyxy
14
+ from ..plots import Annotator, colors
15
+
16
+
17
+ @threaded
18
+ def plot_images_and_masks(images, targets, masks, semasks, paths=None, fname='images.jpg', names=None):
19
+
20
+ try:
21
+ if images.shape[-2:] != semasks.shape[-2:]:
22
+ m = torch.nn.Upsample(scale_factor=4, mode='nearest')
23
+ semasks = m(semasks)
24
+
25
+ for idx in range(images.shape[0]):
26
+ output_img = draw_segmentation_masks(
27
+ image = images[idx, :, :, :].cpu().to(dtype = torch.uint8),
28
+ masks = semasks[idx, :, :, :].cpu().to(dtype = torch.bool),
29
+ alpha = 1)
30
+ cv2.imwrite(
31
+ '{}_{}.jpg'.format(fname, idx),
32
+ torch.permute(output_img, (1, 2, 0)).numpy()
33
+ )
34
+ except:
35
+ pass
36
+
37
+ # Plot image grid with labels
38
+ if isinstance(images, torch.Tensor):
39
+ images = images.cpu().float().numpy()
40
+ if isinstance(targets, torch.Tensor):
41
+ targets = targets.cpu().numpy()
42
+ if isinstance(masks, torch.Tensor):
43
+ masks = masks.cpu().numpy().astype(int)
44
+ if isinstance(semasks, torch.Tensor):
45
+ semasks = semasks.cpu().numpy().astype(int)
46
+
47
+ max_size = 1920 # max image size
48
+ max_subplots = 16 # max image subplots, i.e. 4x4
49
+ bs, _, h, w = images.shape # batch size, _, height, width
50
+ bs = min(bs, max_subplots) # limit plot images
51
+ ns = np.ceil(bs ** 0.5) # number of subplots (square)
52
+ if np.max(images[0]) <= 1:
53
+ images *= 255 # de-normalise (optional)
54
+
55
+ # Build Image
56
+ mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
57
+ for i, im in enumerate(images):
58
+ if i == max_subplots: # if last batch has fewer images than we expect
59
+ break
60
+ x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
61
+ im = im.transpose(1, 2, 0)
62
+ mosaic[y:y + h, x:x + w, :] = im
63
+
64
+ # Resize (optional)
65
+ scale = max_size / ns / max(h, w)
66
+ if scale < 1:
67
+ h = math.ceil(scale * h)
68
+ w = math.ceil(scale * w)
69
+ mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
70
+
71
+ # Annotate
72
+ fs = int((h + w) * ns * 0.01) # font size
73
+ annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
74
+ for i in range(i + 1):
75
+ x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
76
+ annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
77
+ if paths:
78
+ annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
79
+ if len(targets) > 0:
80
+ idx = targets[:, 0] == i
81
+ ti = targets[idx] # image targets
82
+
83
+ boxes = xywh2xyxy(ti[:, 2:6]).T
84
+ classes = ti[:, 1].astype('int')
85
+ labels = ti.shape[1] == 6 # labels if no conf column
86
+ conf = None if labels else ti[:, 6] # check for confidence presence (label vs pred)
87
+
88
+ if boxes.shape[1]:
89
+ if boxes.max() <= 1.01: # if normalized with tolerance 0.01
90
+ boxes[[0, 2]] *= w # scale to pixels
91
+ boxes[[1, 3]] *= h
92
+ elif scale < 1: # absolute coords need scale if image scales
93
+ boxes *= scale
94
+ boxes[[0, 2]] += x
95
+ boxes[[1, 3]] += y
96
+ for j, box in enumerate(boxes.T.tolist()):
97
+ cls = classes[j]
98
+ color = colors(cls)
99
+ cls = names[cls] if names else cls
100
+ if labels or conf[j] > 0.25: # 0.25 conf thresh
101
+ label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
102
+ annotator.box_label(box, label, color=color)
103
+
104
+ # Plot masks
105
+ if len(masks):
106
+ if masks.max() > 1.0: # mean that masks are overlap
107
+ image_masks = masks[[i]] # (1, 640, 640)
108
+ nl = len(ti)
109
+ index = np.arange(nl).reshape(nl, 1, 1) + 1
110
+ image_masks = np.repeat(image_masks, nl, axis=0)
111
+ image_masks = np.where(image_masks == index, 1.0, 0.0)
112
+ else:
113
+ image_masks = masks[idx]
114
+
115
+ im = np.asarray(annotator.im).copy()
116
+ for j, box in enumerate(boxes.T.tolist()):
117
+ if labels or conf[j] > 0.25: # 0.25 conf thresh
118
+ color = colors(classes[j])
119
+ mh, mw = image_masks[j].shape
120
+ if mh != h or mw != w:
121
+ mask = image_masks[j].astype(np.uint8)
122
+ mask = cv2.resize(mask, (w, h))
123
+ mask = mask.astype(bool)
124
+ else:
125
+ mask = image_masks[j].astype(bool)
126
+ with contextlib.suppress(Exception):
127
+ im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
128
+ annotator.fromarray(im)
129
+ annotator.im.save(fname) # save
130
+
131
+
132
+ def plot_results_with_masks(file="path/to/results.csv", dir="", best=True):
133
+ # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
134
+ save_dir = Path(file).parent if file else Path(dir)
135
+ fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
136
+ ax = ax.ravel()
137
+ files = list(save_dir.glob("results*.csv"))
138
+ assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
139
+ for f in files:
140
+ try:
141
+ data = pd.read_csv(f)
142
+ index = np.argmax(0.9 * data.values[:, 8] + 0.1 * data.values[:, 7] + 0.9 * data.values[:, 12] +
143
+ 0.1 * data.values[:, 11])
144
+ s = [x.strip() for x in data.columns]
145
+ x = data.values[:, 0]
146
+ for i, j in enumerate([1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]):
147
+ y = data.values[:, j]
148
+ # y[y == 0] = np.nan # don't show zero values
149
+ ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=2)
150
+ if best:
151
+ # best
152
+ ax[i].scatter(index, y[index], color="r", label=f"best:{index}", marker="*", linewidth=3)
153
+ ax[i].set_title(s[j] + f"\n{round(y[index], 5)}")
154
+ else:
155
+ # last
156
+ ax[i].scatter(x[-1], y[-1], color="r", label="last", marker="*", linewidth=3)
157
+ ax[i].set_title(s[j] + f"\n{round(y[-1], 5)}")
158
+ # if j in [8, 9, 10]: # share train and val loss y axes
159
+ # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
160
+ except Exception as e:
161
+ print(f"Warning: Plotting error for {f}: {e}")
162
+ ax[1].legend()
163
+ fig.savefig(save_dir / "results.png", dpi=200)
164
+ plt.close()
utils/panoptic/tal/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # init