Spaces:
Runtime error
Runtime error
Commit
·
ee87a3a
0
Parent(s):
Duplicate from TencentARC/Caption-Anything
Browse filesCo-authored-by: wybertwang <wybertwang@users.noreply.huggingface.co>
- .gitattributes +42 -0
- .gitignore +143 -0
- LICENSE +28 -0
- README.md +15 -0
- app.py +599 -0
- assets/UI.png +3 -0
- assets/caption_anything_logo.png +0 -0
- assets/demo1.jpg +3 -0
- assets/demo1.png +3 -0
- assets/demo1.svg +0 -0
- assets/demo2.png +0 -0
- assets/demo2.svg +0 -0
- assets/qingming.gif +3 -0
- assets/times_with_simsun.ttf +3 -0
- assets/title.png +0 -0
- assets/title.svg +1 -0
- caption_anything/__init__.py +0 -0
- caption_anything/captioner/README.md +13 -0
- caption_anything/captioner/__init__.py +15 -0
- caption_anything/captioner/base_captioner.py +200 -0
- caption_anything/captioner/blip.py +72 -0
- caption_anything/captioner/blip2.py +71 -0
- caption_anything/captioner/git.py +67 -0
- caption_anything/captioner/modeling_blip.py +1476 -0
- caption_anything/captioner/modeling_git.py +1587 -0
- caption_anything/captioner/vit_pixel_masks_utils.py +17 -0
- caption_anything/model.py +294 -0
- caption_anything/segmenter/__init__.py +14 -0
- caption_anything/segmenter/base_segmenter.py +184 -0
- caption_anything/segmenter/readme.md +68 -0
- caption_anything/text_refiner/README.md +8 -0
- caption_anything/text_refiner/__init__.py +6 -0
- caption_anything/text_refiner/text_refiner.py +86 -0
- caption_anything/utils/chatbot.py +225 -0
- caption_anything/utils/densecap_painter.py +64 -0
- caption_anything/utils/image_editing_utils.py +127 -0
- caption_anything/utils/parser.py +35 -0
- caption_anything/utils/utils.py +496 -0
- requirements.txt +21 -0
- sam_vit_h_4b8939.pth +3 -0
- test_images/img0.png +0 -0
- test_images/img1.jpg +0 -0
- test_images/img12.jpg +0 -0
- test_images/img14.jpg +0 -0
- test_images/img2.jpg +0 -0
- test_images/img35.webp +0 -0
- test_images/img36.webp +0 -0
- test_images/img5.jpg +0 -0
- test_images/qingming3.jpeg +3 -0
.gitattributes
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
test_img/img18.jpg filter=lfs diff=lfs merge=lfs -text
|
36 |
+
test_img/img22.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
times_with_simsun.ttf filter=lfs diff=lfs merge=lfs -text
|
38 |
+
test_images/qingming3.jpeg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/demo1.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/demo1.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/qingming.gif filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/UI.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
result/
|
2 |
+
model_cache/
|
3 |
+
*.pth
|
4 |
+
teng_grad_start.sh
|
5 |
+
*.jpg
|
6 |
+
*.jpeg
|
7 |
+
*.png
|
8 |
+
*.svg
|
9 |
+
*.gif
|
10 |
+
*.tiff
|
11 |
+
*.webp
|
12 |
+
|
13 |
+
|
14 |
+
# Byte-compiled / optimized / DLL files
|
15 |
+
__pycache__/
|
16 |
+
*.py[cod]
|
17 |
+
*$py.class
|
18 |
+
result/
|
19 |
+
|
20 |
+
# C extensions
|
21 |
+
*.so
|
22 |
+
|
23 |
+
# Distribution / packaging
|
24 |
+
.Python
|
25 |
+
build/
|
26 |
+
develop-eggs/
|
27 |
+
dist/
|
28 |
+
downloads/
|
29 |
+
eggs/
|
30 |
+
.eggs/
|
31 |
+
lib/
|
32 |
+
lib64/
|
33 |
+
parts/
|
34 |
+
sdist/
|
35 |
+
var/
|
36 |
+
wheels/
|
37 |
+
pip-wheel-metadata/
|
38 |
+
share/python-wheels/
|
39 |
+
*.egg-info/
|
40 |
+
.installed.cfg
|
41 |
+
*.egg
|
42 |
+
MANIFEST
|
43 |
+
|
44 |
+
# PyInstaller
|
45 |
+
# Usually these files are written by a python script from a template
|
46 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
47 |
+
*.manifest
|
48 |
+
*.spec
|
49 |
+
|
50 |
+
# Installer logs
|
51 |
+
pip-log.txt
|
52 |
+
pip-delete-this-directory.txt
|
53 |
+
|
54 |
+
# Unit test / coverage reports
|
55 |
+
htmlcov/
|
56 |
+
.tox/
|
57 |
+
.nox/
|
58 |
+
.coverage
|
59 |
+
.coverage.*
|
60 |
+
.cache
|
61 |
+
nosetests.xml
|
62 |
+
coverage.xml
|
63 |
+
*.cover
|
64 |
+
*.py,cover
|
65 |
+
.hypothesis/
|
66 |
+
.pytest_cache/
|
67 |
+
|
68 |
+
# Translations
|
69 |
+
*.mo
|
70 |
+
*.pot
|
71 |
+
|
72 |
+
# Django stuff:
|
73 |
+
*.log
|
74 |
+
local_settings.py
|
75 |
+
db.sqlite3
|
76 |
+
db.sqlite3-journal
|
77 |
+
|
78 |
+
# Flask stuff:
|
79 |
+
instance/
|
80 |
+
.webassets-cache
|
81 |
+
|
82 |
+
# Scrapy stuff:
|
83 |
+
.scrapy
|
84 |
+
|
85 |
+
# Sphinx documentation
|
86 |
+
docs/_build/
|
87 |
+
|
88 |
+
# PyBuilder
|
89 |
+
target/
|
90 |
+
|
91 |
+
# Jupyter Notebook
|
92 |
+
.ipynb_checkpoints
|
93 |
+
|
94 |
+
# IPython
|
95 |
+
profile_default/
|
96 |
+
ipython_config.py
|
97 |
+
|
98 |
+
# pyenv
|
99 |
+
.python-version
|
100 |
+
|
101 |
+
# pipenv
|
102 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
103 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
104 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
105 |
+
# install all needed dependencies.
|
106 |
+
#Pipfile.lock
|
107 |
+
|
108 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
109 |
+
__pypackages__/
|
110 |
+
|
111 |
+
# Celery stuff
|
112 |
+
celerybeat-schedule
|
113 |
+
celerybeat.pid
|
114 |
+
|
115 |
+
# SageMath parsed files
|
116 |
+
*.sage.py
|
117 |
+
|
118 |
+
# Environments
|
119 |
+
.env
|
120 |
+
.venv
|
121 |
+
env/
|
122 |
+
venv/
|
123 |
+
ENV/
|
124 |
+
env.bak/
|
125 |
+
venv.bak/
|
126 |
+
|
127 |
+
# Spyder project settings
|
128 |
+
.spyderproject
|
129 |
+
.spyproject
|
130 |
+
|
131 |
+
# Rope project settings
|
132 |
+
.ropeproject
|
133 |
+
|
134 |
+
# mkdocs documentation
|
135 |
+
/site
|
136 |
+
|
137 |
+
# mypy
|
138 |
+
.mypy_cache/
|
139 |
+
.dmypy.json
|
140 |
+
dmypy.json
|
141 |
+
|
142 |
+
# Pyre type checker
|
143 |
+
.pyre/
|
LICENSE
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2023, Teng Wang
|
4 |
+
|
5 |
+
Redistribution and use in source and binary forms, with or without
|
6 |
+
modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
list of conditions and the following disclaimer.
|
10 |
+
|
11 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
this list of conditions and the following disclaimer in the documentation
|
13 |
+
and/or other materials provided with the distribution.
|
14 |
+
|
15 |
+
3. Neither the name of the copyright holder nor the names of its
|
16 |
+
contributors may be used to endorse or promote products derived from
|
17 |
+
this software without specific prior written permission.
|
18 |
+
|
19 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Caption Anything
|
3 |
+
emoji: 📚
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.26.0
|
8 |
+
python_version: 3.8.9
|
9 |
+
app_file: app.py
|
10 |
+
pinned: false
|
11 |
+
license: apache-2.0
|
12 |
+
duplicated_from: TencentARC/Caption-Anything
|
13 |
+
---
|
14 |
+
|
15 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
from gradio import processing_utils
|
6 |
+
|
7 |
+
from packaging import version
|
8 |
+
from PIL import Image, ImageDraw
|
9 |
+
import functools
|
10 |
+
|
11 |
+
from caption_anything.model import CaptionAnything
|
12 |
+
from caption_anything.utils.image_editing_utils import create_bubble_frame
|
13 |
+
from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter, image_resize
|
14 |
+
from caption_anything.utils.parser import parse_augment
|
15 |
+
from caption_anything.captioner import build_captioner
|
16 |
+
from caption_anything.text_refiner import build_text_refiner
|
17 |
+
from caption_anything.segmenter import build_segmenter
|
18 |
+
from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
|
19 |
+
from segment_anything import sam_model_registry
|
20 |
+
import easyocr
|
21 |
+
|
22 |
+
args = parse_augment()
|
23 |
+
args.segmenter = "huge"
|
24 |
+
args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
|
25 |
+
args.clip_filter = True
|
26 |
+
if args.segmenter_checkpoint is None:
|
27 |
+
_, segmenter_checkpoint = prepare_segmenter(args.segmenter)
|
28 |
+
else:
|
29 |
+
segmenter_checkpoint = args.segmenter_checkpoint
|
30 |
+
|
31 |
+
shared_captioner = build_captioner(args.captioner, args.device, args)
|
32 |
+
shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
|
33 |
+
ocr_lang = ["ch_tra", "en"]
|
34 |
+
shared_ocr_reader = easyocr.Reader(ocr_lang)
|
35 |
+
tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
|
36 |
+
shared_chatbot_tools = build_chatbot_tools(tools_dict)
|
37 |
+
|
38 |
+
|
39 |
+
class ImageSketcher(gr.Image):
|
40 |
+
"""
|
41 |
+
Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
|
42 |
+
"""
|
43 |
+
|
44 |
+
is_template = True # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
|
45 |
+
|
46 |
+
def __init__(self, **kwargs):
|
47 |
+
super().__init__(tool="sketch", **kwargs)
|
48 |
+
|
49 |
+
def preprocess(self, x):
|
50 |
+
if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
|
51 |
+
assert isinstance(x, dict)
|
52 |
+
if x['mask'] is None:
|
53 |
+
decode_image = processing_utils.decode_base64_to_image(x['image'])
|
54 |
+
width, height = decode_image.size
|
55 |
+
mask = np.zeros((height, width, 4), dtype=np.uint8)
|
56 |
+
mask[..., -1] = 255
|
57 |
+
mask = self.postprocess(mask)
|
58 |
+
x['mask'] = mask
|
59 |
+
return super().preprocess(x)
|
60 |
+
|
61 |
+
|
62 |
+
def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
|
63 |
+
session_id=None):
|
64 |
+
segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
|
65 |
+
captioner = captioner
|
66 |
+
if session_id is not None:
|
67 |
+
print('Init caption anything for session {}'.format(session_id))
|
68 |
+
return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
|
69 |
+
|
70 |
+
|
71 |
+
def init_openai_api_key(api_key=""):
|
72 |
+
text_refiner = None
|
73 |
+
visual_chatgpt = None
|
74 |
+
if api_key and len(api_key) > 30:
|
75 |
+
try:
|
76 |
+
text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
|
77 |
+
assert len(text_refiner.llm('hi')) > 0 # test
|
78 |
+
visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
|
79 |
+
except:
|
80 |
+
text_refiner = None
|
81 |
+
visual_chatgpt = None
|
82 |
+
openai_available = text_refiner is not None
|
83 |
+
if openai_available:
|
84 |
+
return [gr.update(visible=True)]*6 + [gr.update(visible=False)]*2 + [text_refiner, visual_chatgpt, None]
|
85 |
+
else:
|
86 |
+
return [gr.update(visible=False)]*6 + [gr.update(visible=True)]*2 + [text_refiner, visual_chatgpt, 'Your OpenAI API Key is not available']
|
87 |
+
|
88 |
+
def init_wo_openai_api_key():
|
89 |
+
return [gr.update(visible=False)]*4 + [gr.update(visible=True)]*2 + [gr.update(visible=False)]*2 + [None, None, None]
|
90 |
+
|
91 |
+
def get_click_prompt(chat_input, click_state, click_mode):
|
92 |
+
inputs = json.loads(chat_input)
|
93 |
+
if click_mode == 'Continuous':
|
94 |
+
points = click_state[0]
|
95 |
+
labels = click_state[1]
|
96 |
+
for input in inputs:
|
97 |
+
points.append(input[:2])
|
98 |
+
labels.append(input[2])
|
99 |
+
elif click_mode == 'Single':
|
100 |
+
points = []
|
101 |
+
labels = []
|
102 |
+
for input in inputs:
|
103 |
+
points.append(input[:2])
|
104 |
+
labels.append(input[2])
|
105 |
+
click_state[0] = points
|
106 |
+
click_state[1] = labels
|
107 |
+
else:
|
108 |
+
raise NotImplementedError
|
109 |
+
|
110 |
+
prompt = {
|
111 |
+
"prompt_type": ["click"],
|
112 |
+
"input_point": click_state[0],
|
113 |
+
"input_label": click_state[1],
|
114 |
+
"multimask_output": "True",
|
115 |
+
}
|
116 |
+
return prompt
|
117 |
+
|
118 |
+
|
119 |
+
def update_click_state(click_state, caption, click_mode):
|
120 |
+
if click_mode == 'Continuous':
|
121 |
+
click_state[2].append(caption)
|
122 |
+
elif click_mode == 'Single':
|
123 |
+
click_state[2] = [caption]
|
124 |
+
else:
|
125 |
+
raise NotImplementedError
|
126 |
+
|
127 |
+
def chat_input_callback(*args):
|
128 |
+
visual_chatgpt, chat_input, click_state, state, aux_state = args
|
129 |
+
if visual_chatgpt is not None:
|
130 |
+
return visual_chatgpt.run_text(chat_input, state, aux_state)
|
131 |
+
else:
|
132 |
+
response = "Text refiner is not initilzed, please input openai api key."
|
133 |
+
state = state + [(chat_input, response)]
|
134 |
+
return state, state
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
def upload_callback(image_input, state, visual_chatgpt=None):
|
139 |
+
|
140 |
+
if isinstance(image_input, dict): # if upload from sketcher_input, input contains image and mask
|
141 |
+
image_input, mask = image_input['image'], image_input['mask']
|
142 |
+
|
143 |
+
click_state = [[], [], []]
|
144 |
+
image_input = image_resize(image_input, res=1024)
|
145 |
+
|
146 |
+
model = build_caption_anything_with_models(
|
147 |
+
args,
|
148 |
+
api_key="",
|
149 |
+
captioner=shared_captioner,
|
150 |
+
sam_model=shared_sam_model,
|
151 |
+
ocr_reader=shared_ocr_reader,
|
152 |
+
session_id=iface.app_id
|
153 |
+
)
|
154 |
+
model.segmenter.set_image(image_input)
|
155 |
+
image_embedding = model.image_embedding
|
156 |
+
original_size = model.original_size
|
157 |
+
input_size = model.input_size
|
158 |
+
|
159 |
+
if visual_chatgpt is not None:
|
160 |
+
print('upload_callback: add caption to chatGPT memory')
|
161 |
+
new_image_path = get_new_image_name('chat_image', func_name='upload')
|
162 |
+
image_input.save(new_image_path)
|
163 |
+
visual_chatgpt.current_image = new_image_path
|
164 |
+
img_caption = model.captioner.inference(image_input, filter=False, args={'text_prompt':''})['caption']
|
165 |
+
Human_prompt = f'\nHuman: The description of the image with path {new_image_path} is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
|
166 |
+
AI_prompt = "Received."
|
167 |
+
visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
|
168 |
+
visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
|
169 |
+
state = [(None, 'Received new image, resize it to width {} and height {}: '.format(image_input.size[0], image_input.size[1]))]
|
170 |
+
|
171 |
+
return state, state, image_input, click_state, image_input, image_input, image_embedding, \
|
172 |
+
original_size, input_size
|
173 |
+
|
174 |
+
|
175 |
+
def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
|
176 |
+
length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
|
177 |
+
evt: gr.SelectData):
|
178 |
+
click_index = evt.index
|
179 |
+
|
180 |
+
if point_prompt == 'Positive':
|
181 |
+
coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
|
182 |
+
else:
|
183 |
+
coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
|
184 |
+
|
185 |
+
prompt = get_click_prompt(coordinate, click_state, click_mode)
|
186 |
+
input_points = prompt['input_point']
|
187 |
+
input_labels = prompt['input_label']
|
188 |
+
|
189 |
+
controls = {'length': length,
|
190 |
+
'sentiment': sentiment,
|
191 |
+
'factuality': factuality,
|
192 |
+
'language': language}
|
193 |
+
|
194 |
+
model = build_caption_anything_with_models(
|
195 |
+
args,
|
196 |
+
api_key="",
|
197 |
+
captioner=shared_captioner,
|
198 |
+
sam_model=shared_sam_model,
|
199 |
+
ocr_reader=shared_ocr_reader,
|
200 |
+
text_refiner=text_refiner,
|
201 |
+
session_id=iface.app_id
|
202 |
+
)
|
203 |
+
|
204 |
+
model.setup(image_embedding, original_size, input_size, is_image_set=True)
|
205 |
+
|
206 |
+
enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
|
207 |
+
out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki, verbose=True, args={'clip_filter': False})[0]
|
208 |
+
|
209 |
+
state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
|
210 |
+
state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
|
211 |
+
update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
|
212 |
+
text = out['generated_captions']['raw_caption']
|
213 |
+
input_mask = np.array(out['mask'].convert('P'))
|
214 |
+
image_input = mask_painter(np.array(image_input), input_mask)
|
215 |
+
origin_image_input = image_input
|
216 |
+
image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
|
217 |
+
input_points=input_points, input_labels=input_labels)
|
218 |
+
x, y = input_points[-1]
|
219 |
+
|
220 |
+
if visual_chatgpt is not None:
|
221 |
+
print('inference_click: add caption to chatGPT memory')
|
222 |
+
new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
|
223 |
+
Image.open(out["crop_save_path"]).save(new_crop_save_path)
|
224 |
+
point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
|
225 |
+
visual_chatgpt.point_prompt = point_prompt
|
226 |
+
|
227 |
+
yield state, state, click_state, image_input
|
228 |
+
if not args.disable_gpt and model.text_refiner:
|
229 |
+
refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
|
230 |
+
enable_wiki=enable_wiki)
|
231 |
+
# new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
|
232 |
+
new_cap = refined_caption['caption']
|
233 |
+
if refined_caption['wiki']:
|
234 |
+
state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
|
235 |
+
state = state + [(None, f"caption: {new_cap}")]
|
236 |
+
refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
|
237 |
+
input_mask,
|
238 |
+
input_points=input_points, input_labels=input_labels)
|
239 |
+
yield state, state, click_state, refined_image_input
|
240 |
+
|
241 |
+
|
242 |
+
def get_sketch_prompt(mask: Image.Image):
|
243 |
+
"""
|
244 |
+
Get the prompt for the sketcher.
|
245 |
+
TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
|
246 |
+
"""
|
247 |
+
|
248 |
+
mask = np.asarray(mask)[..., 0]
|
249 |
+
|
250 |
+
# Get the bounding box of the sketch
|
251 |
+
y, x = np.where(mask != 0)
|
252 |
+
x1, y1 = np.min(x), np.min(y)
|
253 |
+
x2, y2 = np.max(x), np.max(y)
|
254 |
+
|
255 |
+
prompt = {
|
256 |
+
'prompt_type': ['box'],
|
257 |
+
'input_boxes': [
|
258 |
+
[x1, y1, x2, y2]
|
259 |
+
]
|
260 |
+
}
|
261 |
+
|
262 |
+
return prompt
|
263 |
+
|
264 |
+
|
265 |
+
def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
|
266 |
+
original_size, input_size, text_refiner):
|
267 |
+
image_input, mask = sketcher_image['image'], sketcher_image['mask']
|
268 |
+
|
269 |
+
prompt = get_sketch_prompt(mask)
|
270 |
+
boxes = prompt['input_boxes']
|
271 |
+
|
272 |
+
controls = {'length': length,
|
273 |
+
'sentiment': sentiment,
|
274 |
+
'factuality': factuality,
|
275 |
+
'language': language}
|
276 |
+
|
277 |
+
model = build_caption_anything_with_models(
|
278 |
+
args,
|
279 |
+
api_key="",
|
280 |
+
captioner=shared_captioner,
|
281 |
+
sam_model=shared_sam_model,
|
282 |
+
ocr_reader=shared_ocr_reader,
|
283 |
+
text_refiner=text_refiner,
|
284 |
+
session_id=iface.app_id
|
285 |
+
)
|
286 |
+
|
287 |
+
model.setup(image_embedding, original_size, input_size, is_image_set=True)
|
288 |
+
|
289 |
+
enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
|
290 |
+
out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)[0]
|
291 |
+
|
292 |
+
# Update components and states
|
293 |
+
state.append((f'Box: {boxes}', None))
|
294 |
+
state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
|
295 |
+
text = out['generated_captions']['raw_caption']
|
296 |
+
input_mask = np.array(out['mask'].convert('P'))
|
297 |
+
image_input = mask_painter(np.array(image_input), input_mask)
|
298 |
+
|
299 |
+
origin_image_input = image_input
|
300 |
+
|
301 |
+
fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
|
302 |
+
image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
|
303 |
+
|
304 |
+
yield state, state, image_input
|
305 |
+
|
306 |
+
if not args.disable_gpt and model.text_refiner:
|
307 |
+
refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
|
308 |
+
enable_wiki=enable_wiki)
|
309 |
+
|
310 |
+
new_cap = refined_caption['caption']
|
311 |
+
if refined_caption['wiki']:
|
312 |
+
state = state + [(None, "Wiki: {}".format(refined_caption['wiki']))]
|
313 |
+
state = state + [(None, f"caption: {new_cap}")]
|
314 |
+
refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
|
315 |
+
|
316 |
+
yield state, state, refined_image_input
|
317 |
+
|
318 |
+
def clear_chat_memory(visual_chatgpt, keep_global=False):
|
319 |
+
if visual_chatgpt is not None:
|
320 |
+
visual_chatgpt.memory.clear()
|
321 |
+
visual_chatgpt.point_prompt = ""
|
322 |
+
if keep_global:
|
323 |
+
visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
|
324 |
+
else:
|
325 |
+
visual_chatgpt.current_image = None
|
326 |
+
visual_chatgpt.global_prompt = ""
|
327 |
+
|
328 |
+
def cap_everything(image_input, visual_chatgpt, text_refiner):
|
329 |
+
|
330 |
+
model = build_caption_anything_with_models(
|
331 |
+
args,
|
332 |
+
api_key="",
|
333 |
+
captioner=shared_captioner,
|
334 |
+
sam_model=shared_sam_model,
|
335 |
+
ocr_reader=shared_ocr_reader,
|
336 |
+
text_refiner=text_refiner,
|
337 |
+
session_id=iface.app_id
|
338 |
+
)
|
339 |
+
paragraph = model.inference_cap_everything(image_input, verbose=True)
|
340 |
+
# state = state + [(None, f"Caption Everything: {paragraph}")]
|
341 |
+
Human_prompt = f'\nThe description of the image with path {visual_chatgpt.current_image} is:\n{paragraph}\nThis information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
|
342 |
+
AI_prompt = "Received."
|
343 |
+
visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
|
344 |
+
visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
|
345 |
+
return paragraph
|
346 |
+
|
347 |
+
|
348 |
+
def get_style():
|
349 |
+
current_version = version.parse(gr.__version__)
|
350 |
+
if current_version <= version.parse('3.24.1'):
|
351 |
+
style = '''
|
352 |
+
#image_sketcher{min-height:500px}
|
353 |
+
#image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
|
354 |
+
#image_upload{min-height:500px}
|
355 |
+
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
|
356 |
+
'''
|
357 |
+
elif current_version <= version.parse('3.27'):
|
358 |
+
style = '''
|
359 |
+
#image_sketcher{min-height:500px}
|
360 |
+
#image_upload{min-height:500px}
|
361 |
+
'''
|
362 |
+
else:
|
363 |
+
style = None
|
364 |
+
|
365 |
+
return style
|
366 |
+
|
367 |
+
|
368 |
+
def create_ui():
|
369 |
+
title = """<p><h1 align="center">Caption-Anything</h1></p>
|
370 |
+
"""
|
371 |
+
description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: <a href="https://github.com/ttengwang/Caption-Anything">https://github.com/ttengwang/Caption-Anything</a> <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
|
372 |
+
|
373 |
+
examples = [
|
374 |
+
["test_images/img35.webp"],
|
375 |
+
["test_images/img2.jpg"],
|
376 |
+
["test_images/img5.jpg"],
|
377 |
+
["test_images/img12.jpg"],
|
378 |
+
["test_images/img14.jpg"],
|
379 |
+
["test_images/qingming3.jpeg"],
|
380 |
+
["test_images/img1.jpg"],
|
381 |
+
]
|
382 |
+
|
383 |
+
with gr.Blocks(
|
384 |
+
css=get_style()
|
385 |
+
) as iface:
|
386 |
+
state = gr.State([])
|
387 |
+
click_state = gr.State([[], [], []])
|
388 |
+
# chat_state = gr.State([])
|
389 |
+
origin_image = gr.State(None)
|
390 |
+
image_embedding = gr.State(None)
|
391 |
+
text_refiner = gr.State(None)
|
392 |
+
visual_chatgpt = gr.State(None)
|
393 |
+
original_size = gr.State(None)
|
394 |
+
input_size = gr.State(None)
|
395 |
+
# img_caption = gr.State(None)
|
396 |
+
aux_state = gr.State([])
|
397 |
+
|
398 |
+
gr.Markdown(title)
|
399 |
+
gr.Markdown(description)
|
400 |
+
|
401 |
+
with gr.Row():
|
402 |
+
with gr.Column(scale=1.0):
|
403 |
+
with gr.Column(visible=False) as modules_not_need_gpt:
|
404 |
+
with gr.Tab("Click"):
|
405 |
+
image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
|
406 |
+
example_image = gr.Image(type="pil", interactive=False, visible=False)
|
407 |
+
with gr.Row(scale=1.0):
|
408 |
+
with gr.Row(scale=0.4):
|
409 |
+
point_prompt = gr.Radio(
|
410 |
+
choices=["Positive", "Negative"],
|
411 |
+
value="Positive",
|
412 |
+
label="Point Prompt",
|
413 |
+
interactive=True)
|
414 |
+
click_mode = gr.Radio(
|
415 |
+
choices=["Continuous", "Single"],
|
416 |
+
value="Continuous",
|
417 |
+
label="Clicking Mode",
|
418 |
+
interactive=True)
|
419 |
+
with gr.Row(scale=0.4):
|
420 |
+
clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
|
421 |
+
clear_button_image = gr.Button(value="Clear Image", interactive=True)
|
422 |
+
with gr.Tab("Trajectory (beta)"):
|
423 |
+
sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
|
424 |
+
elem_id="image_sketcher")
|
425 |
+
with gr.Row():
|
426 |
+
submit_button_sketcher = gr.Button(value="Submit", interactive=True)
|
427 |
+
|
428 |
+
with gr.Column(visible=False) as modules_need_gpt1:
|
429 |
+
with gr.Row(scale=1.0):
|
430 |
+
language = gr.Dropdown(
|
431 |
+
['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
|
432 |
+
value="English", label="Language", interactive=True)
|
433 |
+
sentiment = gr.Radio(
|
434 |
+
choices=["Positive", "Natural", "Negative"],
|
435 |
+
value="Natural",
|
436 |
+
label="Sentiment",
|
437 |
+
interactive=True,
|
438 |
+
)
|
439 |
+
with gr.Row(scale=1.0):
|
440 |
+
factuality = gr.Radio(
|
441 |
+
choices=["Factual", "Imagination"],
|
442 |
+
value="Factual",
|
443 |
+
label="Factuality",
|
444 |
+
interactive=True,
|
445 |
+
)
|
446 |
+
length = gr.Slider(
|
447 |
+
minimum=10,
|
448 |
+
maximum=80,
|
449 |
+
value=10,
|
450 |
+
step=1,
|
451 |
+
interactive=True,
|
452 |
+
label="Generated Caption Length",
|
453 |
+
)
|
454 |
+
enable_wiki = gr.Radio(
|
455 |
+
choices=["Yes", "No"],
|
456 |
+
value="No",
|
457 |
+
label="Enable Wiki",
|
458 |
+
interactive=True)
|
459 |
+
# with gr.Column(visible=True) as modules_not_need_gpt3:
|
460 |
+
gr.Examples(
|
461 |
+
examples=examples,
|
462 |
+
inputs=[example_image],
|
463 |
+
)
|
464 |
+
with gr.Column(scale=0.5):
|
465 |
+
with gr.Column(visible=True) as module_key_input:
|
466 |
+
openai_api_key = gr.Textbox(
|
467 |
+
placeholder="Input openAI API key",
|
468 |
+
show_label=False,
|
469 |
+
label="OpenAI API Key",
|
470 |
+
lines=1,
|
471 |
+
type="password")
|
472 |
+
with gr.Row(scale=0.5):
|
473 |
+
enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
|
474 |
+
disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
|
475 |
+
variant='primary')
|
476 |
+
with gr.Column(visible=False) as module_notification_box:
|
477 |
+
notification_box = gr.Textbox(lines=1, label="Notification", max_lines=5, show_label=False)
|
478 |
+
with gr.Column(visible=False) as modules_need_gpt2:
|
479 |
+
paragraph_output = gr.Textbox(lines=7, label="Describe Everything", max_lines=7)
|
480 |
+
with gr.Column(visible=False) as modules_need_gpt0:
|
481 |
+
cap_everything_button = gr.Button(value="Caption Everything in a Paragraph", interactive=True)
|
482 |
+
with gr.Column(visible=False) as modules_not_need_gpt2:
|
483 |
+
chatbot = gr.Chatbot(label="Chatbox", ).style(height=550, scale=0.5)
|
484 |
+
with gr.Column(visible=False) as modules_need_gpt3:
|
485 |
+
chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
|
486 |
+
container=False)
|
487 |
+
with gr.Row():
|
488 |
+
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
489 |
+
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
490 |
+
|
491 |
+
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
|
492 |
+
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
|
493 |
+
modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
|
494 |
+
enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
|
495 |
+
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
|
496 |
+
modules_not_need_gpt,
|
497 |
+
modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
|
498 |
+
disable_chatGPT_button.click(init_wo_openai_api_key,
|
499 |
+
outputs=[modules_need_gpt0, modules_need_gpt1, modules_need_gpt2, modules_need_gpt3,
|
500 |
+
modules_not_need_gpt,
|
501 |
+
modules_not_need_gpt2, module_key_input, module_notification_box, text_refiner, visual_chatgpt, notification_box])
|
502 |
+
|
503 |
+
enable_chatGPT_button.click(
|
504 |
+
lambda: (None, [], [], [[], [], []], "", "", ""),
|
505 |
+
[],
|
506 |
+
[image_input, chatbot, state, click_state, paragraph_output, origin_image],
|
507 |
+
queue=False,
|
508 |
+
show_progress=False
|
509 |
+
)
|
510 |
+
openai_api_key.submit(
|
511 |
+
lambda: (None, [], [], [[], [], []], "", "", ""),
|
512 |
+
[],
|
513 |
+
[image_input, chatbot, state, click_state, paragraph_output, origin_image],
|
514 |
+
queue=False,
|
515 |
+
show_progress=False
|
516 |
+
)
|
517 |
+
|
518 |
+
cap_everything_button.click(cap_everything, [origin_image, visual_chatgpt, text_refiner], [paragraph_output])
|
519 |
+
|
520 |
+
clear_button_click.click(
|
521 |
+
lambda x: ([[], [], []], x),
|
522 |
+
[origin_image],
|
523 |
+
[click_state, image_input],
|
524 |
+
queue=False,
|
525 |
+
show_progress=False
|
526 |
+
)
|
527 |
+
clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt])
|
528 |
+
clear_button_image.click(
|
529 |
+
lambda: (None, [], [], [[], [], []], "", "", ""),
|
530 |
+
[],
|
531 |
+
[image_input, chatbot, state, click_state, paragraph_output, origin_image],
|
532 |
+
queue=False,
|
533 |
+
show_progress=False
|
534 |
+
)
|
535 |
+
clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt])
|
536 |
+
clear_button_text.click(
|
537 |
+
lambda: ([], [], [[], [], [], []]),
|
538 |
+
[],
|
539 |
+
[chatbot, state, click_state],
|
540 |
+
queue=False,
|
541 |
+
show_progress=False
|
542 |
+
)
|
543 |
+
clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
|
544 |
+
|
545 |
+
image_input.clear(
|
546 |
+
lambda: (None, [], [], [[], [], []], "", "", ""),
|
547 |
+
[],
|
548 |
+
[image_input, chatbot, state, click_state, paragraph_output, origin_image],
|
549 |
+
queue=False,
|
550 |
+
show_progress=False
|
551 |
+
)
|
552 |
+
|
553 |
+
image_input.clear(clear_chat_memory, inputs=[visual_chatgpt])
|
554 |
+
|
555 |
+
|
556 |
+
image_input.upload(upload_callback, [image_input, state, visual_chatgpt],
|
557 |
+
[chatbot, state, origin_image, click_state, image_input, sketcher_input,
|
558 |
+
image_embedding, original_size, input_size])
|
559 |
+
sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt],
|
560 |
+
[chatbot, state, origin_image, click_state, image_input, sketcher_input,
|
561 |
+
image_embedding, original_size, input_size])
|
562 |
+
chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
|
563 |
+
[chatbot, state, aux_state])
|
564 |
+
chat_input.submit(lambda: "", None, chat_input)
|
565 |
+
submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
|
566 |
+
[chatbot, state, aux_state])
|
567 |
+
submit_button_text.click(lambda: "", None, chat_input)
|
568 |
+
example_image.change(upload_callback, [example_image, state, visual_chatgpt],
|
569 |
+
[chatbot, state, origin_image, click_state, image_input, sketcher_input,
|
570 |
+
image_embedding, original_size, input_size])
|
571 |
+
example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
|
572 |
+
# select coordinate
|
573 |
+
image_input.select(
|
574 |
+
inference_click,
|
575 |
+
inputs=[
|
576 |
+
origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
|
577 |
+
image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
|
578 |
+
],
|
579 |
+
outputs=[chatbot, state, click_state, image_input],
|
580 |
+
show_progress=False, queue=True
|
581 |
+
)
|
582 |
+
|
583 |
+
submit_button_sketcher.click(
|
584 |
+
inference_traject,
|
585 |
+
inputs=[
|
586 |
+
sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
|
587 |
+
original_size, input_size, text_refiner
|
588 |
+
],
|
589 |
+
outputs=[chatbot, state, sketcher_input],
|
590 |
+
show_progress=False, queue=True
|
591 |
+
)
|
592 |
+
|
593 |
+
return iface
|
594 |
+
|
595 |
+
|
596 |
+
if __name__ == '__main__':
|
597 |
+
iface = create_ui()
|
598 |
+
iface.queue(concurrency_count=5, api_open=False, max_size=10)
|
599 |
+
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
assets/UI.png
ADDED
![]() |
Git LFS Details
|
assets/caption_anything_logo.png
ADDED
![]() |
assets/demo1.jpg
ADDED
![]() |
Git LFS Details
|
assets/demo1.png
ADDED
![]() |
Git LFS Details
|
assets/demo1.svg
ADDED
|
assets/demo2.png
ADDED
![]() |
assets/demo2.svg
ADDED
|
assets/qingming.gif
ADDED
![]() |
Git LFS Details
|
assets/times_with_simsun.ttf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b15a12dd4bba4a48885c279a1d16590b652773f02137a7e62ede3411970c59f
|
3 |
+
size 11066612
|
assets/title.png
ADDED
![]() |
assets/title.svg
ADDED
|
caption_anything/__init__.py
ADDED
File without changes
|
caption_anything/captioner/README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
To run BLIP/BLIP2, you should install transformers from source!
|
2 |
+
```
|
3 |
+
!pip install git+https://github.com/huggingface/transformers.git
|
4 |
+
```
|
5 |
+
To run filter module, you should install CLIP repo as a Python package as follow:
|
6 |
+
```
|
7 |
+
!pip install ftfy regex tqdm
|
8 |
+
!pip install git+https://github.com/openai/CLIP.git
|
9 |
+
```
|
10 |
+
To accelerate BLIP2 with int8, you should install accelerate
|
11 |
+
```
|
12 |
+
!pip install accelerate bitsandbytes
|
13 |
+
```
|
caption_anything/captioner/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .blip import BLIPCaptioner
|
2 |
+
from .blip2 import BLIP2Captioner
|
3 |
+
from .git import GITCaptioner
|
4 |
+
from .base_captioner import BaseCaptioner
|
5 |
+
|
6 |
+
|
7 |
+
def build_captioner(type, device, args=None):
|
8 |
+
if type == 'blip':
|
9 |
+
return BLIPCaptioner(device, enable_filter=args.clip_filter)
|
10 |
+
elif type == 'blip2':
|
11 |
+
return BLIP2Captioner(device, enable_filter=args.clip_filter)
|
12 |
+
elif type == 'git':
|
13 |
+
return GITCaptioner(device, enable_filter=args.clip_filter)
|
14 |
+
else:
|
15 |
+
raise NotImplementedError("")
|
caption_anything/captioner/base_captioner.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image, ImageDraw, ImageOps
|
3 |
+
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
|
4 |
+
import json
|
5 |
+
import pdb
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from typing import Any, Union, List
|
9 |
+
import time
|
10 |
+
import clip
|
11 |
+
|
12 |
+
from caption_anything.utils.utils import load_image
|
13 |
+
|
14 |
+
|
15 |
+
def boundary(inputs):
|
16 |
+
col = inputs.shape[1]
|
17 |
+
inputs = inputs.reshape(-1)
|
18 |
+
lens = len(inputs)
|
19 |
+
start = np.argmax(inputs)
|
20 |
+
end = lens - 1 - np.argmax(np.flip(inputs))
|
21 |
+
top = start // col
|
22 |
+
bottom = end // col
|
23 |
+
return top, bottom
|
24 |
+
|
25 |
+
|
26 |
+
def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
|
27 |
+
if type(seg_mask) == str:
|
28 |
+
seg_mask = Image.open(seg_mask)
|
29 |
+
elif type(seg_mask) == np.ndarray:
|
30 |
+
seg_mask = Image.fromarray(seg_mask)
|
31 |
+
seg_mask = np.array(seg_mask) > 0
|
32 |
+
size = max(seg_mask.shape[0], seg_mask.shape[1])
|
33 |
+
top, bottom = boundary(seg_mask)
|
34 |
+
left, right = boundary(seg_mask.T)
|
35 |
+
return [left / size, top / size, right / size, bottom / size]
|
36 |
+
|
37 |
+
|
38 |
+
def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
|
39 |
+
if type(seg_mask) == str:
|
40 |
+
seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE)
|
41 |
+
_, seg_mask = cv2.threshold(seg_mask, 127, 255, 0)
|
42 |
+
elif type(seg_mask) == np.ndarray:
|
43 |
+
assert seg_mask.ndim == 2 # only support single-channel segmentation mask
|
44 |
+
seg_mask = seg_mask.astype('uint8')
|
45 |
+
if seg_mask.dtype == 'bool':
|
46 |
+
seg_mask = seg_mask * 255
|
47 |
+
contours, hierarchy = cv2.findContours(seg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
48 |
+
contours = np.concatenate(contours, axis=0)
|
49 |
+
rect = cv2.minAreaRect(contours)
|
50 |
+
box = cv2.boxPoints(rect)
|
51 |
+
if rect[-1] >= 45:
|
52 |
+
newstart = box.argmin(axis=0)[1] # leftmost
|
53 |
+
else:
|
54 |
+
newstart = box.argmax(axis=0)[0] # topmost
|
55 |
+
box = np.concatenate([box[newstart:], box[:newstart]], axis=0)
|
56 |
+
box = np.int0(box)
|
57 |
+
return box
|
58 |
+
|
59 |
+
|
60 |
+
def get_w_h(rect_points):
|
61 |
+
w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int')
|
62 |
+
h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int')
|
63 |
+
return w, h
|
64 |
+
|
65 |
+
|
66 |
+
def cut_box(img, rect_points):
|
67 |
+
w, h = get_w_h(rect_points)
|
68 |
+
dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0], ], dtype="float32")
|
69 |
+
transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts)
|
70 |
+
cropped_img = cv2.warpPerspective(img, transform, (h, w))
|
71 |
+
return cropped_img
|
72 |
+
|
73 |
+
|
74 |
+
class BaseCaptioner:
|
75 |
+
def __init__(self, device, enable_filter=False):
|
76 |
+
print(f"Initializing ImageCaptioning to {device}")
|
77 |
+
self.device = device
|
78 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
79 |
+
self.processor = None
|
80 |
+
self.model = None
|
81 |
+
self.enable_filter = enable_filter
|
82 |
+
if enable_filter:
|
83 |
+
self.filter, self.preprocess = clip.load('ViT-B/32', device)
|
84 |
+
|
85 |
+
@torch.no_grad()
|
86 |
+
def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str, reference_caption: List[str]=[]):
|
87 |
+
image = load_image(image, return_type='pil')
|
88 |
+
image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224)
|
89 |
+
captions = [caption]
|
90 |
+
if len(reference_caption):
|
91 |
+
captions.extend(reference_caption)
|
92 |
+
text = clip.tokenize(captions).to(self.device) # (>1, 77)
|
93 |
+
image_features = self.filter.encode_image(image) # (1, 512)
|
94 |
+
text_features = self.filter.encode_text(text) # # (>1, 512)
|
95 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
96 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
97 |
+
|
98 |
+
if len(reference_caption):
|
99 |
+
similarity = torch.matmul(image_features, text_features.transpose(1, 0)) / 0.07
|
100 |
+
similarity = similarity.softmax(dim=1)[0, 0].item()
|
101 |
+
else:
|
102 |
+
similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
|
103 |
+
print(f'Clip score of the caption is {similarity}')
|
104 |
+
return similarity
|
105 |
+
|
106 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool = False):
|
107 |
+
raise NotImplementedError()
|
108 |
+
|
109 |
+
def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool = False):
|
110 |
+
raise NotImplementedError()
|
111 |
+
|
112 |
+
def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False, verbose=False, caption_args={}):
|
113 |
+
image = load_image(image, return_type="pil")
|
114 |
+
|
115 |
+
if np.array(box).size == 4:
|
116 |
+
# [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
|
117 |
+
size = max(image.width, image.height)
|
118 |
+
x1, y1, x2, y2 = box
|
119 |
+
image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
|
120 |
+
elif np.array(box).size == 8: # four corners of an irregular rectangle
|
121 |
+
image_crop = cut_box(np.array(image), box)
|
122 |
+
|
123 |
+
crop_save_path = None
|
124 |
+
if verbose:
|
125 |
+
crop_save_path = f'result/crop_{time.time()}.png'
|
126 |
+
Image.fromarray(image_crop).save(crop_save_path)
|
127 |
+
print(f'croped image saved in {crop_save_path}')
|
128 |
+
caption = self.inference(image_crop, filter, caption_args)
|
129 |
+
caption.update({'crop_save_path': crop_save_path})
|
130 |
+
return caption
|
131 |
+
|
132 |
+
def inference_seg(self,
|
133 |
+
image: Union[np.ndarray, str],
|
134 |
+
seg_mask: Union[np.ndarray, Image.Image, str] = None,
|
135 |
+
crop_mode="w_bg",
|
136 |
+
filter=False,
|
137 |
+
disable_regular_box=False,
|
138 |
+
verbose=False,
|
139 |
+
caption_args={}):
|
140 |
+
if seg_mask is None:
|
141 |
+
seg_mask = np.ones(image.size).astype(bool)
|
142 |
+
|
143 |
+
image = load_image(image, return_type="pil")
|
144 |
+
seg_mask = load_image(seg_mask, return_type="pil")
|
145 |
+
|
146 |
+
seg_mask = seg_mask.resize(image.size)
|
147 |
+
seg_mask = np.array(seg_mask) > 0
|
148 |
+
if crop_mode == "wo_bg":
|
149 |
+
image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
|
150 |
+
image = np.uint8(image)
|
151 |
+
else:
|
152 |
+
image = np.array(image)
|
153 |
+
|
154 |
+
if disable_regular_box:
|
155 |
+
min_area_box = seg_to_box(seg_mask)
|
156 |
+
else:
|
157 |
+
min_area_box = new_seg_to_box(seg_mask)
|
158 |
+
return self.inference_box(image, min_area_box, filter, verbose, caption_args)
|
159 |
+
|
160 |
+
def generate_seg_cropped_image(self,
|
161 |
+
image: Union[np.ndarray, str],
|
162 |
+
seg_mask: Union[np.ndarray, Image.Image, str],
|
163 |
+
crop_mode="w_bg",
|
164 |
+
disable_regular_box=False):
|
165 |
+
image = load_image(image, return_type="pil")
|
166 |
+
seg_mask = load_image(seg_mask, return_type="pil")
|
167 |
+
|
168 |
+
seg_mask = seg_mask.resize(image.size)
|
169 |
+
seg_mask = np.array(seg_mask) > 0
|
170 |
+
|
171 |
+
if crop_mode == "wo_bg":
|
172 |
+
image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
|
173 |
+
else:
|
174 |
+
image = np.array(image)
|
175 |
+
|
176 |
+
if disable_regular_box:
|
177 |
+
box = seg_to_box(seg_mask)
|
178 |
+
else:
|
179 |
+
box = new_seg_to_box(seg_mask)
|
180 |
+
|
181 |
+
if np.array(box).size == 4:
|
182 |
+
# [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
|
183 |
+
size = max(image.shape[0], image.shape[1])
|
184 |
+
x1, y1, x2, y2 = box
|
185 |
+
image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
|
186 |
+
elif np.array(box).size == 8: # four corners of an irregular rectangle
|
187 |
+
image_crop = cut_box(np.array(image), box)
|
188 |
+
crop_save_path = f'result/crop_{time.time()}.png'
|
189 |
+
Image.fromarray(image_crop).save(crop_save_path)
|
190 |
+
print(f'croped image saved in {crop_save_path}')
|
191 |
+
return crop_save_path
|
192 |
+
|
193 |
+
|
194 |
+
if __name__ == '__main__':
|
195 |
+
model = BaseCaptioner(device='cuda:0')
|
196 |
+
image_path = 'test_images/img2.jpg'
|
197 |
+
seg_mask = np.zeros((15, 15))
|
198 |
+
seg_mask[5:10, 5:10] = 1
|
199 |
+
seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
|
200 |
+
print(model.inference_seg(image_path, seg_mask))
|
caption_anything/captioner/blip.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from transformers import BlipProcessor
|
4 |
+
|
5 |
+
from caption_anything.utils.utils import load_image
|
6 |
+
from .modeling_blip import BlipForConditionalGeneration
|
7 |
+
import numpy as np
|
8 |
+
from typing import Union
|
9 |
+
from .base_captioner import BaseCaptioner
|
10 |
+
import torchvision.transforms.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
class BLIPCaptioner(BaseCaptioner):
|
14 |
+
def __init__(self, device, enable_filter=False):
|
15 |
+
super().__init__(device, enable_filter)
|
16 |
+
self.device = device
|
17 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
18 |
+
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
|
19 |
+
self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large",
|
20 |
+
torch_dtype=self.torch_dtype).to(self.device)
|
21 |
+
|
22 |
+
@torch.no_grad()
|
23 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False, args={}):
|
24 |
+
image = load_image(image, return_type="pil")
|
25 |
+
inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
|
26 |
+
out = self.model.generate(**inputs, max_new_tokens=50)
|
27 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
28 |
+
|
29 |
+
result = {}
|
30 |
+
if self.enable_filter and filter:
|
31 |
+
clip_score = self.filter_caption(image, captions)
|
32 |
+
result['clip_score'] = clip_score
|
33 |
+
result.update({'caption':captions})
|
34 |
+
print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
|
35 |
+
return {'caption': captions}
|
36 |
+
|
37 |
+
@torch.no_grad()
|
38 |
+
def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
|
39 |
+
filter=False, disable_regular_box=False):
|
40 |
+
result = {}
|
41 |
+
crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
|
42 |
+
disable_regular_box=disable_regular_box)
|
43 |
+
image = load_image(image, return_type="pil")
|
44 |
+
inputs = self.processor(image, return_tensors="pt")
|
45 |
+
pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
|
46 |
+
_, _, H, W = pixel_values.shape
|
47 |
+
seg_mask = Image.fromarray(seg_mask.astype(float))
|
48 |
+
seg_mask = seg_mask.resize((H, W))
|
49 |
+
seg_mask = F.pil_to_tensor(seg_mask) > 0.5
|
50 |
+
seg_mask = seg_mask.float()
|
51 |
+
pixel_masks = seg_mask.unsqueeze(0).to(self.device)
|
52 |
+
out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
|
53 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
54 |
+
if self.enable_filter and filter:
|
55 |
+
clip_score = self.filter_caption(image, captions)
|
56 |
+
result['clip_score'] = clip_score
|
57 |
+
result.update({'caption':captions, 'crop_save_path':crop_save_path})
|
58 |
+
print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
|
59 |
+
return result
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
model = BLIPCaptioner(device='cuda:0')
|
64 |
+
# image_path = 'test_images/img2.jpg'
|
65 |
+
image_path = 'image/SAM/img10.jpg'
|
66 |
+
seg_mask = np.zeros((15, 15))
|
67 |
+
seg_mask[5:10, 5:10] = 1
|
68 |
+
seg_mask = 'test_images/img10.jpg.raw_mask.png'
|
69 |
+
image_path = 'test_images/img2.jpg'
|
70 |
+
seg_mask = 'test_images/img2.jpg.raw_mask.png'
|
71 |
+
print(f'process image {image_path}')
|
72 |
+
print(model.inference_with_reduced_tokens(image_path, seg_mask))
|
caption_anything/captioner/blip2.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
from typing import Union
|
5 |
+
from transformers import AutoProcessor, Blip2ForConditionalGeneration
|
6 |
+
|
7 |
+
from caption_anything.utils.utils import is_platform_win, load_image
|
8 |
+
from .base_captioner import BaseCaptioner
|
9 |
+
import time
|
10 |
+
|
11 |
+
class BLIP2Captioner(BaseCaptioner):
|
12 |
+
def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
|
13 |
+
super().__init__(device, enable_filter)
|
14 |
+
self.device = device
|
15 |
+
self.dialogue = dialogue
|
16 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
17 |
+
self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
18 |
+
if is_platform_win():
|
19 |
+
self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="sequential", torch_dtype=self.torch_dtype)
|
20 |
+
else:
|
21 |
+
self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map='sequential', load_in_8bit=True)
|
22 |
+
|
23 |
+
@torch.no_grad()
|
24 |
+
def inference(self,
|
25 |
+
image: Union[np.ndarray, Image.Image, str],
|
26 |
+
filter=False,
|
27 |
+
args={}):
|
28 |
+
args['return_ppl'] = args.get('return_ppl', False)
|
29 |
+
args['text_prompt'] = args.get('text_prompt', 'Question: what does the image show? Answer:')
|
30 |
+
args['reference_caption'] = args.get('reference_caption', [])
|
31 |
+
|
32 |
+
image = load_image(image, return_type="pil")
|
33 |
+
result = {}
|
34 |
+
if not self.dialogue:
|
35 |
+
inputs = self.processor(image, text = args['text_prompt'], return_tensors="pt").to(self.device, self.torch_dtype)
|
36 |
+
out = self.model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
|
37 |
+
caption = self.processor.decode(out.sequences[0], skip_special_tokens=True).strip()
|
38 |
+
if self.enable_filter and filter:
|
39 |
+
print('reference caption: {}, caption: {}'.format(args['reference_caption'], caption))
|
40 |
+
clip_score = self.filter_caption(image, caption, args['reference_caption'])
|
41 |
+
result['clip_score'] = clip_score
|
42 |
+
if args['return_ppl']:
|
43 |
+
ppl_score = torch.stack(out.scores, dim=1).softmax(dim=2).log().max(dim=2)[0].sum(dim=1)[0]
|
44 |
+
result['ppl_score'] = ppl_score.item()
|
45 |
+
print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {caption}")
|
46 |
+
result['caption'] = caption
|
47 |
+
return result
|
48 |
+
else:
|
49 |
+
context = []
|
50 |
+
template = "Question: {} Answer: {}."
|
51 |
+
while(True):
|
52 |
+
input_texts = input()
|
53 |
+
if input_texts == 'end':
|
54 |
+
break
|
55 |
+
prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + input_texts + " Answer:"
|
56 |
+
inputs = self.processor(image, text = prompt, return_tensors="pt").to(self.device, self.torch_dtype)
|
57 |
+
out = self.model.generate(**inputs, max_new_tokens=50)
|
58 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
59 |
+
context.append((input_texts, captions))
|
60 |
+
result['caption'] = captions
|
61 |
+
return result
|
62 |
+
|
63 |
+
if __name__ == '__main__':
|
64 |
+
|
65 |
+
dialogue = False
|
66 |
+
model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache')
|
67 |
+
image_path = 'test_images/img2.jpg'
|
68 |
+
seg_mask = np.zeros((224,224))
|
69 |
+
seg_mask[50:200, 50:200] = 1
|
70 |
+
print(f'process image {image_path}')
|
71 |
+
print(model.inference_seg(image_path, seg_mask))
|
caption_anything/captioner/git.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GitProcessor, AutoProcessor
|
2 |
+
|
3 |
+
from caption_anything.utils.utils import load_image
|
4 |
+
from .modeling_git import GitForCausalLM
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
from .base_captioner import BaseCaptioner
|
8 |
+
import numpy as np
|
9 |
+
from typing import Union
|
10 |
+
import torchvision.transforms.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
class GITCaptioner(BaseCaptioner):
|
14 |
+
def __init__(self, device, enable_filter=False):
|
15 |
+
super().__init__(device, enable_filter)
|
16 |
+
self.device = device
|
17 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
18 |
+
self.processor = AutoProcessor.from_pretrained("microsoft/git-large")
|
19 |
+
self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
|
20 |
+
|
21 |
+
@torch.no_grad()
|
22 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False, args={}):
|
23 |
+
image = load_image(image, return_type="pil")
|
24 |
+
pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
|
25 |
+
generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
|
26 |
+
captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
27 |
+
|
28 |
+
result = {}
|
29 |
+
if self.enable_filter and filter:
|
30 |
+
clip_score = self.filter_caption(image, captions)
|
31 |
+
result['clip_score'] = clip_score
|
32 |
+
result.update({'caption':captions})
|
33 |
+
print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {captions}")
|
34 |
+
return {'caption': captions}
|
35 |
+
|
36 |
+
@torch.no_grad()
|
37 |
+
def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
|
38 |
+
filter=False, disable_regular_box=False):
|
39 |
+
result = {}
|
40 |
+
crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
|
41 |
+
disable_regular_box=disable_regular_box)
|
42 |
+
image = load_image(image, return_type="pil")
|
43 |
+
inputs = self.processor(images=image, return_tensors="pt")
|
44 |
+
pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
|
45 |
+
_, _, H, W = pixel_values.shape
|
46 |
+
seg_mask = Image.fromarray(seg_mask.astype(float))
|
47 |
+
seg_mask = seg_mask.resize((H, W))
|
48 |
+
seg_mask = F.pil_to_tensor(seg_mask) > 0.5
|
49 |
+
seg_mask = seg_mask.float()
|
50 |
+
pixel_masks = seg_mask.unsqueeze(0).to(self.device)
|
51 |
+
out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
|
52 |
+
captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
|
53 |
+
if self.enable_filter and filter:
|
54 |
+
clip_score = self.filter_caption(image, captions)
|
55 |
+
result['clip_score'] = clip_score
|
56 |
+
print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
|
57 |
+
result.update({'caption':captions, 'crop_save_path':crop_save_path})
|
58 |
+
return result
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == '__main__':
|
62 |
+
model = GITCaptioner(device='cuda:2', enable_filter=False)
|
63 |
+
image_path = 'test_images/img2.jpg'
|
64 |
+
seg_mask = np.zeros((224, 224))
|
65 |
+
seg_mask[50:200, 50:200] = 1
|
66 |
+
print(f'process image {image_path}')
|
67 |
+
print(model.inference_with_reduced_tokens(image_path, seg_mask))
|
caption_anything/captioner/modeling_blip.py
ADDED
@@ -0,0 +1,1476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch BLIP model."""
|
16 |
+
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Any, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.utils.checkpoint
|
22 |
+
from torch import nn
|
23 |
+
from torch.nn.functional import normalize
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
27 |
+
from transformers.modeling_utils import PreTrainedModel
|
28 |
+
from transformers.utils import (
|
29 |
+
ModelOutput,
|
30 |
+
add_start_docstrings,
|
31 |
+
add_start_docstrings_to_model_forward,
|
32 |
+
logging,
|
33 |
+
replace_return_docstrings,
|
34 |
+
)
|
35 |
+
from transformers.models.blip.configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
|
36 |
+
from transformers.models.blip.modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
|
37 |
+
from .vit_pixel_masks_utils import ViTPatchMaskGenerator
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__)
|
40 |
+
|
41 |
+
_CHECKPOINT_FOR_DOC = "Salesforce/blip-vqa-base"
|
42 |
+
|
43 |
+
BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
44 |
+
"Salesforce/blip-vqa-base",
|
45 |
+
"Salesforce/blip-vqa-capfit-large",
|
46 |
+
"Salesforce/blip-image-captioning-base",
|
47 |
+
"Salesforce/blip-image-captioning-large",
|
48 |
+
"Salesforce/blip-itm-base-coco",
|
49 |
+
"Salesforce/blip-itm-large-coco",
|
50 |
+
"Salesforce/blip-itm-base-flikr",
|
51 |
+
"Salesforce/blip-itm-large-flikr",
|
52 |
+
# See all BLIP models at https://huggingface.co/models?filter=blip
|
53 |
+
]
|
54 |
+
|
55 |
+
|
56 |
+
# Copied from transformers.models.clip.modeling_clip.contrastive_loss
|
57 |
+
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
58 |
+
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
59 |
+
|
60 |
+
|
61 |
+
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip
|
62 |
+
def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
63 |
+
caption_loss = contrastive_loss(similarity)
|
64 |
+
image_loss = contrastive_loss(similarity.t())
|
65 |
+
return (caption_loss + image_loss) / 2.0
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
|
69 |
+
class BlipForConditionalGenerationModelOutput(ModelOutput):
|
70 |
+
"""
|
71 |
+
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
|
72 |
+
last hidden states. This class also adds the loss term from the text decoder.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
|
76 |
+
Languge modeling loss from the text decoder.
|
77 |
+
decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
|
78 |
+
Prediction scores of the language modeling head of the text decoder model.
|
79 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
|
80 |
+
The image embeddings obtained after applying the Vision Transformer model to the input image.
|
81 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
82 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
83 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
|
84 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
85 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
86 |
+
|
87 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
88 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
|
89 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
90 |
+
sequence_length)`.
|
91 |
+
|
92 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
93 |
+
heads.
|
94 |
+
"""
|
95 |
+
|
96 |
+
loss: Optional[Tuple[torch.FloatTensor]] = None
|
97 |
+
decoder_logits: Optional[Tuple[torch.FloatTensor]] = None
|
98 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
99 |
+
last_hidden_state: torch.FloatTensor = None
|
100 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
101 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
102 |
+
|
103 |
+
|
104 |
+
@dataclass
|
105 |
+
class BlipTextVisionModelOutput(ModelOutput):
|
106 |
+
"""
|
107 |
+
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
|
108 |
+
last hidden states. This class also adds the loss term from the text decoder.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
112 |
+
Languge modeling loss from the text decoder.
|
113 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
114 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
115 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
116 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
117 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
118 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
119 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
120 |
+
|
121 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
122 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
123 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
124 |
+
sequence_length)`.
|
125 |
+
|
126 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
127 |
+
heads.
|
128 |
+
"""
|
129 |
+
|
130 |
+
loss: Optional[torch.FloatTensor] = None
|
131 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
132 |
+
last_hidden_state: torch.FloatTensor = None
|
133 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
134 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
135 |
+
|
136 |
+
|
137 |
+
@dataclass
|
138 |
+
class BlipImageTextMatchingModelOutput(ModelOutput):
|
139 |
+
"""
|
140 |
+
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
|
141 |
+
last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
|
142 |
+
scores.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
itm_score (`torch.FloatTensor`):
|
146 |
+
The image-text similarity scores.
|
147 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
148 |
+
Languge modeling loss from the text decoder.
|
149 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
150 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
151 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
152 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
153 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
154 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
155 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
156 |
+
|
157 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
158 |
+
vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
|
159 |
+
Last layer hidden-state of the vision of the vision-only branch of the model.
|
160 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
161 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
162 |
+
sequence_length)`.
|
163 |
+
|
164 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
165 |
+
heads.
|
166 |
+
question_embeds (`torch.FloatTensor`):
|
167 |
+
The question embeddings obtained by the text projection layer.
|
168 |
+
"""
|
169 |
+
|
170 |
+
itm_score: Optional[torch.FloatTensor] = None
|
171 |
+
loss: Optional[torch.FloatTensor] = None
|
172 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
173 |
+
last_hidden_state: torch.FloatTensor = None
|
174 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
175 |
+
vision_pooler_output: Optional[torch.FloatTensor] = None
|
176 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
177 |
+
question_embeds: Optional[Tuple[torch.FloatTensor]] = None
|
178 |
+
|
179 |
+
|
180 |
+
@dataclass
|
181 |
+
class BlipOutput(ModelOutput):
|
182 |
+
"""
|
183 |
+
Args:
|
184 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
|
185 |
+
Contrastive loss for image-text similarity.
|
186 |
+
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
|
187 |
+
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
|
188 |
+
similarity scores.
|
189 |
+
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
|
190 |
+
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
|
191 |
+
similarity scores.
|
192 |
+
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
193 |
+
The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
|
194 |
+
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
195 |
+
The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
|
196 |
+
text_model_output(`BaseModelOutputWithPooling`):
|
197 |
+
The output of the [`BlipTextModel`].
|
198 |
+
vision_model_output(`BaseModelOutputWithPooling`):
|
199 |
+
The output of the [`BlipVisionModel`].
|
200 |
+
"""
|
201 |
+
|
202 |
+
loss: Optional[torch.FloatTensor] = None
|
203 |
+
logits_per_image: torch.FloatTensor = None
|
204 |
+
logits_per_text: torch.FloatTensor = None
|
205 |
+
text_embeds: torch.FloatTensor = None
|
206 |
+
image_embeds: torch.FloatTensor = None
|
207 |
+
text_model_output: BaseModelOutputWithPooling = None
|
208 |
+
vision_model_output: BaseModelOutputWithPooling = None
|
209 |
+
|
210 |
+
def to_tuple(self) -> Tuple[Any]:
|
211 |
+
return tuple(
|
212 |
+
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
|
213 |
+
for k in self.keys()
|
214 |
+
)
|
215 |
+
|
216 |
+
|
217 |
+
class BlipVisionEmbeddings(nn.Module):
|
218 |
+
def __init__(self, config: BlipVisionConfig):
|
219 |
+
super().__init__()
|
220 |
+
self.config = config
|
221 |
+
self.embed_dim = config.hidden_size
|
222 |
+
self.image_size = config.image_size
|
223 |
+
self.patch_size = config.patch_size
|
224 |
+
|
225 |
+
self.class_embedding = nn.Parameter(
|
226 |
+
torch.randn(1, 1, self.embed_dim),
|
227 |
+
)
|
228 |
+
|
229 |
+
self.patch_embedding = nn.Conv2d(
|
230 |
+
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
|
231 |
+
)
|
232 |
+
|
233 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
234 |
+
self.num_positions = self.num_patches + 1
|
235 |
+
|
236 |
+
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
237 |
+
|
238 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
239 |
+
batch_size = pixel_values.shape[0]
|
240 |
+
target_dtype = self.patch_embedding.weight.dtype
|
241 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
242 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
243 |
+
|
244 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
245 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
246 |
+
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
|
247 |
+
return embeddings
|
248 |
+
|
249 |
+
|
250 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip
|
251 |
+
class BlipTextEmbeddings(nn.Module):
|
252 |
+
def __init__(self, config: BlipTextConfig):
|
253 |
+
super().__init__()
|
254 |
+
embed_dim = config.hidden_size
|
255 |
+
|
256 |
+
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
257 |
+
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
258 |
+
|
259 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
260 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
261 |
+
|
262 |
+
def forward(
|
263 |
+
self,
|
264 |
+
input_ids: Optional[torch.LongTensor] = None,
|
265 |
+
position_ids: Optional[torch.LongTensor] = None,
|
266 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
267 |
+
) -> torch.Tensor:
|
268 |
+
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
269 |
+
|
270 |
+
if position_ids is None:
|
271 |
+
position_ids = self.position_ids[:, :seq_length]
|
272 |
+
|
273 |
+
if inputs_embeds is None:
|
274 |
+
inputs_embeds = self.token_embedding(input_ids)
|
275 |
+
|
276 |
+
position_embeddings = self.position_embedding(position_ids)
|
277 |
+
embeddings = inputs_embeds + position_embeddings
|
278 |
+
|
279 |
+
return embeddings
|
280 |
+
|
281 |
+
|
282 |
+
class BlipAttention(nn.Module):
|
283 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
284 |
+
|
285 |
+
def __init__(self, config):
|
286 |
+
super().__init__()
|
287 |
+
self.config = config
|
288 |
+
self.embed_dim = config.hidden_size
|
289 |
+
self.num_heads = config.num_attention_heads
|
290 |
+
self.head_dim = self.embed_dim // self.num_heads
|
291 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
292 |
+
raise ValueError(
|
293 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
294 |
+
f" {self.num_heads})."
|
295 |
+
)
|
296 |
+
self.scale = self.head_dim**-0.5
|
297 |
+
self.dropout = nn.Dropout(config.attention_dropout)
|
298 |
+
|
299 |
+
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
|
300 |
+
|
301 |
+
self.projection = nn.Linear(self.embed_dim, self.embed_dim)
|
302 |
+
|
303 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
304 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
305 |
+
|
306 |
+
def forward(
|
307 |
+
self,
|
308 |
+
hidden_states: torch.Tensor,
|
309 |
+
head_mask: Optional[torch.Tensor] = None,
|
310 |
+
output_attentions: Optional[bool] = False,
|
311 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
312 |
+
"""Input shape: Batch x Time x Channel"""
|
313 |
+
|
314 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
315 |
+
|
316 |
+
mixed_qkv = self.qkv(hidden_states)
|
317 |
+
mixed_qkv = (
|
318 |
+
self.qkv(hidden_states)
|
319 |
+
.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
|
320 |
+
.permute(2, 0, 3, 1, 4)
|
321 |
+
)
|
322 |
+
query_states, key_states, value_states = (
|
323 |
+
mixed_qkv[0],
|
324 |
+
mixed_qkv[1],
|
325 |
+
mixed_qkv[2],
|
326 |
+
)
|
327 |
+
|
328 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
329 |
+
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
|
330 |
+
|
331 |
+
attention_scores = attention_scores * self.scale
|
332 |
+
|
333 |
+
# Normalize the attention scores to probabilities.
|
334 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
335 |
+
|
336 |
+
# This is actually dropping out entire tokens to attend to, which might
|
337 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
338 |
+
attention_probs = self.dropout(attention_probs)
|
339 |
+
|
340 |
+
# Mask heads if we want to
|
341 |
+
if head_mask is not None:
|
342 |
+
attention_probs = attention_probs * head_mask
|
343 |
+
|
344 |
+
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
|
345 |
+
|
346 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
|
347 |
+
context_layer = context_layer.reshape(new_context_layer_shape)
|
348 |
+
|
349 |
+
output = self.projection(context_layer)
|
350 |
+
|
351 |
+
outputs = (output, attention_probs) if output_attentions else (output, None)
|
352 |
+
|
353 |
+
return outputs
|
354 |
+
|
355 |
+
|
356 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip
|
357 |
+
class BlipMLP(nn.Module):
|
358 |
+
def __init__(self, config):
|
359 |
+
super().__init__()
|
360 |
+
self.config = config
|
361 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
362 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
363 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
364 |
+
|
365 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
366 |
+
hidden_states = self.fc1(hidden_states)
|
367 |
+
hidden_states = self.activation_fn(hidden_states)
|
368 |
+
hidden_states = self.fc2(hidden_states)
|
369 |
+
return hidden_states
|
370 |
+
|
371 |
+
|
372 |
+
class BlipEncoderLayer(nn.Module):
|
373 |
+
def __init__(self, config: BlipConfig):
|
374 |
+
super().__init__()
|
375 |
+
self.embed_dim = config.hidden_size
|
376 |
+
self.self_attn = BlipAttention(config)
|
377 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
378 |
+
self.mlp = BlipMLP(config)
|
379 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
380 |
+
|
381 |
+
def forward(
|
382 |
+
self,
|
383 |
+
hidden_states: torch.Tensor,
|
384 |
+
attention_mask: torch.Tensor,
|
385 |
+
output_attentions: Optional[bool] = False,
|
386 |
+
) -> Tuple[torch.FloatTensor]:
|
387 |
+
"""
|
388 |
+
Args:
|
389 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
390 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
391 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
392 |
+
`(config.encoder_attention_heads,)`.
|
393 |
+
output_attentions (`bool`, *optional*):
|
394 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
395 |
+
returned tensors for more detail.
|
396 |
+
"""
|
397 |
+
residual = hidden_states
|
398 |
+
|
399 |
+
hidden_states = self.layer_norm1(hidden_states)
|
400 |
+
hidden_states, attn_weights = self.self_attn(
|
401 |
+
hidden_states=hidden_states,
|
402 |
+
head_mask=attention_mask,
|
403 |
+
output_attentions=output_attentions,
|
404 |
+
)
|
405 |
+
hidden_states = hidden_states + residual
|
406 |
+
residual = hidden_states
|
407 |
+
hidden_states = self.layer_norm2(hidden_states)
|
408 |
+
hidden_states = self.mlp(hidden_states)
|
409 |
+
|
410 |
+
hidden_states = hidden_states + residual
|
411 |
+
|
412 |
+
outputs = (hidden_states,)
|
413 |
+
|
414 |
+
if output_attentions:
|
415 |
+
outputs += (attn_weights,)
|
416 |
+
|
417 |
+
return outputs
|
418 |
+
|
419 |
+
|
420 |
+
class BlipPreTrainedModel(PreTrainedModel):
|
421 |
+
"""
|
422 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
423 |
+
models.
|
424 |
+
"""
|
425 |
+
|
426 |
+
config_class = BlipConfig
|
427 |
+
base_model_prefix = "blip"
|
428 |
+
supports_gradient_checkpointing = True
|
429 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
430 |
+
|
431 |
+
def _init_weights(self, module):
|
432 |
+
"""Initialize the weights"""
|
433 |
+
factor = self.config.initializer_range
|
434 |
+
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
|
435 |
+
module.weight.data.normal_(mean=0.0, std=factor)
|
436 |
+
if hasattr(module, "bias") and module.bias is not None:
|
437 |
+
module.bias.data.zero_()
|
438 |
+
|
439 |
+
if isinstance(module, BlipVisionEmbeddings):
|
440 |
+
if hasattr(self.config, "vision_config"):
|
441 |
+
factor = self.config.vision_config.initializer_range
|
442 |
+
nn.init.trunc_normal_(
|
443 |
+
module.position_embedding,
|
444 |
+
mean=0.0,
|
445 |
+
std=factor,
|
446 |
+
)
|
447 |
+
|
448 |
+
nn.init.trunc_normal_(
|
449 |
+
module.class_embedding,
|
450 |
+
mean=0.0,
|
451 |
+
std=factor,
|
452 |
+
)
|
453 |
+
|
454 |
+
elif isinstance(module, nn.LayerNorm):
|
455 |
+
module.bias.data.zero_()
|
456 |
+
module.weight.data.fill_(1.0)
|
457 |
+
elif isinstance(module, nn.Linear) and module.bias is not None:
|
458 |
+
module.bias.data.zero_()
|
459 |
+
|
460 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
461 |
+
if isinstance(module, BlipEncoder):
|
462 |
+
module.gradient_checkpointing = value
|
463 |
+
|
464 |
+
|
465 |
+
BLIP_START_DOCSTRING = r"""
|
466 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
467 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
468 |
+
etc.)
|
469 |
+
|
470 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
471 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
472 |
+
and behavior.
|
473 |
+
|
474 |
+
Parameters:
|
475 |
+
config ([`BlipConfig`]): Model configuration class with all the parameters of the model.
|
476 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
477 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
478 |
+
"""
|
479 |
+
|
480 |
+
BLIP_TEXT_INPUTS_DOCSTRING = r"""
|
481 |
+
Args:
|
482 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
483 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
484 |
+
it.
|
485 |
+
|
486 |
+
Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
|
487 |
+
|
488 |
+
[What are input IDs?](../glossary#input-ids)
|
489 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
490 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
491 |
+
|
492 |
+
- 1 for tokens that are **not masked**,
|
493 |
+
- 0 for tokens that are **masked**.
|
494 |
+
|
495 |
+
[What are attention masks?](../glossary#attention-mask)
|
496 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
497 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
498 |
+
config.max_position_embeddings - 1]`.
|
499 |
+
|
500 |
+
[What are position IDs?](../glossary#position-ids)
|
501 |
+
output_attentions (`bool`, *optional*):
|
502 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
503 |
+
tensors for more detail.
|
504 |
+
output_hidden_states (`bool`, *optional*):
|
505 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
506 |
+
more detail.
|
507 |
+
return_dict (`bool`, *optional*):
|
508 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
509 |
+
"""
|
510 |
+
|
511 |
+
BLIP_VISION_INPUTS_DOCSTRING = r"""
|
512 |
+
Args:
|
513 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
514 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
515 |
+
[`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
|
516 |
+
output_attentions (`bool`, *optional*):
|
517 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
518 |
+
tensors for more detail.
|
519 |
+
output_hidden_states (`bool`, *optional*):
|
520 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
521 |
+
more detail.
|
522 |
+
return_dict (`bool`, *optional*):
|
523 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
524 |
+
"""
|
525 |
+
|
526 |
+
BLIP_INPUTS_DOCSTRING = r"""
|
527 |
+
Args:
|
528 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
529 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
530 |
+
it.
|
531 |
+
|
532 |
+
Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details.
|
533 |
+
|
534 |
+
[What are input IDs?](../glossary#input-ids)
|
535 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
536 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
537 |
+
|
538 |
+
- 1 for tokens that are **not masked**,
|
539 |
+
- 0 for tokens that are **masked**.
|
540 |
+
|
541 |
+
[What are attention masks?](../glossary#attention-mask)
|
542 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
543 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
544 |
+
config.max_position_embeddings - 1]`.
|
545 |
+
|
546 |
+
[What are position IDs?](../glossary#position-ids)
|
547 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
548 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
549 |
+
[`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details.
|
550 |
+
return_loss (`bool`, *optional*):
|
551 |
+
Whether or not to return the contrastive loss.
|
552 |
+
output_attentions (`bool`, *optional*):
|
553 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
554 |
+
tensors for more detail.
|
555 |
+
output_hidden_states (`bool`, *optional*):
|
556 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
557 |
+
more detail.
|
558 |
+
return_dict (`bool`, *optional*):
|
559 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
560 |
+
"""
|
561 |
+
|
562 |
+
|
563 |
+
class BlipEncoder(nn.Module):
|
564 |
+
"""
|
565 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
566 |
+
[`BlipEncoderLayer`].
|
567 |
+
|
568 |
+
Args:
|
569 |
+
config (`BlipConfig`):
|
570 |
+
The corresponding vision configuration for the `BlipEncoder`.
|
571 |
+
"""
|
572 |
+
|
573 |
+
def __init__(self, config: BlipConfig):
|
574 |
+
super().__init__()
|
575 |
+
self.config = config
|
576 |
+
self.layers = nn.ModuleList([BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
577 |
+
self.gradient_checkpointing = False
|
578 |
+
|
579 |
+
def forward(
|
580 |
+
self,
|
581 |
+
inputs_embeds,
|
582 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
583 |
+
output_attentions: Optional[bool] = None,
|
584 |
+
output_hidden_states: Optional[bool] = None,
|
585 |
+
return_dict: Optional[bool] = None,
|
586 |
+
) -> Union[Tuple, BaseModelOutput]:
|
587 |
+
r"""
|
588 |
+
Args:
|
589 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
590 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
591 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
592 |
+
than the model's internal embedding lookup matrix.
|
593 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
594 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
595 |
+
|
596 |
+
- 1 for tokens that are **not masked**,
|
597 |
+
- 0 for tokens that are **masked**.
|
598 |
+
|
599 |
+
[What are attention masks?](../glossary#attention-mask)
|
600 |
+
output_attentions (`bool`, *optional*):
|
601 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
602 |
+
returned tensors for more detail.
|
603 |
+
output_hidden_states (`bool`, *optional*):
|
604 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
605 |
+
for more detail.
|
606 |
+
return_dict (`bool`, *optional*):
|
607 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
608 |
+
"""
|
609 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
610 |
+
output_hidden_states = (
|
611 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
612 |
+
)
|
613 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
614 |
+
|
615 |
+
encoder_states = () if output_hidden_states else None
|
616 |
+
all_attentions = () if output_attentions else None
|
617 |
+
|
618 |
+
hidden_states = inputs_embeds
|
619 |
+
for idx, encoder_layer in enumerate(self.layers):
|
620 |
+
if output_hidden_states:
|
621 |
+
encoder_states = encoder_states + (hidden_states,)
|
622 |
+
if self.gradient_checkpointing and self.training:
|
623 |
+
|
624 |
+
def create_custom_forward(module):
|
625 |
+
def custom_forward(*inputs):
|
626 |
+
return module(*inputs, output_attentions)
|
627 |
+
|
628 |
+
return custom_forward
|
629 |
+
|
630 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
631 |
+
create_custom_forward(encoder_layer),
|
632 |
+
hidden_states,
|
633 |
+
attention_mask,
|
634 |
+
)
|
635 |
+
else:
|
636 |
+
layer_outputs = encoder_layer(
|
637 |
+
hidden_states,
|
638 |
+
attention_mask,
|
639 |
+
output_attentions=output_attentions,
|
640 |
+
)
|
641 |
+
|
642 |
+
hidden_states = layer_outputs[0]
|
643 |
+
|
644 |
+
if output_attentions:
|
645 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
646 |
+
|
647 |
+
if output_hidden_states:
|
648 |
+
encoder_states = encoder_states + (hidden_states,)
|
649 |
+
|
650 |
+
if not return_dict:
|
651 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
652 |
+
return BaseModelOutput(
|
653 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
654 |
+
)
|
655 |
+
|
656 |
+
|
657 |
+
class BlipVisionModel(BlipPreTrainedModel):
|
658 |
+
main_input_name = "pixel_values"
|
659 |
+
config_class = BlipVisionConfig
|
660 |
+
|
661 |
+
def __init__(self, config: BlipVisionConfig):
|
662 |
+
super().__init__(config)
|
663 |
+
self.config = config
|
664 |
+
embed_dim = config.hidden_size
|
665 |
+
self.embeddings = BlipVisionEmbeddings(config)
|
666 |
+
self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size)
|
667 |
+
self.encoder = BlipEncoder(config)
|
668 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
669 |
+
|
670 |
+
self.post_init()
|
671 |
+
|
672 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
673 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=BlipVisionConfig)
|
674 |
+
def forward(
|
675 |
+
self,
|
676 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
677 |
+
pixel_masks: Optional[torch.LongTensor] = None,
|
678 |
+
output_attentions: Optional[bool] = None,
|
679 |
+
output_hidden_states: Optional[bool] = None,
|
680 |
+
return_dict: Optional[bool] = None,
|
681 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
682 |
+
r"""
|
683 |
+
Returns:
|
684 |
+
|
685 |
+
"""
|
686 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
687 |
+
output_hidden_states = (
|
688 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
689 |
+
)
|
690 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
691 |
+
|
692 |
+
if pixel_values is None:
|
693 |
+
raise ValueError("You have to specify pixel_values")
|
694 |
+
|
695 |
+
hidden_states = self.embeddings(pixel_values)
|
696 |
+
B, N, D = hidden_states.shape
|
697 |
+
# print('Before mask:', hidden_states.shape)
|
698 |
+
if pixel_masks is not None:
|
699 |
+
assert pixel_masks.shape[0] == 1
|
700 |
+
patch_masks = self.patch_mask_generator(pixel_masks)
|
701 |
+
# print(patch_masks.shape)
|
702 |
+
patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states)
|
703 |
+
hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D)
|
704 |
+
# print('After mask:', hidden_states.shape)
|
705 |
+
|
706 |
+
encoder_outputs = self.encoder(
|
707 |
+
inputs_embeds=hidden_states,
|
708 |
+
output_attentions=output_attentions,
|
709 |
+
output_hidden_states=output_hidden_states,
|
710 |
+
return_dict=return_dict,
|
711 |
+
)
|
712 |
+
|
713 |
+
last_hidden_state = encoder_outputs[0]
|
714 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
715 |
+
|
716 |
+
pooled_output = last_hidden_state[:, 0, :]
|
717 |
+
pooled_output = self.post_layernorm(pooled_output)
|
718 |
+
|
719 |
+
if not return_dict:
|
720 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
721 |
+
|
722 |
+
return BaseModelOutputWithPooling(
|
723 |
+
last_hidden_state=last_hidden_state,
|
724 |
+
pooler_output=pooled_output,
|
725 |
+
hidden_states=encoder_outputs.hidden_states,
|
726 |
+
attentions=encoder_outputs.attentions,
|
727 |
+
)
|
728 |
+
|
729 |
+
def get_input_embeddings(self):
|
730 |
+
return self.embeddings
|
731 |
+
|
732 |
+
|
733 |
+
@add_start_docstrings(BLIP_START_DOCSTRING)
|
734 |
+
class BlipModel(BlipPreTrainedModel):
|
735 |
+
config_class = BlipConfig
|
736 |
+
|
737 |
+
def __init__(self, config: BlipConfig):
|
738 |
+
super().__init__(config)
|
739 |
+
|
740 |
+
if not isinstance(config.text_config, BlipTextConfig):
|
741 |
+
raise ValueError(
|
742 |
+
"config.text_config is expected to be of type BlipTextConfig but is of type"
|
743 |
+
f" {type(config.text_config)}."
|
744 |
+
)
|
745 |
+
|
746 |
+
if not isinstance(config.vision_config, BlipVisionConfig):
|
747 |
+
raise ValueError(
|
748 |
+
"config.vision_config is expected to be of type BlipVisionConfig but is of type"
|
749 |
+
f" {type(config.vision_config)}."
|
750 |
+
)
|
751 |
+
|
752 |
+
text_config = config.text_config
|
753 |
+
vision_config = config.vision_config
|
754 |
+
|
755 |
+
self.projection_dim = config.projection_dim
|
756 |
+
self.text_embed_dim = text_config.hidden_size
|
757 |
+
self.vision_embed_dim = vision_config.hidden_size
|
758 |
+
|
759 |
+
self.text_model = BlipTextModel(text_config)
|
760 |
+
self.vision_model = BlipVisionModel(vision_config)
|
761 |
+
|
762 |
+
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
763 |
+
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
764 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
|
765 |
+
|
766 |
+
# Initialize weights and apply final processing
|
767 |
+
self.post_init()
|
768 |
+
|
769 |
+
@add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
|
770 |
+
def get_text_features(
|
771 |
+
self,
|
772 |
+
input_ids: Optional[torch.Tensor] = None,
|
773 |
+
attention_mask: Optional[torch.Tensor] = None,
|
774 |
+
position_ids: Optional[torch.Tensor] = None,
|
775 |
+
return_dict: Optional[bool] = None,
|
776 |
+
) -> torch.FloatTensor:
|
777 |
+
r"""
|
778 |
+
Returns:
|
779 |
+
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
780 |
+
applying the projection layer to the pooled output of [`BlipTextModel`].
|
781 |
+
|
782 |
+
Examples:
|
783 |
+
|
784 |
+
```python
|
785 |
+
>>> from transformers import AutoProcessor, BlipModel
|
786 |
+
|
787 |
+
>>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
|
788 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
789 |
+
|
790 |
+
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
791 |
+
>>> text_features = model.get_text_features(**inputs)
|
792 |
+
```"""
|
793 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
794 |
+
|
795 |
+
text_outputs = self.text_model(
|
796 |
+
input_ids=input_ids,
|
797 |
+
attention_mask=attention_mask,
|
798 |
+
position_ids=position_ids,
|
799 |
+
return_dict=return_dict,
|
800 |
+
)
|
801 |
+
|
802 |
+
pooled_output = text_outputs[1]
|
803 |
+
text_features = self.text_projection(pooled_output)
|
804 |
+
|
805 |
+
return text_features
|
806 |
+
|
807 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
808 |
+
def get_image_features(
|
809 |
+
self,
|
810 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
811 |
+
return_dict: Optional[bool] = None,
|
812 |
+
) -> torch.FloatTensor:
|
813 |
+
r"""
|
814 |
+
Returns:
|
815 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
816 |
+
applying the projection layer to the pooled output of [`BlipVisionModel`].
|
817 |
+
|
818 |
+
Examples:
|
819 |
+
|
820 |
+
```python
|
821 |
+
>>> from PIL import Image
|
822 |
+
>>> import requests
|
823 |
+
>>> from transformers import AutoProcessor, BlipModel
|
824 |
+
|
825 |
+
>>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
|
826 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
827 |
+
|
828 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
829 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
830 |
+
|
831 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
832 |
+
|
833 |
+
>>> image_features = model.get_image_features(**inputs)
|
834 |
+
```"""
|
835 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
836 |
+
|
837 |
+
vision_outputs = self.vision_model(
|
838 |
+
pixel_values=pixel_values,
|
839 |
+
return_dict=return_dict,
|
840 |
+
)
|
841 |
+
|
842 |
+
pooled_output = vision_outputs[1] # pooled_output
|
843 |
+
image_features = self.visual_projection(pooled_output)
|
844 |
+
|
845 |
+
return image_features
|
846 |
+
|
847 |
+
@add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING)
|
848 |
+
@replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig)
|
849 |
+
def forward(
|
850 |
+
self,
|
851 |
+
input_ids: Optional[torch.LongTensor] = None,
|
852 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
853 |
+
pixel_masks: Optional[torch.FloatTensor] = None,
|
854 |
+
attention_mask: Optional[torch.Tensor] = None,
|
855 |
+
position_ids: Optional[torch.LongTensor] = None,
|
856 |
+
return_loss: Optional[bool] = None,
|
857 |
+
output_attentions: Optional[bool] = None,
|
858 |
+
output_hidden_states: Optional[bool] = None,
|
859 |
+
return_dict: Optional[bool] = None,
|
860 |
+
) -> Union[Tuple, BlipOutput]:
|
861 |
+
r"""
|
862 |
+
Returns:
|
863 |
+
|
864 |
+
Examples:
|
865 |
+
|
866 |
+
```python
|
867 |
+
>>> from PIL import Image
|
868 |
+
>>> import requests
|
869 |
+
>>> from transformers import AutoProcessor, BlipModel
|
870 |
+
|
871 |
+
>>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
|
872 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
873 |
+
|
874 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
875 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
876 |
+
|
877 |
+
>>> inputs = processor(
|
878 |
+
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
879 |
+
... )
|
880 |
+
|
881 |
+
>>> outputs = model(**inputs)
|
882 |
+
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
883 |
+
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
884 |
+
```"""
|
885 |
+
# Use BLIP model's config for some fields (if specified) instead of those of vision & text components.
|
886 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
887 |
+
output_hidden_states = (
|
888 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
889 |
+
)
|
890 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
891 |
+
|
892 |
+
vision_outputs = self.vision_model(
|
893 |
+
pixel_values=pixel_values,
|
894 |
+
pixel_masks=pixel_masks,
|
895 |
+
output_attentions=output_attentions,
|
896 |
+
output_hidden_states=output_hidden_states,
|
897 |
+
return_dict=return_dict,
|
898 |
+
)
|
899 |
+
|
900 |
+
text_outputs = self.text_model(
|
901 |
+
input_ids=input_ids,
|
902 |
+
attention_mask=attention_mask,
|
903 |
+
position_ids=position_ids,
|
904 |
+
output_attentions=output_attentions,
|
905 |
+
output_hidden_states=output_hidden_states,
|
906 |
+
return_dict=return_dict,
|
907 |
+
)
|
908 |
+
|
909 |
+
image_embeds = vision_outputs[1]
|
910 |
+
image_embeds = self.visual_projection(image_embeds)
|
911 |
+
|
912 |
+
text_embeds = text_outputs[1]
|
913 |
+
text_embeds = self.text_projection(text_embeds)
|
914 |
+
|
915 |
+
# normalized features
|
916 |
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
917 |
+
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
918 |
+
|
919 |
+
# cosine similarity as logits
|
920 |
+
logit_scale = self.logit_scale.exp()
|
921 |
+
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
922 |
+
logits_per_image = logits_per_text.t()
|
923 |
+
|
924 |
+
loss = None
|
925 |
+
if return_loss:
|
926 |
+
loss = blip_loss(logits_per_text)
|
927 |
+
|
928 |
+
if not return_dict:
|
929 |
+
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
930 |
+
return ((loss,) + output) if loss is not None else output
|
931 |
+
|
932 |
+
return BlipOutput(
|
933 |
+
loss=loss,
|
934 |
+
logits_per_image=logits_per_image,
|
935 |
+
logits_per_text=logits_per_text,
|
936 |
+
text_embeds=text_embeds,
|
937 |
+
image_embeds=image_embeds,
|
938 |
+
text_model_output=text_outputs,
|
939 |
+
vision_model_output=vision_outputs,
|
940 |
+
)
|
941 |
+
|
942 |
+
|
943 |
+
@add_start_docstrings(
|
944 |
+
"""
|
945 |
+
BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
|
946 |
+
`input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise,
|
947 |
+
the decoder starts generating text from the [BOS] (beginning-of-sequence) token. will start generating the caption
|
948 |
+
from the text input. If no text input is provided, the decoder will start with the [BOS] token only.
|
949 |
+
""",
|
950 |
+
BLIP_START_DOCSTRING,
|
951 |
+
)
|
952 |
+
class BlipForConditionalGeneration(BlipPreTrainedModel):
|
953 |
+
config_class = BlipConfig
|
954 |
+
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
|
955 |
+
main_input_name = "pixel_values"
|
956 |
+
|
957 |
+
def __init__(self, config: BlipConfig):
|
958 |
+
super().__init__(config)
|
959 |
+
|
960 |
+
self.vision_model = BlipVisionModel(config.vision_config)
|
961 |
+
|
962 |
+
self.text_decoder = BlipTextLMHeadModel(config.text_config)
|
963 |
+
|
964 |
+
self.decoder_input_ids = config.text_config.bos_token_id
|
965 |
+
self.decoder_pad_token_id = config.text_config.pad_token_id
|
966 |
+
|
967 |
+
# Initialize weights and apply final processing
|
968 |
+
self.post_init()
|
969 |
+
|
970 |
+
def get_input_embeddings(self) -> nn.Module:
|
971 |
+
return self.vision_model.embeddings.patch_embedding
|
972 |
+
|
973 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
974 |
+
@replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
|
975 |
+
def forward(
|
976 |
+
self,
|
977 |
+
pixel_values: torch.FloatTensor,
|
978 |
+
input_ids: Optional[torch.LongTensor] = None,
|
979 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
980 |
+
output_attentions: Optional[bool] = None,
|
981 |
+
output_hidden_states: Optional[bool] = None,
|
982 |
+
labels: Optional[torch.LongTensor] = None,
|
983 |
+
return_dict: Optional[bool] = None,
|
984 |
+
) -> Union[Tuple, BlipForConditionalGenerationModelOutput]:
|
985 |
+
r"""
|
986 |
+
Returns:
|
987 |
+
|
988 |
+
Examples:
|
989 |
+
|
990 |
+
```python
|
991 |
+
>>> from PIL import Image
|
992 |
+
>>> import requests
|
993 |
+
>>> from transformers import AutoProcessor, BlipForConditionalGeneration
|
994 |
+
|
995 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
996 |
+
>>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
|
997 |
+
|
998 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
999 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1000 |
+
>>> text = "A picture of"
|
1001 |
+
|
1002 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1003 |
+
|
1004 |
+
>>> outputs = model(**inputs)
|
1005 |
+
```"""
|
1006 |
+
|
1007 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1008 |
+
|
1009 |
+
vision_outputs = self.vision_model(
|
1010 |
+
pixel_values=pixel_values,
|
1011 |
+
output_attentions=output_attentions,
|
1012 |
+
output_hidden_states=output_hidden_states,
|
1013 |
+
return_dict=return_dict,
|
1014 |
+
)
|
1015 |
+
|
1016 |
+
image_embeds = vision_outputs[0]
|
1017 |
+
|
1018 |
+
outputs = self.text_decoder(
|
1019 |
+
input_ids=input_ids,
|
1020 |
+
attention_mask=attention_mask,
|
1021 |
+
encoder_hidden_states=image_embeds,
|
1022 |
+
labels=labels,
|
1023 |
+
return_dict=return_dict,
|
1024 |
+
reduction="mean",
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
if not return_dict:
|
1028 |
+
outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
1029 |
+
return tuple(output for output in outputs if output is not None)
|
1030 |
+
|
1031 |
+
return BlipForConditionalGenerationModelOutput(
|
1032 |
+
loss=outputs.loss,
|
1033 |
+
decoder_logits=outputs.logits,
|
1034 |
+
image_embeds=image_embeds,
|
1035 |
+
last_hidden_state=vision_outputs.last_hidden_state,
|
1036 |
+
hidden_states=vision_outputs.hidden_states,
|
1037 |
+
attentions=vision_outputs.attentions,
|
1038 |
+
)
|
1039 |
+
|
1040 |
+
@torch.no_grad()
|
1041 |
+
def generate(
|
1042 |
+
self,
|
1043 |
+
pixel_values: torch.FloatTensor,
|
1044 |
+
pixel_masks: torch.Tensor = None,
|
1045 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1046 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1047 |
+
**generate_kwargs,
|
1048 |
+
) -> torch.LongTensor:
|
1049 |
+
r"""
|
1050 |
+
Overrides *generate* function to be able to use the model as a conditional generator
|
1051 |
+
|
1052 |
+
Parameters:
|
1053 |
+
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
|
1054 |
+
Input image to be processed
|
1055 |
+
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
|
1056 |
+
The sequence used as a prompt for the generation.
|
1057 |
+
attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
|
1058 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
1059 |
+
|
1060 |
+
|
1061 |
+
Examples:
|
1062 |
+
```python
|
1063 |
+
>>> from PIL import Image
|
1064 |
+
>>> import requests
|
1065 |
+
>>> from transformers import AutoProcessor, BlipForConditionalGeneration
|
1066 |
+
|
1067 |
+
>>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
|
1068 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
1069 |
+
|
1070 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1071 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1072 |
+
|
1073 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
1074 |
+
|
1075 |
+
>>> outputs = model.generate(**inputs)
|
1076 |
+
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
|
1077 |
+
two cats are laying on a couch
|
1078 |
+
```
|
1079 |
+
"""
|
1080 |
+
|
1081 |
+
batch_size = pixel_values.shape[0]
|
1082 |
+
vision_outputs = self.vision_model(
|
1083 |
+
pixel_values=pixel_values,
|
1084 |
+
pixel_masks=pixel_masks,
|
1085 |
+
)
|
1086 |
+
|
1087 |
+
image_embeds = vision_outputs[0]
|
1088 |
+
|
1089 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
|
1090 |
+
|
1091 |
+
if isinstance(input_ids, list):
|
1092 |
+
input_ids = torch.LongTensor(input_ids)
|
1093 |
+
elif input_ids is None:
|
1094 |
+
input_ids = (
|
1095 |
+
torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
|
1096 |
+
.repeat(batch_size, 1)
|
1097 |
+
.to(image_embeds.device)
|
1098 |
+
)
|
1099 |
+
|
1100 |
+
input_ids[:, 0] = self.config.text_config.bos_token_id
|
1101 |
+
attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
|
1102 |
+
|
1103 |
+
outputs = self.text_decoder.generate(
|
1104 |
+
input_ids=input_ids[:, :-1],
|
1105 |
+
eos_token_id=self.config.text_config.sep_token_id,
|
1106 |
+
pad_token_id=self.config.text_config.pad_token_id,
|
1107 |
+
attention_mask=attention_mask,
|
1108 |
+
encoder_hidden_states=image_embeds,
|
1109 |
+
encoder_attention_mask=image_attention_mask,
|
1110 |
+
**generate_kwargs,
|
1111 |
+
)
|
1112 |
+
|
1113 |
+
return outputs
|
1114 |
+
|
1115 |
+
|
1116 |
+
@add_start_docstrings(
|
1117 |
+
"""
|
1118 |
+
BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
|
1119 |
+
decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
|
1120 |
+
with the encoding of the image, and the text decoder will output the answer to the question.
|
1121 |
+
""",
|
1122 |
+
BLIP_START_DOCSTRING,
|
1123 |
+
)
|
1124 |
+
class BlipForQuestionAnswering(BlipPreTrainedModel):
|
1125 |
+
config_class = BlipConfig
|
1126 |
+
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
|
1127 |
+
|
1128 |
+
def __init__(self, config: BlipConfig):
|
1129 |
+
super().__init__(config)
|
1130 |
+
|
1131 |
+
self.vision_model = BlipVisionModel(config.vision_config)
|
1132 |
+
|
1133 |
+
self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
|
1134 |
+
|
1135 |
+
self.text_decoder = BlipTextLMHeadModel(config.text_config)
|
1136 |
+
|
1137 |
+
self.decoder_pad_token_id = config.text_config.pad_token_id
|
1138 |
+
self.decoder_start_token_id = config.text_config.bos_token_id
|
1139 |
+
|
1140 |
+
# Initialize weights and apply final processing
|
1141 |
+
self.post_init()
|
1142 |
+
|
1143 |
+
def get_input_embeddings(self) -> nn.Module:
|
1144 |
+
return self.vision_model.embeddings.patch_embedding
|
1145 |
+
|
1146 |
+
# Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
|
1147 |
+
def _shift_right(self, input_ids):
|
1148 |
+
pad_token_id = self.decoder_pad_token_id
|
1149 |
+
|
1150 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
1151 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
1152 |
+
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
1153 |
+
|
1154 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1155 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
1156 |
+
|
1157 |
+
return shifted_input_ids
|
1158 |
+
|
1159 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
1160 |
+
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
1161 |
+
def forward(
|
1162 |
+
self,
|
1163 |
+
input_ids: torch.LongTensor,
|
1164 |
+
pixel_values: torch.FloatTensor,
|
1165 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1166 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1167 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1168 |
+
output_attentions: Optional[bool] = None,
|
1169 |
+
output_hidden_states: Optional[bool] = None,
|
1170 |
+
labels: Optional[torch.LongTensor] = None,
|
1171 |
+
return_dict: Optional[bool] = None,
|
1172 |
+
) -> Union[Tuple, BlipTextVisionModelOutput]:
|
1173 |
+
r"""
|
1174 |
+
Returns:
|
1175 |
+
|
1176 |
+
Examples:
|
1177 |
+
|
1178 |
+
```python
|
1179 |
+
>>> from PIL import Image
|
1180 |
+
>>> import requests
|
1181 |
+
>>> from transformers import AutoProcessor, BlipForQuestionAnswering
|
1182 |
+
|
1183 |
+
>>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
|
1184 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
1185 |
+
|
1186 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1187 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1188 |
+
|
1189 |
+
>>> # training
|
1190 |
+
>>> text = "How many cats are in the picture?"
|
1191 |
+
>>> label = "2"
|
1192 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1193 |
+
>>> labels = processor(text=label, return_tensors="pt").input_ids
|
1194 |
+
|
1195 |
+
>>> inputs["labels"] = labels
|
1196 |
+
>>> outputs = model(**inputs)
|
1197 |
+
>>> loss = outputs.loss
|
1198 |
+
>>> loss.backward()
|
1199 |
+
|
1200 |
+
>>> # inference
|
1201 |
+
>>> text = "How many cats are in the picture?"
|
1202 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1203 |
+
>>> outputs = model.generate(**inputs)
|
1204 |
+
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
|
1205 |
+
2
|
1206 |
+
```"""
|
1207 |
+
if labels is None and decoder_input_ids is None:
|
1208 |
+
raise ValueError(
|
1209 |
+
"Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
|
1210 |
+
" `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
|
1211 |
+
" are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
|
1212 |
+
)
|
1213 |
+
|
1214 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1215 |
+
|
1216 |
+
vision_outputs = self.vision_model(
|
1217 |
+
pixel_values=pixel_values,
|
1218 |
+
output_attentions=output_attentions,
|
1219 |
+
output_hidden_states=output_hidden_states,
|
1220 |
+
return_dict=return_dict,
|
1221 |
+
)
|
1222 |
+
|
1223 |
+
image_embeds = vision_outputs[0]
|
1224 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
|
1225 |
+
|
1226 |
+
question_embeds = self.text_encoder(
|
1227 |
+
input_ids=input_ids,
|
1228 |
+
attention_mask=attention_mask,
|
1229 |
+
encoder_hidden_states=image_embeds,
|
1230 |
+
encoder_attention_mask=image_attention_mask,
|
1231 |
+
return_dict=return_dict,
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
1235 |
+
|
1236 |
+
if labels is not None and decoder_input_ids is None:
|
1237 |
+
# get decoder inputs from shifting lm labels to the right - this is used in training mode
|
1238 |
+
decoder_input_ids = self._shift_right(labels)
|
1239 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1240 |
+
labels = labels.masked_fill(labels == self.decoder_pad_token_id, -100)
|
1241 |
+
|
1242 |
+
answer_output = self.text_decoder(
|
1243 |
+
input_ids=decoder_input_ids,
|
1244 |
+
attention_mask=decoder_attention_mask,
|
1245 |
+
encoder_hidden_states=question_embeds,
|
1246 |
+
encoder_attention_mask=attention_mask,
|
1247 |
+
labels=labels,
|
1248 |
+
return_dict=return_dict,
|
1249 |
+
reduction="mean",
|
1250 |
+
)
|
1251 |
+
|
1252 |
+
if labels is not None:
|
1253 |
+
decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean()
|
1254 |
+
else:
|
1255 |
+
decoder_loss = None
|
1256 |
+
|
1257 |
+
if not return_dict:
|
1258 |
+
outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
1259 |
+
return tuple(output for output in outputs if output is not None)
|
1260 |
+
|
1261 |
+
return BlipTextVisionModelOutput(
|
1262 |
+
loss=decoder_loss,
|
1263 |
+
image_embeds=image_embeds,
|
1264 |
+
last_hidden_state=vision_outputs.last_hidden_state,
|
1265 |
+
hidden_states=vision_outputs.hidden_states,
|
1266 |
+
attentions=vision_outputs.attentions,
|
1267 |
+
)
|
1268 |
+
|
1269 |
+
@torch.no_grad()
|
1270 |
+
def generate(
|
1271 |
+
self,
|
1272 |
+
input_ids: torch.LongTensor,
|
1273 |
+
pixel_values: torch.FloatTensor,
|
1274 |
+
pixel_masks: torch.Tensor = None,
|
1275 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1276 |
+
**generate_kwargs,
|
1277 |
+
) -> torch.LongTensor:
|
1278 |
+
r"""
|
1279 |
+
Overrides *generate* function to be able to use the model as a conditional generator
|
1280 |
+
|
1281 |
+
Parameters:
|
1282 |
+
input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
|
1283 |
+
The sequence used as a prompt for the generation.
|
1284 |
+
pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
|
1285 |
+
Input image to be processed
|
1286 |
+
attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
|
1287 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
|
1288 |
+
tokens that are NOT MASKED, `0` for MASKED tokens.
|
1289 |
+
**generate_kwargs:
|
1290 |
+
Additional arguments passed to the *generate* function of the decoder
|
1291 |
+
|
1292 |
+
|
1293 |
+
Examples:
|
1294 |
+
```python
|
1295 |
+
>>> from PIL import Image
|
1296 |
+
>>> import requests
|
1297 |
+
>>> from transformers import AutoProcessor, BlipForQuestionAnswering
|
1298 |
+
|
1299 |
+
>>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
|
1300 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
1301 |
+
|
1302 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1303 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1304 |
+
>>> text = "How many cats are in the picture?"
|
1305 |
+
|
1306 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1307 |
+
|
1308 |
+
>>> outputs = model.generate(**inputs)
|
1309 |
+
>>> print(processor.decode(outputs[0], skip_special_tokens=True))
|
1310 |
+
2
|
1311 |
+
```
|
1312 |
+
"""
|
1313 |
+
vision_outputs = self.vision_model(
|
1314 |
+
pixel_values=pixel_values,
|
1315 |
+
pixel_masks=pixel_masks
|
1316 |
+
)
|
1317 |
+
|
1318 |
+
image_embeds = vision_outputs[0]
|
1319 |
+
|
1320 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
|
1321 |
+
|
1322 |
+
if isinstance(input_ids, list):
|
1323 |
+
input_ids = torch.LongTensor(input_ids)
|
1324 |
+
|
1325 |
+
question_outputs = self.text_encoder(
|
1326 |
+
input_ids=input_ids,
|
1327 |
+
attention_mask=attention_mask,
|
1328 |
+
encoder_hidden_states=image_embeds,
|
1329 |
+
encoder_attention_mask=image_attention_mask,
|
1330 |
+
return_dict=False,
|
1331 |
+
)
|
1332 |
+
|
1333 |
+
question_embeds = question_outputs[0]
|
1334 |
+
|
1335 |
+
question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device)
|
1336 |
+
|
1337 |
+
bos_ids = torch.full(
|
1338 |
+
(question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
|
1339 |
+
)
|
1340 |
+
|
1341 |
+
outputs = self.text_decoder.generate(
|
1342 |
+
input_ids=bos_ids,
|
1343 |
+
eos_token_id=self.config.text_config.sep_token_id,
|
1344 |
+
pad_token_id=self.config.text_config.pad_token_id,
|
1345 |
+
encoder_hidden_states=question_embeds,
|
1346 |
+
encoder_attention_mask=question_attention_mask,
|
1347 |
+
**generate_kwargs,
|
1348 |
+
)
|
1349 |
+
|
1350 |
+
return outputs
|
1351 |
+
|
1352 |
+
|
1353 |
+
@add_start_docstrings(
|
1354 |
+
"""
|
1355 |
+
BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
|
1356 |
+
image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
|
1357 |
+
the image.
|
1358 |
+
""",
|
1359 |
+
BLIP_START_DOCSTRING,
|
1360 |
+
)
|
1361 |
+
class BlipForImageTextRetrieval(BlipPreTrainedModel):
|
1362 |
+
config_class = BlipConfig
|
1363 |
+
|
1364 |
+
def __init__(self, config: BlipConfig):
|
1365 |
+
super().__init__(config)
|
1366 |
+
|
1367 |
+
self.vision_model = BlipVisionModel(config.vision_config)
|
1368 |
+
|
1369 |
+
self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
|
1370 |
+
|
1371 |
+
# vision projection layer
|
1372 |
+
self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)
|
1373 |
+
|
1374 |
+
# text projection layer
|
1375 |
+
self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)
|
1376 |
+
|
1377 |
+
# image text matching head
|
1378 |
+
self.itm_head = nn.Linear(config.text_config.hidden_size, 2)
|
1379 |
+
|
1380 |
+
self.decoder_pad_token_id = (
|
1381 |
+
config.text_config.pad_token_id
|
1382 |
+
if not hasattr(config, "decoder_pad_token_id")
|
1383 |
+
else config.decoder_pad_token_id
|
1384 |
+
)
|
1385 |
+
self.decoder_start_token_id = (
|
1386 |
+
config.text_config.bos_token_id
|
1387 |
+
if not hasattr(config, "decoder_start_token_id")
|
1388 |
+
else config.decoder_start_token_id
|
1389 |
+
)
|
1390 |
+
|
1391 |
+
# Initialize weights and apply final processing
|
1392 |
+
self.post_init()
|
1393 |
+
|
1394 |
+
def get_input_embeddings(self) -> nn.Module:
|
1395 |
+
return self.vision_model.embeddings.patch_embedding
|
1396 |
+
|
1397 |
+
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
1398 |
+
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
1399 |
+
def forward(
|
1400 |
+
self,
|
1401 |
+
input_ids: torch.LongTensor,
|
1402 |
+
pixel_values: torch.FloatTensor,
|
1403 |
+
use_itm_head: Optional[bool] = True,
|
1404 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1405 |
+
output_attentions: Optional[bool] = None,
|
1406 |
+
output_hidden_states: Optional[bool] = None,
|
1407 |
+
return_dict: Optional[bool] = None,
|
1408 |
+
) -> Union[Tuple, BlipTextVisionModelOutput]:
|
1409 |
+
r"""
|
1410 |
+
Returns:
|
1411 |
+
|
1412 |
+
Examples:
|
1413 |
+
|
1414 |
+
```python
|
1415 |
+
>>> from PIL import Image
|
1416 |
+
>>> import requests
|
1417 |
+
>>> from transformers import AutoProcessor, BlipForImageTextRetrieval
|
1418 |
+
|
1419 |
+
>>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
|
1420 |
+
>>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
|
1421 |
+
|
1422 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1423 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1424 |
+
>>> text = "an image of a cat"
|
1425 |
+
|
1426 |
+
>>> inputs = processor(images=image, text=text, return_tensors="pt")
|
1427 |
+
>>> outputs = model(**inputs)
|
1428 |
+
```
|
1429 |
+
"""
|
1430 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1431 |
+
|
1432 |
+
vision_outputs = self.vision_model(
|
1433 |
+
pixel_values=pixel_values,
|
1434 |
+
output_attentions=output_attentions,
|
1435 |
+
output_hidden_states=output_hidden_states,
|
1436 |
+
return_dict=return_dict,
|
1437 |
+
)
|
1438 |
+
|
1439 |
+
image_embeds = vision_outputs[0]
|
1440 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
|
1441 |
+
|
1442 |
+
if use_itm_head:
|
1443 |
+
question_embeds = self.text_encoder(
|
1444 |
+
input_ids=input_ids,
|
1445 |
+
attention_mask=attention_mask,
|
1446 |
+
encoder_hidden_states=image_embeds,
|
1447 |
+
encoder_attention_mask=image_atts,
|
1448 |
+
return_dict=return_dict,
|
1449 |
+
)
|
1450 |
+
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
1451 |
+
|
1452 |
+
output = self.itm_head(question_embeds[:, 0, :])
|
1453 |
+
else:
|
1454 |
+
question_embeds = self.text_encoder(
|
1455 |
+
input_ids=input_ids,
|
1456 |
+
attention_mask=attention_mask,
|
1457 |
+
return_dict=return_dict,
|
1458 |
+
)
|
1459 |
+
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
1460 |
+
|
1461 |
+
image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
|
1462 |
+
text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)
|
1463 |
+
|
1464 |
+
output = image_feat @ text_feat.t()
|
1465 |
+
|
1466 |
+
if not return_dict:
|
1467 |
+
outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,)
|
1468 |
+
return tuple(output for output in outputs if output is not None)
|
1469 |
+
|
1470 |
+
return BlipImageTextMatchingModelOutput(
|
1471 |
+
itm_score=output,
|
1472 |
+
last_hidden_state=vision_outputs.last_hidden_state,
|
1473 |
+
hidden_states=vision_outputs.hidden_states,
|
1474 |
+
attentions=vision_outputs.attentions,
|
1475 |
+
question_embeds=question_embeds,
|
1476 |
+
)
|
caption_anything/captioner/modeling_git.py
ADDED
@@ -0,0 +1,1587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
|
3 |
+
# All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch GIT model."""
|
17 |
+
|
18 |
+
|
19 |
+
import math
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import CrossEntropyLoss
|
27 |
+
|
28 |
+
from transformers.activations import ACT2FN
|
29 |
+
from transformers.file_utils import ModelOutput
|
30 |
+
from transformers.modeling_outputs import (
|
31 |
+
BaseModelOutput,
|
32 |
+
BaseModelOutputWithPast,
|
33 |
+
BaseModelOutputWithPooling,
|
34 |
+
CausalLMOutputWithPast,
|
35 |
+
)
|
36 |
+
from transformers.modeling_utils import PreTrainedModel
|
37 |
+
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
38 |
+
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
39 |
+
from transformers.models.git.configuration_git import GitConfig, GitVisionConfig
|
40 |
+
from .vit_pixel_masks_utils import ViTPatchMaskGenerator
|
41 |
+
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
_CHECKPOINT_FOR_DOC = "microsoft/git-base"
|
46 |
+
_CONFIG_FOR_DOC = "GitConfig"
|
47 |
+
|
48 |
+
GIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
49 |
+
"microsoft/git-base",
|
50 |
+
# See all GIT models at https://huggingface.co/models?filter=git
|
51 |
+
]
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
|
56 |
+
class GitVisionModelOutput(ModelOutput):
|
57 |
+
"""
|
58 |
+
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
62 |
+
The image embeddings obtained by applying the projection layer to the pooler_output.
|
63 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
64 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
65 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
66 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
67 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
68 |
+
|
69 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
70 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
71 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
72 |
+
sequence_length)`.
|
73 |
+
|
74 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
75 |
+
heads.
|
76 |
+
"""
|
77 |
+
|
78 |
+
image_embeds: Optional[torch.FloatTensor] = None
|
79 |
+
last_hidden_state: torch.FloatTensor = None
|
80 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
81 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
82 |
+
|
83 |
+
|
84 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
85 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
86 |
+
"""
|
87 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
88 |
+
"""
|
89 |
+
bsz, src_len = mask.size()
|
90 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
91 |
+
|
92 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
93 |
+
|
94 |
+
inverted_mask = 1.0 - expanded_mask
|
95 |
+
|
96 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
97 |
+
|
98 |
+
|
99 |
+
class GitEmbeddings(nn.Module):
|
100 |
+
"""Construct the embeddings from word and position embeddings."""
|
101 |
+
|
102 |
+
def __init__(self, config):
|
103 |
+
super().__init__()
|
104 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
105 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
106 |
+
|
107 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
108 |
+
# any TensorFlow checkpoint file
|
109 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
110 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
111 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
112 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
113 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
114 |
+
|
115 |
+
def forward(
|
116 |
+
self,
|
117 |
+
input_ids: Optional[torch.LongTensor] = None,
|
118 |
+
position_ids: Optional[torch.LongTensor] = None,
|
119 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
120 |
+
past_key_values_length: int = 0,
|
121 |
+
) -> torch.Tensor:
|
122 |
+
if input_ids is not None:
|
123 |
+
input_shape = input_ids.size()
|
124 |
+
else:
|
125 |
+
input_shape = inputs_embeds.size()[:-1]
|
126 |
+
|
127 |
+
seq_length = input_shape[1]
|
128 |
+
|
129 |
+
if position_ids is None:
|
130 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
131 |
+
|
132 |
+
if inputs_embeds is None:
|
133 |
+
embeddings = self.word_embeddings(input_ids)
|
134 |
+
else:
|
135 |
+
embeddings = inputs_embeds
|
136 |
+
|
137 |
+
if self.position_embedding_type == "absolute":
|
138 |
+
position_embeddings = self.position_embeddings(position_ids)
|
139 |
+
embeddings += position_embeddings
|
140 |
+
embeddings = self.LayerNorm(embeddings)
|
141 |
+
embeddings = self.dropout(embeddings)
|
142 |
+
return embeddings
|
143 |
+
|
144 |
+
|
145 |
+
class GitSelfAttention(nn.Module):
|
146 |
+
def __init__(self, config, position_embedding_type=None):
|
147 |
+
super().__init__()
|
148 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
149 |
+
raise ValueError(
|
150 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
151 |
+
f"heads ({config.num_attention_heads})"
|
152 |
+
)
|
153 |
+
|
154 |
+
self.num_attention_heads = config.num_attention_heads
|
155 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
156 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
157 |
+
self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
158 |
+
if config.num_image_with_embedding is not None:
|
159 |
+
self.image_patch_tokens *= config.num_image_with_embedding
|
160 |
+
|
161 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
162 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
163 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
164 |
+
|
165 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
166 |
+
self.position_embedding_type = position_embedding_type or getattr(
|
167 |
+
config, "position_embedding_type", "absolute"
|
168 |
+
)
|
169 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
170 |
+
self.max_position_embeddings = config.max_position_embeddings
|
171 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
172 |
+
|
173 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
174 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
175 |
+
x = x.view(new_x_shape)
|
176 |
+
return x.permute(0, 2, 1, 3)
|
177 |
+
|
178 |
+
def forward(
|
179 |
+
self,
|
180 |
+
hidden_states: torch.Tensor,
|
181 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
182 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
183 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
184 |
+
output_attentions: Optional[bool] = False,
|
185 |
+
pixel_values_present: Optional[bool] = False,
|
186 |
+
image_token_num: Optional[int] = None
|
187 |
+
) -> Tuple[torch.Tensor]:
|
188 |
+
mixed_query_layer = self.query(hidden_states)
|
189 |
+
if image_token_num is not None:
|
190 |
+
cutoff = image_token_num
|
191 |
+
else:
|
192 |
+
cutoff = self.image_patch_tokens if pixel_values_present else 0
|
193 |
+
if past_key_value is not None:
|
194 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
195 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
196 |
+
key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
|
197 |
+
value_layer = torch.cat(
|
198 |
+
[value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
|
199 |
+
)
|
200 |
+
else:
|
201 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
202 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
203 |
+
|
204 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
205 |
+
|
206 |
+
use_cache = past_key_value is not None
|
207 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
208 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
209 |
+
# key/value_states (first "if" case)
|
210 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
211 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
212 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
213 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
214 |
+
# NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
|
215 |
+
past_key_value = (
|
216 |
+
key_layer[:, :, cutoff:, :],
|
217 |
+
value_layer[:, :, cutoff:, :],
|
218 |
+
)
|
219 |
+
|
220 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
221 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
222 |
+
|
223 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
224 |
+
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
225 |
+
if use_cache:
|
226 |
+
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
227 |
+
-1, 1
|
228 |
+
)
|
229 |
+
else:
|
230 |
+
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
231 |
+
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
232 |
+
distance = position_ids_l - position_ids_r
|
233 |
+
|
234 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
235 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
236 |
+
|
237 |
+
if self.position_embedding_type == "relative_key":
|
238 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
239 |
+
attention_scores = attention_scores + relative_position_scores
|
240 |
+
elif self.position_embedding_type == "relative_key_query":
|
241 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
242 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
243 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
244 |
+
|
245 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
246 |
+
if attention_mask is not None:
|
247 |
+
# Apply the attention mask is (precomputed for all layers in GitModel forward() function)
|
248 |
+
attention_scores = attention_scores + attention_mask
|
249 |
+
|
250 |
+
# Normalize the attention scores to probabilities.
|
251 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
252 |
+
|
253 |
+
# This is actually dropping out entire tokens to attend to, which might
|
254 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
255 |
+
attention_probs = self.dropout(attention_probs)
|
256 |
+
|
257 |
+
# Mask heads if we want to
|
258 |
+
if head_mask is not None:
|
259 |
+
attention_probs = attention_probs * head_mask
|
260 |
+
|
261 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
262 |
+
|
263 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
264 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
265 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
266 |
+
|
267 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
268 |
+
|
269 |
+
outputs = outputs + (past_key_value,)
|
270 |
+
return outputs
|
271 |
+
|
272 |
+
|
273 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
|
274 |
+
class GitSelfOutput(nn.Module):
|
275 |
+
def __init__(self, config):
|
276 |
+
super().__init__()
|
277 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
278 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
279 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
280 |
+
|
281 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
282 |
+
hidden_states = self.dense(hidden_states)
|
283 |
+
hidden_states = self.dropout(hidden_states)
|
284 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
285 |
+
return hidden_states
|
286 |
+
|
287 |
+
|
288 |
+
class GitAttention(nn.Module):
|
289 |
+
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git
|
290 |
+
def __init__(self, config, position_embedding_type=None):
|
291 |
+
super().__init__()
|
292 |
+
self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
|
293 |
+
self.output = GitSelfOutput(config)
|
294 |
+
self.pruned_heads = set()
|
295 |
+
|
296 |
+
# Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
|
297 |
+
def prune_heads(self, heads):
|
298 |
+
if len(heads) == 0:
|
299 |
+
return
|
300 |
+
heads, index = find_pruneable_heads_and_indices(
|
301 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
302 |
+
)
|
303 |
+
|
304 |
+
# Prune linear layers
|
305 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
306 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
307 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
308 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
309 |
+
|
310 |
+
# Update hyper params and store pruned heads
|
311 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
312 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
313 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
314 |
+
|
315 |
+
def forward(
|
316 |
+
self,
|
317 |
+
hidden_states: torch.Tensor,
|
318 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
319 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
320 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
321 |
+
output_attentions: Optional[bool] = False,
|
322 |
+
pixel_values_present: Optional[bool] = False,
|
323 |
+
image_token_num: Optional[int] = None
|
324 |
+
) -> Tuple[torch.Tensor]:
|
325 |
+
self_outputs = self.self(
|
326 |
+
hidden_states,
|
327 |
+
attention_mask,
|
328 |
+
head_mask,
|
329 |
+
past_key_value,
|
330 |
+
output_attentions,
|
331 |
+
pixel_values_present,
|
332 |
+
image_token_num
|
333 |
+
)
|
334 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
335 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
336 |
+
return outputs
|
337 |
+
|
338 |
+
|
339 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate
|
340 |
+
class GitIntermediate(nn.Module):
|
341 |
+
def __init__(self, config):
|
342 |
+
super().__init__()
|
343 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
344 |
+
if isinstance(config.hidden_act, str):
|
345 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
346 |
+
else:
|
347 |
+
self.intermediate_act_fn = config.hidden_act
|
348 |
+
|
349 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
350 |
+
hidden_states = self.dense(hidden_states)
|
351 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
352 |
+
return hidden_states
|
353 |
+
|
354 |
+
|
355 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput
|
356 |
+
class GitOutput(nn.Module):
|
357 |
+
def __init__(self, config):
|
358 |
+
super().__init__()
|
359 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
360 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
361 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
362 |
+
|
363 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
364 |
+
hidden_states = self.dense(hidden_states)
|
365 |
+
hidden_states = self.dropout(hidden_states)
|
366 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
367 |
+
return hidden_states
|
368 |
+
|
369 |
+
|
370 |
+
class GitLayer(nn.Module):
|
371 |
+
def __init__(self, config):
|
372 |
+
super().__init__()
|
373 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
374 |
+
self.seq_len_dim = 1
|
375 |
+
self.attention = GitAttention(config)
|
376 |
+
self.intermediate = GitIntermediate(config)
|
377 |
+
self.output = GitOutput(config)
|
378 |
+
|
379 |
+
def forward(
|
380 |
+
self,
|
381 |
+
hidden_states: torch.Tensor,
|
382 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
383 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
384 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
385 |
+
output_attentions: Optional[bool] = False,
|
386 |
+
pixel_values_present: Optional[bool] = False,
|
387 |
+
image_token_num: Optional[bool] = None,
|
388 |
+
) -> Tuple[torch.Tensor]:
|
389 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
390 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
391 |
+
self_attention_outputs = self.attention(
|
392 |
+
hidden_states,
|
393 |
+
attention_mask,
|
394 |
+
head_mask,
|
395 |
+
output_attentions=output_attentions,
|
396 |
+
past_key_value=self_attn_past_key_value,
|
397 |
+
pixel_values_present=pixel_values_present,
|
398 |
+
image_token_num=image_token_num
|
399 |
+
)
|
400 |
+
attention_output = self_attention_outputs[0]
|
401 |
+
|
402 |
+
# if decoder, the last output is tuple of self-attn cache
|
403 |
+
outputs = self_attention_outputs[1:-1]
|
404 |
+
present_key_value = self_attention_outputs[-1]
|
405 |
+
|
406 |
+
layer_output = apply_chunking_to_forward(
|
407 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
408 |
+
)
|
409 |
+
outputs = (layer_output,) + outputs
|
410 |
+
|
411 |
+
# if decoder, return the attn key/values as the last output
|
412 |
+
outputs = outputs + (present_key_value,)
|
413 |
+
|
414 |
+
return outputs
|
415 |
+
|
416 |
+
def feed_forward_chunk(self, attention_output):
|
417 |
+
intermediate_output = self.intermediate(attention_output)
|
418 |
+
layer_output = self.output(intermediate_output, attention_output)
|
419 |
+
return layer_output
|
420 |
+
|
421 |
+
|
422 |
+
class GitEncoder(nn.Module):
|
423 |
+
# Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Git
|
424 |
+
def __init__(self, config):
|
425 |
+
super().__init__()
|
426 |
+
self.config = config
|
427 |
+
self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)])
|
428 |
+
self.gradient_checkpointing = False
|
429 |
+
|
430 |
+
def forward(
|
431 |
+
self,
|
432 |
+
hidden_states: torch.Tensor,
|
433 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
434 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
435 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
436 |
+
use_cache: Optional[bool] = None,
|
437 |
+
output_attentions: Optional[bool] = False,
|
438 |
+
output_hidden_states: Optional[bool] = False,
|
439 |
+
pixel_values_present: Optional[bool] = False,
|
440 |
+
image_token_num: Optional[int] = None,
|
441 |
+
return_dict: Optional[bool] = True,
|
442 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
|
443 |
+
if self.gradient_checkpointing and self.training:
|
444 |
+
if use_cache:
|
445 |
+
logger.warning_once(
|
446 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
447 |
+
)
|
448 |
+
use_cache = False
|
449 |
+
|
450 |
+
all_hidden_states = () if output_hidden_states else None
|
451 |
+
all_self_attentions = () if output_attentions else None
|
452 |
+
|
453 |
+
next_decoder_cache = () if use_cache else None
|
454 |
+
for i, layer_module in enumerate(self.layer):
|
455 |
+
if output_hidden_states:
|
456 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
457 |
+
|
458 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
459 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
460 |
+
|
461 |
+
if self.gradient_checkpointing and self.training:
|
462 |
+
|
463 |
+
def create_custom_forward(module):
|
464 |
+
def custom_forward(*inputs):
|
465 |
+
return module(*inputs, past_key_value, output_attentions)
|
466 |
+
|
467 |
+
return custom_forward
|
468 |
+
|
469 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
470 |
+
create_custom_forward(layer_module),
|
471 |
+
hidden_states,
|
472 |
+
attention_mask,
|
473 |
+
layer_head_mask,
|
474 |
+
)
|
475 |
+
else:
|
476 |
+
layer_outputs = layer_module(
|
477 |
+
hidden_states,
|
478 |
+
attention_mask,
|
479 |
+
layer_head_mask,
|
480 |
+
past_key_value,
|
481 |
+
output_attentions,
|
482 |
+
pixel_values_present,
|
483 |
+
image_token_num,
|
484 |
+
|
485 |
+
)
|
486 |
+
|
487 |
+
hidden_states = layer_outputs[0]
|
488 |
+
if use_cache:
|
489 |
+
next_decoder_cache += (layer_outputs[-1],)
|
490 |
+
if output_attentions:
|
491 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
492 |
+
|
493 |
+
if output_hidden_states:
|
494 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
495 |
+
|
496 |
+
if not return_dict:
|
497 |
+
return tuple(
|
498 |
+
v
|
499 |
+
for v in [
|
500 |
+
hidden_states,
|
501 |
+
next_decoder_cache,
|
502 |
+
all_hidden_states,
|
503 |
+
all_self_attentions,
|
504 |
+
]
|
505 |
+
if v is not None
|
506 |
+
)
|
507 |
+
return BaseModelOutputWithPast(
|
508 |
+
last_hidden_state=hidden_states,
|
509 |
+
past_key_values=next_decoder_cache,
|
510 |
+
hidden_states=all_hidden_states,
|
511 |
+
attentions=all_self_attentions,
|
512 |
+
)
|
513 |
+
|
514 |
+
|
515 |
+
class GitPreTrainedModel(PreTrainedModel):
|
516 |
+
"""
|
517 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
518 |
+
models.
|
519 |
+
"""
|
520 |
+
|
521 |
+
config_class = GitConfig
|
522 |
+
base_model_prefix = "git"
|
523 |
+
supports_gradient_checkpointing = True
|
524 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
525 |
+
|
526 |
+
def _init_weights(self, module):
|
527 |
+
"""Initialize the weights"""
|
528 |
+
if isinstance(module, GitVisionEmbeddings):
|
529 |
+
nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
|
530 |
+
nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
|
531 |
+
nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
|
532 |
+
if isinstance(module, nn.Linear):
|
533 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
534 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
535 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
536 |
+
if module.bias is not None:
|
537 |
+
module.bias.data.zero_()
|
538 |
+
elif isinstance(module, nn.Embedding):
|
539 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
540 |
+
if module.padding_idx is not None:
|
541 |
+
module.weight.data[module.padding_idx].zero_()
|
542 |
+
elif isinstance(module, nn.LayerNorm):
|
543 |
+
module.bias.data.zero_()
|
544 |
+
module.weight.data.fill_(1.0)
|
545 |
+
|
546 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
547 |
+
if isinstance(module, (GitEncoder, GitVisionEncoder)):
|
548 |
+
module.gradient_checkpointing = value
|
549 |
+
|
550 |
+
|
551 |
+
GIT_START_DOCSTRING = r"""
|
552 |
+
|
553 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
554 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
555 |
+
etc.)
|
556 |
+
|
557 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
558 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
559 |
+
and behavior.
|
560 |
+
|
561 |
+
Parameters:
|
562 |
+
config ([`GitConfig`]): Model configuration class with all the parameters of the model.
|
563 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
564 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
565 |
+
"""
|
566 |
+
|
567 |
+
GIT_INPUTS_DOCSTRING = r"""
|
568 |
+
Args:
|
569 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
570 |
+
Indices of input sequence tokens in the vocabulary.
|
571 |
+
|
572 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
573 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
574 |
+
|
575 |
+
[What are input IDs?](../glossary#input-ids)
|
576 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
577 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
578 |
+
|
579 |
+
- 1 for tokens that are **not masked**,
|
580 |
+
- 0 for tokens that are **masked**.
|
581 |
+
|
582 |
+
[What are attention masks?](../glossary#attention-mask)
|
583 |
+
|
584 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
585 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
586 |
+
config.max_position_embeddings - 1]`.
|
587 |
+
|
588 |
+
[What are position IDs?](../glossary#position-ids)
|
589 |
+
|
590 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
591 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
592 |
+
[`CLIPImageProcessor.__call__`] for details.
|
593 |
+
|
594 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
595 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
596 |
+
|
597 |
+
- 1 indicates the head is **not masked**,
|
598 |
+
- 0 indicates the head is **masked**.
|
599 |
+
|
600 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
601 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
602 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
603 |
+
model's internal embedding lookup matrix.
|
604 |
+
output_attentions (`bool`, *optional*):
|
605 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
606 |
+
tensors for more detail.
|
607 |
+
output_hidden_states (`bool`, *optional*):
|
608 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
609 |
+
more detail.
|
610 |
+
return_dict (`bool`, *optional*):
|
611 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
612 |
+
"""
|
613 |
+
|
614 |
+
|
615 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
|
616 |
+
class GitVisionEmbeddings(nn.Module):
|
617 |
+
def __init__(self, config: GitVisionConfig):
|
618 |
+
super().__init__()
|
619 |
+
self.config = config
|
620 |
+
self.embed_dim = config.hidden_size
|
621 |
+
self.image_size = config.image_size
|
622 |
+
self.patch_size = config.patch_size
|
623 |
+
|
624 |
+
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
625 |
+
|
626 |
+
self.patch_embedding = nn.Conv2d(
|
627 |
+
in_channels=config.num_channels,
|
628 |
+
out_channels=self.embed_dim,
|
629 |
+
kernel_size=self.patch_size,
|
630 |
+
stride=self.patch_size,
|
631 |
+
bias=False,
|
632 |
+
)
|
633 |
+
|
634 |
+
self.num_patches = (self.image_size // self.patch_size) ** 2
|
635 |
+
self.num_positions = self.num_patches + 1
|
636 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
637 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
638 |
+
|
639 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
640 |
+
batch_size = pixel_values.shape[0]
|
641 |
+
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
642 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
643 |
+
|
644 |
+
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
645 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
646 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)
|
647 |
+
return embeddings
|
648 |
+
|
649 |
+
|
650 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP
|
651 |
+
class GitVisionMLP(nn.Module):
|
652 |
+
def __init__(self, config):
|
653 |
+
super().__init__()
|
654 |
+
self.config = config
|
655 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
656 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
657 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
658 |
+
|
659 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
660 |
+
hidden_states = self.fc1(hidden_states)
|
661 |
+
hidden_states = self.activation_fn(hidden_states)
|
662 |
+
hidden_states = self.fc2(hidden_states)
|
663 |
+
return hidden_states
|
664 |
+
|
665 |
+
|
666 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPAttention
|
667 |
+
class GitVisionAttention(nn.Module):
|
668 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
669 |
+
|
670 |
+
def __init__(self, config):
|
671 |
+
super().__init__()
|
672 |
+
self.config = config
|
673 |
+
self.embed_dim = config.hidden_size
|
674 |
+
self.num_heads = config.num_attention_heads
|
675 |
+
self.head_dim = self.embed_dim // self.num_heads
|
676 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
677 |
+
raise ValueError(
|
678 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
679 |
+
f" {self.num_heads})."
|
680 |
+
)
|
681 |
+
self.scale = self.head_dim**-0.5
|
682 |
+
self.dropout = config.attention_dropout
|
683 |
+
|
684 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
685 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
686 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
687 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
688 |
+
|
689 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
690 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
691 |
+
|
692 |
+
def forward(
|
693 |
+
self,
|
694 |
+
hidden_states: torch.Tensor,
|
695 |
+
attention_mask: Optional[torch.Tensor] = None,
|
696 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
697 |
+
output_attentions: Optional[bool] = False,
|
698 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
699 |
+
"""Input shape: Batch x Time x Channel"""
|
700 |
+
|
701 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
702 |
+
|
703 |
+
# get query proj
|
704 |
+
query_states = self.q_proj(hidden_states) * self.scale
|
705 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
706 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
707 |
+
|
708 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
709 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
710 |
+
key_states = key_states.view(*proj_shape)
|
711 |
+
value_states = value_states.view(*proj_shape)
|
712 |
+
|
713 |
+
src_len = key_states.size(1)
|
714 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
715 |
+
|
716 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
717 |
+
raise ValueError(
|
718 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
719 |
+
f" {attn_weights.size()}"
|
720 |
+
)
|
721 |
+
|
722 |
+
# apply the causal_attention_mask first
|
723 |
+
if causal_attention_mask is not None:
|
724 |
+
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
725 |
+
raise ValueError(
|
726 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
727 |
+
f" {causal_attention_mask.size()}"
|
728 |
+
)
|
729 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
730 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
731 |
+
|
732 |
+
if attention_mask is not None:
|
733 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
734 |
+
raise ValueError(
|
735 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
736 |
+
)
|
737 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
738 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
739 |
+
|
740 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
741 |
+
|
742 |
+
if output_attentions:
|
743 |
+
# this operation is a bit akward, but it's required to
|
744 |
+
# make sure that attn_weights keeps its gradient.
|
745 |
+
# In order to do so, attn_weights have to reshaped
|
746 |
+
# twice and have to be reused in the following
|
747 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
748 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
749 |
+
else:
|
750 |
+
attn_weights_reshaped = None
|
751 |
+
|
752 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
753 |
+
|
754 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
755 |
+
|
756 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
757 |
+
raise ValueError(
|
758 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
759 |
+
f" {attn_output.size()}"
|
760 |
+
)
|
761 |
+
|
762 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
763 |
+
attn_output = attn_output.transpose(1, 2)
|
764 |
+
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
765 |
+
|
766 |
+
attn_output = self.out_proj(attn_output)
|
767 |
+
|
768 |
+
return attn_output, attn_weights_reshaped
|
769 |
+
|
770 |
+
|
771 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GitVision
|
772 |
+
class GitVisionEncoderLayer(nn.Module):
|
773 |
+
def __init__(self, config: GitVisionConfig):
|
774 |
+
super().__init__()
|
775 |
+
self.embed_dim = config.hidden_size
|
776 |
+
self.self_attn = GitVisionAttention(config)
|
777 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
778 |
+
self.mlp = GitVisionMLP(config)
|
779 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
780 |
+
|
781 |
+
def forward(
|
782 |
+
self,
|
783 |
+
hidden_states: torch.Tensor,
|
784 |
+
attention_mask: torch.Tensor,
|
785 |
+
causal_attention_mask: torch.Tensor,
|
786 |
+
output_attentions: Optional[bool] = False,
|
787 |
+
) -> Tuple[torch.FloatTensor]:
|
788 |
+
"""
|
789 |
+
Args:
|
790 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
791 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
792 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
793 |
+
`(config.encoder_attention_heads,)`.
|
794 |
+
output_attentions (`bool`, *optional*):
|
795 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
796 |
+
returned tensors for more detail.
|
797 |
+
"""
|
798 |
+
residual = hidden_states
|
799 |
+
|
800 |
+
hidden_states = self.layer_norm1(hidden_states)
|
801 |
+
hidden_states, attn_weights = self.self_attn(
|
802 |
+
hidden_states=hidden_states,
|
803 |
+
attention_mask=attention_mask,
|
804 |
+
causal_attention_mask=causal_attention_mask,
|
805 |
+
output_attentions=output_attentions,
|
806 |
+
)
|
807 |
+
hidden_states = residual + hidden_states
|
808 |
+
|
809 |
+
residual = hidden_states
|
810 |
+
hidden_states = self.layer_norm2(hidden_states)
|
811 |
+
hidden_states = self.mlp(hidden_states)
|
812 |
+
hidden_states = residual + hidden_states
|
813 |
+
|
814 |
+
outputs = (hidden_states,)
|
815 |
+
|
816 |
+
if output_attentions:
|
817 |
+
outputs += (attn_weights,)
|
818 |
+
|
819 |
+
return outputs
|
820 |
+
|
821 |
+
|
822 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->GitVision, CLIPConfig
|
823 |
+
class GitVisionEncoder(nn.Module):
|
824 |
+
"""
|
825 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
826 |
+
[`GitVisionEncoderLayer`].
|
827 |
+
|
828 |
+
Args:
|
829 |
+
config: GitVisionConfig
|
830 |
+
"""
|
831 |
+
|
832 |
+
def __init__(self, config: GitVisionConfig):
|
833 |
+
super().__init__()
|
834 |
+
self.config = config
|
835 |
+
self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
836 |
+
self.gradient_checkpointing = False
|
837 |
+
|
838 |
+
def forward(
|
839 |
+
self,
|
840 |
+
inputs_embeds,
|
841 |
+
attention_mask: Optional[torch.Tensor] = None,
|
842 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
843 |
+
output_attentions: Optional[bool] = None,
|
844 |
+
output_hidden_states: Optional[bool] = None,
|
845 |
+
return_dict: Optional[bool] = None,
|
846 |
+
) -> Union[Tuple, BaseModelOutput]:
|
847 |
+
r"""
|
848 |
+
Args:
|
849 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
850 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
851 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
852 |
+
than the model's internal embedding lookup matrix.
|
853 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
854 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
855 |
+
|
856 |
+
- 1 for tokens that are **not masked**,
|
857 |
+
- 0 for tokens that are **masked**.
|
858 |
+
|
859 |
+
[What are attention masks?](../glossary#attention-mask)
|
860 |
+
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
861 |
+
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
862 |
+
|
863 |
+
- 1 for tokens that are **not masked**,
|
864 |
+
- 0 for tokens that are **masked**.
|
865 |
+
|
866 |
+
[What are attention masks?](../glossary#attention-mask)
|
867 |
+
output_attentions (`bool`, *optional*):
|
868 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
869 |
+
returned tensors for more detail.
|
870 |
+
output_hidden_states (`bool`, *optional*):
|
871 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
872 |
+
for more detail.
|
873 |
+
return_dict (`bool`, *optional*):
|
874 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
875 |
+
"""
|
876 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
877 |
+
output_hidden_states = (
|
878 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
879 |
+
)
|
880 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
881 |
+
|
882 |
+
encoder_states = () if output_hidden_states else None
|
883 |
+
all_attentions = () if output_attentions else None
|
884 |
+
|
885 |
+
hidden_states = inputs_embeds
|
886 |
+
for idx, encoder_layer in enumerate(self.layers):
|
887 |
+
if output_hidden_states:
|
888 |
+
encoder_states = encoder_states + (hidden_states,)
|
889 |
+
if self.gradient_checkpointing and self.training:
|
890 |
+
|
891 |
+
def create_custom_forward(module):
|
892 |
+
def custom_forward(*inputs):
|
893 |
+
return module(*inputs, output_attentions)
|
894 |
+
|
895 |
+
return custom_forward
|
896 |
+
|
897 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
898 |
+
create_custom_forward(encoder_layer),
|
899 |
+
hidden_states,
|
900 |
+
attention_mask,
|
901 |
+
causal_attention_mask,
|
902 |
+
)
|
903 |
+
else:
|
904 |
+
layer_outputs = encoder_layer(
|
905 |
+
hidden_states,
|
906 |
+
attention_mask,
|
907 |
+
causal_attention_mask,
|
908 |
+
output_attentions=output_attentions,
|
909 |
+
)
|
910 |
+
|
911 |
+
hidden_states = layer_outputs[0]
|
912 |
+
|
913 |
+
if output_attentions:
|
914 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
915 |
+
|
916 |
+
if output_hidden_states:
|
917 |
+
encoder_states = encoder_states + (hidden_states,)
|
918 |
+
|
919 |
+
if not return_dict:
|
920 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
921 |
+
return BaseModelOutput(
|
922 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
923 |
+
)
|
924 |
+
|
925 |
+
|
926 |
+
GIT_VISION_INPUTS_DOCSTRING = r"""
|
927 |
+
Args:
|
928 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
929 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
930 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
931 |
+
output_attentions (`bool`, *optional*):
|
932 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
933 |
+
tensors for more detail.
|
934 |
+
output_hidden_states (`bool`, *optional*):
|
935 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
936 |
+
more detail.
|
937 |
+
return_dict (`bool`, *optional*):
|
938 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
939 |
+
"""
|
940 |
+
|
941 |
+
|
942 |
+
class GitVisionTransformer(nn.Module):
|
943 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ with CLIPEncoder->GitVisionEncoder, CLIP->Git
|
944 |
+
def __init__(self, config: GitVisionConfig):
|
945 |
+
super().__init__()
|
946 |
+
self.config = config
|
947 |
+
embed_dim = config.hidden_size
|
948 |
+
|
949 |
+
self.embeddings = GitVisionEmbeddings(config)
|
950 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
951 |
+
self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size)
|
952 |
+
self.encoder = GitVisionEncoder(config)
|
953 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
954 |
+
|
955 |
+
@add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
|
956 |
+
@replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
|
957 |
+
def forward(
|
958 |
+
self,
|
959 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
960 |
+
pixel_masks: Optional[torch.Tensor] = None,
|
961 |
+
output_attentions: Optional[bool] = None,
|
962 |
+
output_hidden_states: Optional[bool] = None,
|
963 |
+
return_dict: Optional[bool] = None,
|
964 |
+
) -> Union[Tuple, BaseModelOutput]:
|
965 |
+
r"""
|
966 |
+
Returns:
|
967 |
+
|
968 |
+
"""
|
969 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
970 |
+
output_hidden_states = (
|
971 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
972 |
+
)
|
973 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
974 |
+
|
975 |
+
if pixel_values is None:
|
976 |
+
raise ValueError("You have to specify pixel_values")
|
977 |
+
|
978 |
+
hidden_states = self.embeddings(pixel_values)
|
979 |
+
B, N, D = hidden_states.shape
|
980 |
+
# print('Before mask:', hidden_states.shape)
|
981 |
+
if pixel_masks is not None:
|
982 |
+
assert pixel_masks.shape[0] == 1
|
983 |
+
patch_masks = self.patch_mask_generator(pixel_masks)
|
984 |
+
# print(patch_masks.shape)
|
985 |
+
patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states)
|
986 |
+
hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D)
|
987 |
+
# print('After mask:', hidden_states.shape)
|
988 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
989 |
+
|
990 |
+
encoder_outputs = self.encoder(
|
991 |
+
inputs_embeds=hidden_states,
|
992 |
+
output_attentions=output_attentions,
|
993 |
+
output_hidden_states=output_hidden_states,
|
994 |
+
return_dict=return_dict,
|
995 |
+
)
|
996 |
+
|
997 |
+
last_hidden_state = encoder_outputs[0]
|
998 |
+
|
999 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
1000 |
+
|
1001 |
+
if not return_dict:
|
1002 |
+
return (last_hidden_state,) + encoder_outputs[1:]
|
1003 |
+
|
1004 |
+
return BaseModelOutput(
|
1005 |
+
last_hidden_state=last_hidden_state,
|
1006 |
+
hidden_states=encoder_outputs.hidden_states,
|
1007 |
+
attentions=encoder_outputs.attentions,
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
|
1011 |
+
@add_start_docstrings(
|
1012 |
+
"""The vision model from CLIP, used in GIT, without any head or projection on top.""",
|
1013 |
+
GIT_START_DOCSTRING,
|
1014 |
+
)
|
1015 |
+
class GitVisionModel(GitPreTrainedModel):
|
1016 |
+
config_class = GitVisionConfig
|
1017 |
+
main_input_name = "pixel_values"
|
1018 |
+
|
1019 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
|
1020 |
+
def __init__(self, config: GitVisionConfig):
|
1021 |
+
super().__init__(config)
|
1022 |
+
self.vision_model = GitVisionTransformer(config)
|
1023 |
+
# Initialize weights and apply final processing
|
1024 |
+
self.post_init()
|
1025 |
+
|
1026 |
+
def get_input_embeddings(self) -> nn.Module:
|
1027 |
+
return self.vision_model.embeddings.patch_embedding
|
1028 |
+
|
1029 |
+
@add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
|
1030 |
+
@replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
|
1031 |
+
def forward(
|
1032 |
+
self,
|
1033 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1034 |
+
pixel_masks: Optional[torch.Tensor] = None,
|
1035 |
+
output_attentions: Optional[bool] = None,
|
1036 |
+
output_hidden_states: Optional[bool] = None,
|
1037 |
+
return_dict: Optional[bool] = None,
|
1038 |
+
) -> Union[Tuple, BaseModelOutput]:
|
1039 |
+
r"""
|
1040 |
+
Returns:
|
1041 |
+
|
1042 |
+
Examples:
|
1043 |
+
|
1044 |
+
```python
|
1045 |
+
>>> from PIL import Image
|
1046 |
+
>>> import requests
|
1047 |
+
>>> from transformers import AutoProcessor, GitVisionModel
|
1048 |
+
|
1049 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
|
1050 |
+
>>> model = GitVisionModel.from_pretrained("microsoft/git-base")
|
1051 |
+
|
1052 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1053 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1054 |
+
|
1055 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
1056 |
+
|
1057 |
+
>>> outputs = model(**inputs)
|
1058 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
1059 |
+
```"""
|
1060 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1061 |
+
|
1062 |
+
return self.vision_model(
|
1063 |
+
pixel_values=pixel_values,
|
1064 |
+
pixel_masks=pixel_masks,
|
1065 |
+
output_attentions=output_attentions,
|
1066 |
+
output_hidden_states=output_hidden_states,
|
1067 |
+
return_dict=return_dict,
|
1068 |
+
)
|
1069 |
+
|
1070 |
+
|
1071 |
+
class GitProjection(nn.Module):
|
1072 |
+
def __init__(self, config: GitConfig):
|
1073 |
+
super().__init__()
|
1074 |
+
self.config = config
|
1075 |
+
self.visual_projection = nn.Sequential(
|
1076 |
+
nn.Linear(config.vision_config.hidden_size, config.hidden_size),
|
1077 |
+
nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
|
1081 |
+
return self.visual_projection(embeddings)
|
1082 |
+
|
1083 |
+
|
1084 |
+
@add_start_docstrings(
|
1085 |
+
"The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states"
|
1086 |
+
" without any specific head on top.",
|
1087 |
+
GIT_START_DOCSTRING,
|
1088 |
+
)
|
1089 |
+
class GitModel(GitPreTrainedModel):
|
1090 |
+
def __init__(self, config):
|
1091 |
+
super().__init__(config)
|
1092 |
+
self.config = config
|
1093 |
+
|
1094 |
+
self.embeddings = GitEmbeddings(config)
|
1095 |
+
self.image_encoder = GitVisionModel(config.vision_config)
|
1096 |
+
self.encoder = GitEncoder(config)
|
1097 |
+
|
1098 |
+
self.visual_projection = GitProjection(config)
|
1099 |
+
|
1100 |
+
if config.num_image_with_embedding is not None:
|
1101 |
+
self.img_temperal_embedding = nn.ParameterList(
|
1102 |
+
nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
|
1103 |
+
for _ in range(config.num_image_with_embedding)
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
# Initialize weights and apply final processing
|
1107 |
+
self.post_init()
|
1108 |
+
|
1109 |
+
def get_input_embeddings(self):
|
1110 |
+
return self.embeddings.word_embeddings
|
1111 |
+
|
1112 |
+
def set_input_embeddings(self, value):
|
1113 |
+
self.embeddings.word_embeddings = value
|
1114 |
+
|
1115 |
+
def _prune_heads(self, heads_to_prune):
|
1116 |
+
"""
|
1117 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
1118 |
+
class PreTrainedModel
|
1119 |
+
"""
|
1120 |
+
for layer, heads in heads_to_prune.items():
|
1121 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
1122 |
+
|
1123 |
+
def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
1124 |
+
# Default mask is for forward direction. Flip for backward direction.
|
1125 |
+
mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
|
1126 |
+
mask = mask.masked_fill(mask == 1, float("-inf"))
|
1127 |
+
return mask
|
1128 |
+
|
1129 |
+
def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
|
1130 |
+
num_tgt = tgt.shape[1]
|
1131 |
+
num_memory = memory.shape[1]
|
1132 |
+
device = tgt.device
|
1133 |
+
dtype = tgt.dtype
|
1134 |
+
top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
|
1135 |
+
top_right = torch.full(
|
1136 |
+
(num_memory, num_tgt + past_key_values_length),
|
1137 |
+
float("-inf"),
|
1138 |
+
device=tgt.device,
|
1139 |
+
dtype=dtype,
|
1140 |
+
)
|
1141 |
+
bottom_left = torch.zeros(
|
1142 |
+
(num_tgt, num_memory),
|
1143 |
+
dtype=dtype,
|
1144 |
+
device=tgt_mask.device,
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
if past_key_values_length > 0:
|
1148 |
+
tgt_mask = torch.zeros(
|
1149 |
+
(tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
|
1150 |
+
dtype=dtype,
|
1151 |
+
device=tgt_mask.device,
|
1152 |
+
)
|
1153 |
+
|
1154 |
+
left = torch.cat((top_left, bottom_left), dim=0)
|
1155 |
+
right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
|
1156 |
+
|
1157 |
+
full_attention_mask = torch.cat((left, right), dim=1)[None, :]
|
1158 |
+
|
1159 |
+
if memory_key_padding_mask is None:
|
1160 |
+
memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
|
1161 |
+
# if it is False, it means valid. That is, it is not a padding
|
1162 |
+
if memory_key_padding_mask.dtype != torch.bool:
|
1163 |
+
raise ValueError("Memory key padding mask must be a boolean tensor.")
|
1164 |
+
zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
|
1165 |
+
zero_negative_infinity[memory_key_padding_mask] = float("-inf")
|
1166 |
+
full_attention_mask = full_attention_mask.expand(
|
1167 |
+
(memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
|
1168 |
+
)
|
1169 |
+
full_attention_mask = full_attention_mask.clone()
|
1170 |
+
origin_left = full_attention_mask[:, :, :num_memory]
|
1171 |
+
update = zero_negative_infinity[:, None, :]
|
1172 |
+
full_attention_mask[:, :, :num_memory] = origin_left + update
|
1173 |
+
|
1174 |
+
# add axis for multi-head
|
1175 |
+
full_attention_mask = full_attention_mask[:, None, :, :]
|
1176 |
+
|
1177 |
+
return full_attention_mask
|
1178 |
+
|
1179 |
+
@add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1180 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
1181 |
+
def forward(
|
1182 |
+
self,
|
1183 |
+
input_ids: Optional[torch.Tensor] = None,
|
1184 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1185 |
+
position_ids: Optional[torch.Tensor] = None,
|
1186 |
+
pixel_values: Optional[torch.Tensor] = None,
|
1187 |
+
pixel_masks: Optional[torch.Tensor] = None,
|
1188 |
+
head_mask: Optional[torch.Tensor] = None,
|
1189 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1190 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1191 |
+
use_cache: Optional[bool] = None,
|
1192 |
+
output_attentions: Optional[bool] = None,
|
1193 |
+
output_hidden_states: Optional[bool] = None,
|
1194 |
+
return_dict: Optional[bool] = None,
|
1195 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
|
1196 |
+
r"""
|
1197 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1198 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1199 |
+
|
1200 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
1201 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
1202 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1203 |
+
use_cache (`bool`, *optional*):
|
1204 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1205 |
+
`past_key_values`).
|
1206 |
+
|
1207 |
+
Returns:
|
1208 |
+
|
1209 |
+
Examples:
|
1210 |
+
|
1211 |
+
```python
|
1212 |
+
>>> from transformers import AutoProcessor, AutoModel
|
1213 |
+
>>> import requests
|
1214 |
+
>>> from PIL import Image
|
1215 |
+
|
1216 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
|
1217 |
+
>>> model = AutoModel.from_pretrained("microsoft/git-base")
|
1218 |
+
|
1219 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1220 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1221 |
+
|
1222 |
+
>>> text = "this is an image of two cats"
|
1223 |
+
|
1224 |
+
>>> inputs = processor(text, images=image, return_tensors="pt")
|
1225 |
+
|
1226 |
+
>>> outputs = model(**inputs)
|
1227 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
1228 |
+
```"""
|
1229 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1230 |
+
output_hidden_states = (
|
1231 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1232 |
+
)
|
1233 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1234 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1235 |
+
|
1236 |
+
if input_ids is not None and inputs_embeds is not None:
|
1237 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
1238 |
+
elif input_ids is not None:
|
1239 |
+
input_shape = input_ids.size()
|
1240 |
+
elif inputs_embeds is not None:
|
1241 |
+
input_shape = inputs_embeds.size()[:-1]
|
1242 |
+
else:
|
1243 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
1244 |
+
|
1245 |
+
seq_length = input_shape[1]
|
1246 |
+
|
1247 |
+
# past_key_values_length
|
1248 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
1249 |
+
|
1250 |
+
# Prepare head mask if needed
|
1251 |
+
# 1.0 in head_mask indicate we keep the head
|
1252 |
+
# attention_probs has shape bsz x n_heads x N x N
|
1253 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
1254 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
1255 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
1256 |
+
|
1257 |
+
projected_visual_features = None
|
1258 |
+
if pixel_values is not None:
|
1259 |
+
if pixel_values.ndim == 4:
|
1260 |
+
# here we assume pixel_values is of shape (batch_size, num_channels, height, width)
|
1261 |
+
visual_features = self.image_encoder(pixel_values=pixel_values, pixel_masks=pixel_masks).last_hidden_state
|
1262 |
+
|
1263 |
+
elif pixel_values.ndim == 5:
|
1264 |
+
# here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
|
1265 |
+
visual_features = []
|
1266 |
+
for frame_idx in range(pixel_values.shape[1]):
|
1267 |
+
visual_features_frame = self.image_encoder(pixel_values[:, frame_idx, :, :]).last_hidden_state
|
1268 |
+
visual_features_frame += self.img_temperal_embedding[frame_idx]
|
1269 |
+
visual_features.append(visual_features_frame)
|
1270 |
+
|
1271 |
+
# finally, concatenate all features along sequence dimension
|
1272 |
+
visual_features = torch.cat(visual_features, dim=1)
|
1273 |
+
|
1274 |
+
else:
|
1275 |
+
raise ValueError("pixel_values must be of rank 4 or 5")
|
1276 |
+
|
1277 |
+
projected_visual_features = self.visual_projection(visual_features)
|
1278 |
+
image_token_num = projected_visual_features.shape[1]
|
1279 |
+
embedding_output = self.embeddings(
|
1280 |
+
input_ids=input_ids,
|
1281 |
+
position_ids=position_ids,
|
1282 |
+
inputs_embeds=inputs_embeds,
|
1283 |
+
past_key_values_length=past_key_values_length,
|
1284 |
+
)
|
1285 |
+
|
1286 |
+
if projected_visual_features is None:
|
1287 |
+
projected_visual_features = torch.zeros(
|
1288 |
+
(embedding_output.shape[0], 0, embedding_output.shape[2]),
|
1289 |
+
dtype=embedding_output.dtype,
|
1290 |
+
device=embedding_output.device,
|
1291 |
+
)
|
1292 |
+
|
1293 |
+
# Repeat visual features to match embedding batch size.
|
1294 |
+
projected_visual_features = projected_visual_features.repeat(
|
1295 |
+
embedding_output.size(0) // projected_visual_features.size(0), 1, 1
|
1296 |
+
)
|
1297 |
+
|
1298 |
+
# concatenate patch token and text token embeddings
|
1299 |
+
hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
|
1300 |
+
|
1301 |
+
# By default, an additive causal mask is created
|
1302 |
+
# for masking the future (one direction).
|
1303 |
+
tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
|
1304 |
+
|
1305 |
+
# Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
|
1306 |
+
combined_attention_mask = self.create_attention_mask(
|
1307 |
+
tgt=embedding_output,
|
1308 |
+
memory=projected_visual_features,
|
1309 |
+
tgt_mask=tgt_mask,
|
1310 |
+
past_key_values_length=past_key_values_length,
|
1311 |
+
)
|
1312 |
+
|
1313 |
+
if attention_mask is not None:
|
1314 |
+
# if the user provides an attention mask, we add it to the default one
|
1315 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1316 |
+
expanded_attn_mask = _expand_mask(attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]).to(
|
1317 |
+
embedding_output.device
|
1318 |
+
)
|
1319 |
+
if past_key_values_length > 0:
|
1320 |
+
expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
|
1321 |
+
else:
|
1322 |
+
combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
|
1323 |
+
|
1324 |
+
encoder_outputs = self.encoder(
|
1325 |
+
hidden_states,
|
1326 |
+
attention_mask=combined_attention_mask,
|
1327 |
+
head_mask=head_mask,
|
1328 |
+
past_key_values=past_key_values,
|
1329 |
+
use_cache=use_cache,
|
1330 |
+
output_attentions=output_attentions,
|
1331 |
+
output_hidden_states=output_hidden_states,
|
1332 |
+
return_dict=return_dict,
|
1333 |
+
pixel_values_present=pixel_values is not None,
|
1334 |
+
image_token_num=image_token_num
|
1335 |
+
)
|
1336 |
+
sequence_output = encoder_outputs[0]
|
1337 |
+
|
1338 |
+
if not return_dict:
|
1339 |
+
return (sequence_output,) + encoder_outputs[1:]
|
1340 |
+
|
1341 |
+
return BaseModelOutputWithPast(
|
1342 |
+
last_hidden_state=sequence_output,
|
1343 |
+
past_key_values=encoder_outputs.past_key_values,
|
1344 |
+
hidden_states=encoder_outputs.hidden_states,
|
1345 |
+
attentions=encoder_outputs.attentions,
|
1346 |
+
)
|
1347 |
+
|
1348 |
+
|
1349 |
+
@add_start_docstrings(
|
1350 |
+
"""GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
|
1351 |
+
)
|
1352 |
+
class GitForCausalLM(GitPreTrainedModel):
|
1353 |
+
def __init__(self, config):
|
1354 |
+
super().__init__(config)
|
1355 |
+
|
1356 |
+
self.git = GitModel(config)
|
1357 |
+
self.output = nn.Linear(config.hidden_size, config.vocab_size)
|
1358 |
+
|
1359 |
+
# Initialize weights and apply final processing
|
1360 |
+
self.post_init()
|
1361 |
+
|
1362 |
+
def get_output_embeddings(self):
|
1363 |
+
return self.output
|
1364 |
+
|
1365 |
+
def set_output_embeddings(self, new_embeddings):
|
1366 |
+
self.output = new_embeddings
|
1367 |
+
|
1368 |
+
@add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1369 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1370 |
+
def forward(
|
1371 |
+
self,
|
1372 |
+
input_ids: Optional[torch.Tensor] = None,
|
1373 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1374 |
+
position_ids: Optional[torch.Tensor] = None,
|
1375 |
+
pixel_values: Optional[torch.Tensor] = None,
|
1376 |
+
pixel_masks: Optional[torch.Tensor] = None,
|
1377 |
+
head_mask: Optional[torch.Tensor] = None,
|
1378 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1379 |
+
labels: Optional[torch.Tensor] = None,
|
1380 |
+
past_key_values: Optional[List[torch.Tensor]] = None,
|
1381 |
+
use_cache: Optional[bool] = None,
|
1382 |
+
output_attentions: Optional[bool] = None,
|
1383 |
+
output_hidden_states: Optional[bool] = None,
|
1384 |
+
return_dict: Optional[bool] = None,
|
1385 |
+
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
|
1386 |
+
r"""
|
1387 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1388 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
1389 |
+
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
1390 |
+
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
|
1391 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1392 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1393 |
+
|
1394 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
1395 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
1396 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
1397 |
+
use_cache (`bool`, *optional*):
|
1398 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
1399 |
+
`past_key_values`).
|
1400 |
+
|
1401 |
+
Returns:
|
1402 |
+
|
1403 |
+
Examples:
|
1404 |
+
|
1405 |
+
Image captioning example:
|
1406 |
+
|
1407 |
+
```python
|
1408 |
+
>>> from transformers import AutoProcessor, AutoModelForCausalLM
|
1409 |
+
>>> import requests
|
1410 |
+
>>> from PIL import Image
|
1411 |
+
|
1412 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
|
1413 |
+
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
|
1414 |
+
|
1415 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1416 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1417 |
+
|
1418 |
+
>>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
1419 |
+
|
1420 |
+
>>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
|
1421 |
+
>>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
1422 |
+
>>> print(generated_caption)
|
1423 |
+
two cats sleeping on a pink blanket next to remotes.
|
1424 |
+
```
|
1425 |
+
|
1426 |
+
Visual question answering (VQA) example:
|
1427 |
+
|
1428 |
+
```python
|
1429 |
+
>>> from transformers import AutoProcessor, AutoModelForCausalLM
|
1430 |
+
>>> from huggingface_hub import hf_hub_download
|
1431 |
+
>>> from PIL import Image
|
1432 |
+
|
1433 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
|
1434 |
+
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
|
1435 |
+
|
1436 |
+
>>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
|
1437 |
+
>>> image = Image.open(file_path).convert("RGB")
|
1438 |
+
|
1439 |
+
>>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
1440 |
+
|
1441 |
+
>>> question = "what does the front of the bus say at the top?"
|
1442 |
+
|
1443 |
+
>>> input_ids = processor(text=question, add_special_tokens=False).input_ids
|
1444 |
+
>>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
|
1445 |
+
>>> input_ids = torch.tensor(input_ids).unsqueeze(0)
|
1446 |
+
|
1447 |
+
>>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
|
1448 |
+
>>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
|
1449 |
+
['what does the front of the bus say at the top? special']
|
1450 |
+
```
|
1451 |
+
|
1452 |
+
Video captioning example:
|
1453 |
+
|
1454 |
+
```python
|
1455 |
+
>>> import av
|
1456 |
+
>>> import numpy as np
|
1457 |
+
>>> from PIL import Image
|
1458 |
+
>>> from huggingface_hub import hf_hub_download
|
1459 |
+
>>> from transformers import AutoProcessor, AutoModelForCausalLM
|
1460 |
+
|
1461 |
+
>>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
|
1462 |
+
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")
|
1463 |
+
|
1464 |
+
>>> # set seed for reproducability
|
1465 |
+
>>> np.random.seed(45)
|
1466 |
+
|
1467 |
+
|
1468 |
+
>>> def read_video_pyav(container, indices):
|
1469 |
+
... '''
|
1470 |
+
... Decode the video with PyAV decoder.
|
1471 |
+
... Args:
|
1472 |
+
... container (`av.container.input.InputContainer`): PyAV container.
|
1473 |
+
... indices (`List[int]`): List of frame indices to decode.
|
1474 |
+
... Returns:
|
1475 |
+
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
|
1476 |
+
... '''
|
1477 |
+
... frames = []
|
1478 |
+
... container.seek(0)
|
1479 |
+
... start_index = indices[0]
|
1480 |
+
... end_index = indices[-1]
|
1481 |
+
... for i, frame in enumerate(container.decode(video=0)):
|
1482 |
+
... if i > end_index:
|
1483 |
+
... break
|
1484 |
+
... if i >= start_index and i in indices:
|
1485 |
+
... frames.append(frame)
|
1486 |
+
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
1487 |
+
|
1488 |
+
|
1489 |
+
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
|
1490 |
+
... converted_len = int(clip_len * frame_sample_rate)
|
1491 |
+
... end_idx = np.random.randint(converted_len, seg_len)
|
1492 |
+
... start_idx = end_idx - converted_len
|
1493 |
+
... indices = np.linspace(start_idx, end_idx, num=clip_len)
|
1494 |
+
... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
|
1495 |
+
... return indices
|
1496 |
+
|
1497 |
+
|
1498 |
+
>>> # load video
|
1499 |
+
>>> file_path = hf_hub_download(
|
1500 |
+
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
|
1501 |
+
... )
|
1502 |
+
>>> container = av.open(file_path)
|
1503 |
+
|
1504 |
+
>>> # sample frames
|
1505 |
+
>>> num_frames = model.config.num_image_with_embedding
|
1506 |
+
>>> indices = sample_frame_indices(
|
1507 |
+
... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
|
1508 |
+
... )
|
1509 |
+
>>> frames = read_video_pyav(container, indices)
|
1510 |
+
|
1511 |
+
>>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values
|
1512 |
+
|
1513 |
+
>>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
|
1514 |
+
|
1515 |
+
>>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
|
1516 |
+
Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
|
1517 |
+
```
|
1518 |
+
"""
|
1519 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1520 |
+
if labels is not None:
|
1521 |
+
use_cache = False
|
1522 |
+
|
1523 |
+
outputs = self.git(
|
1524 |
+
input_ids,
|
1525 |
+
attention_mask=attention_mask,
|
1526 |
+
position_ids=position_ids,
|
1527 |
+
pixel_values=pixel_values,
|
1528 |
+
pixel_masks=pixel_masks,
|
1529 |
+
head_mask=head_mask,
|
1530 |
+
inputs_embeds=inputs_embeds,
|
1531 |
+
past_key_values=past_key_values,
|
1532 |
+
use_cache=use_cache,
|
1533 |
+
output_attentions=output_attentions,
|
1534 |
+
output_hidden_states=output_hidden_states,
|
1535 |
+
return_dict=return_dict,
|
1536 |
+
)
|
1537 |
+
|
1538 |
+
sequence_output = outputs[0]
|
1539 |
+
logits = self.output(sequence_output)
|
1540 |
+
|
1541 |
+
loss = None
|
1542 |
+
if labels is not None:
|
1543 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1544 |
+
num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
|
1545 |
+
shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
|
1546 |
+
labels = labels[:, 1:].contiguous()
|
1547 |
+
loss_fct = CrossEntropyLoss()
|
1548 |
+
loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
1549 |
+
|
1550 |
+
if not return_dict:
|
1551 |
+
output = (logits,) + outputs[1:]
|
1552 |
+
return ((loss,) + output) if loss is not None else output
|
1553 |
+
|
1554 |
+
return CausalLMOutputWithPast(
|
1555 |
+
loss=loss,
|
1556 |
+
logits=logits,
|
1557 |
+
past_key_values=outputs.past_key_values,
|
1558 |
+
hidden_states=outputs.hidden_states,
|
1559 |
+
attentions=outputs.attentions,
|
1560 |
+
)
|
1561 |
+
|
1562 |
+
def prepare_inputs_for_generation(
|
1563 |
+
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
|
1564 |
+
):
|
1565 |
+
# cut decoder_input_ids if past_key_values is used
|
1566 |
+
if past_key_values is not None:
|
1567 |
+
input_ids = input_ids[:, -1:]
|
1568 |
+
|
1569 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1570 |
+
input_shape = input_ids.shape
|
1571 |
+
if attention_mask is None:
|
1572 |
+
attention_mask = input_ids.new_ones(input_shape)
|
1573 |
+
|
1574 |
+
return {
|
1575 |
+
"input_ids": input_ids,
|
1576 |
+
"attention_mask": attention_mask,
|
1577 |
+
"pixel_values": kwargs.get("pixel_values", None),
|
1578 |
+
"pixel_masks": kwargs.get("pixel_masks", None),
|
1579 |
+
"past_key_values": past_key_values,
|
1580 |
+
"use_cache": use_cache,
|
1581 |
+
}
|
1582 |
+
|
1583 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
1584 |
+
reordered_past = ()
|
1585 |
+
for layer_past in past_key_values:
|
1586 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
1587 |
+
return reordered_past
|
caption_anything/captioner/vit_pixel_masks_utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class ViTPatchMaskGenerator(nn.Module):
|
7 |
+
def __init__(self, patch_size) -> None:
|
8 |
+
super(ViTPatchMaskGenerator, self).__init__()
|
9 |
+
self.patch_size = patch_size
|
10 |
+
self.pool = nn.MaxPool2d(kernel_size=patch_size, stride=patch_size)
|
11 |
+
|
12 |
+
def forward(self, pixel_masks):
|
13 |
+
patch_mask = self.pool(pixel_masks)
|
14 |
+
patch_mask = patch_mask.bool().flatten(1)
|
15 |
+
cls_token_mask = patch_mask.new_ones([patch_mask.shape[0], 1]).bool()
|
16 |
+
patch_mask = torch.cat([cls_token_mask, patch_mask], dim=-1)
|
17 |
+
return patch_mask
|
caption_anything/model.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import pdb
|
4 |
+
import time
|
5 |
+
from PIL import Image
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import easyocr
|
10 |
+
import copy
|
11 |
+
import time
|
12 |
+
from caption_anything.captioner import build_captioner, BaseCaptioner
|
13 |
+
from caption_anything.segmenter import build_segmenter, build_segmenter_densecap
|
14 |
+
from caption_anything.text_refiner import build_text_refiner
|
15 |
+
from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image, get_image_shape
|
16 |
+
from caption_anything.utils.utils import mask_painter_foreground_all, mask_painter, xywh_to_x1y1x2y2, image_resize
|
17 |
+
from caption_anything.utils.densecap_painter import draw_bbox
|
18 |
+
|
19 |
+
class CaptionAnything:
|
20 |
+
def __init__(self, args, api_key="", captioner=None, segmenter=None, ocr_reader=None, text_refiner=None):
|
21 |
+
self.args = args
|
22 |
+
self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
|
23 |
+
self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
|
24 |
+
self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args, model=self.segmenter.model)
|
25 |
+
self.ocr_lang = ["ch_tra", "en"]
|
26 |
+
self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(self.ocr_lang)
|
27 |
+
|
28 |
+
|
29 |
+
self.text_refiner = None
|
30 |
+
if not args.disable_gpt:
|
31 |
+
if text_refiner is not None:
|
32 |
+
self.text_refiner = text_refiner
|
33 |
+
elif api_key != "":
|
34 |
+
self.init_refiner(api_key)
|
35 |
+
self.require_caption_prompt = args.captioner == 'blip2'
|
36 |
+
|
37 |
+
@property
|
38 |
+
def image_embedding(self):
|
39 |
+
return self.segmenter.image_embedding
|
40 |
+
|
41 |
+
@image_embedding.setter
|
42 |
+
def image_embedding(self, image_embedding):
|
43 |
+
self.segmenter.image_embedding = image_embedding
|
44 |
+
|
45 |
+
@property
|
46 |
+
def original_size(self):
|
47 |
+
return self.segmenter.predictor.original_size
|
48 |
+
|
49 |
+
@original_size.setter
|
50 |
+
def original_size(self, original_size):
|
51 |
+
self.segmenter.predictor.original_size = original_size
|
52 |
+
|
53 |
+
@property
|
54 |
+
def input_size(self):
|
55 |
+
return self.segmenter.predictor.input_size
|
56 |
+
|
57 |
+
@input_size.setter
|
58 |
+
def input_size(self, input_size):
|
59 |
+
self.segmenter.predictor.input_size = input_size
|
60 |
+
|
61 |
+
def setup(self, image_embedding, original_size, input_size, is_image_set):
|
62 |
+
self.image_embedding = image_embedding
|
63 |
+
self.original_size = original_size
|
64 |
+
self.input_size = input_size
|
65 |
+
self.segmenter.predictor.is_image_set = is_image_set
|
66 |
+
|
67 |
+
def init_refiner(self, api_key):
|
68 |
+
try:
|
69 |
+
self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
|
70 |
+
self.text_refiner.llm('hi') # test
|
71 |
+
except:
|
72 |
+
self.text_refiner = None
|
73 |
+
print('OpenAI GPT is not available')
|
74 |
+
|
75 |
+
def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False, verbose=False, is_densecap=False, args={}):
|
76 |
+
# segment with prompt
|
77 |
+
print("CA prompt: ", prompt, "CA controls", controls)
|
78 |
+
is_seg_everything = 'everything' in prompt['prompt_type']
|
79 |
+
|
80 |
+
args['seg_crop_mode'] = args.get('seg_crop_mode', self.args.seg_crop_mode)
|
81 |
+
args['clip_filter'] = args.get('clip_filter', self.args.clip_filter)
|
82 |
+
args['disable_regular_box'] = args.get('disable_regular_box', self.args.disable_regular_box)
|
83 |
+
args['context_captions'] = args.get('context_captions', self.args.context_captions)
|
84 |
+
args['enable_reduce_tokens'] = args.get('enable_reduce_tokens', self.args.enable_reduce_tokens)
|
85 |
+
args['enable_morphologyex'] = args.get('enable_morphologyex', self.args.enable_morphologyex)
|
86 |
+
args['topN'] = args.get('topN', 10) if is_seg_everything else 1
|
87 |
+
args['min_mask_area'] = args.get('min_mask_area', 0)
|
88 |
+
|
89 |
+
if not is_densecap:
|
90 |
+
seg_results = self.segmenter.inference(image, prompt)
|
91 |
+
else:
|
92 |
+
seg_results = self.segmenter_densecap.inference(image, prompt)
|
93 |
+
|
94 |
+
seg_masks, seg_bbox, seg_area = seg_results if is_seg_everything else (seg_results, None, None)
|
95 |
+
|
96 |
+
if args['topN'] > 1: # sort by area
|
97 |
+
samples = list(zip(*[seg_masks, seg_bbox, seg_area]))
|
98 |
+
# top_samples = sorted(samples, key=lambda x: x[2], reverse=True)
|
99 |
+
# seg_masks, seg_bbox, seg_area = list(zip(*top_samples))
|
100 |
+
samples = list(filter(lambda x: x[2] > args['min_mask_area'], samples))
|
101 |
+
samples = samples[:args['topN']]
|
102 |
+
seg_masks, seg_bbox, seg_area = list(zip(*samples))
|
103 |
+
|
104 |
+
out_list = []
|
105 |
+
for i, seg_mask in enumerate(seg_masks):
|
106 |
+
if args['enable_morphologyex']:
|
107 |
+
seg_mask = 255 * seg_mask.astype(np.uint8)
|
108 |
+
seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
|
109 |
+
seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
|
110 |
+
seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
|
111 |
+
seg_mask = seg_mask[:, :, 0] > 0
|
112 |
+
|
113 |
+
seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
|
114 |
+
mask_save_path = None
|
115 |
+
|
116 |
+
if verbose:
|
117 |
+
mask_save_path = f'result/mask_{time.time()}.png'
|
118 |
+
if not os.path.exists(os.path.dirname(mask_save_path)):
|
119 |
+
os.makedirs(os.path.dirname(mask_save_path))
|
120 |
+
|
121 |
+
if seg_mask_img.mode != 'RGB':
|
122 |
+
seg_mask_img = seg_mask_img.convert('RGB')
|
123 |
+
seg_mask_img.save(mask_save_path)
|
124 |
+
print('seg_mask path: ', mask_save_path)
|
125 |
+
print("seg_mask.shape: ", seg_mask.shape)
|
126 |
+
|
127 |
+
|
128 |
+
# captioning with mask
|
129 |
+
if args['enable_reduce_tokens']:
|
130 |
+
result = self.captioner.inference_with_reduced_tokens(image, seg_mask,
|
131 |
+
crop_mode=args['seg_crop_mode'],
|
132 |
+
filter=args['clip_filter'],
|
133 |
+
disable_regular_box=args['disable_regular_box'],
|
134 |
+
verbose=verbose,
|
135 |
+
caption_args=args)
|
136 |
+
else:
|
137 |
+
result = self.captioner.inference_seg(image, seg_mask,
|
138 |
+
crop_mode=args['seg_crop_mode'],
|
139 |
+
filter=args['clip_filter'],
|
140 |
+
disable_regular_box=args['disable_regular_box'],
|
141 |
+
verbose=verbose,
|
142 |
+
caption_args=args)
|
143 |
+
caption = result.get('caption', None)
|
144 |
+
crop_save_path = result.get('crop_save_path', None)
|
145 |
+
|
146 |
+
# refining with TextRefiner
|
147 |
+
context_captions = []
|
148 |
+
if args['context_captions']:
|
149 |
+
context_captions.append(self.captioner.inference(image)['caption'])
|
150 |
+
if not disable_gpt and self.text_refiner is not None:
|
151 |
+
refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
|
152 |
+
enable_wiki=enable_wiki)
|
153 |
+
else:
|
154 |
+
refined_caption = {'raw_caption': caption}
|
155 |
+
out = {'generated_captions': refined_caption,
|
156 |
+
'crop_save_path': crop_save_path,
|
157 |
+
'mask_save_path': mask_save_path,
|
158 |
+
'mask': seg_mask_img,
|
159 |
+
'bbox': seg_bbox[i] if seg_bbox is not None else None,
|
160 |
+
'area': seg_area[i] if seg_area is not None else None,
|
161 |
+
'context_captions': context_captions,
|
162 |
+
'ppl_score': result.get('ppl_score', -100.),
|
163 |
+
'clip_score': result.get('clip_score', 0.)
|
164 |
+
}
|
165 |
+
out_list.append(out)
|
166 |
+
return out_list
|
167 |
+
|
168 |
+
def parse_dense_caption(self, image, topN=10, reference_caption=[], verbose=False):
|
169 |
+
width, height = get_image_shape(image)
|
170 |
+
prompt = {'prompt_type': ['everything']}
|
171 |
+
densecap_args = {
|
172 |
+
'return_ppl': True,
|
173 |
+
'clip_filter': True,
|
174 |
+
'reference_caption': reference_caption,
|
175 |
+
'text_prompt': "", # 'Question: what does the image show? Answer:'
|
176 |
+
'seg_crop_mode': 'w_bg',
|
177 |
+
# 'text_prompt': "",
|
178 |
+
# 'seg_crop_mode': 'wo_bg',
|
179 |
+
'disable_regular_box': False,
|
180 |
+
'topN': topN,
|
181 |
+
'min_ppl_score': -1.8,
|
182 |
+
'min_clip_score': 0.30,
|
183 |
+
'min_mask_area': 2500,
|
184 |
+
}
|
185 |
+
|
186 |
+
dense_captions = self.inference(image, prompt,
|
187 |
+
controls=None,
|
188 |
+
disable_gpt=True,
|
189 |
+
verbose=verbose,
|
190 |
+
is_densecap=True,
|
191 |
+
args=densecap_args)
|
192 |
+
print('Process Dense Captioning: \n', dense_captions)
|
193 |
+
dense_captions = list(filter(lambda x: x['ppl_score'] / (1+len(x['generated_captions']['raw_caption'].split())) >= densecap_args['min_ppl_score'], dense_captions))
|
194 |
+
dense_captions = list(filter(lambda x: x['clip_score'] >= densecap_args['min_clip_score'], dense_captions))
|
195 |
+
dense_cap_prompt = []
|
196 |
+
for cap in dense_captions:
|
197 |
+
x, y, w, h = cap['bbox']
|
198 |
+
cx, cy = x + w/2, (y + h/2)
|
199 |
+
dense_cap_prompt.append("({}: X:{:.0f}, Y:{:.0f}, Width:{:.0f}, Height:{:.0f})".format(cap['generated_captions']['raw_caption'], cx, cy, w, h))
|
200 |
+
|
201 |
+
if verbose:
|
202 |
+
all_masks = [np.array(item['mask'].convert('P')) for item in dense_captions]
|
203 |
+
new_image = mask_painter_foreground_all(np.array(image), all_masks, background_alpha=0.4)
|
204 |
+
save_path = 'result/dense_caption_mask.png'
|
205 |
+
Image.fromarray(new_image).save(save_path)
|
206 |
+
print(f'Dense captioning mask saved in {save_path}')
|
207 |
+
|
208 |
+
vis_path = 'result/dense_caption_vis_{}.png'.format(time.time())
|
209 |
+
dense_cap_painter_input = [{'bbox': xywh_to_x1y1x2y2(cap['bbox']),
|
210 |
+
'caption': cap['generated_captions']['raw_caption']} for cap in dense_captions]
|
211 |
+
draw_bbox(load_image(image, return_type='numpy'), vis_path, dense_cap_painter_input, show_caption=True)
|
212 |
+
print(f'Dense Captioning visualization saved in {vis_path}')
|
213 |
+
return ','.join(dense_cap_prompt)
|
214 |
+
|
215 |
+
def parse_ocr(self, image, thres=0.2):
|
216 |
+
width, height = get_image_shape(image)
|
217 |
+
image = load_image(image, return_type='numpy')
|
218 |
+
bounds = self.ocr_reader.readtext(image)
|
219 |
+
bounds = [bound for bound in bounds if bound[2] > thres]
|
220 |
+
print('Process OCR Text:\n', bounds)
|
221 |
+
|
222 |
+
ocr_prompt = []
|
223 |
+
for box, text, conf in bounds:
|
224 |
+
p0, p1, p2, p3 = box
|
225 |
+
ocr_prompt.append('(\"{}\": X:{:.0f}, Y:{:.0f})'.format(text, (p0[0]+p1[0]+p2[0]+p3[0])/4, (p0[1]+p1[1]+p2[1]+p3[1])/4))
|
226 |
+
ocr_prompt = '\n'.join(ocr_prompt)
|
227 |
+
|
228 |
+
# ocr_prompt = self.text_refiner.llm(f'The image have some scene texts with their locations: {ocr_prompt}. Please group these individual words into one or several phrase based on their relative positions (only give me your answer, do not show explanination)').strip()
|
229 |
+
|
230 |
+
# ocr_prefix1 = f'The image have some scene texts with their locations: {ocr_prompt}. Please group these individual words into one or several phrase based on their relative positions (only give me your answer, do not show explanination)'
|
231 |
+
# ocr_prefix2 = f'Please group these individual words into 1-3 phrases, given scene texts with their locations: {ocr_prompt}. You return is one or several strings and infer their locations. (only give me your answer like (“man working”, X: value, Y: value), do not show explanination)'
|
232 |
+
# ocr_prefix4 = f'summarize the individual scene text words detected by OCR tools into a fluent sentence based on their positions and distances. You should strictly describe all of the given scene text words. Do not miss any given word. Do not create non-exist words. Do not appear numeric positions. The individual words are given:\n{ocr_prompt}\n'
|
233 |
+
# ocr_prefix3 = f'combine the individual scene text words detected by OCR tools into one/several fluent phrases/sentences based on their positions and distances. You should strictly copy or correct all of the given scene text words. Do not miss any given word. Do not create non-exist words. The response is several strings seperate with their location (X, Y), each of which represents a phrase. The individual words are given:\n{ocr_prompt}\n'
|
234 |
+
# response = self.text_refiner.llm(ocr_prefix3).strip() if len(ocr_prompt) else ""
|
235 |
+
return ocr_prompt
|
236 |
+
|
237 |
+
def inference_cap_everything(self, image, verbose=False):
|
238 |
+
image = load_image(image, return_type='pil')
|
239 |
+
image = image_resize(image, res=1024)
|
240 |
+
width, height = get_image_shape(image)
|
241 |
+
other_args = {'text_prompt': ""} if self.require_caption_prompt else {}
|
242 |
+
img_caption = self.captioner.inference(image, filter=False, args=other_args)['caption']
|
243 |
+
dense_caption_prompt = self.parse_dense_caption(image, topN=10, verbose=verbose, reference_caption=[])
|
244 |
+
scene_text_prompt = self.parse_ocr(image, thres=0.2)
|
245 |
+
# scene_text_prompt = "N/A"
|
246 |
+
|
247 |
+
# the summarize_prompt is modified from https://github.com/JialianW/GRiT and https://github.com/showlab/Image2Paragraph
|
248 |
+
summarize_prompt = "Imagine you are a blind but intelligent image captioner. You should generate a descriptive, coherent and human-like paragraph based on the given information (a,b,c,d) instead of imagination:\na) Image Resolution: {image_size}\nb) Image Caption:{image_caption}\nc) Dense Caption: {dense_caption}\nd) Scene Text: {scene_text}\nThere are some rules for your response: Show objects with their attributes (e.g. position, color, size, shape, texture).\nPrimarily describe common objects with large size.\nProvide context of the image.\nShow relative position between objects.\nLess than 6 sentences.\nDo not appear number.\nDo not describe any individual letter.\nDo not show the image resolution.\nIngore the white background."
|
249 |
+
prompt = summarize_prompt.format(**{
|
250 |
+
"image_size": "width {} height {}".format(width, height),
|
251 |
+
"image_caption":img_caption,
|
252 |
+
"dense_caption": dense_caption_prompt,
|
253 |
+
"scene_text": scene_text_prompt})
|
254 |
+
print(f'caption everything prompt: {prompt}')
|
255 |
+
response = self.text_refiner.llm(prompt).strip()
|
256 |
+
# chinese_response = self.text_refiner.llm('Translate it into Chinese: {}'.format(response)).strip()
|
257 |
+
return response
|
258 |
+
|
259 |
+
if __name__ == "__main__":
|
260 |
+
from caption_anything.utils.parser import parse_augment
|
261 |
+
args = parse_augment()
|
262 |
+
image_path = 'result/wt/memes/87226084.jpg'
|
263 |
+
image = Image.open(image_path)
|
264 |
+
prompts = [
|
265 |
+
{
|
266 |
+
"prompt_type": ["click"],
|
267 |
+
"input_point": [[500, 300], [200, 500]],
|
268 |
+
"input_label": [1, 0],
|
269 |
+
"multimask_output": "True",
|
270 |
+
},
|
271 |
+
# {
|
272 |
+
# "prompt_type": ["click"],
|
273 |
+
# "input_point": [[300, 800]],
|
274 |
+
# "input_label": [1],
|
275 |
+
# "multimask_output": "True",
|
276 |
+
# }
|
277 |
+
]
|
278 |
+
controls = {
|
279 |
+
"length": "30",
|
280 |
+
"sentiment": "positive",
|
281 |
+
# "imagination": "True",
|
282 |
+
"imagination": "False",
|
283 |
+
"language": "English",
|
284 |
+
}
|
285 |
+
|
286 |
+
model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
|
287 |
+
img_dir = 'test_images/memes'
|
288 |
+
for image_file in os.listdir(img_dir):
|
289 |
+
image_path = os.path.join(img_dir, image_file)
|
290 |
+
print('image_path:', image_path)
|
291 |
+
paragraph = model.inference_cap_everything(image_path, verbose=True)
|
292 |
+
print('Caption Everything:\n', paragraph)
|
293 |
+
ocr = model.parse_ocr(image_path)
|
294 |
+
print('OCR', ocr)
|
caption_anything/segmenter/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_segmenter import BaseSegmenter
|
2 |
+
from caption_anything.utils.utils import seg_model_map
|
3 |
+
import copy
|
4 |
+
|
5 |
+
def build_segmenter(model_name, device, args, model=None):
|
6 |
+
return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model, args=args)
|
7 |
+
|
8 |
+
def build_segmenter_densecap(model_name, device, args, model=None):
|
9 |
+
args_for_densecap = copy.deepcopy(args)
|
10 |
+
args_for_densecap.pred_iou_thresh = 0.88
|
11 |
+
args_for_densecap.min_mask_region_area = 400
|
12 |
+
args_for_densecap.stability_score_thresh = 0.95
|
13 |
+
args_for_densecap.box_nms_thresh = 0.3
|
14 |
+
return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model, args=args)
|
caption_anything/segmenter/base_segmenter.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
from PIL import Image, ImageDraw, ImageOps
|
5 |
+
import numpy as np
|
6 |
+
from typing import Union
|
7 |
+
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
8 |
+
from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import PIL
|
11 |
+
|
12 |
+
|
13 |
+
class BaseSegmenter:
|
14 |
+
def __init__(self, device, checkpoint, model_name='huge', reuse_feature=True, model=None, args=None):
|
15 |
+
print(f"Initializing BaseSegmenter to {device}")
|
16 |
+
self.device = device
|
17 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
18 |
+
self.processor = None
|
19 |
+
if model is None:
|
20 |
+
if checkpoint is None:
|
21 |
+
_, checkpoint = prepare_segmenter(model_name)
|
22 |
+
self.model = sam_model_registry[seg_model_map[model_name]](checkpoint=checkpoint)
|
23 |
+
self.checkpoint = checkpoint
|
24 |
+
self.model.to(device=self.device)
|
25 |
+
else:
|
26 |
+
self.model = model
|
27 |
+
self.reuse_feature = reuse_feature
|
28 |
+
self.predictor = SamPredictor(self.model)
|
29 |
+
|
30 |
+
sam_generator_keys = ['pred_iou_thresh', 'min_mask_region_area', 'stability_score_thresh', 'box_nms_thresh']
|
31 |
+
generator_args = {k:v for k,v in vars(args).items() if k in sam_generator_keys}
|
32 |
+
self.mask_generator = SamAutomaticMaskGenerator(model=self.model, **generator_args)
|
33 |
+
self.image_embedding = None
|
34 |
+
self.image = None
|
35 |
+
|
36 |
+
@torch.no_grad()
|
37 |
+
def set_image(self, image: Union[np.ndarray, Image.Image, str]):
|
38 |
+
image = load_image(image, return_type='numpy')
|
39 |
+
self.image = image
|
40 |
+
if self.reuse_feature:
|
41 |
+
self.predictor.set_image(image)
|
42 |
+
self.image_embedding = self.predictor.get_image_embedding()
|
43 |
+
print(self.image_embedding.shape)
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], control: dict):
|
47 |
+
"""
|
48 |
+
SAM inference of image according to control.
|
49 |
+
Args:
|
50 |
+
image: str or PIL.Image or np.ndarray
|
51 |
+
control: dict to control SAM.
|
52 |
+
prompt_type:
|
53 |
+
1. {control['prompt_type'] = ['everything']} to segment everything in the image.
|
54 |
+
2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
|
55 |
+
3. {control['prompt_type'] = ['click'] to segment according to click.
|
56 |
+
4. {control['prompt_type'] = ['box'] to segment according to box.
|
57 |
+
input_point: list of [x, y] coordinates of click.
|
58 |
+
input_label: List of labels for points accordingly, 0 for negative, 1 for positive.
|
59 |
+
input_box: List of [x1, y1, x2, y2] coordinates of box.
|
60 |
+
multimask_output:
|
61 |
+
If true, the model will return three masks.
|
62 |
+
For ambiguous input prompts (such as a single click), this will often
|
63 |
+
produce better masks than a single prediction. If only a single
|
64 |
+
mask is needed, the model's predicted quality score can be used
|
65 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
66 |
+
input prompts, multimask_output=False can give better results.
|
67 |
+
Returns:
|
68 |
+
masks: np.ndarray of shape [num_masks, height, width]
|
69 |
+
|
70 |
+
"""
|
71 |
+
image = load_image(image, return_type='numpy')
|
72 |
+
if 'everything' in control['prompt_type']:
|
73 |
+
masks = self.mask_generator.generate(image)
|
74 |
+
new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
|
75 |
+
bbox = np.array([mask["bbox"] for mask in masks])
|
76 |
+
area = np.array([mask["area"] for mask in masks])
|
77 |
+
return new_masks, bbox, area
|
78 |
+
else:
|
79 |
+
if not self.reuse_feature or self.image_embedding is None:
|
80 |
+
self.set_image(image)
|
81 |
+
self.predictor.set_image(self.image)
|
82 |
+
else:
|
83 |
+
assert self.image_embedding is not None
|
84 |
+
self.predictor.features = self.image_embedding
|
85 |
+
|
86 |
+
if 'mutimask_output' in control:
|
87 |
+
masks, scores, logits = self.predictor.predict(
|
88 |
+
point_coords=np.array(control['input_point']),
|
89 |
+
point_labels=np.array(control['input_label']),
|
90 |
+
multimask_output=True,
|
91 |
+
)
|
92 |
+
elif 'input_boxes' in control:
|
93 |
+
transformed_boxes = self.predictor.transform.apply_boxes_torch(
|
94 |
+
torch.tensor(control["input_boxes"], device=self.predictor.device),
|
95 |
+
image.shape[1::-1] # Reverse shape because numpy is (W, H) and function need (H, W)
|
96 |
+
)
|
97 |
+
masks, _, _ = self.predictor.predict_torch(
|
98 |
+
point_coords=None,
|
99 |
+
point_labels=None,
|
100 |
+
boxes=transformed_boxes,
|
101 |
+
multimask_output=False,
|
102 |
+
)
|
103 |
+
masks = masks.squeeze(1).cpu().numpy()
|
104 |
+
|
105 |
+
else:
|
106 |
+
input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
|
107 |
+
input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
|
108 |
+
input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
|
109 |
+
|
110 |
+
masks, scores, logits = self.predictor.predict(
|
111 |
+
point_coords=input_point,
|
112 |
+
point_labels=input_label,
|
113 |
+
box=input_box,
|
114 |
+
multimask_output=False,
|
115 |
+
)
|
116 |
+
|
117 |
+
if 0 in control['input_label']:
|
118 |
+
mask_input = logits[np.argmax(scores), :, :]
|
119 |
+
masks, scores, logits = self.predictor.predict(
|
120 |
+
point_coords=input_point,
|
121 |
+
point_labels=input_label,
|
122 |
+
box=input_box,
|
123 |
+
mask_input=mask_input[None, :, :],
|
124 |
+
multimask_output=False,
|
125 |
+
)
|
126 |
+
|
127 |
+
return masks
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
image_path = 'segmenter/images/truck.jpg'
|
132 |
+
prompts = [
|
133 |
+
# {
|
134 |
+
# "prompt_type":["click"],
|
135 |
+
# "input_point":[[500, 375]],
|
136 |
+
# "input_label":[1],
|
137 |
+
# "multimask_output":"True",
|
138 |
+
# },
|
139 |
+
{
|
140 |
+
"prompt_type": ["click"],
|
141 |
+
"input_point": [[1000, 600], [1325, 625]],
|
142 |
+
"input_label": [1, 0],
|
143 |
+
},
|
144 |
+
# {
|
145 |
+
# "prompt_type":["click", "box"],
|
146 |
+
# "input_box":[425, 600, 700, 875],
|
147 |
+
# "input_point":[[575, 750]],
|
148 |
+
# "input_label": [0]
|
149 |
+
# },
|
150 |
+
# {
|
151 |
+
# "prompt_type":["box"],
|
152 |
+
# "input_boxes": [
|
153 |
+
# [75, 275, 1725, 850],
|
154 |
+
# [425, 600, 700, 875],
|
155 |
+
# [1375, 550, 1650, 800],
|
156 |
+
# [1240, 675, 1400, 750],
|
157 |
+
# ]
|
158 |
+
# },
|
159 |
+
# {
|
160 |
+
# "prompt_type":["everything"]
|
161 |
+
# },
|
162 |
+
]
|
163 |
+
|
164 |
+
init_time = time.time()
|
165 |
+
segmenter = BaseSegmenter(
|
166 |
+
device='cuda',
|
167 |
+
# checkpoint='sam_vit_h_4b8939.pth',
|
168 |
+
checkpoint='segmenter/sam_vit_h_4b8939.pth',
|
169 |
+
model_type='vit_h',
|
170 |
+
reuse_feature=True
|
171 |
+
)
|
172 |
+
print(f'init time: {time.time() - init_time}')
|
173 |
+
|
174 |
+
image_path = 'test_images/img2.jpg'
|
175 |
+
infer_time = time.time()
|
176 |
+
for i, prompt in enumerate(prompts):
|
177 |
+
print(f'{prompt["prompt_type"]} mode')
|
178 |
+
image = Image.open(image_path)
|
179 |
+
segmenter.set_image(np.array(image))
|
180 |
+
masks = segmenter.inference(np.array(image), prompt)
|
181 |
+
Image.fromarray(masks[0]).save('seg.png')
|
182 |
+
print(masks.shape)
|
183 |
+
|
184 |
+
print(f'infer time: {time.time() - infer_time}')
|
caption_anything/segmenter/readme.md
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Prepare SAM
|
2 |
+
```
|
3 |
+
pip install git+https://github.com/facebookresearch/segment-anything.git
|
4 |
+
```
|
5 |
+
or
|
6 |
+
```
|
7 |
+
git clone git@github.com:facebookresearch/segment-anything.git
|
8 |
+
cd segment-anything; pip install -e .
|
9 |
+
```
|
10 |
+
|
11 |
+
```
|
12 |
+
pip install opencv-python pycocotools matplotlib onnxruntime onnx
|
13 |
+
```
|
14 |
+
### Download the checkpoint:
|
15 |
+
|
16 |
+
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
17 |
+
|
18 |
+
### Inference
|
19 |
+
|
20 |
+
The prompts are in json format:
|
21 |
+
|
22 |
+
```
|
23 |
+
prompts = [
|
24 |
+
{
|
25 |
+
"prompt_type":["click"],
|
26 |
+
"input_point":[[500, 375]],
|
27 |
+
"input_label":[1],
|
28 |
+
"multimask_output":"True",
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"prompt_type":["click"],
|
32 |
+
"input_point":[[500, 375], [1125, 625]],
|
33 |
+
"input_label":[1, 0],
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"prompt_type":["click", "box"],
|
37 |
+
"input_box":[425, 600, 700, 875],
|
38 |
+
"input_point":[[575, 750]],
|
39 |
+
"input_label": [0]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"prompt_type":["box"],
|
43 |
+
"input_boxes": [
|
44 |
+
[75, 275, 1725, 850],
|
45 |
+
[425, 600, 700, 875],
|
46 |
+
[1375, 550, 1650, 800],
|
47 |
+
[1240, 675, 1400, 750],
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"prompt_type":["everything"]
|
52 |
+
},
|
53 |
+
]
|
54 |
+
```
|
55 |
+
|
56 |
+
In `base_segmenter.py`:
|
57 |
+
```
|
58 |
+
segmenter = BaseSegmenter(
|
59 |
+
device='cuda',
|
60 |
+
checkpoint='sam_vit_h_4b8939.pth',
|
61 |
+
model_type='vit_h'
|
62 |
+
)
|
63 |
+
|
64 |
+
for i, prompt in enumerate(prompts):
|
65 |
+
masks = segmenter.inference(image_path, prompt)
|
66 |
+
```
|
67 |
+
|
68 |
+
Outputs are masks (True and False numpy Matrix), shape: (num of masks, height, weight)
|
caption_anything/text_refiner/README.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Install
|
2 |
+
* python >= 3.8.1
|
3 |
+
|
4 |
+
```bash
|
5 |
+
pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html # CUDA version could be different
|
6 |
+
pip install openai pillow transformers
|
7 |
+
pip install langchain==0.0.101
|
8 |
+
```
|
caption_anything/text_refiner/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .text_refiner import TextRefiner
|
2 |
+
|
3 |
+
|
4 |
+
def build_text_refiner(type, device, args=None, api_key=""):
|
5 |
+
if type == 'base':
|
6 |
+
return TextRefiner(device, api_key)
|
caption_anything/text_refiner/text_refiner.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.llms.openai import OpenAI
|
2 |
+
import torch
|
3 |
+
from PIL import Image, ImageDraw, ImageOps
|
4 |
+
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
|
5 |
+
import pdb
|
6 |
+
|
7 |
+
class TextRefiner:
|
8 |
+
def __init__(self, device, api_key=""):
|
9 |
+
print(f"Initializing TextRefiner to {device}")
|
10 |
+
self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
|
11 |
+
self.prompt_tag = {
|
12 |
+
"imagination": {"True": "could",
|
13 |
+
"False": "could not"}
|
14 |
+
}
|
15 |
+
self.short_prompts = {
|
16 |
+
"length": "around {length} words",
|
17 |
+
"sentiment": "of {sentiment} sentiment",
|
18 |
+
"language": "in {language}",
|
19 |
+
}
|
20 |
+
|
21 |
+
self.long_prompts = {
|
22 |
+
"imagination": "The new sentence could extend the original description by using your imagination to create additional details, or think about what might have happened before or after the scene in the image, but should not conflict with the original sentence",
|
23 |
+
}
|
24 |
+
|
25 |
+
self.wiki_prompts = "I want you to act as a Wikipedia page. I will give you a sentence and you will parse the single main object in the sentence and provide a summary of that object in the format of a Wikipedia page. Your summary should be informative and factual, covering the most important aspects of the object. Start your summary with an introductory paragraph that gives an overview of the object. The overall length of the response should be around 100 words. You should not describe the parsing process and only provide the final summary. The sentence is \"{query}\"."
|
26 |
+
|
27 |
+
self.control_prompts = "As a text reviser, you will convert an image description into a new sentence or long paragraph. The new text is {prompts}. {long_prompts} The sentence is \"{query}\" (give me the revised sentence only)"
|
28 |
+
|
29 |
+
def parse(self, response):
|
30 |
+
out = response.strip()
|
31 |
+
return out
|
32 |
+
|
33 |
+
def parse2(self, response):
|
34 |
+
out = response.strip()
|
35 |
+
return out
|
36 |
+
|
37 |
+
def prepare_input(self, query, short_prompts, long_prompts):
|
38 |
+
input = self.control_prompts.format(**{'prompts': ', '.join(short_prompts), 'long_prompts': '. '.join(long_prompts), 'query': query})
|
39 |
+
print('prompt: ', input)
|
40 |
+
return input
|
41 |
+
|
42 |
+
def inference(self, query: str, controls: dict, context: list=[], enable_wiki=False):
|
43 |
+
"""
|
44 |
+
query: the caption of the region of interest, generated by captioner
|
45 |
+
controls: a dict of control singals, e.g., {"length": 5, "sentiment": "positive"}
|
46 |
+
"""
|
47 |
+
prompts = []
|
48 |
+
long_prompts = []
|
49 |
+
for control, value in controls.items():
|
50 |
+
# if control in self.prompt_tag:
|
51 |
+
# value = self.prompt_tag[control][value]
|
52 |
+
if control in self.short_prompts:
|
53 |
+
prompts.append(self.short_prompts[control].format(**{control: value}))
|
54 |
+
else:
|
55 |
+
if value in [True, "True", "true"]:
|
56 |
+
long_prompts.append(self.long_prompts[control])
|
57 |
+
input = self.prepare_input(query, prompts, long_prompts)
|
58 |
+
response = self.llm(input)
|
59 |
+
response = self.parse(response)
|
60 |
+
|
61 |
+
response_wiki = ""
|
62 |
+
if enable_wiki:
|
63 |
+
tmp_configs = {"query": query}
|
64 |
+
prompt_wiki = self.wiki_prompts.format(**tmp_configs)
|
65 |
+
response_wiki = self.llm(prompt_wiki)
|
66 |
+
response_wiki = self.parse2(response_wiki)
|
67 |
+
out = {
|
68 |
+
'raw_caption': query,
|
69 |
+
'caption': response,
|
70 |
+
'wiki': response_wiki
|
71 |
+
}
|
72 |
+
print(out)
|
73 |
+
return out
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
model = TextRefiner(device='cpu')
|
77 |
+
controls = {
|
78 |
+
"length": "30",
|
79 |
+
"sentiment": "negative",
|
80 |
+
# "imagination": "True",
|
81 |
+
"imagination": "False",
|
82 |
+
"language": "English",
|
83 |
+
}
|
84 |
+
# model.inference(query='a dog is sitting on a brown bench', controls=controls)
|
85 |
+
model.inference(query='a cat is sleeping', controls=controls)
|
86 |
+
|
caption_anything/utils/chatbot.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft
|
2 |
+
# Modified from Visual ChatGPT Project https://github.com/microsoft/TaskMatrix/blob/main/visual_chatgpt.py
|
3 |
+
|
4 |
+
import os
|
5 |
+
import gradio as gr
|
6 |
+
import re
|
7 |
+
import uuid
|
8 |
+
from PIL import Image, ImageDraw, ImageOps
|
9 |
+
import numpy as np
|
10 |
+
import argparse
|
11 |
+
import inspect
|
12 |
+
|
13 |
+
from langchain.agents.initialize import initialize_agent
|
14 |
+
from langchain.agents.tools import Tool
|
15 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
16 |
+
from langchain.llms.openai import OpenAI
|
17 |
+
import torch
|
18 |
+
from PIL import Image, ImageDraw, ImageOps
|
19 |
+
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
|
20 |
+
|
21 |
+
VISUAL_CHATGPT_PREFIX = """
|
22 |
+
I want you act as Caption Anything Chatbox (short as CATchat), which is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. You are able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
23 |
+
|
24 |
+
As a language model, you can not directly read images, but can invoke VQA tool to indirectly understand pictures, by repeatly asking questions about the objects and scene of the image. You should carefully asking informative questions to maximize your information about this image content. Each image will have a file name formed as "chat_image/xxx.png", you are very strict to the file name and will never fabricate nonexistent files.
|
25 |
+
|
26 |
+
You have access to the following tools:"""
|
27 |
+
|
28 |
+
|
29 |
+
# TOOLS:
|
30 |
+
# ------
|
31 |
+
|
32 |
+
# Visual ChatGPT has access to the following tools:"""
|
33 |
+
|
34 |
+
VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
|
35 |
+
|
36 |
+
"Thought: Do I need to use a tool? Yes
|
37 |
+
Action: the action to take, should be one of [{tool_names}], remember the action must to be one tool
|
38 |
+
Action Input: the input to the action
|
39 |
+
Observation: the result of the action"
|
40 |
+
|
41 |
+
When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
|
42 |
+
|
43 |
+
"Thought: Do I need to use a tool? No
|
44 |
+
{ai_prefix}: [your response here]"
|
45 |
+
|
46 |
+
"""
|
47 |
+
|
48 |
+
VISUAL_CHATGPT_SUFFIX = """
|
49 |
+
Begin Chatting!
|
50 |
+
|
51 |
+
Previous conversation history:
|
52 |
+
{chat_history}
|
53 |
+
|
54 |
+
New input: {input}
|
55 |
+
As a language model, you must repeatly to use VQA tools to observe images. You response should be consistent with the outputs of the VQA tool instead of imagination. Do not repeat asking the same question.
|
56 |
+
|
57 |
+
Thought: Do I need to use a tool? {agent_scratchpad} (You are strictly to use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
|
58 |
+
|
59 |
+
os.makedirs('chat_image', exist_ok=True)
|
60 |
+
|
61 |
+
|
62 |
+
def prompts(name, description):
|
63 |
+
def decorator(func):
|
64 |
+
func.name = name
|
65 |
+
func.description = description
|
66 |
+
return func
|
67 |
+
return decorator
|
68 |
+
|
69 |
+
def cut_dialogue_history(history_memory, keep_last_n_words=500):
|
70 |
+
if history_memory is None or len(history_memory) == 0:
|
71 |
+
return history_memory
|
72 |
+
tokens = history_memory.split()
|
73 |
+
n_tokens = len(tokens)
|
74 |
+
print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
|
75 |
+
if n_tokens < keep_last_n_words:
|
76 |
+
return history_memory
|
77 |
+
paragraphs = history_memory.split('\n')
|
78 |
+
last_n_tokens = n_tokens
|
79 |
+
while last_n_tokens >= keep_last_n_words:
|
80 |
+
last_n_tokens -= len(paragraphs[0].split(' '))
|
81 |
+
paragraphs = paragraphs[1:]
|
82 |
+
return '\n' + '\n'.join(paragraphs)
|
83 |
+
|
84 |
+
def get_new_image_name(folder='chat_image', func_name="update"):
|
85 |
+
this_new_uuid = str(uuid.uuid4())[:8]
|
86 |
+
new_file_name = f'{func_name}_{this_new_uuid}.png'
|
87 |
+
return os.path.join(folder, new_file_name)
|
88 |
+
|
89 |
+
class VisualQuestionAnswering:
|
90 |
+
def __init__(self, device):
|
91 |
+
print(f"Initializing VisualQuestionAnswering to {device}")
|
92 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
93 |
+
self.device = device
|
94 |
+
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
95 |
+
self.model = BlipForQuestionAnswering.from_pretrained(
|
96 |
+
"Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
|
97 |
+
# self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
98 |
+
# self.model = BlipForQuestionAnswering.from_pretrained(
|
99 |
+
# "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
|
100 |
+
|
101 |
+
@prompts(name="Answer Question About The Image",
|
102 |
+
description="VQA tool is useful when you need an answer for a question based on an image. "
|
103 |
+
"like: what is the color of an object, how many cats in this figure, where is the child sitting, what does the cat doing, why is he laughing."
|
104 |
+
"The input to this tool should be a comma separated string of two, representing the image path and the question.")
|
105 |
+
def inference(self, inputs):
|
106 |
+
image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
107 |
+
raw_image = Image.open(image_path).convert('RGB')
|
108 |
+
inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
|
109 |
+
out = self.model.generate(**inputs)
|
110 |
+
answer = self.processor.decode(out[0], skip_special_tokens=True)
|
111 |
+
print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
|
112 |
+
f"Output Answer: {answer}")
|
113 |
+
return answer
|
114 |
+
|
115 |
+
def build_chatbot_tools(load_dict):
|
116 |
+
print(f"Initializing ChatBot, load_dict={load_dict}")
|
117 |
+
models = {}
|
118 |
+
# Load Basic Foundation Models
|
119 |
+
for class_name, device in load_dict.items():
|
120 |
+
models[class_name] = globals()[class_name](device=device)
|
121 |
+
|
122 |
+
# Load Template Foundation Models
|
123 |
+
for class_name, module in globals().items():
|
124 |
+
if getattr(module, 'template_model', False):
|
125 |
+
template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
|
126 |
+
loaded_names = set([type(e).__name__ for e in models.values()])
|
127 |
+
if template_required_names.issubset(loaded_names):
|
128 |
+
models[class_name] = globals()[class_name](
|
129 |
+
**{name: models[name] for name in template_required_names})
|
130 |
+
|
131 |
+
tools = []
|
132 |
+
for instance in models.values():
|
133 |
+
for e in dir(instance):
|
134 |
+
if e.startswith('inference'):
|
135 |
+
func = getattr(instance, e)
|
136 |
+
tools.append(Tool(name=func.name, description=func.description, func=func))
|
137 |
+
return tools
|
138 |
+
|
139 |
+
class ConversationBot:
|
140 |
+
def __init__(self, tools, api_key=""):
|
141 |
+
# load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
|
142 |
+
llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0.7, openai_api_key=api_key)
|
143 |
+
self.llm = llm
|
144 |
+
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
145 |
+
self.tools = tools
|
146 |
+
self.current_image = None
|
147 |
+
self.point_prompt = ""
|
148 |
+
self.global_prompt = ""
|
149 |
+
self.agent = initialize_agent(
|
150 |
+
self.tools,
|
151 |
+
self.llm,
|
152 |
+
agent="conversational-react-description",
|
153 |
+
verbose=True,
|
154 |
+
memory=self.memory,
|
155 |
+
return_intermediate_steps=True,
|
156 |
+
agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
|
157 |
+
'suffix': VISUAL_CHATGPT_SUFFIX}, )
|
158 |
+
|
159 |
+
def constructe_intermediate_steps(self, agent_res):
|
160 |
+
ans = []
|
161 |
+
for action, output in agent_res:
|
162 |
+
if hasattr(action, "tool_input"):
|
163 |
+
use_tool = "Yes"
|
164 |
+
act = (f"Thought: Do I need to use a tool? {use_tool}\nAction: {action.tool}\nAction Input: {action.tool_input}", f"Observation: {output}")
|
165 |
+
else:
|
166 |
+
use_tool = "No"
|
167 |
+
act = (f"Thought: Do I need to use a tool? {use_tool}", f"AI: {output}")
|
168 |
+
act= list(map(lambda x: x.replace('\n', '<br>'), act))
|
169 |
+
ans.append(act)
|
170 |
+
return ans
|
171 |
+
|
172 |
+
def run_text(self, text, state, aux_state):
|
173 |
+
self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
|
174 |
+
if self.point_prompt != "":
|
175 |
+
Human_prompt = f'\nHuman: {self.point_prompt}\n'
|
176 |
+
AI_prompt = 'Ok'
|
177 |
+
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
|
178 |
+
self.point_prompt = ""
|
179 |
+
res = self.agent({"input": text})
|
180 |
+
res['output'] = res['output'].replace("\\", "/")
|
181 |
+
response = re.sub('(chat_image/\S*png)', lambda m: f'})*{m.group(0)}*', res['output'])
|
182 |
+
state = state + [(text, response)]
|
183 |
+
|
184 |
+
aux_state = aux_state + [(f"User Input: {text}", None)]
|
185 |
+
aux_state = aux_state + self.constructe_intermediate_steps(res['intermediate_steps'])
|
186 |
+
print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
|
187 |
+
f"Current Memory: {self.agent.memory.buffer}\n"
|
188 |
+
f"Aux state: {aux_state}\n"
|
189 |
+
)
|
190 |
+
return state, state, aux_state, aux_state
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == '__main__':
|
194 |
+
parser = argparse.ArgumentParser()
|
195 |
+
parser.add_argument('--load', type=str, default="VisualQuestionAnswering_cuda:0")
|
196 |
+
parser.add_argument('--port', type=int, default=1015)
|
197 |
+
|
198 |
+
args = parser.parse_args()
|
199 |
+
load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
|
200 |
+
tools = build_chatbot_tools(load_dict)
|
201 |
+
bot = ConversationBot(tools)
|
202 |
+
with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
|
203 |
+
with gr.Row():
|
204 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="CATchat").style(height=1000,scale=0.5)
|
205 |
+
auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
|
206 |
+
state = gr.State([])
|
207 |
+
aux_state = gr.State([])
|
208 |
+
with gr.Row():
|
209 |
+
with gr.Column(scale=0.7):
|
210 |
+
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
|
211 |
+
container=False)
|
212 |
+
with gr.Column(scale=0.15, min_width=0):
|
213 |
+
clear = gr.Button("Clear")
|
214 |
+
with gr.Column(scale=0.15, min_width=0):
|
215 |
+
btn = gr.UploadButton("Upload", file_types=["image"])
|
216 |
+
|
217 |
+
txt.submit(bot.run_text, [txt, state, aux_state], [chatbot, state, aux_state, auxwindow])
|
218 |
+
txt.submit(lambda: "", None, txt)
|
219 |
+
btn.upload(bot.run_image, [btn, state, txt, aux_state], [chatbot, state, txt, aux_state, auxwindow])
|
220 |
+
clear.click(bot.memory.clear)
|
221 |
+
clear.click(lambda: [], None, chatbot)
|
222 |
+
clear.click(lambda: [], None, auxwindow)
|
223 |
+
clear.click(lambda: [], None, state)
|
224 |
+
clear.click(lambda: [], None, aux_state)
|
225 |
+
demo.launch(server_name="0.0.0.0", server_port=args.port, share=True)
|
caption_anything/utils/densecap_painter.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
from typing import List
|
5 |
+
import random
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
def draw_bbox(img: Union[np.ndarray, str], save_name: str, bbox: List[dict], show_caption: bool = False):
|
9 |
+
"""
|
10 |
+
bbox: [{'image_id': str, 'bbox': [x1, y1, x2, y2], 'caption': str}, ...]
|
11 |
+
"""
|
12 |
+
if isinstance(img, str):
|
13 |
+
img = cv2.imread(img)
|
14 |
+
|
15 |
+
RGB = [0, 50, 100, 150, 200, 250]
|
16 |
+
for box in bbox:
|
17 |
+
box['bbox'] = [int(_) for _ in box['bbox']]
|
18 |
+
x1, y1, x2, y2 = box['bbox']
|
19 |
+
caption = box['caption']
|
20 |
+
box_color = random.choices(RGB, k = 3)
|
21 |
+
(text_width, text_height), _ = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, fontScale = 0.5, thickness = 2)
|
22 |
+
cv2.rectangle(img, (x1, y1), (x2, y2), color = box_color, thickness = 2)
|
23 |
+
if show_caption:
|
24 |
+
cv2.putText(img, caption, (x1, y1 + text_height), cv2.FONT_HERSHEY_SIMPLEX, fontScale = 0.5, color = box_color, thickness = 2)
|
25 |
+
|
26 |
+
cv2.imwrite(save_name, img)
|
27 |
+
# cv2.imshow('visualise', img)
|
28 |
+
# cv2.waitKey(0)
|
29 |
+
|
30 |
+
def parse_bbox(anno, image_id: int = None):
|
31 |
+
|
32 |
+
with open(anno, 'r') as f:
|
33 |
+
predictions = json.load(f)
|
34 |
+
|
35 |
+
if image_id is None:
|
36 |
+
image_id = next(iter(predictions))
|
37 |
+
|
38 |
+
return predictions[image_id]
|
39 |
+
|
40 |
+
def gt_bbox(anno, img_name: int = None):
|
41 |
+
|
42 |
+
with open(anno, 'r') as f:
|
43 |
+
annotations = json.load(f)
|
44 |
+
annotations = annotations['annotations']
|
45 |
+
|
46 |
+
gt = []
|
47 |
+
img_name = int(img_name[:-4])
|
48 |
+
for annotation in annotations:
|
49 |
+
if annotation['image_id'] == 63:
|
50 |
+
x1, y1, w, h = annotation['bbox']
|
51 |
+
gt.append({'bbox': [x1, y1, x1 + w, y1 + h], 'caption': annotation['caption']})
|
52 |
+
return gt
|
53 |
+
|
54 |
+
if __name__ == '__main__':
|
55 |
+
|
56 |
+
img_name = '63.jpg'
|
57 |
+
show_caption = True
|
58 |
+
anno = 'vg_dense_captioning_blip2_top48_0.88_1000_0.96_debugTrue_predictions_shard_all.json'
|
59 |
+
|
60 |
+
img = cv2.imread(img_name)
|
61 |
+
examp_bbox = parse_bbox(anno)
|
62 |
+
ground_truth_bbox = gt_bbox('test.json', img_name)
|
63 |
+
draw_bbox(img, 'GT.jpg', ground_truth_bbox, show_caption)
|
64 |
+
draw_bbox(img, 'Pred.jpg', examp_bbox, show_caption)
|
caption_anything/utils/image_editing_utils.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageDraw, ImageFont
|
2 |
+
import copy
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
|
7 |
+
def wrap_text(text, font, max_width):
|
8 |
+
lines = []
|
9 |
+
words = text.split(' ')
|
10 |
+
current_line = ''
|
11 |
+
|
12 |
+
for word in words:
|
13 |
+
if font.getsize(current_line + word)[0] <= max_width:
|
14 |
+
current_line += word + ' '
|
15 |
+
else:
|
16 |
+
lines.append(current_line)
|
17 |
+
current_line = word + ' '
|
18 |
+
|
19 |
+
lines.append(current_line)
|
20 |
+
return lines
|
21 |
+
|
22 |
+
|
23 |
+
def create_bubble_frame(image, text, point, segmask, input_points=(), input_labels=(),
|
24 |
+
font_path='assets/times_with_simsun.ttf', font_size_ratio=0.033, point_size_ratio=0.01):
|
25 |
+
# Load the image
|
26 |
+
if input_points is None:
|
27 |
+
input_points = []
|
28 |
+
if input_labels is None:
|
29 |
+
input_labels = []
|
30 |
+
|
31 |
+
if type(image) == np.ndarray:
|
32 |
+
image = Image.fromarray(image)
|
33 |
+
|
34 |
+
image = copy.deepcopy(image)
|
35 |
+
width, height = image.size
|
36 |
+
|
37 |
+
# Calculate max_text_width and font_size based on image dimensions and total number of characters
|
38 |
+
total_chars = len(text)
|
39 |
+
max_text_width = int(0.4 * width)
|
40 |
+
font_size = int(height * font_size_ratio)
|
41 |
+
point_size = max(int(height * point_size_ratio), 1)
|
42 |
+
|
43 |
+
# Load the font
|
44 |
+
font = ImageFont.truetype(font_path, font_size)
|
45 |
+
|
46 |
+
# Wrap the text to fit within the max_text_width
|
47 |
+
lines = wrap_text(text, font, max_text_width)
|
48 |
+
text_width = max([font.getsize(line)[0] for line in lines])
|
49 |
+
_, text_height = font.getsize(lines[0])
|
50 |
+
text_height = text_height * len(lines)
|
51 |
+
|
52 |
+
# Define bubble frame dimensions
|
53 |
+
padding = 10
|
54 |
+
bubble_width = text_width + 2 * padding
|
55 |
+
bubble_height = text_height + 2 * padding
|
56 |
+
|
57 |
+
# Create a new image for the bubble frame
|
58 |
+
bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 248, 220, 0))
|
59 |
+
|
60 |
+
# Draw the bubble frame on the new image
|
61 |
+
draw = ImageDraw.Draw(bubble)
|
62 |
+
# draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
|
63 |
+
draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
|
64 |
+
fill=(255, 248, 220, 120), outline=None, width=2)
|
65 |
+
# Draw the wrapped text line by line
|
66 |
+
y_text = padding
|
67 |
+
for line in lines:
|
68 |
+
draw.text((padding, y_text), line, font=font, fill=(0, 0, 0, 255))
|
69 |
+
y_text += font.getsize(line)[1]
|
70 |
+
|
71 |
+
# Determine the point by the min area rect of mask
|
72 |
+
try:
|
73 |
+
ret, thresh = cv2.threshold(segmask, 127, 255, 0)
|
74 |
+
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
75 |
+
largest_contour = max(contours, key=cv2.contourArea)
|
76 |
+
min_area_rect = cv2.minAreaRect(largest_contour)
|
77 |
+
box = cv2.boxPoints(min_area_rect)
|
78 |
+
sorted_points = box[np.argsort(box[:, 0])]
|
79 |
+
right_most_points = sorted_points[-2:]
|
80 |
+
right_down_most_point = right_most_points[np.argsort(right_most_points[:, 1])][-1]
|
81 |
+
x, y = int(right_down_most_point[0]), int(right_down_most_point[1])
|
82 |
+
except:
|
83 |
+
x, y = point
|
84 |
+
# Calculate the bubble frame position
|
85 |
+
if x + bubble_width > width:
|
86 |
+
x = width - bubble_width
|
87 |
+
if y + bubble_height > height:
|
88 |
+
y = height - bubble_height
|
89 |
+
|
90 |
+
# Paste the bubble frame onto the image
|
91 |
+
image.paste(bubble, (x, y), bubble)
|
92 |
+
draw = ImageDraw.Draw(image)
|
93 |
+
colors = [(0, 191, 255, 255), (255, 106, 106, 255)]
|
94 |
+
for p, label in zip(input_points, input_labels):
|
95 |
+
point_x, point_y = p[0], p[1]
|
96 |
+
left = point_x - point_size
|
97 |
+
top = point_y - point_size
|
98 |
+
right = point_x + point_size
|
99 |
+
bottom = point_y + point_size
|
100 |
+
draw.ellipse((left, top, right, bottom), fill=colors[label])
|
101 |
+
return image
|
102 |
+
|
103 |
+
|
104 |
+
def draw_rounded_rectangle(draw, xy, corner_radius, fill=None, outline=None, width=1):
|
105 |
+
x1, y1, x2, y2 = xy
|
106 |
+
|
107 |
+
draw.rectangle(
|
108 |
+
(x1, y1 + corner_radius, x2, y2 - corner_radius),
|
109 |
+
fill=fill,
|
110 |
+
outline=outline,
|
111 |
+
width=width
|
112 |
+
)
|
113 |
+
draw.rectangle(
|
114 |
+
(x1 + corner_radius, y1, x2 - corner_radius, y2),
|
115 |
+
fill=fill,
|
116 |
+
outline=outline,
|
117 |
+
width=width
|
118 |
+
)
|
119 |
+
|
120 |
+
draw.pieslice((x1, y1, x1 + corner_radius * 2, y1 + corner_radius * 2), 180, 270, fill=fill, outline=outline,
|
121 |
+
width=width)
|
122 |
+
draw.pieslice((x2 - corner_radius * 2, y1, x2, y1 + corner_radius * 2), 270, 360, fill=fill, outline=outline,
|
123 |
+
width=width)
|
124 |
+
draw.pieslice((x2 - corner_radius * 2, y2 - corner_radius * 2, x2, y2), 0, 90, fill=fill, outline=outline,
|
125 |
+
width=width)
|
126 |
+
draw.pieslice((x1, y2 - corner_radius * 2, x1 + corner_radius * 2, y2), 90, 180, fill=fill, outline=outline,
|
127 |
+
width=width)
|
caption_anything/utils/parser.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
def parse_augment():
|
4 |
+
parser = argparse.ArgumentParser()
|
5 |
+
parser.add_argument('--captioner', type=str, default="blip2")
|
6 |
+
parser.add_argument('--segmenter', type=str, default="huge")
|
7 |
+
parser.add_argument('--text_refiner', type=str, default="base")
|
8 |
+
parser.add_argument('--segmenter_checkpoint', type=str, default=None, help="SAM checkpoint path")
|
9 |
+
parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'],
|
10 |
+
help="whether to add or remove background of the image when captioning")
|
11 |
+
parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
|
12 |
+
parser.add_argument('--context_captions', action="store_true",
|
13 |
+
help="use surrounding captions to enhance current caption (TODO)")
|
14 |
+
parser.add_argument('--disable_regular_box', action="store_true", default=False,
|
15 |
+
help="crop image with a regular box")
|
16 |
+
parser.add_argument('--device', type=str, default="cuda:0")
|
17 |
+
parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
|
18 |
+
parser.add_argument('--debug', action="store_true")
|
19 |
+
parser.add_argument('--gradio_share', action="store_true")
|
20 |
+
parser.add_argument('--disable_gpt', action="store_true")
|
21 |
+
parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
|
22 |
+
parser.add_argument('--disable_reuse_features', action="store_true", default=False)
|
23 |
+
parser.add_argument('--enable_morphologyex', action="store_true", default=False)
|
24 |
+
parser.add_argument('--chat_tools_dict', type=str, default='VisualQuestionAnswering_cuda:0', help='Visual ChatGPT tools, only useful when running gradio applications')
|
25 |
+
|
26 |
+
parser.add_argument('--pred_iou_thresh', type=float, default=0.88, help="sam post-precessing")
|
27 |
+
parser.add_argument('--min_mask_region_area', type=int, default=0, help="sam post-precessing")
|
28 |
+
parser.add_argument('--stability_score_thresh', type=float, default=0.95, help='sam post-processing')
|
29 |
+
parser.add_argument('--box_nms_thresh', type=float, default=0.7, help='sam post-processing')
|
30 |
+
|
31 |
+
args = parser.parse_args()
|
32 |
+
|
33 |
+
if args.debug:
|
34 |
+
print(args)
|
35 |
+
return args
|
caption_anything/utils/utils.py
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import sys
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import hashlib
|
7 |
+
import requests
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from typing import Union
|
11 |
+
|
12 |
+
from PIL import Image
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
def load_image(image: Union[np.ndarray, Image.Image, str], return_type='numpy'):
|
17 |
+
"""
|
18 |
+
Load image from path or PIL.Image or numpy.ndarray to required format.
|
19 |
+
"""
|
20 |
+
|
21 |
+
# Check if image is already in return_type
|
22 |
+
if isinstance(image, Image.Image) and return_type == 'pil' or \
|
23 |
+
isinstance(image, np.ndarray) and return_type == 'numpy':
|
24 |
+
return image
|
25 |
+
|
26 |
+
# PIL.Image as intermediate format
|
27 |
+
if isinstance(image, str):
|
28 |
+
image = Image.open(image)
|
29 |
+
elif isinstance(image, np.ndarray):
|
30 |
+
image = Image.fromarray(image)
|
31 |
+
|
32 |
+
if image.mode == "RGBA":
|
33 |
+
image = image.convert("RGB")
|
34 |
+
|
35 |
+
if return_type == 'pil':
|
36 |
+
return image
|
37 |
+
elif return_type == 'numpy':
|
38 |
+
return np.asarray(image)
|
39 |
+
else:
|
40 |
+
raise NotImplementedError()
|
41 |
+
|
42 |
+
|
43 |
+
def image_resize(image: Image.Image, res=1024):
|
44 |
+
width, height = org_size = image.size
|
45 |
+
ratio = min(1.0 * res / max(width, height), 1.0)
|
46 |
+
if ratio < 1.0:
|
47 |
+
image = image.resize((int(width * ratio), int(height * ratio)))
|
48 |
+
print('Scaling image from {} to {}'.format(org_size, image.size))
|
49 |
+
return image
|
50 |
+
|
51 |
+
def xywh_to_x1y1x2y2(bbox):
|
52 |
+
x, y, w, h = bbox
|
53 |
+
return x,y,x+w,y+h
|
54 |
+
|
55 |
+
|
56 |
+
def x1y1x2y2_to_xywh(bbox):
|
57 |
+
x1, y1, x2, y2 = bbox
|
58 |
+
return x1,y1,x2-x1,y2-y1
|
59 |
+
|
60 |
+
|
61 |
+
def get_image_shape(image):
|
62 |
+
if isinstance(image, str):
|
63 |
+
return Image.open(image).size
|
64 |
+
elif isinstance(image, np.ndarray):
|
65 |
+
return image.shape
|
66 |
+
elif isinstance(image, Image.Image):
|
67 |
+
return image.size
|
68 |
+
else:
|
69 |
+
raise NotImplementedError
|
70 |
+
|
71 |
+
def is_platform_win():
|
72 |
+
return sys.platform == "win32"
|
73 |
+
|
74 |
+
|
75 |
+
def colormap(rgb=True):
|
76 |
+
color_list = np.array(
|
77 |
+
[
|
78 |
+
0.000, 0.000, 0.000,
|
79 |
+
1.000, 1.000, 1.000,
|
80 |
+
1.000, 0.498, 0.313,
|
81 |
+
0.392, 0.581, 0.929,
|
82 |
+
0.000, 0.447, 0.741,
|
83 |
+
0.850, 0.325, 0.098,
|
84 |
+
0.929, 0.694, 0.125,
|
85 |
+
0.494, 0.184, 0.556,
|
86 |
+
0.466, 0.674, 0.188,
|
87 |
+
0.301, 0.745, 0.933,
|
88 |
+
0.635, 0.078, 0.184,
|
89 |
+
0.300, 0.300, 0.300,
|
90 |
+
0.600, 0.600, 0.600,
|
91 |
+
1.000, 0.000, 0.000,
|
92 |
+
1.000, 0.500, 0.000,
|
93 |
+
0.749, 0.749, 0.000,
|
94 |
+
0.000, 1.000, 0.000,
|
95 |
+
0.000, 0.000, 1.000,
|
96 |
+
0.667, 0.000, 1.000,
|
97 |
+
0.333, 0.333, 0.000,
|
98 |
+
0.333, 0.667, 0.000,
|
99 |
+
0.333, 1.000, 0.000,
|
100 |
+
0.667, 0.333, 0.000,
|
101 |
+
0.667, 0.667, 0.000,
|
102 |
+
0.667, 1.000, 0.000,
|
103 |
+
1.000, 0.333, 0.000,
|
104 |
+
1.000, 0.667, 0.000,
|
105 |
+
1.000, 1.000, 0.000,
|
106 |
+
0.000, 0.333, 0.500,
|
107 |
+
0.000, 0.667, 0.500,
|
108 |
+
0.000, 1.000, 0.500,
|
109 |
+
0.333, 0.000, 0.500,
|
110 |
+
0.333, 0.333, 0.500,
|
111 |
+
0.333, 0.667, 0.500,
|
112 |
+
0.333, 1.000, 0.500,
|
113 |
+
0.667, 0.000, 0.500,
|
114 |
+
0.667, 0.333, 0.500,
|
115 |
+
0.667, 0.667, 0.500,
|
116 |
+
0.667, 1.000, 0.500,
|
117 |
+
1.000, 0.000, 0.500,
|
118 |
+
1.000, 0.333, 0.500,
|
119 |
+
1.000, 0.667, 0.500,
|
120 |
+
1.000, 1.000, 0.500,
|
121 |
+
0.000, 0.333, 1.000,
|
122 |
+
0.000, 0.667, 1.000,
|
123 |
+
0.000, 1.000, 1.000,
|
124 |
+
0.333, 0.000, 1.000,
|
125 |
+
0.333, 0.333, 1.000,
|
126 |
+
0.333, 0.667, 1.000,
|
127 |
+
0.333, 1.000, 1.000,
|
128 |
+
0.667, 0.000, 1.000,
|
129 |
+
0.667, 0.333, 1.000,
|
130 |
+
0.667, 0.667, 1.000,
|
131 |
+
0.667, 1.000, 1.000,
|
132 |
+
1.000, 0.000, 1.000,
|
133 |
+
1.000, 0.333, 1.000,
|
134 |
+
1.000, 0.667, 1.000,
|
135 |
+
0.167, 0.000, 0.000,
|
136 |
+
0.333, 0.000, 0.000,
|
137 |
+
0.500, 0.000, 0.000,
|
138 |
+
0.667, 0.000, 0.000,
|
139 |
+
0.833, 0.000, 0.000,
|
140 |
+
1.000, 0.000, 0.000,
|
141 |
+
0.000, 0.167, 0.000,
|
142 |
+
0.000, 0.333, 0.000,
|
143 |
+
0.000, 0.500, 0.000,
|
144 |
+
0.000, 0.667, 0.000,
|
145 |
+
0.000, 0.833, 0.000,
|
146 |
+
0.000, 1.000, 0.000,
|
147 |
+
0.000, 0.000, 0.167,
|
148 |
+
0.000, 0.000, 0.333,
|
149 |
+
0.000, 0.000, 0.500,
|
150 |
+
0.000, 0.000, 0.667,
|
151 |
+
0.000, 0.000, 0.833,
|
152 |
+
0.000, 0.000, 1.000,
|
153 |
+
0.143, 0.143, 0.143,
|
154 |
+
0.286, 0.286, 0.286,
|
155 |
+
0.429, 0.429, 0.429,
|
156 |
+
0.571, 0.571, 0.571,
|
157 |
+
0.714, 0.714, 0.714,
|
158 |
+
0.857, 0.857, 0.857
|
159 |
+
]
|
160 |
+
).astype(np.float32)
|
161 |
+
color_list = color_list.reshape((-1, 3)) * 255
|
162 |
+
if not rgb:
|
163 |
+
color_list = color_list[:, ::-1]
|
164 |
+
return color_list
|
165 |
+
|
166 |
+
|
167 |
+
color_list = colormap()
|
168 |
+
color_list = color_list.astype('uint8').tolist()
|
169 |
+
|
170 |
+
|
171 |
+
def vis_add_mask(image, mask, color, alpha, kernel_size):
|
172 |
+
color = np.array(color)
|
173 |
+
mask = mask.astype('float').copy()
|
174 |
+
mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
|
175 |
+
for i in range(3):
|
176 |
+
image[:, :, i] = image[:, :, i] * (1 - alpha + mask) + color[i] * (alpha - mask)
|
177 |
+
return image
|
178 |
+
|
179 |
+
|
180 |
+
def vis_add_mask_wo_blur(image, mask, color, alpha):
|
181 |
+
color = np.array(color)
|
182 |
+
mask = mask.astype('float').copy()
|
183 |
+
for i in range(3):
|
184 |
+
image[:, :, i] = image[:, :, i] * (1 - alpha + mask) + color[i] * (alpha - mask)
|
185 |
+
return image
|
186 |
+
|
187 |
+
|
188 |
+
def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha,
|
189 |
+
contour_alpha):
|
190 |
+
background_color = np.array(background_color)
|
191 |
+
contour_color = np.array(contour_color)
|
192 |
+
|
193 |
+
# background_mask = 1 - background_mask
|
194 |
+
# contour_mask = 1 - contour_mask
|
195 |
+
|
196 |
+
for i in range(3):
|
197 |
+
image[:, :, i] = image[:, :, i] * (1 - background_alpha + background_mask * background_alpha) \
|
198 |
+
+ background_color[i] * (background_alpha - background_mask * background_alpha)
|
199 |
+
|
200 |
+
image[:, :, i] = image[:, :, i] * (1 - contour_alpha + contour_mask * contour_alpha) \
|
201 |
+
+ contour_color[i] * (contour_alpha - contour_mask * contour_alpha)
|
202 |
+
|
203 |
+
return image.astype('uint8')
|
204 |
+
|
205 |
+
|
206 |
+
def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3,
|
207 |
+
contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
|
208 |
+
"""
|
209 |
+
add color mask to the background/foreground area
|
210 |
+
input_image: numpy array (w, h, C)
|
211 |
+
input_mask: numpy array (w, h)
|
212 |
+
background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
|
213 |
+
background_blur_radius: radius of background blur, must be odd number
|
214 |
+
contour_width: width of mask contour, must be odd number
|
215 |
+
contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
|
216 |
+
background_color: color index of the background (area with input_mask == False)
|
217 |
+
contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
|
218 |
+
paint_foreground: True for paint on foreground, False for background. Default: Flase
|
219 |
+
|
220 |
+
Output:
|
221 |
+
painted_image: numpy array
|
222 |
+
"""
|
223 |
+
assert input_image.shape[:2] == input_mask.shape, 'different shape'
|
224 |
+
assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
|
225 |
+
|
226 |
+
# 0: background, 1: foreground
|
227 |
+
input_mask[input_mask > 0] = 255
|
228 |
+
if paint_foreground:
|
229 |
+
painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha,
|
230 |
+
background_blur_radius) # black for background
|
231 |
+
else:
|
232 |
+
# mask background
|
233 |
+
painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha,
|
234 |
+
background_blur_radius) # black for background
|
235 |
+
# mask contour
|
236 |
+
contour_mask = input_mask.copy()
|
237 |
+
contour_mask = cv2.Canny(contour_mask, 100, 200) # contour extraction
|
238 |
+
# widden contour
|
239 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
|
240 |
+
contour_mask = cv2.dilate(contour_mask, kernel)
|
241 |
+
painted_image = vis_add_mask(painted_image, 255 - contour_mask, color_list[contour_color], contour_alpha,
|
242 |
+
contour_width)
|
243 |
+
return painted_image
|
244 |
+
|
245 |
+
|
246 |
+
def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7,
|
247 |
+
contour_width=3, contour_color=3, contour_alpha=1):
|
248 |
+
"""
|
249 |
+
paint color mask on the all foreground area
|
250 |
+
input_image: numpy array with shape (w, h, C)
|
251 |
+
input_mask: list of masks, each mask is a numpy array with shape (w,h)
|
252 |
+
background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
|
253 |
+
background_blur_radius: radius of background blur, must be odd number
|
254 |
+
contour_width: width of mask contour, must be odd number
|
255 |
+
contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
|
256 |
+
background_color: color index of the background (area with input_mask == False)
|
257 |
+
contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
|
258 |
+
|
259 |
+
Output:
|
260 |
+
painted_image: numpy array
|
261 |
+
"""
|
262 |
+
|
263 |
+
for i, input_mask in enumerate(input_masks):
|
264 |
+
input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
|
265 |
+
contour_color, contour_alpha, background_color=i + 2, paint_foreground=True)
|
266 |
+
return input_image
|
267 |
+
|
268 |
+
|
269 |
+
def mask_generator_00(mask, background_radius, contour_radius):
|
270 |
+
# no background width when '00'
|
271 |
+
# distance map
|
272 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
273 |
+
dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
|
274 |
+
dist_map = dist_transform_fore - dist_transform_back
|
275 |
+
# ...:::!!!:::...
|
276 |
+
contour_radius += 2
|
277 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
278 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
279 |
+
contour_mask[contour_mask > 0.5] = 1.
|
280 |
+
|
281 |
+
return mask, contour_mask
|
282 |
+
|
283 |
+
|
284 |
+
def mask_generator_01(mask, background_radius, contour_radius):
|
285 |
+
# no background width when '00'
|
286 |
+
# distance map
|
287 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
288 |
+
dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
|
289 |
+
dist_map = dist_transform_fore - dist_transform_back
|
290 |
+
# ...:::!!!:::...
|
291 |
+
contour_radius += 2
|
292 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
293 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
294 |
+
return mask, contour_mask
|
295 |
+
|
296 |
+
|
297 |
+
def mask_generator_10(mask, background_radius, contour_radius):
|
298 |
+
# distance map
|
299 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
300 |
+
dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
|
301 |
+
dist_map = dist_transform_fore - dist_transform_back
|
302 |
+
# .....:::::!!!!!
|
303 |
+
background_mask = np.clip(dist_map, -background_radius, background_radius)
|
304 |
+
background_mask = (background_mask - np.min(background_mask))
|
305 |
+
background_mask = background_mask / np.max(background_mask)
|
306 |
+
# ...:::!!!:::...
|
307 |
+
contour_radius += 2
|
308 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
309 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
310 |
+
contour_mask[contour_mask > 0.5] = 1.
|
311 |
+
return background_mask, contour_mask
|
312 |
+
|
313 |
+
|
314 |
+
def mask_generator_11(mask, background_radius, contour_radius):
|
315 |
+
# distance map
|
316 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
317 |
+
dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
|
318 |
+
dist_map = dist_transform_fore - dist_transform_back
|
319 |
+
# .....:::::!!!!!
|
320 |
+
background_mask = np.clip(dist_map, -background_radius, background_radius)
|
321 |
+
background_mask = (background_mask - np.min(background_mask))
|
322 |
+
background_mask = background_mask / np.max(background_mask)
|
323 |
+
# ...:::!!!:::...
|
324 |
+
contour_radius += 2
|
325 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
326 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
327 |
+
return background_mask, contour_mask
|
328 |
+
|
329 |
+
|
330 |
+
def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3,
|
331 |
+
contour_color=3, contour_alpha=1, mode='11'):
|
332 |
+
"""
|
333 |
+
Input:
|
334 |
+
input_image: numpy array
|
335 |
+
input_mask: numpy array
|
336 |
+
background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
|
337 |
+
background_blur_radius: radius of background blur, must be odd number
|
338 |
+
contour_width: width of mask contour, must be odd number
|
339 |
+
contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
|
340 |
+
contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
|
341 |
+
mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
|
342 |
+
|
343 |
+
Output:
|
344 |
+
painted_image: numpy array
|
345 |
+
"""
|
346 |
+
assert input_image.shape[:2] == input_mask.shape, 'different shape'
|
347 |
+
assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
|
348 |
+
assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
|
349 |
+
|
350 |
+
# downsample input image and mask
|
351 |
+
width, height = input_image.shape[0], input_image.shape[1]
|
352 |
+
res = 1024
|
353 |
+
ratio = min(1.0 * res / max(width, height), 1.0)
|
354 |
+
input_image = cv2.resize(input_image, (int(height * ratio), int(width * ratio)))
|
355 |
+
input_mask = cv2.resize(input_mask, (int(height * ratio), int(width * ratio)))
|
356 |
+
|
357 |
+
# 0: background, 1: foreground
|
358 |
+
msk = np.clip(input_mask, 0, 1)
|
359 |
+
|
360 |
+
# generate masks for background and contour pixels
|
361 |
+
background_radius = (background_blur_radius - 1) // 2
|
362 |
+
contour_radius = (contour_width - 1) // 2
|
363 |
+
generator_dict = {'00': mask_generator_00, '01': mask_generator_01, '10': mask_generator_10,
|
364 |
+
'11': mask_generator_11}
|
365 |
+
background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
|
366 |
+
|
367 |
+
# paint
|
368 |
+
painted_image = vis_add_mask_wo_gaussian \
|
369 |
+
(input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha,
|
370 |
+
contour_alpha) # black for background
|
371 |
+
|
372 |
+
return painted_image
|
373 |
+
|
374 |
+
|
375 |
+
seg_model_map = {
|
376 |
+
'base': 'vit_b',
|
377 |
+
'large': 'vit_l',
|
378 |
+
'huge': 'vit_h'
|
379 |
+
}
|
380 |
+
ckpt_url_map = {
|
381 |
+
'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
|
382 |
+
'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
|
383 |
+
'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
|
384 |
+
}
|
385 |
+
expected_sha256_map = {
|
386 |
+
'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
|
387 |
+
'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
|
388 |
+
'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
|
389 |
+
}
|
390 |
+
|
391 |
+
|
392 |
+
def prepare_segmenter(segmenter="huge", download_root: str = None):
|
393 |
+
"""
|
394 |
+
Prepare segmenter model and download checkpoint if necessary.
|
395 |
+
|
396 |
+
Returns: segmenter model name from 'vit_b', 'vit_l', 'vit_h'.
|
397 |
+
|
398 |
+
"""
|
399 |
+
|
400 |
+
os.makedirs('result', exist_ok=True)
|
401 |
+
seg_model_name = seg_model_map[segmenter]
|
402 |
+
checkpoint_url = ckpt_url_map[seg_model_name]
|
403 |
+
folder = download_root or os.path.expanduser("~/.cache/SAM")
|
404 |
+
filename = os.path.basename(checkpoint_url)
|
405 |
+
segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
|
406 |
+
|
407 |
+
return seg_model_name, segmenter_checkpoint
|
408 |
+
|
409 |
+
|
410 |
+
def download_checkpoint(url, folder, filename, expected_sha256):
|
411 |
+
os.makedirs(folder, exist_ok=True)
|
412 |
+
download_target = os.path.join(folder, filename)
|
413 |
+
if os.path.isfile(download_target):
|
414 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
415 |
+
return download_target
|
416 |
+
|
417 |
+
print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
|
418 |
+
with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
|
419 |
+
progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
|
420 |
+
for data in response.iter_content(chunk_size=1024):
|
421 |
+
size = output.write(data)
|
422 |
+
progress.update(size)
|
423 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
424 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
425 |
+
return download_target
|
426 |
+
|
427 |
+
|
428 |
+
if __name__ == '__main__':
|
429 |
+
|
430 |
+
background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
|
431 |
+
background_blur_radius = 31 # radius of background blur, must be odd number
|
432 |
+
contour_width = 11 # contour width, must be odd number
|
433 |
+
contour_color = 3 # id in color map, 0: black, 1: white, >1: others
|
434 |
+
contour_alpha = 1 # transparency of background, 0: no contour highlighted
|
435 |
+
|
436 |
+
# load input image and mask
|
437 |
+
input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
|
438 |
+
input_mask = np.array(Image.open('./test_images/painter_input_mask.jpg').convert('P'))
|
439 |
+
|
440 |
+
# paint
|
441 |
+
overall_time_1 = 0
|
442 |
+
overall_time_2 = 0
|
443 |
+
overall_time_3 = 0
|
444 |
+
overall_time_4 = 0
|
445 |
+
overall_time_5 = 0
|
446 |
+
|
447 |
+
for i in range(50):
|
448 |
+
t2 = time.time()
|
449 |
+
painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
|
450 |
+
contour_width, contour_color, contour_alpha, mode='00')
|
451 |
+
e2 = time.time()
|
452 |
+
|
453 |
+
t3 = time.time()
|
454 |
+
painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
|
455 |
+
contour_width, contour_color, contour_alpha, mode='10')
|
456 |
+
e3 = time.time()
|
457 |
+
|
458 |
+
t1 = time.time()
|
459 |
+
painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
|
460 |
+
contour_color, contour_alpha)
|
461 |
+
e1 = time.time()
|
462 |
+
|
463 |
+
t4 = time.time()
|
464 |
+
painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
|
465 |
+
contour_width, contour_color, contour_alpha, mode='01')
|
466 |
+
e4 = time.time()
|
467 |
+
|
468 |
+
t5 = time.time()
|
469 |
+
painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
|
470 |
+
contour_width, contour_color, contour_alpha, mode='11')
|
471 |
+
e5 = time.time()
|
472 |
+
|
473 |
+
overall_time_1 += (e1 - t1)
|
474 |
+
overall_time_2 += (e2 - t2)
|
475 |
+
overall_time_3 += (e3 - t3)
|
476 |
+
overall_time_4 += (e4 - t4)
|
477 |
+
overall_time_5 += (e5 - t5)
|
478 |
+
|
479 |
+
print(f'average time w gaussian: {overall_time_1 / 50}')
|
480 |
+
print(f'average time w/o gaussian00: {overall_time_2 / 50}')
|
481 |
+
print(f'average time w/o gaussian10: {overall_time_3 / 50}')
|
482 |
+
print(f'average time w/o gaussian01: {overall_time_4 / 50}')
|
483 |
+
print(f'average time w/o gaussian11: {overall_time_5 / 50}')
|
484 |
+
|
485 |
+
# save
|
486 |
+
painted_image_00 = Image.fromarray(painted_image_00)
|
487 |
+
painted_image_00.save('./test_images/painter_output_image_00.png')
|
488 |
+
|
489 |
+
painted_image_10 = Image.fromarray(painted_image_10)
|
490 |
+
painted_image_10.save('./test_images/painter_output_image_10.png')
|
491 |
+
|
492 |
+
painted_image_01 = Image.fromarray(painted_image_01)
|
493 |
+
painted_image_01.save('./test_images/painter_output_image_01.png')
|
494 |
+
|
495 |
+
painted_image_11 = Image.fromarray(painted_image_11)
|
496 |
+
painted_image_11.save('./test_images/painter_output_image_11.png')
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://download.pytorch.org/whl/cu111/torch-1.10.1%2Bcu111-cp38-cp38-linux_x86_64.whl
|
2 |
+
https://download.pytorch.org/whl/cu111/torchvision-0.11.2%2Bcu111-cp38-cp38-linux_x86_64.whl
|
3 |
+
https://download.pytorch.org/whl/cu111/torchaudio-0.10.1%2Bcu111-cp38-cp38-linux_x86_64.whl
|
4 |
+
openai
|
5 |
+
pillow
|
6 |
+
langchain==0.0.101
|
7 |
+
transformers==4.28.1
|
8 |
+
ftfy
|
9 |
+
regex
|
10 |
+
tqdm
|
11 |
+
git+https://github.com/openai/CLIP.git
|
12 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
13 |
+
opencv-python==4.5.5.64
|
14 |
+
pycocotools
|
15 |
+
matplotlib
|
16 |
+
onnxruntime
|
17 |
+
onnx
|
18 |
+
gradio==3.27.0
|
19 |
+
accelerate
|
20 |
+
bitsandbytes==0.34
|
21 |
+
easyocr
|
sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
|
3 |
+
size 2564550879
|
test_images/img0.png
ADDED
![]() |
test_images/img1.jpg
ADDED
![]() |
test_images/img12.jpg
ADDED
![]() |
test_images/img14.jpg
ADDED
![]() |
test_images/img2.jpg
ADDED
![]() |
test_images/img35.webp
ADDED
![]() |
test_images/img36.webp
ADDED
![]() |
test_images/img5.jpg
ADDED
![]() |
test_images/qingming3.jpeg
ADDED
![]() |
Git LFS Details
|