watchtowerss and wybertwang committed on
Commit ee87a3a · 0 Parent(s)

Duplicate from TencentARC/Caption-Anything


Co-authored-by: wybertwang <wybertwang@users.noreply.huggingface.co>

Files changed (49)
  1. .gitattributes +42 -0
  2. .gitignore +143 -0
  3. LICENSE +28 -0
  4. README.md +15 -0
  5. app.py +599 -0
  6. assets/UI.png +3 -0
  7. assets/caption_anything_logo.png +0 -0
  8. assets/demo1.jpg +3 -0
  9. assets/demo1.png +3 -0
  10. assets/demo1.svg +0 -0
  11. assets/demo2.png +0 -0
  12. assets/demo2.svg +0 -0
  13. assets/qingming.gif +3 -0
  14. assets/times_with_simsun.ttf +3 -0
  15. assets/title.png +0 -0
  16. assets/title.svg +1 -0
  17. caption_anything/__init__.py +0 -0
  18. caption_anything/captioner/README.md +13 -0
  19. caption_anything/captioner/__init__.py +15 -0
  20. caption_anything/captioner/base_captioner.py +200 -0
  21. caption_anything/captioner/blip.py +72 -0
  22. caption_anything/captioner/blip2.py +71 -0
  23. caption_anything/captioner/git.py +67 -0
  24. caption_anything/captioner/modeling_blip.py +1476 -0
  25. caption_anything/captioner/modeling_git.py +1587 -0
  26. caption_anything/captioner/vit_pixel_masks_utils.py +17 -0
  27. caption_anything/model.py +294 -0
  28. caption_anything/segmenter/__init__.py +14 -0
  29. caption_anything/segmenter/base_segmenter.py +184 -0
  30. caption_anything/segmenter/readme.md +68 -0
  31. caption_anything/text_refiner/README.md +8 -0
  32. caption_anything/text_refiner/__init__.py +6 -0
  33. caption_anything/text_refiner/text_refiner.py +86 -0
  34. caption_anything/utils/chatbot.py +225 -0
  35. caption_anything/utils/densecap_painter.py +64 -0
  36. caption_anything/utils/image_editing_utils.py +127 -0
  37. caption_anything/utils/parser.py +35 -0
  38. caption_anything/utils/utils.py +496 -0
  39. requirements.txt +21 -0
  40. sam_vit_h_4b8939.pth +3 -0
  41. test_images/img0.png +0 -0
  42. test_images/img1.jpg +0 -0
  43. test_images/img12.jpg +0 -0
  44. test_images/img14.jpg +0 -0
  45. test_images/img2.jpg +0 -0
  46. test_images/img35.webp +0 -0
  47. test_images/img36.webp +0 -0
  48. test_images/img5.jpg +0 -0
  49. test_images/qingming3.jpeg +3 -0
.gitattributes ADDED
@@ -0,0 +1,42 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ test_img/img18.jpg filter=lfs diff=lfs merge=lfs -text
+ test_img/img22.jpg filter=lfs diff=lfs merge=lfs -text
+ times_with_simsun.ttf filter=lfs diff=lfs merge=lfs -text
+ test_images/qingming3.jpeg filter=lfs diff=lfs merge=lfs -text
+ assets/demo1.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/demo1.png filter=lfs diff=lfs merge=lfs -text
+ assets/qingming.gif filter=lfs diff=lfs merge=lfs -text
+ assets/UI.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,143 @@
+ result/
+ model_cache/
+ *.pth
+ teng_grad_start.sh
+ *.jpg
+ *.jpeg
+ *.png
+ *.svg
+ *.gif
+ *.tiff
+ *.webp
+
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ result/
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
LICENSE ADDED
@@ -0,0 +1,28 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2023, Teng Wang
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md ADDED
@@ -0,0 +1,15 @@
+ ---
+ title: Caption Anything
+ emoji: 📚
+ colorFrom: green
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.26.0
+ python_version: 3.8.9
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: TencentARC/Caption-Anything
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,599 @@
1
+ import os
2
+ import json
3
+ import gradio as gr
4
+ import numpy as np
5
+ from gradio import processing_utils
6
+
7
+ from packaging import version
8
+ from PIL import Image, ImageDraw
9
+ import functools
10
+
11
+ from caption_anything.model import CaptionAnything
12
+ from caption_anything.utils.image_editing_utils import create_bubble_frame
13
+ from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
14
+ from caption_anything.utils.parser import parse_augment
15
+ from caption_anything.captioner import build_captioner
16
+ from caption_anything.text_refiner import build_text_refiner
17
+ from caption_anything.segmenter import build_segmenter
18
+ from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
19
+ from segment_anything import sam_model_registry
20
+ import easyocr
21
+
22
+ args = parse_augment()
23
+ args.segmenter = "huge"
24
+ args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
25
+ args.clip_filter = True
26
+ if args.segmenter_checkpoint is None:
27
+ _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
28
+ else:
29
+ segmenter_checkpoint = args.segmenter_checkpoint
30
+
31
+ shared_captioner = build_captioner(args.captioner, args.device, args)
32
+ shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
33
+ ocr_lang = ["ch_tra", "en"]
34
+ shared_ocr_reader = easyocr.Reader(ocr_lang)
35
+ tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
36
+ shared_chatbot_tools = build_chatbot_tools(tools_dict)
37
+
38
+
39
+ class ImageSketcher(gr.Image):
40
+ """
41
+ Fixes the gradio.Image bug that prevents uploading when tool == 'sketch'.
42
+ """
43
+
44
+ is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
45
+
46
+ def __init__(self, **kwargs):
47
+ super().__init__(tool="sketch", **kwargs)
48
+
49
+ def preprocess(self, x):
50
+ if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
51
+ assert isinstance(x, dict)
52
+ if x['mask'] is None:
53
+ decode_image = processing_utils.decode_base64_to_image(x['image'])
54
+ width, height = decode_image.size
55
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
56
+ mask[..., -1] = 255
57
+ mask = self.postprocess(mask)
58
+ x['mask'] = mask
59
+ return super().preprocess(x)
60
+
61
+
62
+ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
63
+ session_id=None):
64
+ segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
65
+ captioner = captioner
66
+ if session_id is not None:
67
+ print('Init caption anything for session {}'.format(session_id))
68
+ return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
69
+
70
+
71
+ def init_openai_api_key(api_key=""):
72
+ text_refiner = None
73
+ visual_chatgpt = None
74
+ if api_key and len(api_key) > 30:
75
+ try:
76
+ text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
77
+ assert len(text_refiner.llm('hi')) > 0 # test
78
+ visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
79
+ except Exception:
80
+ text_refiner = None
81
+ visual_chatgpt = None
82
+ openai_available = text_refiner is not None
83
+ if openai_available:
84
+ return [gr.update(visible=True)]*6 + [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
85
+ else:
86
+ return [gr.update(visible=False)]*6 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']
87
+
88
+ def init_wo_openai_api_key():
89
+ return [gr.update(visible=False)]*4 + [gr.update(visible=True)]*2 + [gr.update(visible=False)]*2 + [None, None, None]
90
+
91
+ def get_click_prompt(chat_input, click_state, click_mode):
92
+ inputs = json.loads(chat_input)
93
+ if click_mode == 'Continuous':
94
+ points = click_state[0]
95
+ labels = click_state[1]
96
+ for input in inputs:
97
+ points.append(input[:2])
98
+ labels.append(input[2])
99
+ elif click_mode == 'Single':
100
+ points = []
101
+ labels = []
102
+ for input in inputs:
103
+ points.append(input[:2])
104
+ labels.append(input[2])
105
+ click_state[0] = points
106
+ click_state[1] = labels
107
+ else:
108
+ raise NotImplementedError
109
+
110
+ prompt = {
111
+ "prompt_type": ["click"],
112
+ "input_point": click_state[0],
113
+ "input_label": click_state[1],
114
+ "multimask_output": "True",
115
+ }
116
+ return prompt
117
+
118
+
119
+ def update_click_state(click_state, caption, click_mode):
120
+ if click_mode == 'Continuous':
121
+ click_state[2].append(caption)
122
+ elif click_mode == 'Single':
123
+ click_state[2] = [caption]
124
+ else:
125
+ raise NotImplementedError
126
+
127
+ def chat_input_callback(*args):
128
+ visual_chatgpt, chat_input, click_state, state, aux_state = args
129
+ if visual_chatgpt is not None:
130
+ return visual_chatgpt.run_text(chat_input, state, aux_state)
131
+ else:
132
+ response = "Text refiner is not initialized; please input an OpenAI API key."
133
+ state = state + [(chat_input, response)]
134
+ return state, state
135
+
136
+
137
+
138
+ def upload_callback(image_input, state, visual_chatgpt=None):
139
+
140
+ if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
141
+ image_input, mask = image_input['image'], image_input['mask']
142
+
143
+ click_state = [[], [], []]
144
+ image_input = image_resize(image_input, res=1024)
145
+
146
+ model = build_caption_anything_with_models(
147
+ args,
148
+ api_key="",
149
+ captioner=shared_captioner,
150
+ sam_model=shared_sam_model,
151
+ ocr_reader=shared_ocr_reader,
152
+ session_id=iface.app_id
153
+ )
154
+ model.segmenter.set_image(image_input)
155
+ image_embedding = model.image_embedding
156
+ original_size = model.original_size
157
+ input_size = model.input_size
158
+
159
+ if visual_chatgpt is not None:
160
+ print('upload_callback: add caption to chatGPT memory')
161
+ new_image_path = get_new_image_name('chat_image', func_name='upload')
162
+ image_input.save(new_image_path)
163
+ visual_chatgpt.current_image = new_image_path
164
+ img_caption = model.captioner.inference(image_input, filter=False, args={'text_prompt':''})['caption']
165
+ Human_prompt = f'\nHuman: The description of the image with path {new_image_path} is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
166
+ AI_prompt = "Received."
167
+ visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
168
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
169
+ state = [(None, 'Received new image, resize it to width {} and height {}: '.format(image_input.size[0], image_input.size[1]))]
170
+
171
+ return state, state, image_input, click_state, image_input, image_input, image_embedding, \
172
+ original_size, input_size
173
+
174
+
175
+ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
176
+ length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
177
+ evt: gr.SelectData):
178
+ click_index = evt.index
179
+
180
+ if point_prompt == 'Positive':
181
+ coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
182
+ else:
183
+ coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
184
+
185
+ prompt = get_click_prompt(coordinate, click_state, click_mode)
186
+ input_points = prompt['input_point']
187
+ input_labels = prompt['input_label']
188
+
189
+ controls = {'length': length,
190
+ 'sentiment': sentiment,
191
+ 'factuality': factuality,
192
+ 'language': language}
193
+
194
+ model = build_caption_anything_with_models(
195
+ args,
196
+ api_key="",
197
+ captioner=shared_captioner,
198
+ sam_model=shared_sam_model,
199
+ ocr_reader=shared_ocr_reader,
200
+ text_refiner=text_refiner,
201
+ session_id=iface.app_id
202
+ )
203
+
204
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
205
+
206
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
207
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
208
+
209
+ state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
210
+ state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
211
+ update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
212
+ text = out['generated_captions']['raw_caption']
213
+ input_mask = np.array(out['mask'].convert('P'))
214
+ image_input = mask_painter(np.array(image_input), input_mask)
215
+ origin_image_input = image_input
216
+ image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
217
+ input_points=input_points, input_labels=input_labels)
218
+ x, y = input_points[-1]
219
+
220
+ if visual_chatgpt is not None:
221
+ print('inference_click: add caption to chatGPT memory')
222
+ new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
223
+ Image.open(out["crop_save_path"]).save(new_crop_save_path)
224
+ point_prompt = f'You should primarily use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If the human mentions objects not in the selected region, you can use tools on the whole image.'
225
+ visual_chatgpt.point_prompt = point_prompt
226
+
227
+ yield state, state, click_state, image_input
228
+ if not args.disable_gpt and model.text_refiner:
229
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
230
+ enable_wiki=enable_wiki)
231
+ # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
232
+ new_cap = refined_caption['caption']
233
+ if refined_caption['wiki']:
234
+ state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
235
+ state = state + [(None, f"caption: {new_cap}")]
236
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
237
+ input_mask,
238
+ input_points=input_points, input_labels=input_labels)
239
+ yield state, state, click_state, refined_image_input
240
+
241
+
242
+ def get_sketch_prompt(mask: Image.Image):
243
+ """
244
+ Get the prompt for the sketcher.
245
+ TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
246
+ """
247
+
248
+ mask = np.asarray(mask)[..., 0]
249
+
250
+ # Get the bounding box of the sketch
251
+ y, x = np.where(mask != 0)
252
+ x1, y1 = np.min(x), np.min(y)
253
+ x2, y2 = np.max(x), np.max(y)
254
+
255
+ prompt = {
256
+ 'prompt_type': ['box'],
257
+ 'input_boxes': [
258
+ [x1, y1, x2, y2]
259
+ ]
260
+ }
261
+
262
+ return prompt
263
+
264
+
265
+ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
266
+ original_size, input_size, text_refiner):
267
+ image_input, mask = sketcher_image['image'], sketcher_image['mask']
268
+
269
+ prompt = get_sketch_prompt(mask)
270
+ boxes = prompt['input_boxes']
271
+
272
+ controls = {'length': length,
273
+ 'sentiment': sentiment,
274
+ 'factuality': factuality,
275
+ 'language': language}
276
+
277
+ model = build_caption_anything_with_models(
278
+ args,
279
+ api_key="",
280
+ captioner=shared_captioner,
281
+ sam_model=shared_sam_model,
282
+ ocr_reader=shared_ocr_reader,
283
+ text_refiner=text_refiner,
284
+ session_id=iface.app_id
285
+ )
286
+
287
+ model.setup(image_embedding, original_size, input_size, is_image_set=True)
288
+
289
+ enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
290
+ out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)[0]
291
+
292
+ # Update components and states
293
+ state.append((f'Box: {boxes}', None))
294
+ state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
295
+ text = out['generated_captions']['raw_caption']
296
+ input_mask = np.array(out['mask'].convert('P'))
297
+ image_input = mask_painter(np.array(image_input), input_mask)
298
+
299
+ origin_image_input = image_input
300
+
301
+ fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
302
+ image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
303
+
304
+ yield state, state, image_input
305
+
306
+ if not args.disable_gpt and model.text_refiner:
307
+ refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
308
+ enable_wiki=enable_wiki)
309
+
310
+ new_cap = refined_caption['caption']
311
+ if refined_caption['wiki']:
312
+ state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
313
+ state = state + [(None, f"caption: {new_cap}")]
314
+ refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
315
+
316
+ yield state, state, refined_image_input
317
+
318
+ def clear_chat_memory(visual_chatgpt, keep_global=False):
319
+ if visual_chatgpt is not None:
320
+ visual_chatgpt.memory.clear()
321
+ visual_chatgpt.point_prompt = ""
322
+ if keep_global:
323
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
324
+ else:
325
+ visual_chatgpt.current_image = None
326
+ visual_chatgpt.global_prompt = ""
327
+
328
+ def cap_everything(image_input, visual_chatgpt, text_refiner):
329
+
330
+ model = build_caption_anything_with_models(
331
+ args,
332
+ api_key="",
333
+ captioner=shared_captioner,
334
+ sam_model=shared_sam_model,
335
+ ocr_reader=shared_ocr_reader,
336
+ text_refiner=text_refiner,
337
+ session_id=iface.app_id
338
+ )
339
+ paragraph = model.inference_cap_everything(image_input, verbose=True)
340
+ # state = state + [(None, f"Caption Everything: {paragraph}")]
341
+ Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
342
+ AI_prompt = "Received."
343
+ visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
344
+ visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
345
+ return paragraph
346
+
347
+
348
+ def get_style():
349
+ current_version = version.parse(gr.__version__)
350
+ if current_version <= version.parse('3.24.1'):
351
+ style = '''
352
+ #image_sketcher{min-height:500px}
353
+ #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
354
+ #image_upload{min-height:500px}
355
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
356
+ '''
357
+ elif current_version <= version.parse('3.27'):
358
+ style = '''
359
+ #image_sketcher{min-height:500px}
360
+ #image_upload{min-height:500px}
361
+ '''
362
+ else:
363
+ style = None
364
+
365
+ return style
366
+
367
+
368
+ def create_ui():
369
+ title = """<p><h1 align="center">Caption-Anything</h1></p>
370
+ """
371
+ description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: <a href="https://github.com/ttengwang/Caption-Anything">https://github.com/ttengwang/Caption-Anything</a> <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
372
+
373
+ examples = [
374
+ ["test_images/img35.webp"],
375
+ ["test_images/img2.jpg"],
376
+ ["test_images/img5.jpg"],
377
+ ["test_images/img12.jpg"],
378
+ ["test_images/img14.jpg"],
379
+ ["test_images/qingming3.jpeg"],
380
+ ["test_images/img1.jpg"],
381
+ ]
382
+
383
+ with gr.Blocks(
384
+ css=get_style()
385
+ ) as iface:
386
+ state = gr.State([])
387
+ click_state = gr.State([[], [], []])
388
+ # chat_state = gr.State([])
389
+ origin_image = gr.State(None)
390
+ image_embedding = gr.State(None)
391
+ text_refiner = gr.State(None)
392
+ visual_chatgpt = gr.State(None)
393
+ original_size = gr.State(None)
394
+ input_size = gr.State(None)
395
+ # img_caption = gr.State(None)
396
+ aux_state = gr.State([])
397
+
398
+ gr.Markdown(title)
399
+ gr.Markdown(description)
400
+
401
+ with gr.Row():
402
+ with gr.Column(scale=1.0):
403
+ with gr.Column(visible=False) as modules_not_need_gpt:
404
+ with gr.Tab("Click"):
405
+ image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
406
+ example_image = gr.Image(type="pil", interactive=False, visible=False)
407
+ with gr.Row(scale=1.0):
408
+ with gr.Row(scale=0.4):
409
+ point_prompt = gr.Radio(
410
+ choices=["Positive", "Negative"],
411
+ value="Positive",
412
+ label="Point Prompt",
413
+ interactive=True)
414
+ click_mode = gr.Radio(
415
+ choices=["Continuous", "Single"],
416
+ value="Continuous",
417
+ label="Clicking Mode",
418
+ interactive=True)
419
+ with gr.Row(scale=0.4):
420
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
421
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
422
+ with gr.Tab("Trajectory (beta)"):
423
+ sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
424
+ elem_id="image_sketcher")
425
+ with gr.Row():
426
+ submit_button_sketcher = gr.Button(value="Submit", interactive=True)
427
+
428
+ with gr.Column(visible=False) as modules_need_gpt1:
429
+ with gr.Row(scale=1.0):
430
+ language = gr.Dropdown(
431
+ ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
432
+ value="English", label="Language", interactive=True)
433
+ sentiment = gr.Radio(
434
+ choices=["Positive", "Natural", "Negative"],
435
+ value="Natural",
436
+ label="Sentiment",
437
+ interactive=True,
438
+ )
439
+ with gr.Row(scale=1.0):
440
+ factuality = gr.Radio(
441
+ choices=["Factual", "Imagination"],
442
+ value="Factual",
443
+ label="Factuality",
444
+ interactive=True,
445
+ )
446
+ length = gr.Slider(
447
+ minimum=10,
448
+ maximum=80,
449
+ value=10,
450
+ step=1,
451
+ interactive=True,
452
+ label="Generated Caption Length",
453
+ )
454
+ enable_wiki = gr.Radio(
455
+ choices=["Yes", "No"],
456
+ value="No",
457
+ label="Enable Wiki",
458
+ interactive=True)
459
+ # with gr.Column(visible=True) as modules_not_need_gpt3:
460
+ gr.Examples(
461
+ examples=examples,
462
+ inputs=[example_image],
463
+ )
464
+ with gr.Column(scale=0.5):
465
+ with gr.Column(visible=True) as module_key_input:
466
+ openai_api_key = gr.Textbox(
467
+ placeholder="Input openAI API key",
468
+ show_label=False,
469
+ label="OpenAI API Key",
470
+ lines=1,
471
+ type="password")
472
+ with gr.Row(scale=0.5):
473
+ enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
474
+ disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
475
+ variant='primary')
476
+ with gr.Column(visible=False) as module_notification_box:
477
+ notification_box = gr.Textbox(lines=1, label="Notification", max_lines=5, show_label=False)
478
+ with gr.Column(visible=False) as modules_need_gpt2:
479
+ paragraph_output = gr.Textbox(lines=7, label="Describe Everything", max_lines=7)
480
+ with gr.Column(visible=False) as modules_need_gpt0:
481
+ cap_everything_button = gr.Button(value="Caption Everything in a Paragraph", interactive=True)
482
+ with gr.Column(visible=False) as modules_not_need_gpt2:
483
+ chatbot = gr.Chatbot(label="Chatbox", ).style(height=550, scale=0.5)
484
+ with gr.Column(visible=False) as modules_need_gpt3:
485
+ chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
486
+ container=False)
487
+ with gr.Row():
488
+ clear_button_text = gr.Button(value="Clear Text", interactive=True)
489
+ submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
490
+
491
+ openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
492
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
493
+ modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
494
+ enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
495
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
496
+ modules_not_need_gpt,
497
+ modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
498
+ disable_chatGPT_button.click(init_wo_openai_api_key,
499
+ outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
500
+ modules_not_need_gpt,
501
+ modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
502
+
503
+ enable_chatGPT_button.click(
504
+ lambda: (None, [], [], [[], [], []], "", "", ""),
505
+ [],
506
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
507
+ queue=False,
508
+ show_progress=False
509
+ )
510
+ openai_api_key.submit(
511
+ lambda: (None, [], [], [[], [], []], "", "", ""),
512
+ [],
513
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
514
+ queue=False,
515
+ show_progress=False
516
+ )
517
+
518
+ cap_everything_button.click(cap_everything, [origin_image, visual_chatgpt, text_refiner], [paragraph_output])
519
+
520
+ clear_button_click.click(
521
+ lambda x: ([[], [], []], x),
522
+ [origin_image],
523
+ [click_state, image_input],
524
+ queue=False,
525
+ show_progress=False
526
+ )
527
+ clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt])
528
+ clear_button_image.click(
529
+ lambda: (None, [], [], [[], [], []], "", "", ""),
530
+ [],
531
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
532
+ queue=False,
533
+ show_progress=False
534
+ )
535
+ clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt])
536
+ clear_button_text.click(
537
+ lambda: ([], [], [[], [], []]),
538
+ [],
539
+ [chatbot, state, click_state],
540
+ queue=False,
541
+ show_progress=False
542
+ )
543
+ clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
544
+
545
+ image_input.clear(
546
+ lambda: (None, [], [], [[], [], []], "", "", ""),
547
+ [],
548
+ [image_input, chatbot, state, click_state, paragraph_output, origin_image],
549
+ queue=False,
550
+ show_progress=False
551
+ )
552
+
553
+ image_input.clear(clear_chat_memory, inputs=[visual_chatgpt])
554
+
555
+
556
+ image_input.upload(upload_callback, [image_input, state, visual_chatgpt],
557
+ [chatbot, state, origin_image, click_state, image_input, sketcher_input,
558
+ image_embedding, original_size, input_size])
559
+ sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt],
560
+ [chatbot, state, origin_image, click_state, image_input, sketcher_input,
561
+ image_embedding, original_size, input_size])
562
+ chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
563
+ [chatbot, state, aux_state])
564
+ chat_input.submit(lambda: "", None, chat_input)
565
+ submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
566
+ [chatbot, state, aux_state])
567
+ submit_button_text.click(lambda: "", None, chat_input)
568
+ example_image.change(upload_callback, [example_image, state, visual_chatgpt],
569
+ [chatbot, state, origin_image, click_state, image_input, sketcher_input,
570
+ image_embedding, original_size, input_size])
571
+ example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
572
+ # select coordinate
573
+ image_input.select(
574
+ inference_click,
575
+ inputs=[
576
+ origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
577
+ image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
578
+ ],
579
+ outputs=[chatbot, state, click_state, image_input],
580
+ show_progress=False, queue=True
581
+ )
582
+
583
+ submit_button_sketcher.click(
584
+ inference_traject,
585
+ inputs=[
586
+ sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
587
+ original_size, input_size, text_refiner
588
+ ],
589
+ outputs=[chatbot, state, sketcher_input],
590
+ show_progress=False, queue=True
591
+ )
592
+
593
+ return iface
594
+
595
+
596
+ if __name__ == '__main__':
597
+ iface = create_ui()
598
+ iface.queue(concurrency_count=5, api_open=False, max_size=10)
599
+ iface.launch(server_name="0.0.0.0", enable_queue=True)
assets/UI.png ADDED

Git LFS Details

  • SHA256: bce7f8b8b11832a98d85ecf7755274df5860d9b5eb35738dabbb2e585d70ddd4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.64 MB
assets/caption_anything_logo.png ADDED
assets/demo1.jpg ADDED

Git LFS Details

  • SHA256: 7a3bf5f8e4e8a79824f06916cdd41c94c23c5159abf3ecd5045732f27dd358f2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
assets/demo1.png ADDED

Git LFS Details

  • SHA256: 2bd22e897705a8cebb3f1fc2ddf857eeeb1736b7b627cf8c24ed84c17728a4cc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
assets/demo1.svg ADDED
assets/demo2.png ADDED
assets/demo2.svg ADDED
assets/qingming.gif ADDED

Git LFS Details

  • SHA256: dc052aad5ab86a9a0ac1483853f2370686add2a4b0a5088be86598bec01b533e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.64 MB
assets/times_with_simsun.ttf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b15a12dd4bba4a48885c279a1d16590b652773f02137a7e62ede3411970c59f
+ size 11066612
assets/title.png ADDED
assets/title.svg ADDED
caption_anything/__init__.py ADDED
File without changes
caption_anything/captioner/README.md ADDED
@@ -0,0 +1,13 @@
+ To run BLIP/BLIP2, install transformers from source:
+ ```
+ !pip install git+https://github.com/huggingface/transformers.git
+ ```
+ To run the filter module, install the CLIP repo as a Python package as follows:
+ ```
+ !pip install ftfy regex tqdm
+ !pip install git+https://github.com/openai/CLIP.git
+ ```
+ To accelerate BLIP2 with int8, install accelerate and bitsandbytes:
+ ```
+ !pip install accelerate bitsandbytes
+ ```
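For reference, a minimal usage sketch of the captioner factory defined below in `caption_anything/captioner/__init__.py`, assuming the dependencies above are installed and using a `SimpleNamespace` stand-in for the parsed CLI args that `app.py` normally supplies:

```python
# Minimal sketch: build a BLIP captioner and caption a bundled test image.
# The SimpleNamespace args object is an assumption; build_captioner only reads args.clip_filter.
from types import SimpleNamespace

from caption_anything.captioner import build_captioner

args = SimpleNamespace(clip_filter=False)
captioner = build_captioner('blip', 'cpu', args)   # 'blip2' and 'git' are also accepted
out = captioner.inference('test_images/img2.jpg', filter=False)
print(out['caption'])
```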
caption_anything/captioner/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from .blip import BLIPCaptioner
+ from .blip2 import BLIP2Captioner
+ from .git import GITCaptioner
+ from .base_captioner import BaseCaptioner
+
+
+ def build_captioner(type, device, args=None):
+     if type == 'blip':
+         return BLIPCaptioner(device, enable_filter=args.clip_filter)
+     elif type == 'blip2':
+         return BLIP2Captioner(device, enable_filter=args.clip_filter)
+     elif type == 'git':
+         return GITCaptioner(device, enable_filter=args.clip_filter)
+     else:
+         raise NotImplementedError("")
caption_anything/captioner/base_captioner.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch
2
+ from PIL import Image, ImageDraw, ImageOps
3
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
4
+ import json
5
+ import pdb
6
+ import cv2
7
+ import numpy as np
8
+ from typing import Any, Union, List
9
+ import time
10
+ import clip
11
+
12
+ from caption_anything.utils.utils import load_image
13
+
14
+
15
+ def boundary(inputs):
16
+ col = inputs.shape[1]
17
+ inputs = inputs.reshape(-1)
18
+ lens = len(inputs)
19
+ start = np.argmax(inputs)
20
+ end = lens - 1 - np.argmax(np.flip(inputs))
21
+ top = start // col
22
+ bottom = end // col
23
+ return top, bottom
24
+
25
+
26
+ def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
27
+ if type(seg_mask) == str:
28
+ seg_mask = Image.open(seg_mask)
29
+ elif type(seg_mask) == np.ndarray:
30
+ seg_mask = Image.fromarray(seg_mask)
31
+ seg_mask = np.array(seg_mask) > 0
32
+ size = max(seg_mask.shape[0], seg_mask.shape[1])
33
+ top, bottom = boundary(seg_mask)
34
+ left, right = boundary(seg_mask.T)
35
+ return [left / size, top / size, right / size, bottom / size]
36
+
37
+
38
+ def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
39
+ if type(seg_mask) == str:
40
+ seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE)
41
+ _, seg_mask = cv2.threshold(seg_mask, 127, 255, 0)
42
+ elif type(seg_mask) == np.ndarray:
43
+ assert seg_mask.ndim == 2 # only support single-channel segmentation mask
44
+ seg_mask = seg_mask.astype('uint8')
45
+ if seg_mask.dtype == 'bool':
46
+ seg_mask = seg_mask * 255
47
+ contours, hierarchy = cv2.findContours(seg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
48
+ contours = np.concatenate(contours, axis=0)
49
+ rect = cv2.minAreaRect(contours)
50
+ box = cv2.boxPoints(rect)
51
+ if rect[-1] >= 45:
52
+ newstart = box.argmin(axis=0)[1] # leftmost
53
+ else:
54
+ newstart = box.argmax(axis=0)[0] # topmost
55
+ box = np.concatenate([box[newstart:], box[:newstart]], axis=0)
56
+ box = np.int0(box)
57
+ return box
58
+
59
+
60
+ def get_w_h(rect_points):
61
+ w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int')
62
+ h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int')
63
+ return w, h
64
+
65
+
66
+ def cut_box(img, rect_points):
67
+ w, h = get_w_h(rect_points)
68
+ dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0], ], dtype="float32")
69
+ transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts)
70
+ cropped_img = cv2.warpPerspective(img, transform, (h, w))
71
+ return cropped_img
72
+
73
+
74
+ class BaseCaptioner:
75
+ def __init__(self, device, enable_filter=False):
76
+ print(f"Initializing ImageCaptioning to {device}")
77
+ self.device = device
78
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
79
+ self.processor = None
80
+ self.model = None
81
+ self.enable_filter = enable_filter
82
+ if enable_filter:
83
+ self.filter, self.preprocess = clip.load('ViT-B/32', device)
84
+
85
+ @torch.no_grad()
86
+ def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str, reference_caption: List[str]=[]):
87
+ image = load_image(image, return_type='pil')
88
+ image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224)
89
+ captions = [caption]
90
+ if len(reference_caption):
91
+ captions.extend(reference_caption)
92
+ text = clip.tokenize(captions).to(self.device) # (>1, 77)
93
+ image_features = self.filter.encode_image(image) # (1, 512)
94
+ text_features = self.filter.encode_text(text) # # (>1, 512)
95
+ image_features /= image_features.norm(dim=-1, keepdim=True)
96
+ text_features /= text_features.norm(dim=-1, keepdim=True)
97
+
98
+ if len(reference_caption):
99
+ similarity = torch.matmul(image_features, text_features.transpose(1, 0)) / 0.07
100
+ similarity = similarity.softmax(dim=1)[0, 0].item()
101
+ else:
102
+ similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
103
+ print(f'Clip score of the caption is {similarity}')
104
+ return similarity
105
+
106
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool = False):
107
+ raise NotImplementedError()
108
+
109
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool = False):
110
+ raise NotImplementedError()
111
+
112
+ def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False, verbose=False, caption_args={}):
113
+ image = load_image(image, return_type="pil")
114
+
115
+ if np.array(box).size == 4:
116
+ # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
117
+ size = max(image.width, image.height)
118
+ x1, y1, x2, y2 = box
119
+ image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
120
+ elif np.array(box).size == 8: # four corners of an irregular rectangle
121
+ image_crop = cut_box(np.array(image), box)
122
+
123
+ crop_save_path = None
124
+ if verbose:
125
+ crop_save_path = f'result/crop_{time.time()}.png'
126
+ Image.fromarray(image_crop).save(crop_save_path)
127
+ print(f'cropped image saved in {crop_save_path}')
128
+ caption = self.inference(image_crop, filter, caption_args)
129
+ caption.update({'crop_save_path': crop_save_path})
130
+ return caption
131
+
132
+ def inference_seg(self,
133
+ image: Union[np.ndarray, str],
134
+ seg_mask: Union[np.ndarray, Image.Image, str] = None,
135
+ crop_mode="w_bg",
136
+ filter=False,
137
+ disable_regular_box=False,
138
+ verbose=False,
139
+ caption_args={}):
140
+ if seg_mask is None:
141
+ seg_mask = np.ones(image.size).astype(bool)
142
+
143
+ image = load_image(image, return_type="pil")
144
+ seg_mask = load_image(seg_mask, return_type="pil")
145
+
146
+ seg_mask = seg_mask.resize(image.size)
147
+ seg_mask = np.array(seg_mask) > 0
148
+ if crop_mode == "wo_bg":
149
+ image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
150
+ image = np.uint8(image)
151
+ else:
152
+ image = np.array(image)
153
+
154
+ if disable_regular_box:
155
+ min_area_box = seg_to_box(seg_mask)
156
+ else:
157
+ min_area_box = new_seg_to_box(seg_mask)
158
+ return self.inference_box(image, min_area_box, filter, verbose, caption_args)
159
+
160
+ def generate_seg_cropped_image(self,
161
+ image: Union[np.ndarray, str],
162
+ seg_mask: Union[np.ndarray, Image.Image, str],
163
+ crop_mode="w_bg",
164
+ disable_regular_box=False):
165
+ image = load_image(image, return_type="pil")
166
+ seg_mask = load_image(seg_mask, return_type="pil")
167
+
168
+ seg_mask = seg_mask.resize(image.size)
169
+ seg_mask = np.array(seg_mask) > 0
170
+
171
+ if crop_mode == "wo_bg":
172
+ image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
173
+ else:
174
+ image = np.array(image)
175
+
176
+ if disable_regular_box:
177
+ box = seg_to_box(seg_mask)
178
+ else:
179
+ box = new_seg_to_box(seg_mask)
180
+
181
+ if np.array(box).size == 4:
182
+ # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
183
+ size = max(image.shape[0], image.shape[1])
184
+ x1, y1, x2, y2 = box
185
+ image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
186
+ elif np.array(box).size == 8: # four corners of an irregular rectangle
187
+ image_crop = cut_box(np.array(image), box)
188
+ crop_save_path = f'result/crop_{time.time()}.png'
189
+ Image.fromarray(image_crop).save(crop_save_path)
190
+ print(f'cropped image saved in {crop_save_path}')
191
+ return crop_save_path
192
+
193
+
194
+ if __name__ == '__main__':
195
+ model = BaseCaptioner(device='cuda:0')
196
+ image_path = 'test_images/img2.jpg'
197
+ seg_mask = np.zeros((15, 15))
198
+ seg_mask[5:10, 5:10] = 1
199
+ seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
200
+ print(model.inference_seg(image_path, seg_mask))
caption_anything/captioner/blip.py ADDED
@@ -0,0 +1,72 @@
1
+ import torch
2
+ from PIL import Image
3
+ from transformers import BlipProcessor
4
+
5
+ from caption_anything.utils.utils import load_image
6
+ from .modeling_blip import BlipForConditionalGeneration
7
+ import numpy as np
8
+ from typing import Union
9
+ from .base_captioner import BaseCaptioner
10
+ import torchvision.transforms.functional as F
11
+
12
+
13
+ class BLIPCaptioner(BaseCaptioner):
14
+ def __init__(self, device, enable_filter=False):
15
+ super().__init__(device, enable_filter)
16
+ self.device = device
17
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
18
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
19
+ self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large",
20
+ torch_dtype=self.torch_dtype).to(self.device)
21
+
22
+ @torch.no_grad()
23
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False, args={}):
24
+ image = load_image(image, return_type="pil")
25
+ inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
26
+ out = self.model.generate(**inputs, max_new_tokens=50)
27
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
28
+
29
+ result = {}
30
+ if self.enable_filter and filter:
31
+ clip_score = self.filter_caption(image, captions)
32
+ result['clip_score'] = clip_score
33
+ result.update({'caption':captions})
34
+ print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
35
+ return {'caption': captions}
36
+
37
+ @torch.no_grad()
38
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
39
+ filter=False, disable_regular_box=False):
40
+ result = {}
41
+ crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
42
+ disable_regular_box=disable_regular_box)
43
+ image = load_image(image, return_type="pil")
44
+ inputs = self.processor(image, return_tensors="pt")
45
+ pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
46
+ _, _, H, W = pixel_values.shape
47
+ seg_mask = Image.fromarray(seg_mask.astype(float))
48
+ seg_mask = seg_mask.resize((H, W))
49
+ seg_mask = F.pil_to_tensor(seg_mask) > 0.5
50
+ seg_mask = seg_mask.float()
51
+ pixel_masks = seg_mask.unsqueeze(0).to(self.device)
52
+ out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
53
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
54
+ if self.enable_filter and filter:
55
+ clip_score = self.filter_caption(image, captions)
56
+ result['clip_score'] = clip_score
57
+ result.update({'caption':captions, 'crop_save_path':crop_save_path})
58
+ print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
59
+ return result
60
+
61
+
62
+ if __name__ == '__main__':
63
+ model = BLIPCaptioner(device='cuda:0')
64
+ # image_path = 'test_images/img2.jpg'
65
+ image_path = 'image/SAM/img10.jpg'
66
+ seg_mask = np.zeros((15, 15))
67
+ seg_mask[5:10, 5:10] = 1
68
+ seg_mask = 'test_images/img10.jpg.raw_mask.png'
69
+ image_path = 'test_images/img2.jpg'
70
+ seg_mask = 'test_images/img2.jpg.raw_mask.png'
71
+ print(f'process image {image_path}')
72
+ print(model.inference_with_reduced_tokens(image_path, seg_mask))
caption_anything/captioner/blip2.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+ from typing import Union
5
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
6
+
7
+ from caption_anything.utils.utils import is_platform_win, load_image
8
+ from .base_captioner import BaseCaptioner
9
+ import time
10
+
11
+ class BLIP2Captioner(BaseCaptioner):
12
+ def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
13
+ super().__init__(device, enable_filter)
14
+ self.device = device
15
+ self.dialogue = dialogue
16
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
17
+ self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
18
+ if is_platform_win():
19
+ self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="sequential", torch_dtype=self.torch_dtype)
20
+ else:
21
+ self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map='sequential', load_in_8bit=True)
22
+
23
+ @torch.no_grad()
24
+ def inference(self,
25
+ image: Union[np.ndarray, Image.Image, str],
26
+ filter=False,
27
+ args={}):
28
+ args['return_ppl'] = args.get('return_ppl', False)
29
+ args['text_prompt'] = args.get('text_prompt', 'Question: what does the image show? Answer:')
30
+ args['reference_caption'] = args.get('reference_caption', [])
31
+
32
+ image = load_image(image, return_type="pil")
33
+ result = {}
34
+ if not self.dialogue:
35
+ inputs = self.processor(image, text = args['text_prompt'], return_tensors="pt").to(self.device, self.torch_dtype)
36
+ out = self.model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
37
+ caption = self.processor.decode(out.sequences[0], skip_special_tokens=True).strip()
38
+ if self.enable_filter and filter:
39
+ print('reference caption: {}, caption: {}'.format(args['reference_caption'], caption))
40
+ clip_score = self.filter_caption(image, caption, args['reference_caption'])
41
+ result['clip_score'] = clip_score
42
+ if args['return_ppl']:
43
+ ppl_score = torch.stack(out.scores, dim=1).softmax(dim=2).log().max(dim=2)[0].sum(dim=1)[0]
44
+ result['ppl_score'] = ppl_score.item()
45
+ print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {caption}")
46
+ result['caption'] = caption
47
+ return result
48
+ else:
49
+ context = []
50
+ template = "Question: {} Answer: {}."
51
+ while(True):
52
+ input_texts = input()
53
+ if input_texts == 'end':
54
+ break
55
+ prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + input_texts + " Answer:"
56
+ inputs = self.processor(image, text = prompt, return_tensors="pt").to(self.device, self.torch_dtype)
57
+ out = self.model.generate(**inputs, max_new_tokens=50)
58
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
59
+ context.append((input_texts, captions))
60
+ result['caption'] = captions
61
+ return result
62
+
63
+ if __name__ == '__main__':
64
+
65
+ dialogue = False
66
+ model = BLIP2Captioner(device='cuda:4', dialogue=dialogue)
67
+ image_path = 'test_images/img2.jpg'
68
+ seg_mask = np.zeros((224,224))
69
+ seg_mask[50:200, 50:200] = 1
70
+ print(f'process image {image_path}')
71
+ print(model.inference_seg(image_path, seg_mask))
caption_anything/captioner/git.py ADDED
@@ -0,0 +1,67 @@
1
+ from transformers import GitProcessor, AutoProcessor
2
+
3
+ from caption_anything.utils.utils import load_image
4
+ from .modeling_git import GitForCausalLM
5
+ from PIL import Image
6
+ import torch
7
+ from .base_captioner import BaseCaptioner
8
+ import numpy as np
9
+ from typing import Union
10
+ import torchvision.transforms.functional as F
11
+
12
+
13
+ class GITCaptioner(BaseCaptioner):
14
+ def __init__(self, device, enable_filter=False):
15
+ super().__init__(device, enable_filter)
16
+ self.device = device
17
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
18
+ self.processor = AutoProcessor.from_pretrained("microsoft/git-large")
19
+ self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
20
+
21
+ @torch.no_grad()
22
+ def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False, args={}):
23
+ image = load_image(image, return_type="pil")
24
+ pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
25
+ generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
26
+ captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
27
+
28
+ result = {}
29
+ if self.enable_filter and filter:
30
+ clip_score = self.filter_caption(image, captions)
31
+ result['clip_score'] = clip_score
32
+ result.update({'caption':captions})
33
+ print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {captions}")
34
+ return result
35
+
36
+ @torch.no_grad()
37
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
38
+ filter=False, disable_regular_box=False):
39
+ result = {}
40
+ crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
41
+ disable_regular_box=disable_regular_box)
42
+ image = load_image(image, return_type="pil")
43
+ inputs = self.processor(images=image, return_tensors="pt")
44
+ pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
45
+ _, _, H, W = pixel_values.shape
46
+ seg_mask = Image.fromarray(seg_mask.astype(float))
47
+ seg_mask = seg_mask.resize((H, W))
48
+ seg_mask = F.pil_to_tensor(seg_mask) > 0.5
49
+ seg_mask = seg_mask.float()
50
+ pixel_masks = seg_mask.unsqueeze(0).to(self.device)
51
+ out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
52
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
53
+ if self.enable_filter and filter:
54
+ clip_score = self.filter_caption(image, captions)
55
+ result['clip_score'] = clip_score
56
+ print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
57
+ result.update({'caption':captions, 'crop_save_path':crop_save_path})
58
+ return result
59
+
60
+
61
+ if __name__ == '__main__':
62
+ model = GITCaptioner(device='cuda:2', enable_filter=False)
63
+ image_path = 'test_images/img2.jpg'
64
+ seg_mask = np.zeros((224, 224))
65
+ seg_mask[50:200, 50:200] = 1
66
+ print(f'process image {image_path}')
67
+ print(model.inference_with_reduced_tokens(image_path, seg_mask))
caption_anything/captioner/modeling_blip.py ADDED
@@ -0,0 +1,1476 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch BLIP model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Any, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn.functional import normalize
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.models.blip.configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
36
+ from transformers.models.blip.modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
37
+ from .vit_pixel_masks_utils import ViTPatchMaskGenerator
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ _CHECKPOINT_FOR_DOC = "Salesforce/blip-vqa-base"
42
+
43
+ BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
44
+ "Salesforce/blip-vqa-base",
45
+ "Salesforce/blip-vqa-capfit-large",
46
+ "Salesforce/blip-image-captioning-base",
47
+ "Salesforce/blip-image-captioning-large",
48
+ "Salesforce/blip-itm-base-coco",
49
+ "Salesforce/blip-itm-large-coco",
50
+ "Salesforce/blip-itm-base-flikr",
51
+ "Salesforce/blip-itm-large-flikr",
52
+ # See all BLIP models at https://huggingface.co/models?filter=blip
53
+ ]
54
+
55
+
56
+ # Copied from transformers.models.clip.modeling_clip.contrastive_loss
57
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
58
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
59
+
60
+
61
+ # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip
62
+ def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
63
+ caption_loss = contrastive_loss(similarity)
64
+ image_loss = contrastive_loss(similarity.t())
65
+ return (caption_loss + image_loss) / 2.0
66
+
67
+
68
+ @dataclass
69
+ class BlipForConditionalGenerationModelOutput(ModelOutput):
70
+ """
71
+ Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
72
+ last hidden states. This class also adds the loss term from the text decoder.
73
+
74
+ Args:
75
+ loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
76
+ Language modeling loss from the text decoder.
77
+ decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
78
+ Prediction scores of the language modeling head of the text decoder model.
79
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
80
+ The image embeddings obtained after applying the Vision Transformer model to the input image.
81
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
82
+ Sequence of hidden-states at the output of the last layer of the model.
83
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
84
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
85
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
86
+
87
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
88
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
89
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
90
+ sequence_length)`.
91
+
92
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
93
+ heads.
94
+ """
95
+
96
+ loss: Optional[Tuple[torch.FloatTensor]] = None
97
+ decoder_logits: Optional[Tuple[torch.FloatTensor]] = None
98
+ image_embeds: Optional[torch.FloatTensor] = None
99
+ last_hidden_state: torch.FloatTensor = None
100
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
101
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
102
+
103
+
104
+ @dataclass
105
+ class BlipTextVisionModelOutput(ModelOutput):
106
+ """
107
+ Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
108
+ last hidden states. This class also adds the loss term from the text decoder.
109
+
110
+ Args:
111
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
112
+ Language modeling loss from the text decoder.
113
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
114
+ The image embeddings obtained by applying the projection layer to the pooler_output.
115
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
116
+ Sequence of hidden-states at the output of the last layer of the model.
117
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
118
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
119
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
120
+
121
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
122
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
123
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
124
+ sequence_length)`.
125
+
126
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
127
+ heads.
128
+ """
129
+
130
+ loss: Optional[torch.FloatTensor] = None
131
+ image_embeds: Optional[torch.FloatTensor] = None
132
+ last_hidden_state: torch.FloatTensor = None
133
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
134
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
135
+
136
+
137
+ @dataclass
138
+ class BlipImageTextMatchingModelOutput(ModelOutput):
139
+ """
140
+ Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
141
+ last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
142
+ scores.
143
+
144
+ Args:
145
+ itm_score (`torch.FloatTensor`):
146
+ The image-text similarity scores.
147
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
148
+ Language modeling loss from the text decoder.
149
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
150
+ The image embeddings obtained by applying the projection layer to the pooler_output.
151
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
152
+ Sequence of hidden-states at the output of the last layer of the model.
153
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
154
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
155
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
156
+
157
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
158
+ vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
159
+ Last layer hidden-state of the vision-only branch of the model.
160
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
161
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
162
+ sequence_length)`.
163
+
164
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
165
+ heads.
166
+ question_embeds (`torch.FloatTensor`):
167
+ The question embeddings obtained by the text projection layer.
168
+ """
169
+
170
+ itm_score: Optional[torch.FloatTensor] = None
171
+ loss: Optional[torch.FloatTensor] = None
172
+ image_embeds: Optional[torch.FloatTensor] = None
173
+ last_hidden_state: torch.FloatTensor = None
174
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
175
+ vision_pooler_output: Optional[torch.FloatTensor] = None
176
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
177
+ question_embeds: Optional[Tuple[torch.FloatTensor]] = None
178
+
179
+
180
+ @dataclass
181
+ class BlipOutput(ModelOutput):
182
+ """
183
+ Args:
184
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
185
+ Contrastive loss for image-text similarity.
186
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
187
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
188
+ similarity scores.
189
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
190
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
191
+ similarity scores.
192
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
193
+ The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
194
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim)`):
195
+ The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
196
+ text_model_output(`BaseModelOutputWithPooling`):
197
+ The output of the [`BlipTextModel`].
198
+ vision_model_output(`BaseModelOutputWithPooling`):
199
+ The output of the [`BlipVisionModel`].
200
+ """
201
+
202
+ loss: Optional[torch.FloatTensor] = None
203
+ logits_per_image: torch.FloatTensor = None
204
+ logits_per_text: torch.FloatTensor = None
205
+ text_embeds: torch.FloatTensor = None
206
+ image_embeds: torch.FloatTensor = None
207
+ text_model_output: BaseModelOutputWithPooling = None
208
+ vision_model_output: BaseModelOutputWithPooling = None
209
+
210
+ def to_tuple(self) -> Tuple[Any]:
211
+ return tuple(
212
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
213
+ for k in self.keys()
214
+ )
215
+
216
+
217
+ class BlipVisionEmbeddings(nn.Module):
218
+ def __init__(self, config: BlipVisionConfig):
219
+ super().__init__()
220
+ self.config = config
221
+ self.embed_dim = config.hidden_size
222
+ self.image_size = config.image_size
223
+ self.patch_size = config.patch_size
224
+
225
+ self.class_embedding = nn.Parameter(
226
+ torch.randn(1, 1, self.embed_dim),
227
+ )
228
+
229
+ self.patch_embedding = nn.Conv2d(
230
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
231
+ )
232
+
233
+ self.num_patches = (self.image_size // self.patch_size) ** 2
234
+ self.num_positions = self.num_patches + 1
235
+
236
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
237
+
238
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
239
+ batch_size = pixel_values.shape[0]
240
+ target_dtype = self.patch_embedding.weight.dtype
241
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
242
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
243
+
244
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
245
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
246
+ embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
247
+ return embeddings
248
+
249
+
250
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip
251
+ class BlipTextEmbeddings(nn.Module):
252
+ def __init__(self, config: BlipTextConfig):
253
+ super().__init__()
254
+ embed_dim = config.hidden_size
255
+
256
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
257
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
258
+
259
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
260
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
261
+
262
+ def forward(
263
+ self,
264
+ input_ids: Optional[torch.LongTensor] = None,
265
+ position_ids: Optional[torch.LongTensor] = None,
266
+ inputs_embeds: Optional[torch.FloatTensor] = None,
267
+ ) -> torch.Tensor:
268
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
269
+
270
+ if position_ids is None:
271
+ position_ids = self.position_ids[:, :seq_length]
272
+
273
+ if inputs_embeds is None:
274
+ inputs_embeds = self.token_embedding(input_ids)
275
+
276
+ position_embeddings = self.position_embedding(position_ids)
277
+ embeddings = inputs_embeds + position_embeddings
278
+
279
+ return embeddings
280
+
281
+
282
+ class BlipAttention(nn.Module):
283
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
284
+
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ self.config = config
288
+ self.embed_dim = config.hidden_size
289
+ self.num_heads = config.num_attention_heads
290
+ self.head_dim = self.embed_dim // self.num_heads
291
+ if self.head_dim * self.num_heads != self.embed_dim:
292
+ raise ValueError(
293
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
294
+ f" {self.num_heads})."
295
+ )
296
+ self.scale = self.head_dim**-0.5
297
+ self.dropout = nn.Dropout(config.attention_dropout)
298
+
299
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
300
+
301
+ self.projection = nn.Linear(self.embed_dim, self.embed_dim)
302
+
303
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
304
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ head_mask: Optional[torch.Tensor] = None,
310
+ output_attentions: Optional[bool] = False,
311
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
312
+ """Input shape: Batch x Time x Channel"""
313
+
314
+ bsz, tgt_len, embed_dim = hidden_states.size()
315
+
316
+ mixed_qkv = self.qkv(hidden_states)
317
+ mixed_qkv = (
318
+ self.qkv(hidden_states)
319
+ .reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
320
+ .permute(2, 0, 3, 1, 4)
321
+ )
322
+ query_states, key_states, value_states = (
323
+ mixed_qkv[0],
324
+ mixed_qkv[1],
325
+ mixed_qkv[2],
326
+ )
327
+
328
+ # Take the dot product between "query" and "key" to get the raw attention scores.
329
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
330
+
331
+ attention_scores = attention_scores * self.scale
332
+
333
+ # Normalize the attention scores to probabilities.
334
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
335
+
336
+ # This is actually dropping out entire tokens to attend to, which might
337
+ # seem a bit unusual, but is taken from the original Transformer paper.
338
+ attention_probs = self.dropout(attention_probs)
339
+
340
+ # Mask heads if we want to
341
+ if head_mask is not None:
342
+ attention_probs = attention_probs * head_mask
343
+
344
+ context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
345
+
346
+ new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
347
+ context_layer = context_layer.reshape(new_context_layer_shape)
348
+
349
+ output = self.projection(context_layer)
350
+
351
+ outputs = (output, attention_probs) if output_attentions else (output, None)
352
+
353
+ return outputs
354
+
355
+
356
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip
357
+ class BlipMLP(nn.Module):
358
+ def __init__(self, config):
359
+ super().__init__()
360
+ self.config = config
361
+ self.activation_fn = ACT2FN[config.hidden_act]
362
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
363
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
364
+
365
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
366
+ hidden_states = self.fc1(hidden_states)
367
+ hidden_states = self.activation_fn(hidden_states)
368
+ hidden_states = self.fc2(hidden_states)
369
+ return hidden_states
370
+
371
+
372
+ class BlipEncoderLayer(nn.Module):
373
+ def __init__(self, config: BlipConfig):
374
+ super().__init__()
375
+ self.embed_dim = config.hidden_size
376
+ self.self_attn = BlipAttention(config)
377
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
378
+ self.mlp = BlipMLP(config)
379
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states: torch.Tensor,
384
+ attention_mask: torch.Tensor,
385
+ output_attentions: Optional[bool] = False,
386
+ ) -> Tuple[torch.FloatTensor]:
387
+ """
388
+ Args:
389
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
390
+ attention_mask (`torch.FloatTensor`): attention mask of size
391
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
392
+ `(config.encoder_attention_heads,)`.
393
+ output_attentions (`bool`, *optional*):
394
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
395
+ returned tensors for more detail.
396
+ """
397
+ residual = hidden_states
398
+
399
+ hidden_states = self.layer_norm1(hidden_states)
400
+ hidden_states, attn_weights = self.self_attn(
401
+ hidden_states=hidden_states,
402
+ head_mask=attention_mask,
403
+ output_attentions=output_attentions,
404
+ )
405
+ hidden_states = hidden_states + residual
406
+ residual = hidden_states
407
+ hidden_states = self.layer_norm2(hidden_states)
408
+ hidden_states = self.mlp(hidden_states)
409
+
410
+ hidden_states = hidden_states + residual
411
+
412
+ outputs = (hidden_states,)
413
+
414
+ if output_attentions:
415
+ outputs += (attn_weights,)
416
+
417
+ return outputs
418
+
419
+
420
+ class BlipPreTrainedModel(PreTrainedModel):
421
+ """
422
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
423
+ models.
424
+ """
425
+
426
+ config_class = BlipConfig
427
+ base_model_prefix = "blip"
428
+ supports_gradient_checkpointing = True
429
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
430
+
431
+ def _init_weights(self, module):
432
+ """Initialize the weights"""
433
+ factor = self.config.initializer_range
434
+ if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
435
+ module.weight.data.normal_(mean=0.0, std=factor)
436
+ if hasattr(module, "bias") and module.bias is not None:
437
+ module.bias.data.zero_()
438
+
439
+ if isinstance(module, BlipVisionEmbeddings):
440
+ if hasattr(self.config, "vision_config"):
441
+ factor = self.config.vision_config.initializer_range
442
+ nn.init.trunc_normal_(
443
+ module.position_embedding,
444
+ mean=0.0,
445
+ std=factor,
446
+ )
447
+
448
+ nn.init.trunc_normal_(
449
+ module.class_embedding,
450
+ mean=0.0,
451
+ std=factor,
452
+ )
453
+
454
+ elif isinstance(module, nn.LayerNorm):
455
+ module.bias.data.zero_()
456
+ module.weight.data.fill_(1.0)
457
+ elif isinstance(module, nn.Linear) and module.bias is not None:
458
+ module.bias.data.zero_()
459
+
460
+ def _set_gradient_checkpointing(self, module, value=False):
461
+ if isinstance(module, BlipEncoder):
462
+ module.gradient_checkpointing = value
463
+
464
+
465
+ BLIP_START_DOCSTRING = r"""
466
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
467
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
468
+ etc.)
469
+
470
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
471
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
472
+ and behavior.
473
+
474
+ Parameters:
475
+ config ([`BlipConfig`]): Model configuration class with all the parameters of the model.
476
+ Initializing with a config file does not load the weights associated with the model, only the
477
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
478
+ """
479
+
480
+ BLIP_TEXT_INPUTS_DOCSTRING = r"""
481
+ Args:
482
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
483
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
484
+ it.
485
+
486
+ Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
487
+
488
+ [What are input IDs?](../glossary#input-ids)
489
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
490
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
491
+
492
+ - 1 for tokens that are **not masked**,
493
+ - 0 for tokens that are **masked**.
494
+
495
+ [What are attention masks?](../glossary#attention-mask)
496
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
497
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
498
+ config.max_position_embeddings - 1]`.
499
+
500
+ [What are position IDs?](../glossary#position-ids)
501
+ output_attentions (`bool`, *optional*):
502
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
503
+ tensors for more detail.
504
+ output_hidden_states (`bool`, *optional*):
505
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
506
+ more detail.
507
+ return_dict (`bool`, *optional*):
508
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
509
+ """
510
+
511
+ BLIP_VISION_INPUTS_DOCSTRING = r"""
512
+ Args:
513
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
514
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
515
+ [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
516
+ output_attentions (`bool`, *optional*):
517
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
518
+ tensors for more detail.
519
+ output_hidden_states (`bool`, *optional*):
520
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
521
+ more detail.
522
+ return_dict (`bool`, *optional*):
523
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
524
+ """
525
+
526
+ BLIP_INPUTS_DOCSTRING = r"""
527
+ Args:
528
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
529
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
530
+ it.
531
+
532
+ Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
533
+
534
+ [What are input IDs?](../glossary#input-ids)
535
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
536
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
537
+
538
+ - 1 for tokens that are **not masked**,
539
+ - 0 for tokens that are **masked**.
540
+
541
+ [What are attention masks?](../glossary#attention-mask)
542
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
543
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
544
+ config.max_position_embeddings - 1]`.
545
+
546
+ [What are position IDs?](../glossary#position-ids)
547
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
548
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
549
+ [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
550
+ return_loss (`bool`, *optional*):
551
+ Whether or not to return the contrastive loss.
552
+ output_attentions (`bool`, *optional*):
553
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
554
+ tensors for more detail.
555
+ output_hidden_states (`bool`, *optional*):
556
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
557
+ more detail.
558
+ return_dict (`bool`, *optional*):
559
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
560
+ """
561
+
562
+
563
+ class BlipEncoder(nn.Module):
564
+ """
565
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
566
+ [`BlipEncoderLayer`].
567
+
568
+ Args:
569
+ config (`BlipConfig`):
570
+ The corresponding vision configuration for the `BlipEncoder`.
571
+ """
572
+
573
+ def __init__(self, config: BlipConfig):
574
+ super().__init__()
575
+ self.config = config
576
+ self.layers = nn.ModuleList([BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
577
+ self.gradient_checkpointing = False
578
+
579
+ def forward(
580
+ self,
581
+ inputs_embeds,
582
+ attention_mask: Optional[torch.LongTensor] = None,
583
+ output_attentions: Optional[bool] = None,
584
+ output_hidden_states: Optional[bool] = None,
585
+ return_dict: Optional[bool] = None,
586
+ ) -> Union[Tuple, BaseModelOutput]:
587
+ r"""
588
+ Args:
589
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
590
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
591
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
592
+ than the model's internal embedding lookup matrix.
593
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
594
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
595
+
596
+ - 1 for tokens that are **not masked**,
597
+ - 0 for tokens that are **masked**.
598
+
599
+ [What are attention masks?](../glossary#attention-mask)
600
+ output_attentions (`bool`, *optional*):
601
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
602
+ returned tensors for more detail.
603
+ output_hidden_states (`bool`, *optional*):
604
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
605
+ for more detail.
606
+ return_dict (`bool`, *optional*):
607
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
608
+ """
609
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
610
+ output_hidden_states = (
611
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
612
+ )
613
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
614
+
615
+ encoder_states = () if output_hidden_states else None
616
+ all_attentions = () if output_attentions else None
617
+
618
+ hidden_states = inputs_embeds
619
+ for idx, encoder_layer in enumerate(self.layers):
620
+ if output_hidden_states:
621
+ encoder_states = encoder_states + (hidden_states,)
622
+ if self.gradient_checkpointing and self.training:
623
+
624
+ def create_custom_forward(module):
625
+ def custom_forward(*inputs):
626
+ return module(*inputs, output_attentions)
627
+
628
+ return custom_forward
629
+
630
+ layer_outputs = torch.utils.checkpoint.checkpoint(
631
+ create_custom_forward(encoder_layer),
632
+ hidden_states,
633
+ attention_mask,
634
+ )
635
+ else:
636
+ layer_outputs = encoder_layer(
637
+ hidden_states,
638
+ attention_mask,
639
+ output_attentions=output_attentions,
640
+ )
641
+
642
+ hidden_states = layer_outputs[0]
643
+
644
+ if output_attentions:
645
+ all_attentions = all_attentions + (layer_outputs[1],)
646
+
647
+ if output_hidden_states:
648
+ encoder_states = encoder_states + (hidden_states,)
649
+
650
+ if not return_dict:
651
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
652
+ return BaseModelOutput(
653
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
654
+ )
655
+
656
+
657
+ class BlipVisionModel(BlipPreTrainedModel):
658
+ main_input_name = "pixel_values"
659
+ config_class = BlipVisionConfig
660
+
661
+ def __init__(self, config: BlipVisionConfig):
662
+ super().__init__(config)
663
+ self.config = config
664
+ embed_dim = config.hidden_size
665
+ self.embeddings = BlipVisionEmbeddings(config)
666
+ self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size)
667
+ self.encoder = BlipEncoder(config)
668
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
669
+
670
+ self.post_init()
671
+
672
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
673
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=BlipVisionConfig)
674
+ def forward(
675
+ self,
676
+ pixel_values: Optional[torch.FloatTensor] = None,
677
+ pixel_masks: Optional[torch.LongTensor] = None,
678
+ output_attentions: Optional[bool] = None,
679
+ output_hidden_states: Optional[bool] = None,
680
+ return_dict: Optional[bool] = None,
681
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
682
+ r"""
683
+ Returns:
684
+
685
+ """
686
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
687
+ output_hidden_states = (
688
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
689
+ )
690
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
691
+
692
+ if pixel_values is None:
693
+ raise ValueError("You have to specify pixel_values")
694
+
695
+ hidden_states = self.embeddings(pixel_values)
696
+ B, N, D = hidden_states.shape
697
+ # print('Before mask:', hidden_states.shape)
698
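+ # If a pixel-level mask is given, derive a boolean patch mask and keep only the patch embeddings inside
+ # the region, so the encoder runs over a reduced set of tokens.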
+ if pixel_masks is not None:
699
+ assert pixel_masks.shape[0] == 1
700
+ patch_masks = self.patch_mask_generator(pixel_masks)
701
+ # print(patch_masks.shape)
702
+ patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states)
703
+ hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D)
704
+ # print('After mask:', hidden_states.shape)
705
+
706
+ encoder_outputs = self.encoder(
707
+ inputs_embeds=hidden_states,
708
+ output_attentions=output_attentions,
709
+ output_hidden_states=output_hidden_states,
710
+ return_dict=return_dict,
711
+ )
712
+
713
+ last_hidden_state = encoder_outputs[0]
714
+ last_hidden_state = self.post_layernorm(last_hidden_state)
715
+
716
+ pooled_output = last_hidden_state[:, 0, :]
717
+ pooled_output = self.post_layernorm(pooled_output)
718
+
719
+ if not return_dict:
720
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
721
+
722
+ return BaseModelOutputWithPooling(
723
+ last_hidden_state=last_hidden_state,
724
+ pooler_output=pooled_output,
725
+ hidden_states=encoder_outputs.hidden_states,
726
+ attentions=encoder_outputs.attentions,
727
+ )
728
+
729
+ def get_input_embeddings(self):
730
+ return self.embeddings
731
+
732
+
733
+ @add_start_docstrings(BLIP_START_DOCSTRING)
734
+ class BlipModel(BlipPreTrainedModel):
735
+ config_class = BlipConfig
736
+
737
+ def __init__(self, config: BlipConfig):
738
+ super().__init__(config)
739
+
740
+ if not isinstance(config.text_config, BlipTextConfig):
741
+ raise ValueError(
742
+ "config.text_config is expected to be of type BlipTextConfig but is of type"
743
+ f" {type(config.text_config)}."
744
+ )
745
+
746
+ if not isinstance(config.vision_config, BlipVisionConfig):
747
+ raise ValueError(
748
+ "config.vision_config is expected to be of type BlipVisionConfig but is of type"
749
+ f" {type(config.vision_config)}."
750
+ )
751
+
752
+ text_config = config.text_config
753
+ vision_config = config.vision_config
754
+
755
+ self.projection_dim = config.projection_dim
756
+ self.text_embed_dim = text_config.hidden_size
757
+ self.vision_embed_dim = vision_config.hidden_size
758
+
759
+ self.text_model = BlipTextModel(text_config)
760
+ self.vision_model = BlipVisionModel(vision_config)
761
+
762
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
763
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
764
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
765
+
766
+ # Initialize weights and apply final processing
767
+ self.post_init()
768
+
769
+ @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
770
+ def get_text_features(
771
+ self,
772
+ input_ids: Optional[torch.Tensor] = None,
773
+ attention_mask: Optional[torch.Tensor] = None,
774
+ position_ids: Optional[torch.Tensor] = None,
775
+ return_dict: Optional[bool] = None,
776
+ ) -> torch.FloatTensor:
777
+ r"""
778
+ Returns:
779
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
780
+ applying the projection layer to the pooled output of [`BlipTextModel`].
781
+
782
+ Examples:
783
+
784
+ ```python
785
+ >>> from transformers import AutoProcessor, BlipModel
786
+
787
+ >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
788
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
789
+
790
+ >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
791
+ >>> text_features = model.get_text_features(**inputs)
792
+ ```"""
793
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
794
+
795
+ text_outputs = self.text_model(
796
+ input_ids=input_ids,
797
+ attention_mask=attention_mask,
798
+ position_ids=position_ids,
799
+ return_dict=return_dict,
800
+ )
801
+
802
+ pooled_output = text_outputs[1]
803
+ text_features = self.text_projection(pooled_output)
804
+
805
+ return text_features
806
+
807
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
808
+ def get_image_features(
809
+ self,
810
+ pixel_values: Optional[torch.FloatTensor] = None,
811
+ return_dict: Optional[bool] = None,
812
+ ) -> torch.FloatTensor:
813
+ r"""
814
+ Returns:
815
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
816
+ applying the projection layer to the pooled output of [`BlipVisionModel`].
817
+
818
+ Examples:
819
+
820
+ ```python
821
+ >>> from PIL import Image
822
+ >>> import requests
823
+ >>> from transformers import AutoProcessor, BlipModel
824
+
825
+ >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
826
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
827
+
828
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
829
+ >>> image = Image.open(requests.get(url, stream=True).raw)
830
+
831
+ >>> inputs = processor(images=image, return_tensors="pt")
832
+
833
+ >>> image_features = model.get_image_features(**inputs)
834
+ ```"""
835
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
836
+
837
+ vision_outputs = self.vision_model(
838
+ pixel_values=pixel_values,
839
+ return_dict=return_dict,
840
+ )
841
+
842
+ pooled_output = vision_outputs[1] # pooled_output
843
+ image_features = self.visual_projection(pooled_output)
844
+
845
+ return image_features
846
+
847
+ @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
848
+ @replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig)
849
+ def forward(
850
+ self,
851
+ input_ids: Optional[torch.LongTensor] = None,
852
+ pixel_values: Optional[torch.FloatTensor] = None,
853
+ pixel_masks: Optional[torch.FloatTensor] = None,
854
+ attention_mask: Optional[torch.Tensor] = None,
855
+ position_ids: Optional[torch.LongTensor] = None,
856
+ return_loss: Optional[bool] = None,
857
+ output_attentions: Optional[bool] = None,
858
+ output_hidden_states: Optional[bool] = None,
859
+ return_dict: Optional[bool] = None,
860
+ ) -> Union[Tuple, BlipOutput]:
861
+ r"""
862
+ Returns:
863
+
864
+ Examples:
865
+
866
+ ```python
867
+ >>> from PIL import Image
868
+ >>> import requests
869
+ >>> from transformers import AutoProcessor, BlipModel
870
+
871
+ >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
872
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
873
+
874
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
875
+ >>> image = Image.open(requests.get(url, stream=True).raw)
876
+
877
+ >>> inputs = processor(
878
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
879
+ ... )
880
+
881
+ >>> outputs = model(**inputs)
882
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
883
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
884
+ ```"""
885
+ # Use BLIP model's config for some fields (if specified) instead of those of vision & text components.
886
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
887
+ output_hidden_states = (
888
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
889
+ )
890
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
891
+
892
+ vision_outputs = self.vision_model(
893
+ pixel_values=pixel_values,
894
+ pixel_masks=pixel_masks,
895
+ output_attentions=output_attentions,
896
+ output_hidden_states=output_hidden_states,
897
+ return_dict=return_dict,
898
+ )
899
+
900
+ text_outputs = self.text_model(
901
+ input_ids=input_ids,
902
+ attention_mask=attention_mask,
903
+ position_ids=position_ids,
904
+ output_attentions=output_attentions,
905
+ output_hidden_states=output_hidden_states,
906
+ return_dict=return_dict,
907
+ )
908
+
909
+ image_embeds = vision_outputs[1]
910
+ image_embeds = self.visual_projection(image_embeds)
911
+
912
+ text_embeds = text_outputs[1]
913
+ text_embeds = self.text_projection(text_embeds)
914
+
915
+ # normalized features
916
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
917
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
918
+
919
+ # cosine similarity as logits
920
+ logit_scale = self.logit_scale.exp()
921
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
922
+ logits_per_image = logits_per_text.t()
923
+
924
+ loss = None
925
+ if return_loss:
926
+ loss = blip_loss(logits_per_text)
927
+
928
+ if not return_dict:
929
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
930
+ return ((loss,) + output) if loss is not None else output
931
+
932
+ return BlipOutput(
933
+ loss=loss,
934
+ logits_per_image=logits_per_image,
935
+ logits_per_text=logits_per_text,
936
+ text_embeds=text_embeds,
937
+ image_embeds=image_embeds,
938
+ text_model_output=text_outputs,
939
+ vision_model_output=vision_outputs,
940
+ )
941
+
942
+
943
+ @add_start_docstrings(
944
+ """
945
+ BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
946
+ `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise,
947
+ the decoder starts generating the caption from the [BOS] (beginning-of-sequence) token. If no text input is
948
+ provided, the decoder will start with the [BOS] token only.
949
+ """,
950
+ BLIP_START_DOCSTRING,
951
+ )
952
+ class BlipForConditionalGeneration(BlipPreTrainedModel):
953
+ config_class = BlipConfig
954
+ _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
955
+ main_input_name = "pixel_values"
956
+
957
+ def __init__(self, config: BlipConfig):
958
+ super().__init__(config)
959
+
960
+ self.vision_model = BlipVisionModel(config.vision_config)
961
+
962
+ self.text_decoder = BlipTextLMHeadModel(config.text_config)
963
+
964
+ self.decoder_input_ids = config.text_config.bos_token_id
965
+ self.decoder_pad_token_id = config.text_config.pad_token_id
966
+
967
+ # Initialize weights and apply final processing
968
+ self.post_init()
969
+
970
+ def get_input_embeddings(self) -> nn.Module:
971
+ return self.vision_model.embeddings.patch_embedding
972
+
973
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
974
+ @replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
975
+ def forward(
976
+ self,
977
+ pixel_values: torch.FloatTensor,
978
+ input_ids: Optional[torch.LongTensor] = None,
979
+ attention_mask: Optional[torch.LongTensor] = None,
980
+ output_attentions: Optional[bool] = None,
981
+ output_hidden_states: Optional[bool] = None,
982
+ labels: Optional[torch.LongTensor] = None,
983
+ return_dict: Optional[bool] = None,
984
+ ) -> Union[Tuple, BlipForConditionalGenerationModelOutput]:
985
+ r"""
986
+ Returns:
987
+
988
+ Examples:
989
+
990
+ ```python
991
+ >>> from PIL import Image
992
+ >>> import requests
993
+ >>> from transformers import AutoProcessor, BlipForConditionalGeneration
994
+
995
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
996
+ >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
997
+
998
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
999
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1000
+ >>> text = "A picture of"
1001
+
1002
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1003
+
1004
+ >>> outputs = model(**inputs)
1005
+ ```"""
1006
+
1007
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1008
+
1009
+ vision_outputs = self.vision_model(
1010
+ pixel_values=pixel_values,
1011
+ output_attentions=output_attentions,
1012
+ output_hidden_states=output_hidden_states,
1013
+ return_dict=return_dict,
1014
+ )
1015
+
1016
+ image_embeds = vision_outputs[0]
1017
+
1018
+ outputs = self.text_decoder(
1019
+ input_ids=input_ids,
1020
+ attention_mask=attention_mask,
1021
+ encoder_hidden_states=image_embeds,
1022
+ labels=labels,
1023
+ return_dict=return_dict,
1024
+ reduction="mean",
1025
+ )
1026
+
1027
+ if not return_dict:
1028
+ outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
1029
+ return tuple(output for output in outputs if output is not None)
1030
+
1031
+ return BlipForConditionalGenerationModelOutput(
1032
+ loss=outputs.loss,
1033
+ decoder_logits=outputs.logits,
1034
+ image_embeds=image_embeds,
1035
+ last_hidden_state=vision_outputs.last_hidden_state,
1036
+ hidden_states=vision_outputs.hidden_states,
1037
+ attentions=vision_outputs.attentions,
1038
+ )
1039
+
1040
+ @torch.no_grad()
1041
+ def generate(
1042
+ self,
1043
+ pixel_values: torch.FloatTensor,
1044
+ pixel_masks: torch.Tensor = None,
1045
+ input_ids: Optional[torch.LongTensor] = None,
1046
+ attention_mask: Optional[torch.LongTensor] = None,
1047
+ **generate_kwargs,
1048
+ ) -> torch.LongTensor:
1049
+ r"""
1050
+ Overrides *generate* function to be able to use the model as a conditional generator
1051
+
1052
+ Parameters:
1053
+ pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*):
1054
+ Input image to be processed
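+ pixel_masks (*torch.Tensor*, *optional*):
+ Pixel-level mask forwarded to the vision encoder; when provided, only patch tokens inside the
+ masked region are encoded (see `BlipVisionModel.forward`).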
1055
+ input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
1056
+ The sequence used as a prompt for the generation.
1057
+ attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
1058
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1059
+
1060
+
1061
+ Examples:
1062
+ ```python
1063
+ >>> from PIL import Image
1064
+ >>> import requests
1065
+ >>> from transformers import AutoProcessor, BlipForConditionalGeneration
1066
+
1067
+ >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
1068
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
1069
+
1070
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1071
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1072
+
1073
+ >>> inputs = processor(images=image, return_tensors="pt")
1074
+
1075
+ >>> outputs = model.generate(**inputs)
1076
+ >>> print(processor.decode(outputs[0], skip_special_tokens=True))
1077
+ two cats are laying on a couch
1078
+ ```
1079
+ """
1080
+
1081
+ batch_size = pixel_values.shape[0]
1082
+ vision_outputs = self.vision_model(
1083
+ pixel_values=pixel_values,
1084
+ pixel_masks=pixel_masks,
1085
+ )
1086
+
1087
+ image_embeds = vision_outputs[0]
1088
+
1089
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
1090
+
1091
+ if isinstance(input_ids, list):
1092
+ input_ids = torch.LongTensor(input_ids)
1093
+ elif input_ids is None:
1094
+ input_ids = (
1095
+ torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
1096
+ .repeat(batch_size, 1)
1097
+ .to(image_embeds.device)
1098
+ )
1099
+
1100
+ input_ids[:, 0] = self.config.text_config.bos_token_id
1101
+ attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
1102
+
1103
+ outputs = self.text_decoder.generate(
1104
+ input_ids=input_ids[:, :-1],
1105
+ eos_token_id=self.config.text_config.sep_token_id,
1106
+ pad_token_id=self.config.text_config.pad_token_id,
1107
+ attention_mask=attention_mask,
1108
+ encoder_hidden_states=image_embeds,
1109
+ encoder_attention_mask=image_attention_mask,
1110
+ **generate_kwargs,
1111
+ )
1112
+
1113
+ return outputs
1114
+
1115
+
1116
+ @add_start_docstrings(
1117
+ """
1118
+ BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
1119
+ decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
1120
+ with the encoding of the image, and the text decoder will output the answer to the question.
1121
+ """,
1122
+ BLIP_START_DOCSTRING,
1123
+ )
1124
+ class BlipForQuestionAnswering(BlipPreTrainedModel):
1125
+ config_class = BlipConfig
1126
+ _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
1127
+
1128
+ def __init__(self, config: BlipConfig):
1129
+ super().__init__(config)
1130
+
1131
+ self.vision_model = BlipVisionModel(config.vision_config)
1132
+
1133
+ self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
1134
+
1135
+ self.text_decoder = BlipTextLMHeadModel(config.text_config)
1136
+
1137
+ self.decoder_pad_token_id = config.text_config.pad_token_id
1138
+ self.decoder_start_token_id = config.text_config.bos_token_id
1139
+
1140
+ # Initialize weights and apply final processing
1141
+ self.post_init()
1142
+
1143
+ def get_input_embeddings(self) -> nn.Module:
1144
+ return self.vision_model.embeddings.patch_embedding
1145
+
1146
+ # Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
1147
+ def _shift_right(self, input_ids):
1148
+ pad_token_id = self.decoder_pad_token_id
1149
+
1150
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1151
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1152
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1153
+
1154
+ # replace possible -100 values in labels by `pad_token_id`
1155
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
1156
+
1157
+ return shifted_input_ids
1158
+
1159
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
1160
+ @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
1161
+ def forward(
1162
+ self,
1163
+ input_ids: torch.LongTensor,
1164
+ pixel_values: torch.FloatTensor,
1165
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1166
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1167
+ attention_mask: Optional[torch.LongTensor] = None,
1168
+ output_attentions: Optional[bool] = None,
1169
+ output_hidden_states: Optional[bool] = None,
1170
+ labels: Optional[torch.LongTensor] = None,
1171
+ return_dict: Optional[bool] = None,
1172
+ ) -> Union[Tuple, BlipTextVisionModelOutput]:
1173
+ r"""
1174
+ Returns:
1175
+
1176
+ Examples:
1177
+
1178
+ ```python
1179
+ >>> from PIL import Image
1180
+ >>> import requests
1181
+ >>> from transformers import AutoProcessor, BlipForQuestionAnswering
1182
+
1183
+ >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
1184
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
1185
+
1186
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1187
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1188
+
1189
+ >>> # training
1190
+ >>> text = "How many cats are in the picture?"
1191
+ >>> label = "2"
1192
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1193
+ >>> labels = processor(text=label, return_tensors="pt").input_ids
1194
+
1195
+ >>> inputs["labels"] = labels
1196
+ >>> outputs = model(**inputs)
1197
+ >>> loss = outputs.loss
1198
+ >>> loss.backward()
1199
+
1200
+ >>> # inference
1201
+ >>> text = "How many cats are in the picture?"
1202
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1203
+ >>> outputs = model.generate(**inputs)
1204
+ >>> print(processor.decode(outputs[0], skip_special_tokens=True))
1205
+ 2
1206
+ ```"""
1207
+ if labels is None and decoder_input_ids is None:
1208
+ raise ValueError(
1209
+ "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
1210
+ " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
1211
+ " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
1212
+ )
1213
+
1214
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1215
+
1216
+ vision_outputs = self.vision_model(
1217
+ pixel_values=pixel_values,
1218
+ output_attentions=output_attentions,
1219
+ output_hidden_states=output_hidden_states,
1220
+ return_dict=return_dict,
1221
+ )
1222
+
1223
+ image_embeds = vision_outputs[0]
1224
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
1225
+
1226
+ question_embeds = self.text_encoder(
1227
+ input_ids=input_ids,
1228
+ attention_mask=attention_mask,
1229
+ encoder_hidden_states=image_embeds,
1230
+ encoder_attention_mask=image_attention_mask,
1231
+ return_dict=return_dict,
1232
+ )
1233
+
1234
+ question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
1235
+
1236
+ if labels is not None and decoder_input_ids is None:
1237
+ # get decoder inputs from shifting lm labels to the right - this is used in training mode
1238
+ decoder_input_ids = self._shift_right(labels)
1239
+ # replace possible -100 values in labels by `pad_token_id`
1240
+ labels = labels.masked_fill(labels == self.decoder_pad_token_id, -100)
1241
+
1242
+ answer_output = self.text_decoder(
1243
+ input_ids=decoder_input_ids,
1244
+ attention_mask=decoder_attention_mask,
1245
+ encoder_hidden_states=question_embeds,
1246
+ encoder_attention_mask=attention_mask,
1247
+ labels=labels,
1248
+ return_dict=return_dict,
1249
+ reduction="mean",
1250
+ )
1251
+
1252
+ if labels is not None:
1253
+ decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean()
1254
+ else:
1255
+ decoder_loss = None
1256
+
1257
+ if not return_dict:
1258
+ outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:]
1259
+ return tuple(output for output in outputs if output is not None)
1260
+
1261
+ return BlipTextVisionModelOutput(
1262
+ loss=decoder_loss,
1263
+ image_embeds=image_embeds,
1264
+ last_hidden_state=vision_outputs.last_hidden_state,
1265
+ hidden_states=vision_outputs.hidden_states,
1266
+ attentions=vision_outputs.attentions,
1267
+ )
1268
+
1269
+ @torch.no_grad()
1270
+ def generate(
1271
+ self,
1272
+ input_ids: torch.LongTensor,
1273
+ pixel_values: torch.FloatTensor,
1274
+ pixel_masks: Optional[torch.Tensor] = None,
1275
+ attention_mask: Optional[torch.LongTensor] = None,
1276
+ **generate_kwargs,
1277
+ ) -> torch.LongTensor:
1278
+ r"""
1279
+ Overrides *generate* function to be able to use the model as a conditional generator
1280
+
1281
+ Parameters:
1282
+ input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
1283
+ The sequence used as a prompt for the generation.
1284
+ pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*):
1285
+ Input image to be processed
1286
+ attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
1287
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
1288
+ tokens that are NOT MASKED, `0` for MASKED tokens.
1289
+ **generate_kwargs:
1290
+ Additional arguments passed to the *generate* function of the decoder
1291
+
1292
+
1293
+ Examples:
1294
+ ```python
1295
+ >>> from PIL import Image
1296
+ >>> import requests
1297
+ >>> from transformers import AutoProcessor, BlipForQuestionAnswering
1298
+
1299
+ >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
1300
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
1301
+
1302
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1303
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1304
+ >>> text = "How many cats are in the picture?"
1305
+
1306
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1307
+
1308
+ >>> outputs = model.generate(**inputs)
1309
+ >>> print(processor.decode(outputs[0], skip_special_tokens=True))
1310
+ 2
1311
+ ```
1312
+ """
1313
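+ # NOTE: `pixel_masks` is an addition in this repo; the vision encoder presumably uses it (as the GIT
+ # variant in modeling_git.py does) to keep only the image patches covered by the user-provided mask.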
+ vision_outputs = self.vision_model(
1314
+ pixel_values=pixel_values,
1315
+ pixel_masks=pixel_masks
1316
+ )
1317
+
1318
+ image_embeds = vision_outputs[0]
1319
+
1320
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
1321
+
1322
+ if isinstance(input_ids, list):
1323
+ input_ids = torch.LongTensor(input_ids)
1324
+
1325
+ question_outputs = self.text_encoder(
1326
+ input_ids=input_ids,
1327
+ attention_mask=attention_mask,
1328
+ encoder_hidden_states=image_embeds,
1329
+ encoder_attention_mask=image_attention_mask,
1330
+ return_dict=False,
1331
+ )
1332
+
1333
+ question_embeds = question_outputs[0]
1334
+
1335
+ question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)
1336
+
1337
+ bos_ids = torch.full(
1338
+ (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
1339
+ )
1340
+
1341
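+ # Decoding starts from the decoder-start (BOS) token and cross-attends to the question embeddings,
+ # so `generate_kwargs` only needs to carry decoding options such as beam size or max length.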
+ outputs = self.text_decoder.generate(
1342
+ input_ids=bos_ids,
1343
+ eos_token_id=self.config.text_config.sep_token_id,
1344
+ pad_token_id=self.config.text_config.pad_token_id,
1345
+ encoder_hidden_states=question_embeds,
1346
+ encoder_attention_mask=question_attention_mask,
1347
+ **generate_kwargs,
1348
+ )
1349
+
1350
+ return outputs
1351
+
1352
+
1353
+ @add_start_docstrings(
1354
+ """
1355
+ BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
1356
+ image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
1357
+ the image.
1358
+ """,
1359
+ BLIP_START_DOCSTRING,
1360
+ )
1361
+ class BlipForImageTextRetrieval(BlipPreTrainedModel):
1362
+ config_class = BlipConfig
1363
+
1364
+ def __init__(self, config: BlipConfig):
1365
+ super().__init__(config)
1366
+
1367
+ self.vision_model = BlipVisionModel(config.vision_config)
1368
+
1369
+ self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
1370
+
1371
+ # vision projection layer
1372
+ self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)
1373
+
1374
+ # text projection layer
1375
+ self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)
1376
+
1377
+ # image text matching head
1378
+ self.itm_head = nn.Linear(config.text_config.hidden_size, 2)
1379
+
1380
+ self.decoder_pad_token_id = (
1381
+ config.text_config.pad_token_id
1382
+ if not hasattr(config, "decoder_pad_token_id")
1383
+ else config.decoder_pad_token_id
1384
+ )
1385
+ self.decoder_start_token_id = (
1386
+ config.text_config.bos_token_id
1387
+ if not hasattr(config, "decoder_start_token_id")
1388
+ else config.decoder_start_token_id
1389
+ )
1390
+
1391
+ # Initialize weights and apply final processing
1392
+ self.post_init()
1393
+
1394
+ def get_input_embeddings(self) -> nn.Module:
1395
+ return self.vision_model.embeddings.patch_embedding
1396
+
1397
+ @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
1398
+ @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
1399
+ def forward(
1400
+ self,
1401
+ input_ids: torch.LongTensor,
1402
+ pixel_values: torch.FloatTensor,
1403
+ use_itm_head: Optional[bool] = True,
1404
+ attention_mask: Optional[torch.LongTensor] = None,
1405
+ output_attentions: Optional[bool] = None,
1406
+ output_hidden_states: Optional[bool] = None,
1407
+ return_dict: Optional[bool] = None,
1408
+ ) -> Union[Tuple, BlipTextVisionModelOutput]:
1409
+ r"""
1410
+ Returns:
1411
+
1412
+ Examples:
1413
+
1414
+ ```python
1415
+ >>> from PIL import Image
1416
+ >>> import requests
1417
+ >>> from transformers import AutoProcessor, BlipForImageTextRetrieval
1418
+
1419
+ >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
1420
+ >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
1421
+
1422
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1423
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1424
+ >>> text = "an image of a cat"
1425
+
1426
+ >>> inputs = processor(images=image, text=text, return_tensors="pt")
1427
+ >>> outputs = model(**inputs)
1428
+ ```
1429
+ """
1430
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1431
+
1432
+ vision_outputs = self.vision_model(
1433
+ pixel_values=pixel_values,
1434
+ output_attentions=output_attentions,
1435
+ output_hidden_states=output_hidden_states,
1436
+ return_dict=return_dict,
1437
+ )
1438
+
1439
+ image_embeds = vision_outputs[0]
1440
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
1441
+
1442
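+ # With `use_itm_head=True`, the fused image-text representation is scored by a binary matching head;
+ # otherwise the image and text [CLS] features are projected, normalized, and compared via a dot product.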
+ if use_itm_head:
1443
+ question_embeds = self.text_encoder(
1444
+ input_ids=input_ids,
1445
+ attention_mask=attention_mask,
1446
+ encoder_hidden_states=image_embeds,
1447
+ encoder_attention_mask=image_atts,
1448
+ return_dict=return_dict,
1449
+ )
1450
+ question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
1451
+
1452
+ output = self.itm_head(question_embeds[:, 0, :])
1453
+ else:
1454
+ question_embeds = self.text_encoder(
1455
+ input_ids=input_ids,
1456
+ attention_mask=attention_mask,
1457
+ return_dict=return_dict,
1458
+ )
1459
+ question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
1460
+
1461
+ image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
1462
+ text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)
1463
+
1464
+ output = image_feat @ text_feat.t()
1465
+
1466
+ if not return_dict:
1467
+ outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,)
1468
+ return tuple(output for output in outputs if output is not None)
1469
+
1470
+ return BlipImageTextMatchingModelOutput(
1471
+ itm_score=output,
1472
+ last_hidden_state=vision_outputs.last_hidden_state,
1473
+ hidden_states=vision_outputs.hidden_states,
1474
+ attentions=vision_outputs.attentions,
1475
+ question_embeds=question_embeds,
1476
+ )
caption_anything/captioner/modeling_git.py ADDED
@@ -0,0 +1,1587 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
3
+ # All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch GIT model."""
17
+
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.file_utils import ModelOutput
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutput,
32
+ BaseModelOutputWithPast,
33
+ BaseModelOutputWithPooling,
34
+ CausalLMOutputWithPast,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
38
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
39
+ from transformers.models.git.configuration_git import GitConfig, GitVisionConfig
40
+ from .vit_pixel_masks_utils import ViTPatchMaskGenerator
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CHECKPOINT_FOR_DOC = "microsoft/git-base"
46
+ _CONFIG_FOR_DOC = "GitConfig"
47
+
48
+ GIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
49
+ "microsoft/git-base",
50
+ # See all GIT models at https://huggingface.co/models?filter=git
51
+ ]
52
+
53
+
54
+ @dataclass
55
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
56
+ class GitVisionModelOutput(ModelOutput):
57
+ """
58
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
59
+
60
+ Args:
61
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
62
+ The image embeddings obtained by applying the projection layer to the pooler_output.
63
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
64
+ Sequence of hidden-states at the output of the last layer of the model.
65
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
66
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
67
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
68
+
69
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
70
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
+ sequence_length)`.
73
+
74
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
75
+ heads.
76
+ """
77
+
78
+ image_embeds: Optional[torch.FloatTensor] = None
79
+ last_hidden_state: torch.FloatTensor = None
80
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
81
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
82
+
83
+
84
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
85
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
86
+ """
87
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
88
+ """
89
+ bsz, src_len = mask.size()
90
+ tgt_len = tgt_len if tgt_len is not None else src_len
91
+
92
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
93
+
94
+ inverted_mask = 1.0 - expanded_mask
95
+
96
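+ # positions to attend to stay at 0, while masked positions become the most negative value representable in `dtype`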
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
97
+
98
+
99
+ class GitEmbeddings(nn.Module):
100
+ """Construct the embeddings from word and position embeddings."""
101
+
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
105
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
106
+
107
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
108
+ # any TensorFlow checkpoint file
109
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
110
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
111
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
112
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
113
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
114
+
115
+ def forward(
116
+ self,
117
+ input_ids: Optional[torch.LongTensor] = None,
118
+ position_ids: Optional[torch.LongTensor] = None,
119
+ inputs_embeds: Optional[torch.FloatTensor] = None,
120
+ past_key_values_length: int = 0,
121
+ ) -> torch.Tensor:
122
+ if input_ids is not None:
123
+ input_shape = input_ids.size()
124
+ else:
125
+ input_shape = inputs_embeds.size()[:-1]
126
+
127
+ seq_length = input_shape[1]
128
+
129
+ if position_ids is None:
130
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
131
+
132
+ if inputs_embeds is None:
133
+ embeddings = self.word_embeddings(input_ids)
134
+ else:
135
+ embeddings = inputs_embeds
136
+
137
+ if self.position_embedding_type == "absolute":
138
+ position_embeddings = self.position_embeddings(position_ids)
139
+ embeddings += position_embeddings
140
+ embeddings = self.LayerNorm(embeddings)
141
+ embeddings = self.dropout(embeddings)
142
+ return embeddings
143
+
144
+
145
+ class GitSelfAttention(nn.Module):
146
+ def __init__(self, config, position_embedding_type=None):
147
+ super().__init__()
148
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
149
+ raise ValueError(
150
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
151
+ f"heads ({config.num_attention_heads})"
152
+ )
153
+
154
+ self.num_attention_heads = config.num_attention_heads
155
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
156
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
157
+ self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
158
+ if config.num_image_with_embedding is not None:
159
+ self.image_patch_tokens *= config.num_image_with_embedding
160
+
161
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
162
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
163
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
164
+
165
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
166
+ self.position_embedding_type = position_embedding_type or getattr(
167
+ config, "position_embedding_type", "absolute"
168
+ )
169
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
170
+ self.max_position_embeddings = config.max_position_embeddings
171
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
172
+
173
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
174
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
175
+ x = x.view(new_x_shape)
176
+ return x.permute(0, 2, 1, 3)
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.Tensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ head_mask: Optional[torch.FloatTensor] = None,
183
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
184
+ output_attentions: Optional[bool] = False,
185
+ pixel_values_present: Optional[bool] = False,
186
+ image_token_num: Optional[int] = None
187
+ ) -> Tuple[torch.Tensor]:
188
+ mixed_query_layer = self.query(hidden_states)
189
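+ # `image_token_num` appears to be an addition in this repo: when patch masking drops visual tokens, the
+ # caller passes the actual count so the cache split between image and text tokens lands at the right index.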
+ if image_token_num is not None:
190
+ cutoff = image_token_num
191
+ else:
192
+ cutoff = self.image_patch_tokens if pixel_values_present else 0
193
+ if past_key_value is not None:
194
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
195
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
196
+ key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
197
+ value_layer = torch.cat(
198
+ [value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
199
+ )
200
+ else:
201
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
202
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
203
+
204
+ query_layer = self.transpose_for_scores(mixed_query_layer)
205
+
206
+ use_cache = past_key_value is not None
207
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
208
+ # Further calls to cross_attention layer can then reuse all cross-attention
209
+ # key/value_states (first "if" case)
210
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
211
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
212
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
213
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
214
+ # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
215
+ past_key_value = (
216
+ key_layer[:, :, cutoff:, :],
217
+ value_layer[:, :, cutoff:, :],
218
+ )
219
+
220
+ # Take the dot product between "query" and "key" to get the raw attention scores.
221
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
222
+
223
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
224
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
225
+ if use_cache:
226
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
227
+ -1, 1
228
+ )
229
+ else:
230
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
231
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
232
+ distance = position_ids_l - position_ids_r
233
+
234
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
235
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
236
+
237
+ if self.position_embedding_type == "relative_key":
238
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
239
+ attention_scores = attention_scores + relative_position_scores
240
+ elif self.position_embedding_type == "relative_key_query":
241
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
242
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
243
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
244
+
245
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246
+ if attention_mask is not None:
247
+ # Apply the attention mask (precomputed for all layers in GitModel forward() function)
248
+ attention_scores = attention_scores + attention_mask
249
+
250
+ # Normalize the attention scores to probabilities.
251
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
252
+
253
+ # This is actually dropping out entire tokens to attend to, which might
254
+ # seem a bit unusual, but is taken from the original Transformer paper.
255
+ attention_probs = self.dropout(attention_probs)
256
+
257
+ # Mask heads if we want to
258
+ if head_mask is not None:
259
+ attention_probs = attention_probs * head_mask
260
+
261
+ context_layer = torch.matmul(attention_probs, value_layer)
262
+
263
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
264
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
265
+ context_layer = context_layer.view(new_context_layer_shape)
266
+
267
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
268
+
269
+ outputs = outputs + (past_key_value,)
270
+ return outputs
271
+
272
+
273
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
274
+ class GitSelfOutput(nn.Module):
275
+ def __init__(self, config):
276
+ super().__init__()
277
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
278
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
279
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
280
+
281
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
282
+ hidden_states = self.dense(hidden_states)
283
+ hidden_states = self.dropout(hidden_states)
284
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
285
+ return hidden_states
286
+
287
+
288
+ class GitAttention(nn.Module):
289
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git
290
+ def __init__(self, config, position_embedding_type=None):
291
+ super().__init__()
292
+ self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
293
+ self.output = GitSelfOutput(config)
294
+ self.pruned_heads = set()
295
+
296
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
297
+ def prune_heads(self, heads):
298
+ if len(heads) == 0:
299
+ return
300
+ heads, index = find_pruneable_heads_and_indices(
301
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
302
+ )
303
+
304
+ # Prune linear layers
305
+ self.self.query = prune_linear_layer(self.self.query, index)
306
+ self.self.key = prune_linear_layer(self.self.key, index)
307
+ self.self.value = prune_linear_layer(self.self.value, index)
308
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
309
+
310
+ # Update hyper params and store pruned heads
311
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
312
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
313
+ self.pruned_heads = self.pruned_heads.union(heads)
314
+
315
+ def forward(
316
+ self,
317
+ hidden_states: torch.Tensor,
318
+ attention_mask: Optional[torch.FloatTensor] = None,
319
+ head_mask: Optional[torch.FloatTensor] = None,
320
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
321
+ output_attentions: Optional[bool] = False,
322
+ pixel_values_present: Optional[bool] = False,
323
+ image_token_num: Optional[int] = None
324
+ ) -> Tuple[torch.Tensor]:
325
+ self_outputs = self.self(
326
+ hidden_states,
327
+ attention_mask,
328
+ head_mask,
329
+ past_key_value,
330
+ output_attentions,
331
+ pixel_values_present,
332
+ image_token_num
333
+ )
334
+ attention_output = self.output(self_outputs[0], hidden_states)
335
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
336
+ return outputs
337
+
338
+
339
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
340
+ class GitIntermediate(nn.Module):
341
+ def __init__(self, config):
342
+ super().__init__()
343
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
344
+ if isinstance(config.hidden_act, str):
345
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
346
+ else:
347
+ self.intermediate_act_fn = config.hidden_act
348
+
349
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
350
+ hidden_states = self.dense(hidden_states)
351
+ hidden_states = self.intermediate_act_fn(hidden_states)
352
+ return hidden_states
353
+
354
+
355
+ # Copied from transformers.models.bert.modeling_bert.BertOutput
356
+ class GitOutput(nn.Module):
357
+ def __init__(self, config):
358
+ super().__init__()
359
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
360
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
361
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
362
+
363
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
364
+ hidden_states = self.dense(hidden_states)
365
+ hidden_states = self.dropout(hidden_states)
366
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
367
+ return hidden_states
368
+
369
+
370
+ class GitLayer(nn.Module):
371
+ def __init__(self, config):
372
+ super().__init__()
373
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
374
+ self.seq_len_dim = 1
375
+ self.attention = GitAttention(config)
376
+ self.intermediate = GitIntermediate(config)
377
+ self.output = GitOutput(config)
378
+
379
+ def forward(
380
+ self,
381
+ hidden_states: torch.Tensor,
382
+ attention_mask: Optional[torch.FloatTensor] = None,
383
+ head_mask: Optional[torch.FloatTensor] = None,
384
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
385
+ output_attentions: Optional[bool] = False,
386
+ pixel_values_present: Optional[bool] = False,
387
+ image_token_num: Optional[int] = None,
388
+ ) -> Tuple[torch.Tensor]:
389
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
390
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
391
+ self_attention_outputs = self.attention(
392
+ hidden_states,
393
+ attention_mask,
394
+ head_mask,
395
+ output_attentions=output_attentions,
396
+ past_key_value=self_attn_past_key_value,
397
+ pixel_values_present=pixel_values_present,
398
+ image_token_num=image_token_num
399
+ )
400
+ attention_output = self_attention_outputs[0]
401
+
402
+ # if decoder, the last output is tuple of self-attn cache
403
+ outputs = self_attention_outputs[1:-1]
404
+ present_key_value = self_attention_outputs[-1]
405
+
406
+ layer_output = apply_chunking_to_forward(
407
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
408
+ )
409
+ outputs = (layer_output,) + outputs
410
+
411
+ # if decoder, return the attn key/values as the last output
412
+ outputs = outputs + (present_key_value,)
413
+
414
+ return outputs
415
+
416
+ def feed_forward_chunk(self, attention_output):
417
+ intermediate_output = self.intermediate(attention_output)
418
+ layer_output = self.output(intermediate_output, attention_output)
419
+ return layer_output
420
+
421
+
422
+ class GitEncoder(nn.Module):
423
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Git
424
+ def __init__(self, config):
425
+ super().__init__()
426
+ self.config = config
427
+ self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)])
428
+ self.gradient_checkpointing = False
429
+
430
+ def forward(
431
+ self,
432
+ hidden_states: torch.Tensor,
433
+ attention_mask: Optional[torch.FloatTensor] = None,
434
+ head_mask: Optional[torch.FloatTensor] = None,
435
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
436
+ use_cache: Optional[bool] = None,
437
+ output_attentions: Optional[bool] = False,
438
+ output_hidden_states: Optional[bool] = False,
439
+ pixel_values_present: Optional[bool] = False,
440
+ image_token_num: Optional[int] = None,
441
+ return_dict: Optional[bool] = True,
442
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
443
+ if self.gradient_checkpointing and self.training:
444
+ if use_cache:
445
+ logger.warning_once(
446
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
447
+ )
448
+ use_cache = False
449
+
450
+ all_hidden_states = () if output_hidden_states else None
451
+ all_self_attentions = () if output_attentions else None
452
+
453
+ next_decoder_cache = () if use_cache else None
454
+ for i, layer_module in enumerate(self.layer):
455
+ if output_hidden_states:
456
+ all_hidden_states = all_hidden_states + (hidden_states,)
457
+
458
+ layer_head_mask = head_mask[i] if head_mask is not None else None
459
+ past_key_value = past_key_values[i] if past_key_values is not None else None
460
+
461
+ if self.gradient_checkpointing and self.training:
462
+
463
+ def create_custom_forward(module):
464
+ def custom_forward(*inputs):
465
+ return module(*inputs, past_key_value, output_attentions)
466
+
467
+ return custom_forward
468
+
469
+ layer_outputs = torch.utils.checkpoint.checkpoint(
470
+ create_custom_forward(layer_module),
471
+ hidden_states,
472
+ attention_mask,
473
+ layer_head_mask,
474
+ )
475
+ else:
476
+ layer_outputs = layer_module(
477
+ hidden_states,
478
+ attention_mask,
479
+ layer_head_mask,
480
+ past_key_value,
481
+ output_attentions,
482
+ pixel_values_present,
483
+ image_token_num,
484
+
485
+ )
486
+
487
+ hidden_states = layer_outputs[0]
488
+ if use_cache:
489
+ next_decoder_cache += (layer_outputs[-1],)
490
+ if output_attentions:
491
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
492
+
493
+ if output_hidden_states:
494
+ all_hidden_states = all_hidden_states + (hidden_states,)
495
+
496
+ if not return_dict:
497
+ return tuple(
498
+ v
499
+ for v in [
500
+ hidden_states,
501
+ next_decoder_cache,
502
+ all_hidden_states,
503
+ all_self_attentions,
504
+ ]
505
+ if v is not None
506
+ )
507
+ return BaseModelOutputWithPast(
508
+ last_hidden_state=hidden_states,
509
+ past_key_values=next_decoder_cache,
510
+ hidden_states=all_hidden_states,
511
+ attentions=all_self_attentions,
512
+ )
513
+
514
+
515
+ class GitPreTrainedModel(PreTrainedModel):
516
+ """
517
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
518
+ models.
519
+ """
520
+
521
+ config_class = GitConfig
522
+ base_model_prefix = "git"
523
+ supports_gradient_checkpointing = True
524
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
525
+
526
+ def _init_weights(self, module):
527
+ """Initialize the weights"""
528
+ if isinstance(module, GitVisionEmbeddings):
529
+ nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
530
+ nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
531
+ nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
532
+ if isinstance(module, nn.Linear):
533
+ # Slightly different from the TF version which uses truncated_normal for initialization
534
+ # cf https://github.com/pytorch/pytorch/pull/5617
535
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
536
+ if module.bias is not None:
537
+ module.bias.data.zero_()
538
+ elif isinstance(module, nn.Embedding):
539
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
540
+ if module.padding_idx is not None:
541
+ module.weight.data[module.padding_idx].zero_()
542
+ elif isinstance(module, nn.LayerNorm):
543
+ module.bias.data.zero_()
544
+ module.weight.data.fill_(1.0)
545
+
546
+ def _set_gradient_checkpointing(self, module, value=False):
547
+ if isinstance(module, (GitEncoder, GitVisionEncoder)):
548
+ module.gradient_checkpointing = value
549
+
550
+
551
+ GIT_START_DOCSTRING = r"""
552
+
553
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
554
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
555
+ etc.)
556
+
557
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
558
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
559
+ and behavior.
560
+
561
+ Parameters:
562
+ config ([`GitConfig`]): Model configuration class with all the parameters of the model.
563
+ Initializing with a config file does not load the weights associated with the model, only the
564
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
565
+ """
566
+
567
+ GIT_INPUTS_DOCSTRING = r"""
568
+ Args:
569
+ input_ids (`torch.LongTensor` of shape `({0})`):
570
+ Indices of input sequence tokens in the vocabulary.
571
+
572
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
573
+ [`PreTrainedTokenizer.__call__`] for details.
574
+
575
+ [What are input IDs?](../glossary#input-ids)
576
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
577
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
578
+
579
+ - 1 for tokens that are **not masked**,
580
+ - 0 for tokens that are **masked**.
581
+
582
+ [What are attention masks?](../glossary#attention-mask)
583
+
584
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
585
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
586
+ config.max_position_embeddings - 1]`.
587
+
588
+ [What are position IDs?](../glossary#position-ids)
589
+
590
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
591
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
592
+ [`CLIPImageProcessor.__call__`] for details.
593
+
594
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
595
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
596
+
597
+ - 1 indicates the head is **not masked**,
598
+ - 0 indicates the head is **masked**.
599
+
600
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
601
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
602
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
603
+ model's internal embedding lookup matrix.
604
+ output_attentions (`bool`, *optional*):
605
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
606
+ tensors for more detail.
607
+ output_hidden_states (`bool`, *optional*):
608
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
609
+ more detail.
610
+ return_dict (`bool`, *optional*):
611
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
612
+ """
613
+
614
+
615
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
616
+ class GitVisionEmbeddings(nn.Module):
617
+ def __init__(self, config: GitVisionConfig):
618
+ super().__init__()
619
+ self.config = config
620
+ self.embed_dim = config.hidden_size
621
+ self.image_size = config.image_size
622
+ self.patch_size = config.patch_size
623
+
624
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
625
+
626
+ self.patch_embedding = nn.Conv2d(
627
+ in_channels=config.num_channels,
628
+ out_channels=self.embed_dim,
629
+ kernel_size=self.patch_size,
630
+ stride=self.patch_size,
631
+ bias=False,
632
+ )
633
+
634
+ self.num_patches = (self.image_size // self.patch_size) ** 2
635
+ self.num_positions = self.num_patches + 1
636
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
637
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
638
+
639
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
640
+ batch_size = pixel_values.shape[0]
641
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
642
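+ # flatten the spatial grid: [batch, embed_dim, grid, grid] -> [batch, grid*grid, embed_dim]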
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
643
+
644
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
645
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
646
+ embeddings = embeddings + self.position_embedding(self.position_ids)
647
+ return embeddings
648
+
649
+
650
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP
651
+ class GitVisionMLP(nn.Module):
652
+ def __init__(self, config):
653
+ super().__init__()
654
+ self.config = config
655
+ self.activation_fn = ACT2FN[config.hidden_act]
656
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
657
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
658
+
659
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
660
+ hidden_states = self.fc1(hidden_states)
661
+ hidden_states = self.activation_fn(hidden_states)
662
+ hidden_states = self.fc2(hidden_states)
663
+ return hidden_states
664
+
665
+
666
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention
667
+ class GitVisionAttention(nn.Module):
668
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
669
+
670
+ def __init__(self, config):
671
+ super().__init__()
672
+ self.config = config
673
+ self.embed_dim = config.hidden_size
674
+ self.num_heads = config.num_attention_heads
675
+ self.head_dim = self.embed_dim // self.num_heads
676
+ if self.head_dim * self.num_heads != self.embed_dim:
677
+ raise ValueError(
678
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
679
+ f" {self.num_heads})."
680
+ )
681
+ self.scale = self.head_dim**-0.5
682
+ self.dropout = config.attention_dropout
683
+
684
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
685
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
686
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
687
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
688
+
689
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
690
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
691
+
692
+ def forward(
693
+ self,
694
+ hidden_states: torch.Tensor,
695
+ attention_mask: Optional[torch.Tensor] = None,
696
+ causal_attention_mask: Optional[torch.Tensor] = None,
697
+ output_attentions: Optional[bool] = False,
698
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
699
+ """Input shape: Batch x Time x Channel"""
700
+
701
+ bsz, tgt_len, embed_dim = hidden_states.size()
702
+
703
+ # get query proj
704
+ query_states = self.q_proj(hidden_states) * self.scale
705
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
706
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
707
+
708
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
709
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
710
+ key_states = key_states.view(*proj_shape)
711
+ value_states = value_states.view(*proj_shape)
712
+
713
+ src_len = key_states.size(1)
714
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
715
+
716
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
717
+ raise ValueError(
718
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
719
+ f" {attn_weights.size()}"
720
+ )
721
+
722
+ # apply the causal_attention_mask first
723
+ if causal_attention_mask is not None:
724
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
725
+ raise ValueError(
726
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
727
+ f" {causal_attention_mask.size()}"
728
+ )
729
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
730
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
731
+
732
+ if attention_mask is not None:
733
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
734
+ raise ValueError(
735
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
736
+ )
737
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
738
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
739
+
740
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
741
+
742
+ if output_attentions:
743
+ # this operation is a bit awkward, but it's required to
744
+ # make sure that attn_weights keeps its gradient.
745
+ # In order to do so, attn_weights have to be reshaped
746
+ # twice and have to be reused in the following
747
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
748
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
749
+ else:
750
+ attn_weights_reshaped = None
751
+
752
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
753
+
754
+ attn_output = torch.bmm(attn_probs, value_states)
755
+
756
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
757
+ raise ValueError(
758
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
759
+ f" {attn_output.size()}"
760
+ )
761
+
762
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
763
+ attn_output = attn_output.transpose(1, 2)
764
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
765
+
766
+ attn_output = self.out_proj(attn_output)
767
+
768
+ return attn_output, attn_weights_reshaped
769
+
770
+
771
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GitVision
772
+ class GitVisionEncoderLayer(nn.Module):
773
+ def __init__(self, config: GitVisionConfig):
774
+ super().__init__()
775
+ self.embed_dim = config.hidden_size
776
+ self.self_attn = GitVisionAttention(config)
777
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
778
+ self.mlp = GitVisionMLP(config)
779
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
780
+
781
+ def forward(
782
+ self,
783
+ hidden_states: torch.Tensor,
784
+ attention_mask: torch.Tensor,
785
+ causal_attention_mask: torch.Tensor,
786
+ output_attentions: Optional[bool] = False,
787
+ ) -> Tuple[torch.FloatTensor]:
788
+ """
789
+ Args:
790
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
791
+ attention_mask (`torch.FloatTensor`): attention mask of size
792
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
793
+ `(config.encoder_attention_heads,)`.
794
+ output_attentions (`bool`, *optional*):
795
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
796
+ returned tensors for more detail.
797
+ """
798
+ residual = hidden_states
799
+
800
+ hidden_states = self.layer_norm1(hidden_states)
801
+ hidden_states, attn_weights = self.self_attn(
802
+ hidden_states=hidden_states,
803
+ attention_mask=attention_mask,
804
+ causal_attention_mask=causal_attention_mask,
805
+ output_attentions=output_attentions,
806
+ )
807
+ hidden_states = residual + hidden_states
808
+
809
+ residual = hidden_states
810
+ hidden_states = self.layer_norm2(hidden_states)
811
+ hidden_states = self.mlp(hidden_states)
812
+ hidden_states = residual + hidden_states
813
+
814
+ outputs = (hidden_states,)
815
+
816
+ if output_attentions:
817
+ outputs += (attn_weights,)
818
+
819
+ return outputs
820
+
821
+
822
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->GitVision, CLIPConfig
823
+ class GitVisionEncoder(nn.Module):
824
+ """
825
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
826
+ [`GitVisionEncoderLayer`].
827
+
828
+ Args:
829
+ config: GitVisionConfig
830
+ """
831
+
832
+ def __init__(self, config: GitVisionConfig):
833
+ super().__init__()
834
+ self.config = config
835
+ self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
836
+ self.gradient_checkpointing = False
837
+
838
+ def forward(
839
+ self,
840
+ inputs_embeds,
841
+ attention_mask: Optional[torch.Tensor] = None,
842
+ causal_attention_mask: Optional[torch.Tensor] = None,
843
+ output_attentions: Optional[bool] = None,
844
+ output_hidden_states: Optional[bool] = None,
845
+ return_dict: Optional[bool] = None,
846
+ ) -> Union[Tuple, BaseModelOutput]:
847
+ r"""
848
+ Args:
849
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
850
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
851
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
852
+ than the model's internal embedding lookup matrix.
853
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
854
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
855
+
856
+ - 1 for tokens that are **not masked**,
857
+ - 0 for tokens that are **masked**.
858
+
859
+ [What are attention masks?](../glossary#attention-mask)
860
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
861
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
862
+
863
+ - 1 for tokens that are **not masked**,
864
+ - 0 for tokens that are **masked**.
865
+
866
+ [What are attention masks?](../glossary#attention-mask)
867
+ output_attentions (`bool`, *optional*):
868
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
869
+ returned tensors for more detail.
870
+ output_hidden_states (`bool`, *optional*):
871
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
872
+ for more detail.
873
+ return_dict (`bool`, *optional*):
874
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
875
+ """
876
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
877
+ output_hidden_states = (
878
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
879
+ )
880
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
881
+
882
+ encoder_states = () if output_hidden_states else None
883
+ all_attentions = () if output_attentions else None
884
+
885
+ hidden_states = inputs_embeds
886
+ for idx, encoder_layer in enumerate(self.layers):
887
+ if output_hidden_states:
888
+ encoder_states = encoder_states + (hidden_states,)
889
+ if self.gradient_checkpointing and self.training:
890
+
891
+ def create_custom_forward(module):
892
+ def custom_forward(*inputs):
893
+ return module(*inputs, output_attentions)
894
+
895
+ return custom_forward
896
+
897
+ layer_outputs = torch.utils.checkpoint.checkpoint(
898
+ create_custom_forward(encoder_layer),
899
+ hidden_states,
900
+ attention_mask,
901
+ causal_attention_mask,
902
+ )
903
+ else:
904
+ layer_outputs = encoder_layer(
905
+ hidden_states,
906
+ attention_mask,
907
+ causal_attention_mask,
908
+ output_attentions=output_attentions,
909
+ )
910
+
911
+ hidden_states = layer_outputs[0]
912
+
913
+ if output_attentions:
914
+ all_attentions = all_attentions + (layer_outputs[1],)
915
+
916
+ if output_hidden_states:
917
+ encoder_states = encoder_states + (hidden_states,)
918
+
919
+ if not return_dict:
920
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
921
+ return BaseModelOutput(
922
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
923
+ )
924
+
925
+
926
+ GIT_VISION_INPUTS_DOCSTRING = r"""
927
+ Args:
928
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
929
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
930
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
931
+ output_attentions (`bool`, *optional*):
932
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
933
+ tensors for more detail.
934
+ output_hidden_states (`bool`, *optional*):
935
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
936
+ more detail.
937
+ return_dict (`bool`, *optional*):
938
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
939
+ """
940
+
941
+
942
+ class GitVisionTransformer(nn.Module):
943
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ with CLIPEncoder->GitVisionEncoder, CLIP->Git
944
+ def __init__(self, config: GitVisionConfig):
945
+ super().__init__()
946
+ self.config = config
947
+ embed_dim = config.hidden_size
948
+
949
+ self.embeddings = GitVisionEmbeddings(config)
950
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
951
+ self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size)
952
+ self.encoder = GitVisionEncoder(config)
953
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
954
+
955
+ @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
956
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
957
+ def forward(
958
+ self,
959
+ pixel_values: Optional[torch.FloatTensor] = None,
960
+ pixel_masks: Optional[torch.Tensor] = None,
961
+ output_attentions: Optional[bool] = None,
962
+ output_hidden_states: Optional[bool] = None,
963
+ return_dict: Optional[bool] = None,
964
+ ) -> Union[Tuple, BaseModelOutput]:
965
+ r"""
966
+ Returns:
967
+
968
+ """
969
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
970
+ output_hidden_states = (
971
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
972
+ )
973
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
974
+
975
+ if pixel_values is None:
976
+ raise ValueError("You have to specify pixel_values")
977
+
978
+ hidden_states = self.embeddings(pixel_values)
979
+ B, N, D = hidden_states.shape
980
+ # print('Before mask:', hidden_states.shape)
981
+ if pixel_masks is not None:
982
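+ # ViTPatchMaskGenerator turns the pixel-level mask into a boolean per-patch mask; patches outside the mask
+ # are dropped via masked_select (batch size 1 only, hence the assert), so only the selected region is encoded.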
+ assert pixel_masks.shape[0] == 1
983
+ patch_masks = self.patch_mask_generator(pixel_masks)
984
+ # print(patch_masks.shape)
985
+ patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states)
986
+ hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D)
987
+ # print('After mask:', hidden_states.shape)
988
+ hidden_states = self.pre_layrnorm(hidden_states)
989
+
990
+ encoder_outputs = self.encoder(
991
+ inputs_embeds=hidden_states,
992
+ output_attentions=output_attentions,
993
+ output_hidden_states=output_hidden_states,
994
+ return_dict=return_dict,
995
+ )
996
+
997
+ last_hidden_state = encoder_outputs[0]
998
+
999
+ last_hidden_state = self.post_layernorm(last_hidden_state)
1000
+
1001
+ if not return_dict:
1002
+ return (last_hidden_state,) + encoder_outputs[1:]
1003
+
1004
+ return BaseModelOutput(
1005
+ last_hidden_state=last_hidden_state,
1006
+ hidden_states=encoder_outputs.hidden_states,
1007
+ attentions=encoder_outputs.attentions,
1008
+ )
1009
+
1010
+
1011
+ @add_start_docstrings(
1012
+ """The vision model from CLIP, used in GIT, without any head or projection on top.""",
1013
+ GIT_START_DOCSTRING,
1014
+ )
1015
+ class GitVisionModel(GitPreTrainedModel):
1016
+ config_class = GitVisionConfig
1017
+ main_input_name = "pixel_values"
1018
+
1019
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
1020
+ def __init__(self, config: GitVisionConfig):
1021
+ super().__init__(config)
1022
+ self.vision_model = GitVisionTransformer(config)
1023
+ # Initialize weights and apply final processing
1024
+ self.post_init()
1025
+
1026
+ def get_input_embeddings(self) -> nn.Module:
1027
+ return self.vision_model.embeddings.patch_embedding
1028
+
1029
+ @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
1030
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
1031
+ def forward(
1032
+ self,
1033
+ pixel_values: Optional[torch.FloatTensor] = None,
1034
+ pixel_masks: Optional[torch.Tensor] = None,
1035
+ output_attentions: Optional[bool] = None,
1036
+ output_hidden_states: Optional[bool] = None,
1037
+ return_dict: Optional[bool] = None,
1038
+ ) -> Union[Tuple, BaseModelOutput]:
1039
+ r"""
1040
+ Returns:
1041
+
1042
+ Examples:
1043
+
1044
+ ```python
1045
+ >>> from PIL import Image
1046
+ >>> import requests
1047
+ >>> from transformers import AutoProcessor, GitVisionModel
1048
+
1049
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
1050
+ >>> model = GitVisionModel.from_pretrained("microsoft/git-base")
1051
+
1052
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1053
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1054
+
1055
+ >>> inputs = processor(images=image, return_tensors="pt")
1056
+
1057
+ >>> outputs = model(**inputs)
1058
+ >>> last_hidden_state = outputs.last_hidden_state
1059
+ ```"""
1060
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1061
+
1062
+ return self.vision_model(
1063
+ pixel_values=pixel_values,
1064
+ pixel_masks=pixel_masks,
1065
+ output_attentions=output_attentions,
1066
+ output_hidden_states=output_hidden_states,
1067
+ return_dict=return_dict,
1068
+ )
1069
+
1070
+
1071
+ class GitProjection(nn.Module):
1072
+ def __init__(self, config: GitConfig):
1073
+ super().__init__()
1074
+ self.config = config
1075
+ self.visual_projection = nn.Sequential(
1076
+ nn.Linear(config.vision_config.hidden_size, config.hidden_size),
1077
+ nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
1078
+ )
1079
+
1080
+ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
1081
+ return self.visual_projection(embeddings)
1082
+
1083
+
1084
+ @add_start_docstrings(
1085
+ "The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states"
1086
+ " without any specific head on top.",
1087
+ GIT_START_DOCSTRING,
1088
+ )
1089
+ class GitModel(GitPreTrainedModel):
1090
+ def __init__(self, config):
1091
+ super().__init__(config)
1092
+ self.config = config
1093
+
1094
+ self.embeddings = GitEmbeddings(config)
1095
+ self.image_encoder = GitVisionModel(config.vision_config)
1096
+ self.encoder = GitEncoder(config)
1097
+
1098
+ self.visual_projection = GitProjection(config)
1099
+
1100
+ if config.num_image_with_embedding is not None:
1101
+ self.img_temperal_embedding = nn.ParameterList(
1102
+ nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
1103
+ for _ in range(config.num_image_with_embedding)
1104
+ )
1105
+
1106
+ # Initialize weights and apply final processing
1107
+ self.post_init()
1108
+
1109
+ def get_input_embeddings(self):
1110
+ return self.embeddings.word_embeddings
1111
+
1112
+ def set_input_embeddings(self, value):
1113
+ self.embeddings.word_embeddings = value
1114
+
1115
+ def _prune_heads(self, heads_to_prune):
1116
+ """
1117
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1118
+ class PreTrainedModel
1119
+ """
1120
+ for layer, heads in heads_to_prune.items():
1121
+ self.encoder.layer[layer].attention.prune_heads(heads)
1122
+
1123
+ def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
1124
+ # Default mask is for forward direction. Flip for backward direction.
1125
+ mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
1126
+ mask = mask.masked_fill(mask == 1, float("-inf"))
1127
+ return mask
1128
+
1129
+ def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
1130
+ num_tgt = tgt.shape[1]
1131
+ num_memory = memory.shape[1]
1132
+ device = tgt.device
1133
+ dtype = tgt.dtype
1134
+ top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
1135
+ top_right = torch.full(
1136
+ (num_memory, num_tgt + past_key_values_length),
1137
+ float("-inf"),
1138
+ device=tgt.device,
1139
+ dtype=dtype,
1140
+ )
1141
+ bottom_left = torch.zeros(
1142
+ (num_tgt, num_memory),
1143
+ dtype=dtype,
1144
+ device=tgt_mask.device,
1145
+ )
1146
+
1147
+ if past_key_values_length > 0:
1148
+ tgt_mask = torch.zeros(
1149
+ (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
1150
+ dtype=dtype,
1151
+ device=tgt_mask.device,
1152
+ )
1153
+
1154
+ left = torch.cat((top_left, bottom_left), dim=0)
1155
+ right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
1156
+
1157
+ full_attention_mask = torch.cat((left, right), dim=1)[None, :]
1158
+
1159
+ if memory_key_padding_mask is None:
1160
+ memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
1161
+ # False means the position is valid, i.e., it is not padding
1162
+ if memory_key_padding_mask.dtype != torch.bool:
1163
+ raise ValueError("Memory key padding mask must be a boolean tensor.")
1164
+ zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
1165
+ zero_negative_infinity[memory_key_padding_mask] = float("-inf")
1166
+ full_attention_mask = full_attention_mask.expand(
1167
+ (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
1168
+ )
1169
+ full_attention_mask = full_attention_mask.clone()
1170
+ origin_left = full_attention_mask[:, :, :num_memory]
1171
+ update = zero_negative_infinity[:, None, :]
1172
+ full_attention_mask[:, :, :num_memory] = origin_left + update
1173
+
1174
+ # add axis for multi-head
1175
+ full_attention_mask = full_attention_mask[:, None, :, :]
1176
+
1177
+ return full_attention_mask
1178
+
1179
+ @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1180
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
1181
+ def forward(
1182
+ self,
1183
+ input_ids: Optional[torch.Tensor] = None,
1184
+ attention_mask: Optional[torch.Tensor] = None,
1185
+ position_ids: Optional[torch.Tensor] = None,
1186
+ pixel_values: Optional[torch.Tensor] = None,
1187
+ pixel_masks: Optional[torch.Tensor] = None,
1188
+ head_mask: Optional[torch.Tensor] = None,
1189
+ inputs_embeds: Optional[torch.Tensor] = None,
1190
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1191
+ use_cache: Optional[bool] = None,
1192
+ output_attentions: Optional[bool] = None,
1193
+ output_hidden_states: Optional[bool] = None,
1194
+ return_dict: Optional[bool] = None,
1195
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
1196
+ r"""
1197
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1198
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1199
+
1200
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1201
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1202
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1203
+ use_cache (`bool`, *optional*):
1204
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1205
+ `past_key_values`).
1206
+
1207
+ Returns:
1208
+
1209
+ Examples:
1210
+
1211
+ ```python
1212
+ >>> from transformers import AutoProcessor, AutoModel
1213
+ >>> import requests
1214
+ >>> from PIL import Image
1215
+
1216
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
1217
+ >>> model = AutoModel.from_pretrained("microsoft/git-base")
1218
+
1219
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1220
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1221
+
1222
+ >>> text = "this is an image of two cats"
1223
+
1224
+ >>> inputs = processor(text, images=image, return_tensors="pt")
1225
+
1226
+ >>> outputs = model(**inputs)
1227
+ >>> last_hidden_state = outputs.last_hidden_state
1228
+ ```"""
1229
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1230
+ output_hidden_states = (
1231
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1232
+ )
1233
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1234
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1235
+
1236
+ if input_ids is not None and inputs_embeds is not None:
1237
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1238
+ elif input_ids is not None:
1239
+ input_shape = input_ids.size()
1240
+ elif inputs_embeds is not None:
1241
+ input_shape = inputs_embeds.size()[:-1]
1242
+ else:
1243
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1244
+
1245
+ seq_length = input_shape[1]
1246
+
1247
+ # past_key_values_length
1248
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1249
+
1250
+ # Prepare head mask if needed
1251
+ # 1.0 in head_mask indicate we keep the head
1252
+ # attention_probs has shape bsz x n_heads x N x N
1253
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1254
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1255
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1256
+
1257
+ projected_visual_features = None
+ image_token_num = 0  # default when no pixel_values are given; avoids an unbound name in the encoder call below
1258
+ if pixel_values is not None:
1259
+ if pixel_values.ndim == 4:
1260
+ # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
1261
+ visual_features = self.image_encoder(pixel_values=pixel_values, pixel_masks=pixel_masks).last_hidden_state
1262
+
1263
+ elif pixel_values.ndim == 5:
1264
+ # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
1265
+ visual_features = []
1266
+ for frame_idx in range(pixel_values.shape[1]):
1267
+ visual_features_frame = self.image_encoder(pixel_values[:, frame_idx, :, :]).last_hidden_state
1268
+ visual_features_frame += self.img_temperal_embedding[frame_idx]
1269
+ visual_features.append(visual_features_frame)
1270
+
1271
+ # finally, concatenate all features along sequence dimension
1272
+ visual_features = torch.cat(visual_features, dim=1)
1273
+
1274
+ else:
1275
+ raise ValueError("pixel_values must be of rank 4 or 5")
1276
+
1277
+ projected_visual_features = self.visual_projection(visual_features)
1278
+ image_token_num = projected_visual_features.shape[1]
1279
+ embedding_output = self.embeddings(
1280
+ input_ids=input_ids,
1281
+ position_ids=position_ids,
1282
+ inputs_embeds=inputs_embeds,
1283
+ past_key_values_length=past_key_values_length,
1284
+ )
1285
+
1286
+ if projected_visual_features is None:
1287
+ projected_visual_features = torch.zeros(
1288
+ (embedding_output.shape[0], 0, embedding_output.shape[2]),
1289
+ dtype=embedding_output.dtype,
1290
+ device=embedding_output.device,
1291
+ )
1292
+
1293
+ # Repeat visual features to match embedding batch size.
1294
+ projected_visual_features = projected_visual_features.repeat(
1295
+ embedding_output.size(0) // projected_visual_features.size(0), 1, 1
1296
+ )
1297
+
1298
+ # concatenate patch token and text token embeddings
1299
+ hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
1300
+
1301
+ # By default, an additive causal mask is created
1302
+ # for masking the future (one direction).
1303
+ tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
1304
+
1305
+ # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
1306
+ combined_attention_mask = self.create_attention_mask(
1307
+ tgt=embedding_output,
1308
+ memory=projected_visual_features,
1309
+ tgt_mask=tgt_mask,
1310
+ past_key_values_length=past_key_values_length,
1311
+ )
1312
+
1313
+ if attention_mask is not None:
1314
+ # if the user provides an attention mask, we add it to the default one
1315
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1316
+ expanded_attn_mask = _expand_mask(attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]).to(
1317
+ embedding_output.device
1318
+ )
1319
+ if past_key_values_length > 0:
1320
+ expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
1321
+ else:
1322
+ combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
1323
+
1324
+ encoder_outputs = self.encoder(
1325
+ hidden_states,
1326
+ attention_mask=combined_attention_mask,
1327
+ head_mask=head_mask,
1328
+ past_key_values=past_key_values,
1329
+ use_cache=use_cache,
1330
+ output_attentions=output_attentions,
1331
+ output_hidden_states=output_hidden_states,
1332
+ return_dict=return_dict,
1333
+ pixel_values_present=pixel_values is not None,
1334
+ image_token_num=image_token_num
1335
+ )
1336
+ sequence_output = encoder_outputs[0]
1337
+
1338
+ if not return_dict:
1339
+ return (sequence_output,) + encoder_outputs[1:]
1340
+
1341
+ return BaseModelOutputWithPast(
1342
+ last_hidden_state=sequence_output,
1343
+ past_key_values=encoder_outputs.past_key_values,
1344
+ hidden_states=encoder_outputs.hidden_states,
1345
+ attentions=encoder_outputs.attentions,
1346
+ )
1347
+
1348
+
1349
+ @add_start_docstrings(
1350
+ """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
1351
+ )
1352
+ class GitForCausalLM(GitPreTrainedModel):
1353
+ def __init__(self, config):
1354
+ super().__init__(config)
1355
+
1356
+ self.git = GitModel(config)
1357
+ self.output = nn.Linear(config.hidden_size, config.vocab_size)
1358
+
1359
+ # Initialize weights and apply final processing
1360
+ self.post_init()
1361
+
1362
+ def get_output_embeddings(self):
1363
+ return self.output
1364
+
1365
+ def set_output_embeddings(self, new_embeddings):
1366
+ self.output = new_embeddings
1367
+
1368
+ @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1369
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1370
+ def forward(
1371
+ self,
1372
+ input_ids: Optional[torch.Tensor] = None,
1373
+ attention_mask: Optional[torch.Tensor] = None,
1374
+ position_ids: Optional[torch.Tensor] = None,
1375
+ pixel_values: Optional[torch.Tensor] = None,
1376
+ pixel_masks: Optional[torch.Tensor] = None,
1377
+ head_mask: Optional[torch.Tensor] = None,
1378
+ inputs_embeds: Optional[torch.Tensor] = None,
1379
+ labels: Optional[torch.Tensor] = None,
1380
+ past_key_values: Optional[List[torch.Tensor]] = None,
1381
+ use_cache: Optional[bool] = None,
1382
+ output_attentions: Optional[bool] = None,
1383
+ output_hidden_states: Optional[bool] = None,
1384
+ return_dict: Optional[bool] = None,
1385
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
1386
+ r"""
1387
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1388
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1389
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1390
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1391
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1392
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1393
+
1394
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1395
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1396
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1397
+ use_cache (`bool`, *optional*):
1398
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1399
+ `past_key_values`).
1400
+
1401
+ Returns:
1402
+
1403
+ Examples:
1404
+
1405
+ Image captioning example:
1406
+
1407
+ ```python
1408
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM
1409
+ >>> import requests
1410
+ >>> from PIL import Image
1411
+
1412
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
1413
+ >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
1414
+
1415
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1416
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1417
+
1418
+ >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
1419
+
1420
+ >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
1421
+ >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
1422
+ >>> print(generated_caption)
1423
+ two cats sleeping on a pink blanket next to remotes.
1424
+ ```
1425
+
1426
+ Visual question answering (VQA) example:
1427
+
1428
+ ```python
1429
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM
1430
+ >>> from huggingface_hub import hf_hub_download
1431
+ >>> from PIL import Image
1432
+
1433
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
1434
+ >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
1435
+
1436
+ >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
1437
+ >>> image = Image.open(file_path).convert("RGB")
1438
+
1439
+ >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
1440
+
1441
+ >>> question = "what does the front of the bus say at the top?"
1442
+
1443
+ >>> input_ids = processor(text=question, add_special_tokens=False).input_ids
1444
+ >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
1445
+ >>> input_ids = torch.tensor(input_ids).unsqueeze(0)
1446
+
1447
+ >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
1448
+ >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
1449
+ ['what does the front of the bus say at the top? special']
1450
+ ```
1451
+
1452
+ Video captioning example:
1453
+
1454
+ ```python
1455
+ >>> import av
1456
+ >>> import numpy as np
1457
+ >>> from PIL import Image
1458
+ >>> from huggingface_hub import hf_hub_download
1459
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM
1460
+
1461
+ >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
1462
+ >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")
1463
+
1464
+ >>> # set seed for reproducibility
1465
+ >>> np.random.seed(45)
1466
+
1467
+
1468
+ >>> def read_video_pyav(container, indices):
1469
+ ... '''
1470
+ ... Decode the video with PyAV decoder.
1471
+ ... Args:
1472
+ ... container (`av.container.input.InputContainer`): PyAV container.
1473
+ ... indices (`List[int]`): List of frame indices to decode.
1474
+ ... Returns:
1475
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
1476
+ ... '''
1477
+ ... frames = []
1478
+ ... container.seek(0)
1479
+ ... start_index = indices[0]
1480
+ ... end_index = indices[-1]
1481
+ ... for i, frame in enumerate(container.decode(video=0)):
1482
+ ... if i > end_index:
1483
+ ... break
1484
+ ... if i >= start_index and i in indices:
1485
+ ... frames.append(frame)
1486
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
1487
+
1488
+
1489
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
1490
+ ... converted_len = int(clip_len * frame_sample_rate)
1491
+ ... end_idx = np.random.randint(converted_len, seg_len)
1492
+ ... start_idx = end_idx - converted_len
1493
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
1494
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
1495
+ ... return indices
1496
+
1497
+
1498
+ >>> # load video
1499
+ >>> file_path = hf_hub_download(
1500
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
1501
+ ... )
1502
+ >>> container = av.open(file_path)
1503
+
1504
+ >>> # sample frames
1505
+ >>> num_frames = model.config.num_image_with_embedding
1506
+ >>> indices = sample_frame_indices(
1507
+ ... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
1508
+ ... )
1509
+ >>> frames = read_video_pyav(container, indices)
1510
+
1511
+ >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values
1512
+
1513
+ >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
1514
+
1515
+ >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
1516
+ Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
1517
+ ```
1518
+ """
1519
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1520
+ if labels is not None:
1521
+ use_cache = False
1522
+
1523
+ outputs = self.git(
1524
+ input_ids,
1525
+ attention_mask=attention_mask,
1526
+ position_ids=position_ids,
1527
+ pixel_values=pixel_values,
1528
+ pixel_masks=pixel_masks,
1529
+ head_mask=head_mask,
1530
+ inputs_embeds=inputs_embeds,
1531
+ past_key_values=past_key_values,
1532
+ use_cache=use_cache,
1533
+ output_attentions=output_attentions,
1534
+ output_hidden_states=output_hidden_states,
1535
+ return_dict=return_dict,
1536
+ )
1537
+
1538
+ sequence_output = outputs[0]
1539
+ logits = self.output(sequence_output)
1540
+
1541
+ loss = None
1542
+ if labels is not None:
1543
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1544
+ num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
1545
+ shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
1546
+ labels = labels[:, 1:].contiguous()
1547
+ loss_fct = CrossEntropyLoss()
1548
+ loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
1549
+
1550
+ if not return_dict:
1551
+ output = (logits,) + outputs[1:]
1552
+ return ((loss,) + output) if loss is not None else output
1553
+
1554
+ return CausalLMOutputWithPast(
1555
+ loss=loss,
1556
+ logits=logits,
1557
+ past_key_values=outputs.past_key_values,
1558
+ hidden_states=outputs.hidden_states,
1559
+ attentions=outputs.attentions,
1560
+ )
1561
+
1562
+ def prepare_inputs_for_generation(
1563
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
1564
+ ):
1565
+ # cut decoder_input_ids if past_key_values is used
1566
+ if past_key_values is not None:
1567
+ input_ids = input_ids[:, -1:]
1568
+
1569
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1570
+ input_shape = input_ids.shape
1571
+ if attention_mask is None:
1572
+ attention_mask = input_ids.new_ones(input_shape)
1573
+
1574
+ return {
1575
+ "input_ids": input_ids,
1576
+ "attention_mask": attention_mask,
1577
+ "pixel_values": kwargs.get("pixel_values", None),
1578
+ "pixel_masks": kwargs.get("pixel_masks", None),
1579
+ "past_key_values": past_key_values,
1580
+ "use_cache": use_cache,
1581
+ }
1582
+
1583
+ def _reorder_cache(self, past_key_values, beam_idx):
1584
+ reordered_past = ()
1585
+ for layer_past in past_key_values:
1586
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1587
+ return reordered_past
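The block-wise attention layout built by `create_attention_mask` above can be hard to read from the tensor concatenations alone. Below is a minimal standalone sketch (toy sizes, plain PyTorch, not part of the checked-in files) that reproduces the same layout for 2 visual tokens and 3 text tokens:

```
# Reproduces the additive mask layout of GitModel.create_attention_mask (0 = attend, -inf = blocked):
# visual tokens attend only to visual tokens; text tokens attend to all visual tokens plus earlier text.
import torch

num_memory, num_tgt = 2, 3  # projected visual tokens, text tokens

top_left = torch.zeros(num_memory, num_memory)                 # visual -> visual
top_right = torch.full((num_memory, num_tgt), float("-inf"))   # visual -> text (blocked)
bottom_left = torch.zeros(num_tgt, num_memory)                 # text -> visual
causal = torch.triu(torch.ones(num_tgt, num_tgt), diagonal=1)
causal = causal.masked_fill(causal == 1, float("-inf"))        # text -> future text (blocked)

left = torch.cat((top_left, bottom_left), dim=0)
right = torch.cat((top_right, causal), dim=0)
print(torch.cat((left, right), dim=1))
# tensor([[0., 0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])
```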
caption_anything/captioner/vit_pixel_masks_utils.py ADDED
@@ -0,0 +1,17 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class ViTPatchMaskGenerator(nn.Module):
7
+ def __init__(self, patch_size) -> None:
8
+ super(ViTPatchMaskGenerator, self).__init__()
9
+ self.patch_size = patch_size
10
+ self.pool = nn.MaxPool2d(kernel_size=patch_size, stride=patch_size)
11
+
12
+ def forward(self, pixel_masks):
13
+ patch_mask = self.pool(pixel_masks)
14
+ patch_mask = patch_mask.bool().flatten(1)
15
+ cls_token_mask = patch_mask.new_ones([patch_mask.shape[0], 1]).bool()
16
+ patch_mask = torch.cat([cls_token_mask, patch_mask], dim=-1)
17
+ return patch_mask
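A quick sketch of how this module feeds the token-reduction path in `GitVisionTransformer.forward` above (the `masked_select` on the patch embeddings): the binary pixel mask is max-pooled to patch resolution, flattened, and a leading `True` is prepended so the CLS token is always kept. The shapes below assume a 224x224 input with 16x16 patches; the import path follows the repository layout.

```
# Illustration only: pixel mask -> per-token keep mask (CLS token always kept).
import torch
from caption_anything.captioner.vit_pixel_masks_utils import ViTPatchMaskGenerator

pixel_masks = torch.zeros(1, 1, 224, 224)    # binary mask over the image (float tensor)
pixel_masks[:, :, 64:128, 64:128] = 1.0      # keep a 64x64 region -> 4x4 = 16 patches

patch_mask = ViTPatchMaskGenerator(patch_size=16)(pixel_masks)
print(patch_mask.shape, patch_mask.sum().item())
# torch.Size([1, 197]) 17   (196 patch tokens + CLS; 16 patches survive, plus the CLS token)
```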
caption_anything/model.py ADDED
@@ -0,0 +1,294 @@
1
+ import os
2
+ import argparse
3
+ import pdb
4
+ import time
5
+ from PIL import Image
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ import easyocr
10
+ import copy
11
+ import time
12
+ from caption_anything.captioner import build_captioner, BaseCaptioner
13
+ from caption_anything.segmenter import build_segmenter, build_segmenter_densecap
14
+ from caption_anything.text_refiner import build_text_refiner
15
+ from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image, get_image_shape
16
+ from caption_anything.utils.utils import mask_painter_foreground_all, mask_painter, xywh_to_x1y1x2y2, image_resize
17
+ from caption_anything.utils.densecap_painter import draw_bbox
18
+
19
+ class CaptionAnything:
20
+ def __init__(self, args, api_key="", captioner=None, segmenter=None, ocr_reader=None, text_refiner=None):
21
+ self.args = args
22
+ self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
23
+ self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
24
+ self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args, model=self.segmenter.model)
25
+ self.ocr_lang = ["ch_tra", "en"]
26
+ self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(self.ocr_lang)
27
+
28
+
29
+ self.text_refiner = None
30
+ if not args.disable_gpt:
31
+ if text_refiner is not None:
32
+ self.text_refiner = text_refiner
33
+ elif api_key != "":
34
+ self.init_refiner(api_key)
35
+ self.require_caption_prompt = args.captioner == 'blip2'
36
+
37
+ @property
38
+ def image_embedding(self):
39
+ return self.segmenter.image_embedding
40
+
41
+ @image_embedding.setter
42
+ def image_embedding(self, image_embedding):
43
+ self.segmenter.image_embedding = image_embedding
44
+
45
+ @property
46
+ def original_size(self):
47
+ return self.segmenter.predictor.original_size
48
+
49
+ @original_size.setter
50
+ def original_size(self, original_size):
51
+ self.segmenter.predictor.original_size = original_size
52
+
53
+ @property
54
+ def input_size(self):
55
+ return self.segmenter.predictor.input_size
56
+
57
+ @input_size.setter
58
+ def input_size(self, input_size):
59
+ self.segmenter.predictor.input_size = input_size
60
+
61
+ def setup(self, image_embedding, original_size, input_size, is_image_set):
62
+ self.image_embedding = image_embedding
63
+ self.original_size = original_size
64
+ self.input_size = input_size
65
+ self.segmenter.predictor.is_image_set = is_image_set
66
+
67
+ def init_refiner(self, api_key):
68
+ try:
69
+ self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
70
+ self.text_refiner.llm('hi') # test
71
+ except Exception:
72
+ self.text_refiner = None
73
+ print('OpenAI GPT is not available')
74
+
75
+ def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False, verbose=False, is_densecap=False, args={}):
76
+ # segment with prompt
77
+ print("CA prompt: ", prompt, "CA controls", controls)
78
+ is_seg_everything = 'everything' in prompt['prompt_type']
79
+
80
+ args['seg_crop_mode'] = args.get('seg_crop_mode', self.args.seg_crop_mode)
81
+ args['clip_filter'] = args.get('clip_filter', self.args.clip_filter)
82
+ args['disable_regular_box'] = args.get('disable_regular_box', self.args.disable_regular_box)
83
+ args['context_captions'] = args.get('context_captions', self.args.context_captions)
84
+ args['enable_reduce_tokens'] = args.get('enable_reduce_tokens', self.args.enable_reduce_tokens)
85
+ args['enable_morphologyex'] = args.get('enable_morphologyex', self.args.enable_morphologyex)
86
+ args['topN'] = args.get('topN', 10) if is_seg_everything else 1
87
+ args['min_mask_area'] = args.get('min_mask_area', 0)
88
+
89
+ if not is_densecap:
90
+ seg_results = self.segmenter.inference(image, prompt)
91
+ else:
92
+ seg_results = self.segmenter_densecap.inference(image, prompt)
93
+
94
+ seg_masks, seg_bbox, seg_area = seg_results if is_seg_everything else (seg_results, None, None)
95
+
96
+ if args['topN'] > 1: # sort by area
97
+ samples = list(zip(*[seg_masks, seg_bbox, seg_area]))
98
+ # top_samples = sorted(samples, key=lambda x: x[2], reverse=True)
99
+ # seg_masks, seg_bbox, seg_area = list(zip(*top_samples))
100
+ samples = list(filter(lambda x: x[2] > args['min_mask_area'], samples))
101
+ samples = samples[:args['topN']]
102
+ seg_masks, seg_bbox, seg_area = list(zip(*samples))
103
+
104
+ out_list = []
105
+ for i, seg_mask in enumerate(seg_masks):
106
+ if args['enable_morphologyex']:
107
+ seg_mask = 255 * seg_mask.astype(np.uint8)
108
+ seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
109
+ seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
110
+ seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
111
+ seg_mask = seg_mask[:, :, 0] > 0
112
+
113
+ seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
114
+ mask_save_path = None
115
+
116
+ if verbose:
117
+ mask_save_path = f'result/mask_{time.time()}.png'
118
+ if not os.path.exists(os.path.dirname(mask_save_path)):
119
+ os.makedirs(os.path.dirname(mask_save_path))
120
+
121
+ if seg_mask_img.mode != 'RGB':
122
+ seg_mask_img = seg_mask_img.convert('RGB')
123
+ seg_mask_img.save(mask_save_path)
124
+ print('seg_mask path: ', mask_save_path)
125
+ print("seg_mask.shape: ", seg_mask.shape)
126
+
127
+
128
+ # captioning with mask
129
+ if args['enable_reduce_tokens']:
130
+ result = self.captioner.inference_with_reduced_tokens(image, seg_mask,
131
+ crop_mode=args['seg_crop_mode'],
132
+ filter=args['clip_filter'],
133
+ disable_regular_box=args['disable_regular_box'],
134
+ verbose=verbose,
135
+ caption_args=args)
136
+ else:
137
+ result = self.captioner.inference_seg(image, seg_mask,
138
+ crop_mode=args['seg_crop_mode'],
139
+ filter=args['clip_filter'],
140
+ disable_regular_box=args['disable_regular_box'],
141
+ verbose=verbose,
142
+ caption_args=args)
143
+ caption = result.get('caption', None)
144
+ crop_save_path = result.get('crop_save_path', None)
145
+
146
+ # refining with TextRefiner
147
+ context_captions = []
148
+ if args['context_captions']:
149
+ context_captions.append(self.captioner.inference(image)['caption'])
150
+ if not disable_gpt and self.text_refiner is not None:
151
+ refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
152
+ enable_wiki=enable_wiki)
153
+ else:
154
+ refined_caption = {'raw_caption': caption}
155
+ out = {'generated_captions': refined_caption,
156
+ 'crop_save_path': crop_save_path,
157
+ 'mask_save_path': mask_save_path,
158
+ 'mask': seg_mask_img,
159
+ 'bbox': seg_bbox[i] if seg_bbox is not None else None,
160
+ 'area': seg_area[i] if seg_area is not None else None,
161
+ 'context_captions': context_captions,
162
+ 'ppl_score': result.get('ppl_score', -100.),
163
+ 'clip_score': result.get('clip_score', 0.)
164
+ }
165
+ out_list.append(out)
166
+ return out_list
167
+
168
+ def parse_dense_caption(self, image, topN=10, reference_caption=[], verbose=False):
169
+ width, height = get_image_shape(image)
170
+ prompt = {'prompt_type': ['everything']}
171
+ densecap_args = {
172
+ 'return_ppl': True,
173
+ 'clip_filter': True,
174
+ 'reference_caption': reference_caption,
175
+ 'text_prompt': "", # 'Question: what does the image show? Answer:'
176
+ 'seg_crop_mode': 'w_bg',
177
+ # 'text_prompt': "",
178
+ # 'seg_crop_mode': 'wo_bg',
179
+ 'disable_regular_box': False,
180
+ 'topN': topN,
181
+ 'min_ppl_score': -1.8,
182
+ 'min_clip_score': 0.30,
183
+ 'min_mask_area': 2500,
184
+ }
185
+
186
+ dense_captions = self.inference(image, prompt,
187
+ controls=None,
188
+ disable_gpt=True,
189
+ verbose=verbose,
190
+ is_densecap=True,
191
+ args=densecap_args)
192
+ print('Process Dense Captioning: \n', dense_captions)
193
+ dense_captions = list(filter(lambda x: x['ppl_score'] / (1+len(x['generated_captions']['raw_caption'].split())) >= densecap_args['min_ppl_score'], dense_captions))
194
+ dense_captions = list(filter(lambda x: x['clip_score'] >= densecap_args['min_clip_score'], dense_captions))
195
+ dense_cap_prompt = []
196
+ for cap in dense_captions:
197
+ x, y, w, h = cap['bbox']
198
+ cx, cy = x + w/2, (y + h/2)
199
+ dense_cap_prompt.append("({}: X:{:.0f}, Y:{:.0f}, Width:{:.0f}, Height:{:.0f})".format(cap['generated_captions']['raw_caption'], cx, cy, w, h))
200
+
201
+ if verbose:
202
+ all_masks = [np.array(item['mask'].convert('P')) for item in dense_captions]
203
+ new_image = mask_painter_foreground_all(np.array(image), all_masks, background_alpha=0.4)
204
+ save_path = 'result/dense_caption_mask.png'
205
+ Image.fromarray(new_image).save(save_path)
206
+ print(f'Dense captioning mask saved in {save_path}')
207
+
208
+ vis_path = 'result/dense_caption_vis_{}.png'.format(time.time())
209
+ dense_cap_painter_input = [{'bbox': xywh_to_x1y1x2y2(cap['bbox']),
210
+ 'caption': cap['generated_captions']['raw_caption']} for cap in dense_captions]
211
+ draw_bbox(load_image(image, return_type='numpy'), vis_path, dense_cap_painter_input, show_caption=True)
212
+ print(f'Dense Captioning visualization saved in {vis_path}')
213
+ return ','.join(dense_cap_prompt)
214
+
215
+ def parse_ocr(self, image, thres=0.2):
216
+ width, height = get_image_shape(image)
217
+ image = load_image(image, return_type='numpy')
218
+ bounds = self.ocr_reader.readtext(image)
219
+ bounds = [bound for bound in bounds if bound[2] > thres]
220
+ print('Process OCR Text:\n', bounds)
221
+
222
+ ocr_prompt = []
223
+ for box, text, conf in bounds:
224
+ p0, p1, p2, p3 = box
225
+ ocr_prompt.append('(\"{}\": X:{:.0f}, Y:{:.0f})'.format(text, (p0[0]+p1[0]+p2[0]+p3[0])/4, (p0[1]+p1[1]+p2[1]+p3[1])/4))
226
+ ocr_prompt = '\n'.join(ocr_prompt)
227
+
228
+ # ocr_prompt = self.text_refiner.llm(f'The image have some scene texts with their locations: {ocr_prompt}. Please group these individual words into one or several phrase based on their relative positions (only give me your answer, do not show explanination)').strip()
229
+
230
+ # ocr_prefix1 = f'The image have some scene texts with their locations: {ocr_prompt}. Please group these individual words into one or several phrase based on their relative positions (only give me your answer, do not show explanination)'
231
+ # ocr_prefix2 = f'Please group these individual words into 1-3 phrases, given scene texts with their locations: {ocr_prompt}. You return is one or several strings and infer their locations. (only give me your answer like (“man working”, X: value, Y: value), do not show explanination)'
232
+ # ocr_prefix4 = f'summarize the individual scene text words detected by OCR tools into a fluent sentence based on their positions and distances. You should strictly describe all of the given scene text words. Do not miss any given word. Do not create non-exist words. Do not appear numeric positions. The individual words are given:\n{ocr_prompt}\n'
233
+ # ocr_prefix3 = f'combine the individual scene text words detected by OCR tools into one/several fluent phrases/sentences based on their positions and distances. You should strictly copy or correct all of the given scene text words. Do not miss any given word. Do not create non-exist words. The response is several strings seperate with their location (X, Y), each of which represents a phrase. The individual words are given:\n{ocr_prompt}\n'
234
+ # response = self.text_refiner.llm(ocr_prefix3).strip() if len(ocr_prompt) else ""
235
+ return ocr_prompt
236
+
237
+ def inference_cap_everything(self, image, verbose=False):
238
+ image = load_image(image, return_type='pil')
239
+ image = image_resize(image, res=1024)
240
+ width, height = get_image_shape(image)
241
+ other_args = {'text_prompt': ""} if self.require_caption_prompt else {}
242
+ img_caption = self.captioner.inference(image, filter=False, args=other_args)['caption']
243
+ dense_caption_prompt = self.parse_dense_caption(image, topN=10, verbose=verbose, reference_caption=[])
244
+ scene_text_prompt = self.parse_ocr(image, thres=0.2)
245
+ # scene_text_prompt = "N/A"
246
+
247
+ # the summarize_prompt is modified from https://github.com/JialianW/GRiT and https://github.com/showlab/Image2Paragraph
248
+ summarize_prompt = "Imagine you are a blind but intelligent image captioner. You should generate a descriptive, coherent and human-like paragraph based on the given information (a,b,c,d) instead of imagination:\na) Image Resolution: {image_size}\nb) Image Caption:{image_caption}\nc) Dense Caption: {dense_caption}\nd) Scene Text: {scene_text}\nThere are some rules for your response: Show objects with their attributes (e.g. position, color, size, shape, texture).\nPrimarily describe common objects with large size.\nProvide context of the image.\nShow relative position between objects.\nLess than 6 sentences.\nDo not mention numbers.\nDo not describe any individual letter.\nDo not show the image resolution.\nIgnore the white background."
249
+ prompt = summarize_prompt.format(**{
250
+ "image_size": "width {} height {}".format(width, height),
251
+ "image_caption":img_caption,
252
+ "dense_caption": dense_caption_prompt,
253
+ "scene_text": scene_text_prompt})
254
+ print(f'caption everything prompt: {prompt}')
255
+ response = self.text_refiner.llm(prompt).strip()
256
+ # chinese_response = self.text_refiner.llm('Translate it into Chinese: {}'.format(response)).strip()
257
+ return response
258
+
259
+ if __name__ == "__main__":
260
+ from caption_anything.utils.parser import parse_augment
261
+ args = parse_augment()
262
+ image_path = 'result/wt/memes/87226084.jpg'
263
+ image = Image.open(image_path)
264
+ prompts = [
265
+ {
266
+ "prompt_type": ["click"],
267
+ "input_point": [[500, 300], [200, 500]],
268
+ "input_label": [1, 0],
269
+ "multimask_output": "True",
270
+ },
271
+ # {
272
+ # "prompt_type": ["click"],
273
+ # "input_point": [[300, 800]],
274
+ # "input_label": [1],
275
+ # "multimask_output": "True",
276
+ # }
277
+ ]
278
+ controls = {
279
+ "length": "30",
280
+ "sentiment": "positive",
281
+ # "imagination": "True",
282
+ "imagination": "False",
283
+ "language": "English",
284
+ }
285
+
286
+ model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
287
+ img_dir = 'test_images/memes'
288
+ for image_file in os.listdir(img_dir):
289
+ image_path = os.path.join(img_dir, image_file)
290
+ print('image_path:', image_path)
291
+ paragraph = model.inference_cap_everything(image_path, verbose=True)
292
+ print('Caption Everything:\n', paragraph)
293
+ ocr = model.parse_ocr(image_path)
294
+ print('OCR', ocr)
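The `__main__` block above only exercises `inference_cap_everything`; a hedged sketch of the click-prompt path through `inference` is shown below (the default `args` from `parse_augment`, the bundled test image, and an `OPENAI_API_KEY` in the environment are assumptions):

```
# Caption the object under a single positive click.
import os
from caption_anything.model import CaptionAnything
from caption_anything.utils.parser import parse_augment

args = parse_augment()
model = CaptionAnything(args, api_key=os.environ.get("OPENAI_API_KEY", ""))

click_prompt = {
    "prompt_type": ["click"],
    "input_point": [[500, 300]],
    "input_label": [1],              # 1 = positive click, 0 = negative click
    "multimask_output": "True",
}
controls = {"length": "30", "sentiment": "positive", "imagination": "False", "language": "English"}

outs = model.inference("test_images/img2.jpg", click_prompt, controls, verbose=True)
print(outs[0]["generated_captions"])  # raw caption, plus the GPT-refined caption when enabled
```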
caption_anything/segmenter/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .base_segmenter import BaseSegmenter
2
+ from caption_anything.utils.utils import seg_model_map
3
+ import copy
4
+
5
+ def build_segmenter(model_name, device, args, model=None):
6
+ return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model, args=args)
7
+
8
+ def build_segmenter_densecap(model_name, device, args, model=None):
9
+ args_for_densecap = copy.deepcopy(args)
10
+ args_for_densecap.pred_iou_thresh = 0.88
11
+ args_for_densecap.min_mask_region_area = 400
12
+ args_for_densecap.stability_score_thresh = 0.95
13
+ args_for_densecap.box_nms_thresh = 0.3
14
+ return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model, args=args_for_densecap)
caption_anything/segmenter/base_segmenter.py ADDED
@@ -0,0 +1,184 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image
9
+ import matplotlib.pyplot as plt
10
+ import PIL
11
+
12
+
13
+ class BaseSegmenter:
14
+ def __init__(self, device, checkpoint, model_name='huge', reuse_feature=True, model=None, args=None):
15
+ print(f"Initializing BaseSegmenter to {device}")
16
+ self.device = device
17
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
18
+ self.processor = None
19
+ if model is None:
20
+ if checkpoint is None:
21
+ _, checkpoint = prepare_segmenter(model_name)
22
+ self.model = sam_model_registry[seg_model_map[model_name]](checkpoint=checkpoint)
23
+ self.checkpoint = checkpoint
24
+ self.model.to(device=self.device)
25
+ else:
26
+ self.model = model
27
+ self.reuse_feature = reuse_feature
28
+ self.predictor = SamPredictor(self.model)
29
+
30
+ sam_generator_keys = ['pred_iou_thresh', 'min_mask_region_area', 'stability_score_thresh', 'box_nms_thresh']
31
+ generator_args = {k: v for k, v in vars(args).items() if k in sam_generator_keys} if args is not None else {}
32
+ self.mask_generator = SamAutomaticMaskGenerator(model=self.model, **generator_args)
33
+ self.image_embedding = None
34
+ self.image = None
35
+
36
+ @torch.no_grad()
37
+ def set_image(self, image: Union[np.ndarray, Image.Image, str]):
38
+ image = load_image(image, return_type='numpy')
39
+ self.image = image
40
+ if self.reuse_feature:
41
+ self.predictor.set_image(image)
42
+ self.image_embedding = self.predictor.get_image_embedding()
43
+ print(self.image_embedding.shape)
44
+
45
+ @torch.no_grad()
46
+ def inference(self, image: Union[np.ndarray, Image.Image, str], control: dict):
47
+ """
48
+ SAM inference of image according to control.
49
+ Args:
50
+ image: str or PIL.Image or np.ndarray
51
+ control: dict to control SAM.
52
+ prompt_type:
53
+ 1. {control['prompt_type'] = ['everything']} to segment everything in the image.
54
+ 2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
55
+ 3. {control['prompt_type'] = ['click'] to segment according to click.
56
+ 4. {control['prompt_type'] = ['box'] to segment according to box.
57
+ input_point: list of [x, y] coordinates of click.
58
+ input_label: List of labels for points accordingly, 0 for negative, 1 for positive.
59
+ input_box: List of [x1, y1, x2, y2] coordinates of box.
60
+ multimask_output:
61
+ If true, the model will return three masks.
62
+ For ambiguous input prompts (such as a single click), this will often
63
+ produce better masks than a single prediction. If only a single
64
+ mask is needed, the model's predicted quality score can be used
65
+ to select the best mask. For non-ambiguous prompts, such as multiple
66
+ input prompts, multimask_output=False can give better results.
67
+ Returns:
68
+ masks: np.ndarray of shape [num_masks, height, width]
69
+
70
+ """
71
+ image = load_image(image, return_type='numpy')
72
+ if 'everything' in control['prompt_type']:
73
+ masks = self.mask_generator.generate(image)
74
+ new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
75
+ bbox = np.array([mask["bbox"] for mask in masks])
76
+ area = np.array([mask["area"] for mask in masks])
77
+ return new_masks, bbox, area
78
+ else:
79
+ if not self.reuse_feature or self.image_embedding is None:
80
+ self.set_image(image)
81
+ self.predictor.set_image(self.image)
82
+ else:
83
+ assert self.image_embedding is not None
84
+ self.predictor.features = self.image_embedding
85
+
86
+ if 'mutimask_output' in control:
87
+ masks, scores, logits = self.predictor.predict(
88
+ point_coords=np.array(control['input_point']),
89
+ point_labels=np.array(control['input_label']),
90
+ multimask_output=True,
91
+ )
92
+ elif 'input_boxes' in control:
93
+ transformed_boxes = self.predictor.transform.apply_boxes_torch(
94
+ torch.tensor(control["input_boxes"], device=self.predictor.device),
95
+ image.shape[:2]  # original image size in (H, W), as expected by apply_boxes_torch
96
+ )
97
+ masks, _, _ = self.predictor.predict_torch(
98
+ point_coords=None,
99
+ point_labels=None,
100
+ boxes=transformed_boxes,
101
+ multimask_output=False,
102
+ )
103
+ masks = masks.squeeze(1).cpu().numpy()
104
+
105
+ else:
106
+ input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
107
+ input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
108
+ input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
109
+
110
+ masks, scores, logits = self.predictor.predict(
111
+ point_coords=input_point,
112
+ point_labels=input_label,
113
+ box=input_box,
114
+ multimask_output=False,
115
+ )
116
+
117
+ if 0 in control['input_label']:
118
+ mask_input = logits[np.argmax(scores), :, :]
119
+ masks, scores, logits = self.predictor.predict(
120
+ point_coords=input_point,
121
+ point_labels=input_label,
122
+ box=input_box,
123
+ mask_input=mask_input[None, :, :],
124
+ multimask_output=False,
125
+ )
126
+
127
+ return masks
128
+
129
+
130
+ if __name__ == "__main__":
131
+ image_path = 'segmenter/images/truck.jpg'
132
+ prompts = [
133
+ # {
134
+ # "prompt_type":["click"],
135
+ # "input_point":[[500, 375]],
136
+ # "input_label":[1],
137
+ # "multimask_output":"True",
138
+ # },
139
+ {
140
+ "prompt_type": ["click"],
141
+ "input_point": [[1000, 600], [1325, 625]],
142
+ "input_label": [1, 0],
143
+ },
144
+ # {
145
+ # "prompt_type":["click", "box"],
146
+ # "input_box":[425, 600, 700, 875],
147
+ # "input_point":[[575, 750]],
148
+ # "input_label": [0]
149
+ # },
150
+ # {
151
+ # "prompt_type":["box"],
152
+ # "input_boxes": [
153
+ # [75, 275, 1725, 850],
154
+ # [425, 600, 700, 875],
155
+ # [1375, 550, 1650, 800],
156
+ # [1240, 675, 1400, 750],
157
+ # ]
158
+ # },
159
+ # {
160
+ # "prompt_type":["everything"]
161
+ # },
162
+ ]
163
+
164
+ init_time = time.time()
165
+ segmenter = BaseSegmenter(
166
+ device='cuda',
167
+ # checkpoint='sam_vit_h_4b8939.pth',
168
+ checkpoint='segmenter/sam_vit_h_4b8939.pth',
169
+ model_name='huge',  # BaseSegmenter takes model_name; 'huge' selects SAM's vit_h weights
170
+ reuse_feature=True
171
+ )
172
+ print(f'init time: {time.time() - init_time}')
173
+
174
+ image_path = 'test_images/img2.jpg'
175
+ infer_time = time.time()
176
+ for i, prompt in enumerate(prompts):
177
+ print(f'{prompt["prompt_type"]} mode')
178
+ image = Image.open(image_path)
179
+ segmenter.set_image(np.array(image))
180
+ masks = segmenter.inference(np.array(image), prompt)
181
+ Image.fromarray((masks[0] * 255).astype(np.uint8)).save('seg.png')  # convert boolean mask to uint8 before saving
182
+ print(masks.shape)
183
+
184
+ print(f'infer time: {time.time() - infer_time}')
caption_anything/segmenter/readme.md ADDED
@@ -0,0 +1,68 @@
1
+ ### Prepare SAM
2
+ ```
3
+ pip install git+https://github.com/facebookresearch/segment-anything.git
4
+ ```
5
+ or
6
+ ```
7
+ git clone git@github.com:facebookresearch/segment-anything.git
8
+ cd segment-anything; pip install -e .
9
+ ```
10
+
11
+ ```
12
+ pip install opencv-python pycocotools matplotlib onnxruntime onnx
13
+ ```
14
+ ### Download the checkpoint:
15
+
16
+ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
17
+
18
+ ### Inference
19
+
20
+ The prompts are in json format:
21
+
22
+ ```
23
+ prompts = [
24
+ {
25
+ "prompt_type":["click"],
26
+ "input_point":[[500, 375]],
27
+ "input_label":[1],
28
+ "multimask_output":"True",
29
+ },
30
+ {
31
+ "prompt_type":["click"],
32
+ "input_point":[[500, 375], [1125, 625]],
33
+ "input_label":[1, 0],
34
+ },
35
+ {
36
+ "prompt_type":["click", "box"],
37
+ "input_box":[425, 600, 700, 875],
38
+ "input_point":[[575, 750]],
39
+ "input_label": [0]
40
+ },
41
+ {
42
+ "prompt_type":["box"],
43
+ "input_boxes": [
44
+ [75, 275, 1725, 850],
45
+ [425, 600, 700, 875],
46
+ [1375, 550, 1650, 800],
47
+ [1240, 675, 1400, 750],
48
+ ]
49
+ },
50
+ {
51
+ "prompt_type":["everything"]
52
+ },
53
+ ]
54
+ ```
55
+
56
+ In `base_segmenter.py`:
57
+ ```
58
+ segmenter = BaseSegmenter(
59
+ device='cuda',
60
+ checkpoint='sam_vit_h_4b8939.pth',
61
+ model_name='huge'  # 'huge' corresponds to the sam_vit_h checkpoint
62
+ )
63
+
64
+ for i, prompt in enumerate(prompts):
65
+ masks = segmenter.inference(image_path, prompt)
66
+ ```
67
+
68
+ Outputs are boolean NumPy masks of shape (num_masks, height, width).
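Putting the pieces together, a small end-to-end sketch (the device, checkpoint path, and test image are assumptions) that runs one click prompt and writes the first mask to disk:

```
import numpy as np
from argparse import Namespace
from PIL import Image
from caption_anything.segmenter.base_segmenter import BaseSegmenter

segmenter = BaseSegmenter(
    device='cuda',                      # or 'cpu'
    checkpoint='sam_vit_h_4b8939.pth',  # downloaded SAM weights
    model_name='huge',                  # 'huge' corresponds to SAM's vit_h
    reuse_feature=True,
    args=Namespace(),                   # no SamAutomaticMaskGenerator overrides
)

image = np.array(Image.open('test_images/img2.jpg'))
segmenter.set_image(image)

prompt = {"prompt_type": ["click"], "input_point": [[500, 375]], "input_label": [1]}
masks = segmenter.inference(image, prompt)   # boolean array, (num_masks, height, width)
Image.fromarray((masks[0] * 255).astype(np.uint8)).save('seg.png')
```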
caption_anything/text_refiner/README.md ADDED
@@ -0,0 +1,8 @@
1
+ # Install
2
+ * python >= 3.8.1
3
+
4
+ ```bash
5
+ pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html # CUDA version could be different
6
+ pip install openai pillow transformers
7
+ pip install langchain==0.0.101
8
+ ```
caption_anything/text_refiner/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .text_refiner import TextRefiner
2
+
3
+
4
+ def build_text_refiner(type, device, args=None, api_key=""):
5
+ if type == 'base':
6
+ return TextRefiner(device, api_key)
caption_anything/text_refiner/text_refiner.py ADDED
@@ -0,0 +1,86 @@
1
+ from langchain.llms.openai import OpenAI
2
+ import torch
3
+ from PIL import Image, ImageDraw, ImageOps
4
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
5
+ import pdb
6
+
7
+ class TextRefiner:
8
+ def __init__(self, device, api_key=""):
9
+ print(f"Initializing TextRefiner to {device}")
10
+ self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
11
+ self.prompt_tag = {
12
+ "imagination": {"True": "could",
13
+ "False": "could not"}
14
+ }
15
+ self.short_prompts = {
16
+ "length": "around {length} words",
17
+ "sentiment": "of {sentiment} sentiment",
18
+ "language": "in {language}",
19
+ }
20
+
21
+ self.long_prompts = {
22
+ "imagination": "The new sentence could extend the original description by using your imagination to create additional details, or think about what might have happened before or after the scene in the image, but should not conflict with the original sentence",
23
+ }
24
+
25
+ self.wiki_prompts = "I want you to act as a Wikipedia page. I will give you a sentence and you will parse the single main object in the sentence and provide a summary of that object in the format of a Wikipedia page. Your summary should be informative and factual, covering the most important aspects of the object. Start your summary with an introductory paragraph that gives an overview of the object. The overall length of the response should be around 100 words. You should not describe the parsing process and only provide the final summary. The sentence is \"{query}\"."
26
+
27
+ self.control_prompts = "As a text reviser, you will convert an image description into a new sentence or long paragraph. The new text is {prompts}. {long_prompts} The sentence is \"{query}\" (give me the revised sentence only)"
28
+
29
+ def parse(self, response):
30
+ out = response.strip()
31
+ return out
32
+
33
+ def parse2(self, response):
34
+ out = response.strip()
35
+ return out
36
+
37
+ def prepare_input(self, query, short_prompts, long_prompts):
38
+ input = self.control_prompts.format(**{'prompts': ', '.join(short_prompts), 'long_prompts': '. '.join(long_prompts), 'query': query})
39
+ print('prompt: ', input)
40
+ return input
41
+
42
+ def inference(self, query: str, controls: dict, context: list=[], enable_wiki=False):
43
+ """
44
+ query: the caption of the region of interest, generated by captioner
45
+ controls: a dict of control signals, e.g., {"length": 5, "sentiment": "positive"}
46
+ """
47
+ prompts = []
48
+ long_prompts = []
49
+ for control, value in controls.items():
50
+ # if control in self.prompt_tag:
51
+ # value = self.prompt_tag[control][value]
52
+ if control in self.short_prompts:
53
+ prompts.append(self.short_prompts[control].format(**{control: value}))
54
+ else:
55
+ if value in [True, "True", "true"]:
56
+ long_prompts.append(self.long_prompts[control])
57
+ input = self.prepare_input(query, prompts, long_prompts)
58
+ response = self.llm(input)
59
+ response = self.parse(response)
60
+
61
+ response_wiki = ""
62
+ if enable_wiki:
63
+ tmp_configs = {"query": query}
64
+ prompt_wiki = self.wiki_prompts.format(**tmp_configs)
65
+ response_wiki = self.llm(prompt_wiki)
66
+ response_wiki = self.parse2(response_wiki)
67
+ out = {
68
+ 'raw_caption': query,
69
+ 'caption': response,
70
+ 'wiki': response_wiki
71
+ }
72
+ print(out)
73
+ return out
74
+
75
+ if __name__ == "__main__":
76
+ model = TextRefiner(device='cpu')
77
+ controls = {
78
+ "length": "30",
79
+ "sentiment": "negative",
80
+ # "imagination": "True",
81
+ "imagination": "False",
82
+ "language": "English",
83
+ }
84
+ # model.inference(query='a dog is sitting on a brown bench', controls=controls)
85
+ model.inference(query='a cat is sleeping', controls=controls)
86
+
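To see what actually reaches the LLM, here is a small offline sketch (no API call) that mirrors how `inference` above composes `control_prompts` from a `controls` dict:

```
# Mirrors TextRefiner.prepare_input: compose the revision prompt without calling the API.
short_prompts = {
    "length": "around {length} words",
    "sentiment": "of {sentiment} sentiment",
    "language": "in {language}",
}
control_prompts = ("As a text reviser, you will convert an image description into a new sentence or "
                   "long paragraph. The new text is {prompts}. {long_prompts} The sentence is "
                   "\"{query}\" (give me the revised sentence only)")

controls = {"length": "30", "sentiment": "positive", "language": "English", "imagination": "False"}
query = "a cat is sleeping"

prompts = [short_prompts[k].format(**{k: v}) for k, v in controls.items() if k in short_prompts]
print(control_prompts.format(prompts=", ".join(prompts), long_prompts="", query=query))
# -> ... The new text is around 30 words, of positive sentiment, in English. ...
```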
caption_anything/utils/chatbot.py ADDED
@@ -0,0 +1,225 @@
1
+ # Copyright (c) Microsoft
2
+ # Modified from Visual ChatGPT Project https://github.com/microsoft/TaskMatrix/blob/main/visual_chatgpt.py
3
+
4
+ import os
5
+ import gradio as gr
6
+ import re
7
+ import uuid
8
+ from PIL import Image, ImageDraw, ImageOps
9
+ import numpy as np
10
+ import argparse
11
+ import inspect
12
+
13
+ from langchain.agents.initialize import initialize_agent
14
+ from langchain.agents.tools import Tool
15
+ from langchain.chains.conversation.memory import ConversationBufferMemory
16
+ from langchain.llms.openai import OpenAI
17
+ import torch
18
+ from PIL import Image, ImageDraw, ImageOps
19
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
20
+
21
+ VISUAL_CHATGPT_PREFIX = """
22
+ I want you to act as Caption Anything Chatbox (CATchat for short), which is designed to assist with a wide range of text- and vision-related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. You are able to generate human-like text based on the input you receive, allowing you to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
23
+
24
+ As a language model, you cannot directly read images, but you can invoke the VQA tool to indirectly understand pictures by repeatedly asking questions about the objects and scene of the image. You should carefully ask informative questions to maximize your information about the image content. Each image will have a file name formed as "chat_image/xxx.png"; you are very strict about the file name and will never fabricate nonexistent files.
25
+
26
+ You have access to the following tools:"""
27
+
28
+
29
+ # TOOLS:
30
+ # ------
31
+
32
+ # Visual ChatGPT has access to the following tools:"""
33
+
34
+ VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
35
+
36
+ "Thought: Do I need to use a tool? Yes
37
+ Action: the action to take, should be one of [{tool_names}]; remember the action must be one tool
38
+ Action Input: the input to the action
39
+ Observation: the result of the action"
40
+
41
+ When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
42
+
43
+ "Thought: Do I need to use a tool? No
44
+ {ai_prefix}: [your response here]"
45
+
46
+ """
47
+
48
+ VISUAL_CHATGPT_SUFFIX = """
49
+ Begin Chatting!
50
+
51
+ Previous conversation history:
52
+ {chat_history}
53
+
54
+ New input: {input}
55
+ As a language model, you must repeatedly use the VQA tool to observe images. Your response should be consistent with the outputs of the VQA tool rather than with imagination. Do not ask the same question repeatedly.
56
+
57
+ Thought: Do I need to use a tool? {agent_scratchpad} (You must strictly use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
58
+
59
+ os.makedirs('chat_image', exist_ok=True)
60
+
61
+
62
+ def prompts(name, description):
63
+ def decorator(func):
64
+ func.name = name
65
+ func.description = description
66
+ return func
67
+ return decorator
68
+
69
+ def cut_dialogue_history(history_memory, keep_last_n_words=500):
70
+ if history_memory is None or len(history_memory) == 0:
71
+ return history_memory
72
+ tokens = history_memory.split()
73
+ n_tokens = len(tokens)
74
+ print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
75
+ if n_tokens < keep_last_n_words:
76
+ return history_memory
77
+ paragraphs = history_memory.split('\n')
78
+ last_n_tokens = n_tokens
79
+ while last_n_tokens >= keep_last_n_words:
80
+ last_n_tokens -= len(paragraphs[0].split(' '))
81
+ paragraphs = paragraphs[1:]
82
+ return '\n' + '\n'.join(paragraphs)
83
+
84
+ def get_new_image_name(folder='chat_image', func_name="update"):
85
+ this_new_uuid = str(uuid.uuid4())[:8]
86
+ new_file_name = f'{func_name}_{this_new_uuid}.png'
87
+ return os.path.join(folder, new_file_name)
88
+
89
+ class VisualQuestionAnswering:
90
+ def __init__(self, device):
91
+ print(f"Initializing VisualQuestionAnswering to {device}")
92
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
93
+ self.device = device
94
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
95
+ self.model = BlipForQuestionAnswering.from_pretrained(
96
+ "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
97
+ # self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
98
+ # self.model = BlipForQuestionAnswering.from_pretrained(
99
+ # "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
100
+
101
+ @prompts(name="Answer Question About The Image",
102
+ description="VQA tool is useful when you need an answer for a question based on an image. "
103
+ "like: what is the color of an object, how many cats are in this figure, where is the child sitting, what is the cat doing, why is he laughing. "
104
+ "The input to this tool should be a comma separated string of two, representing the image path and the question.")
105
+ def inference(self, inputs):
106
+ image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
107
+ raw_image = Image.open(image_path).convert('RGB')
108
+ inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
109
+ out = self.model.generate(**inputs)
110
+ answer = self.processor.decode(out[0], skip_special_tokens=True)
111
+ print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
112
+ f"Output Answer: {answer}")
113
+ return answer
114
+
115
+ def build_chatbot_tools(load_dict):
116
+ print(f"Initializing ChatBot, load_dict={load_dict}")
117
+ models = {}
118
+ # Load Basic Foundation Models
119
+ for class_name, device in load_dict.items():
120
+ models[class_name] = globals()[class_name](device=device)
121
+
122
+ # Load Template Foundation Models
123
+ for class_name, module in globals().items():
124
+ if getattr(module, 'template_model', False):
125
+ template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
126
+ loaded_names = set([type(e).__name__ for e in models.values()])
127
+ if template_required_names.issubset(loaded_names):
128
+ models[class_name] = globals()[class_name](
129
+ **{name: models[name] for name in template_required_names})
130
+
131
+ tools = []
132
+ for instance in models.values():
133
+ for e in dir(instance):
134
+ if e.startswith('inference'):
135
+ func = getattr(instance, e)
136
+ tools.append(Tool(name=func.name, description=func.description, func=func))
137
+ return tools
138
+
139
+ class ConversationBot:
140
+ def __init__(self, tools, api_key=""):
141
+ # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
142
+ llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0.7, openai_api_key=api_key)
143
+ self.llm = llm
144
+ self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
145
+ self.tools = tools
146
+ self.current_image = None
147
+ self.point_prompt = ""
148
+ self.global_prompt = ""
149
+ self.agent = initialize_agent(
150
+ self.tools,
151
+ self.llm,
152
+ agent="conversational-react-description",
153
+ verbose=True,
154
+ memory=self.memory,
155
+ return_intermediate_steps=True,
156
+ agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
157
+ 'suffix': VISUAL_CHATGPT_SUFFIX}, )
158
+
159
+ def constructe_intermediate_steps(self, agent_res):
160
+ ans = []
161
+ for action, output in agent_res:
162
+ if hasattr(action, "tool_input"):
163
+ use_tool = "Yes"
164
+ act = (f"Thought: Do I need to use a tool? {use_tool}\nAction: {action.tool}\nAction Input: {action.tool_input}", f"Observation: {output}")
165
+ else:
166
+ use_tool = "No"
167
+ act = (f"Thought: Do I need to use a tool? {use_tool}", f"AI: {output}")
168
+ act= list(map(lambda x: x.replace('\n', '<br>'), act))
169
+ ans.append(act)
170
+ return ans
171
+
172
+ def run_text(self, text, state, aux_state):
173
+ self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
174
+ if self.point_prompt != "":
175
+ Human_prompt = f'\nHuman: {self.point_prompt}\n'
176
+ AI_prompt = 'Ok'
177
+ self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
178
+ self.point_prompt = ""
179
+ res = self.agent({"input": text})
180
+ res['output'] = res['output'].replace("\\", "/")
181
+ response = re.sub('(chat_image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
182
+ state = state + [(text, response)]
183
+
184
+ aux_state = aux_state + [(f"User Input: {text}", None)]
185
+ aux_state = aux_state + self.constructe_intermediate_steps(res['intermediate_steps'])
186
+ print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
187
+ f"Current Memory: {self.agent.memory.buffer}\n"
188
+ f"Aux state: {aux_state}\n"
189
+ )
190
+ return state, state, aux_state, aux_state
191
+
192
+
193
+ if __name__ == '__main__':
194
+ parser = argparse.ArgumentParser()
195
+ parser.add_argument('--load', type=str, default="VisualQuestionAnswering_cuda:0")
196
+ parser.add_argument('--port', type=int, default=1015)
197
+
198
+ args = parser.parse_args()
199
+ load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
200
+ tools = build_chatbot_tools(load_dict)
201
+ bot = ConversationBot(tools)
202
+ with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
203
+ with gr.Row():
204
+ chatbot = gr.Chatbot(elem_id="chatbot", label="CATchat").style(height=1000,scale=0.5)
205
+ auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
206
+ state = gr.State([])
207
+ aux_state = gr.State([])
208
+ with gr.Row():
209
+ with gr.Column(scale=0.7):
210
+ txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
211
+ container=False)
212
+ with gr.Column(scale=0.15, min_width=0):
213
+ clear = gr.Button("Clear")
214
+ with gr.Column(scale=0.15, min_width=0):
215
+ btn = gr.UploadButton("Upload", file_types=["image"])
216
+
217
+ txt.submit(bot.run_text, [txt, state, aux_state], [chatbot, state, aux_state, auxwindow])
218
+ txt.submit(lambda: "", None, txt)
219
+ btn.upload(bot.run_image, [btn, state, txt, aux_state], [chatbot, state, txt, aux_state, auxwindow])
220
+ clear.click(bot.memory.clear)
221
+ clear.click(lambda: [], None, chatbot)
222
+ clear.click(lambda: [], None, auxwindow)
223
+ clear.click(lambda: [], None, state)
224
+ clear.click(lambda: [], None, aux_state)
225
+ demo.launch(server_name="0.0.0.0", server_port=args.port, share=True)
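
Beyond the Gradio demo above, the chatbot can also be driven programmatically. A minimal sketch, assuming an OpenAI API key and a CUDA device are available; the key and image path below are placeholders:

    from caption_anything.utils.chatbot import build_chatbot_tools, ConversationBot

    tools = build_chatbot_tools({'VisualQuestionAnswering': 'cuda:0'})
    bot = ConversationBot(tools, api_key='YOUR_OPENAI_API_KEY')  # placeholder key
    # run_text returns (chat_state, chat_state, aux_state, aux_state) for the two chat windows
    state, _, aux_state, _ = bot.run_text('What is shown in chat_image/example.png?', [], [])
    print(state[-1][1])  # the bot's latest reply
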
caption_anything/utils/densecap_painter.py ADDED
@@ -0,0 +1,64 @@
1
+ import cv2
2
+ import json
3
+ import numpy as np
4
+ from typing import List
5
+ import random
6
+ from typing import Union
7
+
8
+ def draw_bbox(img: Union[np.ndarray, str], save_name: str, bbox: List[dict], show_caption: bool = False):
9
+ """
10
+ bbox: [{'image_id': str, 'bbox': [x1, y1, x2, y2], 'caption': str}, ...]
11
+ """
12
+ if isinstance(img, str):
13
+ img = cv2.imread(img)
14
+
15
+ RGB = [0, 50, 100, 150, 200, 250]
16
+ for box in bbox:
17
+ box['bbox'] = [int(_) for _ in box['bbox']]
18
+ x1, y1, x2, y2 = box['bbox']
19
+ caption = box['caption']
20
+ box_color = random.choices(RGB, k = 3)
21
+ (text_width, text_height), _ = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, fontScale = 0.5, thickness = 2)
22
+ cv2.rectangle(img, (x1, y1), (x2, y2), color = box_color, thickness = 2)
23
+ if show_caption:
24
+ cv2.putText(img, caption, (x1, y1 + text_height), cv2.FONT_HERSHEY_SIMPLEX, fontScale = 0.5, color = box_color, thickness = 2)
25
+
26
+ cv2.imwrite(save_name, img)
27
+ # cv2.imshow('visualise', img)
28
+ # cv2.waitKey(0)
29
+
30
+ def parse_bbox(anno, image_id: int = None):
31
+
32
+ with open(anno, 'r') as f:
33
+ predictions = json.load(f)
34
+
35
+ if image_id is None:
36
+ image_id = next(iter(predictions))
37
+
38
+ return predictions[image_id]
39
+
40
+ def gt_bbox(anno, img_name: int = None):
41
+
42
+ with open(anno, 'r') as f:
43
+ annotations = json.load(f)
44
+ annotations = annotations['annotations']
45
+
46
+ gt = []
47
+ img_name = int(img_name[:-4])
48
+ for annotation in annotations:
49
+ if annotation['image_id'] == img_name:
50
+ x1, y1, w, h = annotation['bbox']
51
+ gt.append({'bbox': [x1, y1, x1 + w, y1 + h], 'caption': annotation['caption']})
52
+ return gt
53
+
54
+ if __name__ == '__main__':
55
+
56
+ img_name = '63.jpg'
57
+ show_caption = True
58
+ anno = 'vg_dense_captioning_blip2_top48_0.88_1000_0.96_debugTrue_predictions_shard_all.json'
59
+
60
+ img = cv2.imread(img_name)
61
+ examp_bbox = parse_bbox(anno)
62
+ ground_truth_bbox = gt_bbox('test.json', img_name)
63
+ draw_bbox(img, 'GT.jpg', ground_truth_bbox, show_caption)
64
+ draw_bbox(img, 'Pred.jpg', examp_bbox, show_caption)
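
A minimal sketch of calling draw_bbox directly; the canvas and captions here are synthetic placeholders rather than real dense-captioning output, but the box format matches the docstring above:

    import numpy as np
    from caption_anything.utils.densecap_painter import draw_bbox

    canvas = np.zeros((480, 640, 3), dtype=np.uint8)   # stand-in for a real image
    boxes = [
        {'image_id': 'demo', 'bbox': [40, 60, 200, 180], 'caption': 'a placeholder region'},
        {'image_id': 'demo', 'bbox': [260, 120, 420, 300], 'caption': 'another placeholder'},
    ]
    draw_bbox(canvas, 'demo_boxes.jpg', boxes, show_caption=True)  # writes the visualisation to demo_boxes.jpg
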
caption_anything/utils/image_editing_utils.py ADDED
@@ -0,0 +1,127 @@
1
+ from PIL import Image, ImageDraw, ImageFont
2
+ import copy
3
+ import numpy as np
4
+ import cv2
5
+
6
+
7
+ def wrap_text(text, font, max_width):
8
+ lines = []
9
+ words = text.split(' ')
10
+ current_line = ''
11
+
12
+ for word in words:
13
+ if font.getsize(current_line + word)[0] <= max_width:
14
+ current_line += word + ' '
15
+ else:
16
+ lines.append(current_line)
17
+ current_line = word + ' '
18
+
19
+ lines.append(current_line)
20
+ return lines
21
+
22
+
23
+ def create_bubble_frame(image, text, point, segmask, input_points=(), input_labels=(),
24
+ font_path='assets/times_with_simsun.ttf', font_size_ratio=0.033, point_size_ratio=0.01):
25
+ # Load the image
26
+ if input_points is None:
27
+ input_points = []
28
+ if input_labels is None:
29
+ input_labels = []
30
+
31
+ if type(image) == np.ndarray:
32
+ image = Image.fromarray(image)
33
+
34
+ image = copy.deepcopy(image)
35
+ width, height = image.size
36
+
37
+ # Calculate max_text_width and font_size based on image dimensions and total number of characters
38
+ total_chars = len(text)
39
+ max_text_width = int(0.4 * width)
40
+ font_size = int(height * font_size_ratio)
41
+ point_size = max(int(height * point_size_ratio), 1)
42
+
43
+ # Load the font
44
+ font = ImageFont.truetype(font_path, font_size)
45
+
46
+ # Wrap the text to fit within the max_text_width
47
+ lines = wrap_text(text, font, max_text_width)
48
+ text_width = max([font.getsize(line)[0] for line in lines])
49
+ _, text_height = font.getsize(lines[0])
50
+ text_height = text_height * len(lines)
51
+
52
+ # Define bubble frame dimensions
53
+ padding = 10
54
+ bubble_width = text_width + 2 * padding
55
+ bubble_height = text_height + 2 * padding
56
+
57
+ # Create a new image for the bubble frame
58
+ bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 248, 220, 0))
59
+
60
+ # Draw the bubble frame on the new image
61
+ draw = ImageDraw.Draw(bubble)
62
+ # draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
63
+ draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
64
+ fill=(255, 248, 220, 120), outline=None, width=2)
65
+ # Draw the wrapped text line by line
66
+ y_text = padding
67
+ for line in lines:
68
+ draw.text((padding, y_text), line, font=font, fill=(0, 0, 0, 255))
69
+ y_text += font.getsize(line)[1]
70
+
71
+ # Determine the point by the min area rect of mask
72
+ try:
73
+ ret, thresh = cv2.threshold(segmask, 127, 255, 0)
74
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
75
+ largest_contour = max(contours, key=cv2.contourArea)
76
+ min_area_rect = cv2.minAreaRect(largest_contour)
77
+ box = cv2.boxPoints(min_area_rect)
78
+ sorted_points = box[np.argsort(box[:, 0])]
79
+ right_most_points = sorted_points[-2:]
80
+ right_down_most_point = right_most_points[np.argsort(right_most_points[:, 1])][-1]
81
+ x, y = int(right_down_most_point[0]), int(right_down_most_point[1])
82
+ except:
83
+ x, y = point
84
+ # Calculate the bubble frame position
85
+ if x + bubble_width > width:
86
+ x = width - bubble_width
87
+ if y + bubble_height > height:
88
+ y = height - bubble_height
89
+
90
+ # Paste the bubble frame onto the image
91
+ image.paste(bubble, (x, y), bubble)
92
+ draw = ImageDraw.Draw(image)
93
+ colors = [(0, 191, 255, 255), (255, 106, 106, 255)]
94
+ for p, label in zip(input_points, input_labels):
95
+ point_x, point_y = p[0], p[1]
96
+ left = point_x - point_size
97
+ top = point_y - point_size
98
+ right = point_x + point_size
99
+ bottom = point_y + point_size
100
+ draw.ellipse((left, top, right, bottom), fill=colors[label])
101
+ return image
102
+
103
+
104
+ def draw_rounded_rectangle(draw, xy, corner_radius, fill=None, outline=None, width=1):
105
+ x1, y1, x2, y2 = xy
106
+
107
+ draw.rectangle(
108
+ (x1, y1 + corner_radius, x2, y2 - corner_radius),
109
+ fill=fill,
110
+ outline=outline,
111
+ width=width
112
+ )
113
+ draw.rectangle(
114
+ (x1 + corner_radius, y1, x2 - corner_radius, y2),
115
+ fill=fill,
116
+ outline=outline,
117
+ width=width
118
+ )
119
+
120
+ draw.pieslice((x1, y1, x1 + corner_radius * 2, y1 + corner_radius * 2), 180, 270, fill=fill, outline=outline,
121
+ width=width)
122
+ draw.pieslice((x2 - corner_radius * 2, y1, x2, y1 + corner_radius * 2), 270, 360, fill=fill, outline=outline,
123
+ width=width)
124
+ draw.pieslice((x2 - corner_radius * 2, y2 - corner_radius * 2, x2, y2), 0, 90, fill=fill, outline=outline,
125
+ width=width)
126
+ draw.pieslice((x1, y2 - corner_radius * 2, x1 + corner_radius * 2, y2), 90, 180, fill=fill, outline=outline,
127
+ width=width)
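
A minimal sketch of create_bubble_frame on a synthetic image and mask; run it from the repository root so the bundled assets/times_with_simsun.ttf font resolves, and note that the caption and coordinates are placeholders:

    import numpy as np
    from caption_anything.utils.image_editing_utils import create_bubble_frame

    image = np.full((512, 512, 3), 255, dtype=np.uint8)   # blank white canvas
    segmask = np.zeros((512, 512), dtype=np.uint8)
    segmask[100:300, 150:350] = 255                        # rectangular stand-in for a SAM mask
    framed = create_bubble_frame(image, 'a placeholder caption', point=(350, 300), segmask=segmask,
                                 input_points=[(250, 200)], input_labels=[1])
    framed.save('bubble_demo.png')                         # PIL image with the caption bubble pasted on
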
caption_anything/utils/parser.py ADDED
@@ -0,0 +1,35 @@
1
+ import argparse
2
+
3
+ def parse_augment():
4
+ parser = argparse.ArgumentParser()
5
+ parser.add_argument('--captioner', type=str, default="blip2")
6
+ parser.add_argument('--segmenter', type=str, default="huge")
7
+ parser.add_argument('--text_refiner', type=str, default="base")
8
+ parser.add_argument('--segmenter_checkpoint', type=str, default=None, help="SAM checkpoint path")
9
+ parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'],
10
+ help="whether to add or remove background of the image when captioning")
11
+ parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
12
+ parser.add_argument('--context_captions', action="store_true",
13
+ help="use surrounding captions to enhance current caption (TODO)")
14
+ parser.add_argument('--disable_regular_box', action="store_true", default=False,
15
+ help="crop image with a regular box")
16
+ parser.add_argument('--device', type=str, default="cuda:0")
17
+ parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
18
+ parser.add_argument('--debug', action="store_true")
19
+ parser.add_argument('--gradio_share', action="store_true")
20
+ parser.add_argument('--disable_gpt', action="store_true")
21
+ parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
22
+ parser.add_argument('--disable_reuse_features', action="store_true", default=False)
23
+ parser.add_argument('--enable_morphologyex', action="store_true", default=False)
24
+ parser.add_argument('--chat_tools_dict', type=str, default='VisualQuestionAnswering_cuda:0', help='Visual ChatGPT tools, only useful when running gradio applications')
25
+
26
+ parser.add_argument('--pred_iou_thresh', type=float, default=0.88, help="sam post-processing")
27
+ parser.add_argument('--min_mask_region_area', type=int, default=0, help="sam post-processing")
28
+ parser.add_argument('--stability_score_thresh', type=float, default=0.95, help='sam post-processing')
29
+ parser.add_argument('--box_nms_thresh', type=float, default=0.7, help='sam post-processing')
30
+
31
+ args = parser.parse_args()
32
+
33
+ if args.debug:
34
+ print(args)
35
+ return args
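
A minimal sketch of how the parsed arguments reach the chatbot tools; the dict comprehension mirrors the parsing already done in chatbot.py's __main__, and the script name is hypothetical:

    from caption_anything.utils.parser import parse_augment

    args = parse_augment()   # e.g. python run.py --captioner blip2 --segmenter huge
    # '--chat_tools_dict' uses the form "ClassName_device[,ClassName_device...]"
    load_dict = {pair.split('_')[0].strip(): pair.split('_')[1].strip()
                 for pair in args.chat_tools_dict.split(',')}
    print(load_dict)         # {'VisualQuestionAnswering': 'cuda:0'} with the default value
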
caption_anything/utils/utils.py ADDED
@@ -0,0 +1,496 @@
1
+ import os
2
+ import time
3
+ import sys
4
+
5
+ import cv2
6
+ import hashlib
7
+ import requests
8
+ import numpy as np
9
+
10
+ from typing import Union
11
+
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+
15
+
16
+ def load_image(image: Union[np.ndarray, Image.Image, str], return_type='numpy'):
17
+ """
18
+ Load image from path or PIL.Image or numpy.ndarray to required format.
19
+ """
20
+
21
+ # Check if image is already in return_type
22
+ if isinstance(image, Image.Image) and return_type == 'pil' or \
23
+ isinstance(image, np.ndarray) and return_type == 'numpy':
24
+ return image
25
+
26
+ # PIL.Image as intermediate format
27
+ if isinstance(image, str):
28
+ image = Image.open(image)
29
+ elif isinstance(image, np.ndarray):
30
+ image = Image.fromarray(image)
31
+
32
+ if image.mode == "RGBA":
33
+ image = image.convert("RGB")
34
+
35
+ if return_type == 'pil':
36
+ return image
37
+ elif return_type == 'numpy':
38
+ return np.asarray(image)
39
+ else:
40
+ raise NotImplementedError()
41
+
42
+
43
+ def image_resize(image: Image.Image, res=1024):
44
+ width, height = org_size = image.size
45
+ ratio = min(1.0 * res / max(width, height), 1.0)
46
+ if ratio < 1.0:
47
+ image = image.resize((int(width * ratio), int(height * ratio)))
48
+ print('Scaling image from {} to {}'.format(org_size, image.size))
49
+ return image
50
+
51
+ def xywh_to_x1y1x2y2(bbox):
52
+ x, y, w, h = bbox
53
+ return x,y,x+w,y+h
54
+
55
+
56
+ def x1y1x2y2_to_xywh(bbox):
57
+ x1, y1, x2, y2 = bbox
58
+ return x1,y1,x2-x1,y2-y1
59
+
60
+
61
+ def get_image_shape(image):
62
+ if isinstance(image, str):
63
+ return Image.open(image).size
64
+ elif isinstance(image, np.ndarray):
65
+ return image.shape
66
+ elif isinstance(image, Image.Image):
67
+ return image.size
68
+ else:
69
+ raise NotImplementedError
70
+
71
+ def is_platform_win():
72
+ return sys.platform == "win32"
73
+
74
+
75
+ def colormap(rgb=True):
76
+ color_list = np.array(
77
+ [
78
+ 0.000, 0.000, 0.000,
79
+ 1.000, 1.000, 1.000,
80
+ 1.000, 0.498, 0.313,
81
+ 0.392, 0.581, 0.929,
82
+ 0.000, 0.447, 0.741,
83
+ 0.850, 0.325, 0.098,
84
+ 0.929, 0.694, 0.125,
85
+ 0.494, 0.184, 0.556,
86
+ 0.466, 0.674, 0.188,
87
+ 0.301, 0.745, 0.933,
88
+ 0.635, 0.078, 0.184,
89
+ 0.300, 0.300, 0.300,
90
+ 0.600, 0.600, 0.600,
91
+ 1.000, 0.000, 0.000,
92
+ 1.000, 0.500, 0.000,
93
+ 0.749, 0.749, 0.000,
94
+ 0.000, 1.000, 0.000,
95
+ 0.000, 0.000, 1.000,
96
+ 0.667, 0.000, 1.000,
97
+ 0.333, 0.333, 0.000,
98
+ 0.333, 0.667, 0.000,
99
+ 0.333, 1.000, 0.000,
100
+ 0.667, 0.333, 0.000,
101
+ 0.667, 0.667, 0.000,
102
+ 0.667, 1.000, 0.000,
103
+ 1.000, 0.333, 0.000,
104
+ 1.000, 0.667, 0.000,
105
+ 1.000, 1.000, 0.000,
106
+ 0.000, 0.333, 0.500,
107
+ 0.000, 0.667, 0.500,
108
+ 0.000, 1.000, 0.500,
109
+ 0.333, 0.000, 0.500,
110
+ 0.333, 0.333, 0.500,
111
+ 0.333, 0.667, 0.500,
112
+ 0.333, 1.000, 0.500,
113
+ 0.667, 0.000, 0.500,
114
+ 0.667, 0.333, 0.500,
115
+ 0.667, 0.667, 0.500,
116
+ 0.667, 1.000, 0.500,
117
+ 1.000, 0.000, 0.500,
118
+ 1.000, 0.333, 0.500,
119
+ 1.000, 0.667, 0.500,
120
+ 1.000, 1.000, 0.500,
121
+ 0.000, 0.333, 1.000,
122
+ 0.000, 0.667, 1.000,
123
+ 0.000, 1.000, 1.000,
124
+ 0.333, 0.000, 1.000,
125
+ 0.333, 0.333, 1.000,
126
+ 0.333, 0.667, 1.000,
127
+ 0.333, 1.000, 1.000,
128
+ 0.667, 0.000, 1.000,
129
+ 0.667, 0.333, 1.000,
130
+ 0.667, 0.667, 1.000,
131
+ 0.667, 1.000, 1.000,
132
+ 1.000, 0.000, 1.000,
133
+ 1.000, 0.333, 1.000,
134
+ 1.000, 0.667, 1.000,
135
+ 0.167, 0.000, 0.000,
136
+ 0.333, 0.000, 0.000,
137
+ 0.500, 0.000, 0.000,
138
+ 0.667, 0.000, 0.000,
139
+ 0.833, 0.000, 0.000,
140
+ 1.000, 0.000, 0.000,
141
+ 0.000, 0.167, 0.000,
142
+ 0.000, 0.333, 0.000,
143
+ 0.000, 0.500, 0.000,
144
+ 0.000, 0.667, 0.000,
145
+ 0.000, 0.833, 0.000,
146
+ 0.000, 1.000, 0.000,
147
+ 0.000, 0.000, 0.167,
148
+ 0.000, 0.000, 0.333,
149
+ 0.000, 0.000, 0.500,
150
+ 0.000, 0.000, 0.667,
151
+ 0.000, 0.000, 0.833,
152
+ 0.000, 0.000, 1.000,
153
+ 0.143, 0.143, 0.143,
154
+ 0.286, 0.286, 0.286,
155
+ 0.429, 0.429, 0.429,
156
+ 0.571, 0.571, 0.571,
157
+ 0.714, 0.714, 0.714,
158
+ 0.857, 0.857, 0.857
159
+ ]
160
+ ).astype(np.float32)
161
+ color_list = color_list.reshape((-1, 3)) * 255
162
+ if not rgb:
163
+ color_list = color_list[:, ::-1]
164
+ return color_list
165
+
166
+
167
+ color_list = colormap()
168
+ color_list = color_list.astype('uint8').tolist()
169
+
170
+
171
+ def vis_add_mask(image, mask, color, alpha, kernel_size):
172
+ color = np.array(color)
173
+ mask = mask.astype('float').copy()
174
+ mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
175
+ for i in range(3):
176
+ image[:, :, i] = image[:, :, i] * (1 - alpha + mask) + color[i] * (alpha - mask)
177
+ return image
178
+
179
+
180
+ def vis_add_mask_wo_blur(image, mask, color, alpha):
181
+ color = np.array(color)
182
+ mask = mask.astype('float').copy()
183
+ for i in range(3):
184
+ image[:, :, i] = image[:, :, i] * (1 - alpha + mask) + color[i] * (alpha - mask)
185
+ return image
186
+
187
+
188
+ def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha,
189
+ contour_alpha):
190
+ background_color = np.array(background_color)
191
+ contour_color = np.array(contour_color)
192
+
193
+ # background_mask = 1 - background_mask
194
+ # contour_mask = 1 - contour_mask
195
+
196
+ for i in range(3):
197
+ image[:, :, i] = image[:, :, i] * (1 - background_alpha + background_mask * background_alpha) \
198
+ + background_color[i] * (background_alpha - background_mask * background_alpha)
199
+
200
+ image[:, :, i] = image[:, :, i] * (1 - contour_alpha + contour_mask * contour_alpha) \
201
+ + contour_color[i] * (contour_alpha - contour_mask * contour_alpha)
202
+
203
+ return image.astype('uint8')
204
+
205
+
206
+ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3,
207
+ contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
208
+ """
209
+ add color mask to the background/foreground area
210
+ input_image: numpy array (w, h, C)
211
+ input_mask: numpy array (w, h)
212
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
213
+ background_blur_radius: radius of background blur, must be odd number
214
+ contour_width: width of mask contour, must be odd number
215
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
216
+ background_color: color index of the background (area with input_mask == False)
217
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
218
+ paint_foreground: True to paint the foreground, False to paint the background. Default: False
219
+
220
+ Output:
221
+ painted_image: numpy array
222
+ """
223
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
224
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
225
+
226
+ # 0: background, 1: foreground
227
+ input_mask[input_mask > 0] = 255
228
+ if paint_foreground:
229
+ painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha,
230
+ background_blur_radius) # black for background
231
+ else:
232
+ # mask background
233
+ painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha,
234
+ background_blur_radius) # black for background
235
+ # mask contour
236
+ contour_mask = input_mask.copy()
237
+ contour_mask = cv2.Canny(contour_mask, 100, 200) # contour extraction
238
+ # widen contour
239
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
240
+ contour_mask = cv2.dilate(contour_mask, kernel)
241
+ painted_image = vis_add_mask(painted_image, 255 - contour_mask, color_list[contour_color], contour_alpha,
242
+ contour_width)
243
+ return painted_image
244
+
245
+
246
+ def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7,
247
+ contour_width=3, contour_color=3, contour_alpha=1):
248
+ """
249
+ paint color mask on the all foreground area
250
+ input_image: numpy array with shape (w, h, C)
251
+ input_mask: list of masks, each mask is a numpy array with shape (w,h)
252
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
253
+ background_blur_radius: radius of background blur, must be odd number
254
+ contour_width: width of mask contour, must be odd number
255
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
256
+ background_color: color index of the background (area with input_mask == False)
257
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
258
+
259
+ Output:
260
+ painted_image: numpy array
261
+ """
262
+
263
+ for i, input_mask in enumerate(input_masks):
264
+ input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
265
+ contour_color, contour_alpha, background_color=i + 2, paint_foreground=True)
266
+ return input_image
267
+
268
+
269
+ def mask_generator_00(mask, background_radius, contour_radius):
270
+ # no background width when '00'
271
+ # distance map
272
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
273
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
274
+ dist_map = dist_transform_fore - dist_transform_back
275
+ # ...:::!!!:::...
276
+ contour_radius += 2
277
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
278
+ contour_mask = contour_mask / np.max(contour_mask)
279
+ contour_mask[contour_mask > 0.5] = 1.
280
+
281
+ return mask, contour_mask
282
+
283
+
284
+ def mask_generator_01(mask, background_radius, contour_radius):
285
+ # no background width when '00'
286
+ # distance map
287
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
288
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
289
+ dist_map = dist_transform_fore - dist_transform_back
290
+ # ...:::!!!:::...
291
+ contour_radius += 2
292
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
293
+ contour_mask = contour_mask / np.max(contour_mask)
294
+ return mask, contour_mask
295
+
296
+
297
+ def mask_generator_10(mask, background_radius, contour_radius):
298
+ # distance map
299
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
300
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
301
+ dist_map = dist_transform_fore - dist_transform_back
302
+ # .....:::::!!!!!
303
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
304
+ background_mask = (background_mask - np.min(background_mask))
305
+ background_mask = background_mask / np.max(background_mask)
306
+ # ...:::!!!:::...
307
+ contour_radius += 2
308
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
309
+ contour_mask = contour_mask / np.max(contour_mask)
310
+ contour_mask[contour_mask > 0.5] = 1.
311
+ return background_mask, contour_mask
312
+
313
+
314
+ def mask_generator_11(mask, background_radius, contour_radius):
315
+ # distance map
316
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
317
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
318
+ dist_map = dist_transform_fore - dist_transform_back
319
+ # .....:::::!!!!!
320
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
321
+ background_mask = (background_mask - np.min(background_mask))
322
+ background_mask = background_mask / np.max(background_mask)
323
+ # ...:::!!!:::...
324
+ contour_radius += 2
325
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
326
+ contour_mask = contour_mask / np.max(contour_mask)
327
+ return background_mask, contour_mask
328
+
329
+
330
+ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3,
331
+ contour_color=3, contour_alpha=1, mode='11'):
332
+ """
333
+ Input:
334
+ input_image: numpy array
335
+ input_mask: numpy array
336
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
337
+ background_blur_radius: radius of background blur, must be odd number
338
+ contour_width: width of mask contour, must be odd number
339
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
340
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
341
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
342
+
343
+ Output:
344
+ painted_image: numpy array
345
+ """
346
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
347
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
348
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
349
+
350
+ # downsample input image and mask
351
+ width, height = input_image.shape[0], input_image.shape[1]
352
+ res = 1024
353
+ ratio = min(1.0 * res / max(width, height), 1.0)
354
+ input_image = cv2.resize(input_image, (int(height * ratio), int(width * ratio)))
355
+ input_mask = cv2.resize(input_mask, (int(height * ratio), int(width * ratio)))
356
+
357
+ # 0: background, 1: foreground
358
+ msk = np.clip(input_mask, 0, 1)
359
+
360
+ # generate masks for background and contour pixels
361
+ background_radius = (background_blur_radius - 1) // 2
362
+ contour_radius = (contour_width - 1) // 2
363
+ generator_dict = {'00': mask_generator_00, '01': mask_generator_01, '10': mask_generator_10,
364
+ '11': mask_generator_11}
365
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
366
+
367
+ # paint
368
+ painted_image = vis_add_mask_wo_gaussian \
369
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha,
370
+ contour_alpha) # black for background
371
+
372
+ return painted_image
373
+
374
+
375
+ seg_model_map = {
376
+ 'base': 'vit_b',
377
+ 'large': 'vit_l',
378
+ 'huge': 'vit_h'
379
+ }
380
+ ckpt_url_map = {
381
+ 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
382
+ 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
383
+ 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
384
+ }
385
+ expected_sha256_map = {
386
+ 'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
387
+ 'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
388
+ 'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
389
+ }
390
+
391
+
392
+ def prepare_segmenter(segmenter="huge", download_root: str = None):
393
+ """
394
+ Prepare segmenter model and download checkpoint if necessary.
395
+
396
+ Returns: segmenter model name from 'vit_b', 'vit_l', 'vit_h'.
397
+
398
+ """
399
+
400
+ os.makedirs('result', exist_ok=True)
401
+ seg_model_name = seg_model_map[segmenter]
402
+ checkpoint_url = ckpt_url_map[seg_model_name]
403
+ folder = download_root or os.path.expanduser("~/.cache/SAM")
404
+ filename = os.path.basename(checkpoint_url)
405
+ segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
406
+
407
+ return seg_model_name, segmenter_checkpoint
408
+
409
+
410
+ def download_checkpoint(url, folder, filename, expected_sha256):
411
+ os.makedirs(folder, exist_ok=True)
412
+ download_target = os.path.join(folder, filename)
413
+ if os.path.isfile(download_target):
414
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
415
+ return download_target
416
+
417
+ print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
418
+ with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
419
+ progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
420
+ for data in response.iter_content(chunk_size=1024):
421
+ size = output.write(data)
422
+ progress.update(size)
423
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
424
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
425
+ return download_target
426
+
427
+
428
+ if __name__ == '__main__':
429
+
430
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
431
+ background_blur_radius = 31 # radius of background blur, must be odd number
432
+ contour_width = 11 # contour width, must be odd number
433
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
434
+ contour_alpha = 1 # transparency of background, 0: no contour highlighted
435
+
436
+ # load input image and mask
437
+ input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
438
+ input_mask = np.array(Image.open('./test_images/painter_input_mask.jpg').convert('P'))
439
+
440
+ # paint
441
+ overall_time_1 = 0
442
+ overall_time_2 = 0
443
+ overall_time_3 = 0
444
+ overall_time_4 = 0
445
+ overall_time_5 = 0
446
+
447
+ for i in range(50):
448
+ t2 = time.time()
449
+ painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
450
+ contour_width, contour_color, contour_alpha, mode='00')
451
+ e2 = time.time()
452
+
453
+ t3 = time.time()
454
+ painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
455
+ contour_width, contour_color, contour_alpha, mode='10')
456
+ e3 = time.time()
457
+
458
+ t1 = time.time()
459
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
460
+ contour_color, contour_alpha)
461
+ e1 = time.time()
462
+
463
+ t4 = time.time()
464
+ painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
465
+ contour_width, contour_color, contour_alpha, mode='01')
466
+ e4 = time.time()
467
+
468
+ t5 = time.time()
469
+ painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
470
+ contour_width, contour_color, contour_alpha, mode='11')
471
+ e5 = time.time()
472
+
473
+ overall_time_1 += (e1 - t1)
474
+ overall_time_2 += (e2 - t2)
475
+ overall_time_3 += (e3 - t3)
476
+ overall_time_4 += (e4 - t4)
477
+ overall_time_5 += (e5 - t5)
478
+
479
+ print(f'average time w gaussian: {overall_time_1 / 50}')
480
+ print(f'average time w/o gaussian00: {overall_time_2 / 50}')
481
+ print(f'average time w/o gaussian10: {overall_time_3 / 50}')
482
+ print(f'average time w/o gaussian01: {overall_time_4 / 50}')
483
+ print(f'average time w/o gaussian11: {overall_time_5 / 50}')
484
+
485
+ # save
486
+ painted_image_00 = Image.fromarray(painted_image_00)
487
+ painted_image_00.save('./test_images/painter_output_image_00.png')
488
+
489
+ painted_image_10 = Image.fromarray(painted_image_10)
490
+ painted_image_10.save('./test_images/painter_output_image_10.png')
491
+
492
+ painted_image_01 = Image.fromarray(painted_image_01)
493
+ painted_image_01.save('./test_images/painter_output_image_01.png')
494
+
495
+ painted_image_11 = Image.fromarray(painted_image_11)
496
+ painted_image_11.save('./test_images/painter_output_image_11.png')
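
A minimal sketch tying these helpers together: prepare_segmenter resolves the SAM variant, downloads the checkpoint into ~/.cache/SAM if missing and verifies its SHA256, and mask_painter dims the background around a mask. The mask below is synthetic and img2.jpg is one of the bundled test images:

    import numpy as np
    from PIL import Image
    from caption_anything.utils.utils import prepare_segmenter, mask_painter

    seg_model_name, ckpt_path = prepare_segmenter(segmenter='base')   # downloads sam_vit_b once, then reuses the cached file
    print(seg_model_name, ckpt_path)                                  # 'vit_b', ~/.cache/SAM/sam_vit_b_01ec64.pth

    image = np.array(Image.open('test_images/img2.jpg').convert('RGB'))
    mask = np.zeros(image.shape[:2], dtype=np.uint8)
    mask[50:200, 80:260] = 1                     # synthetic foreground region
    painted = mask_painter(image, mask)          # blurred dark background + highlighted contour
    Image.fromarray(painted).save('painted_demo.png')
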
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ https://download.pytorch.org/whl/cu111/torch-1.10.1%2Bcu111-cp38-cp38-linux_x86_64.whl
2
+ https://download.pytorch.org/whl/cu111/torchvision-0.11.2%2Bcu111-cp38-cp38-linux_x86_64.whl
3
+ https://download.pytorch.org/whl/cu111/torchaudio-0.10.1%2Bcu111-cp38-cp38-linux_x86_64.whl
4
+ openai
5
+ pillow
6
+ langchain==0.0.101
7
+ transformers==4.28.1
8
+ ftfy
9
+ regex
10
+ tqdm
11
+ git+https://github.com/openai/CLIP.git
12
+ git+https://github.com/facebookresearch/segment-anything.git
13
+ opencv-python==4.5.5.64
14
+ pycocotools
15
+ matplotlib
16
+ onnxruntime
17
+ onnx
18
+ gradio==3.27.0
19
+ accelerate
20
+ bitsandbytes==0.34
21
+ easyocr
sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879
test_images/img0.png ADDED
test_images/img1.jpg ADDED
test_images/img12.jpg ADDED
test_images/img14.jpg ADDED
test_images/img2.jpg ADDED
test_images/img35.webp ADDED
test_images/img36.webp ADDED
test_images/img5.jpg ADDED
test_images/qingming3.jpeg ADDED