jongwoopark7978 committed
Commit · 54216bc
1 Parent(s): 446e69c
chore: add project files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -0
- .gitignore +169 -0
- LLM_stage.py +168 -0
- README.md +119 -0
- VLM_stage.py +154 -0
- coarseKeyframeDetector.py +254 -0
- config/config.py +38 -0
- config/run.sh +4 -0
- extractKeyword.py +132 -0
- figures/KFSelectionFlowComparison.jpg +3 -0
- figures/architecture.png +3 -0
- figures/architecture_qualitative.png +3 -0
- figures/hkf_graph.png +3 -0
- fineKeyframeDetector.py +189 -0
- keywords/Keyword_4531questions.json +0 -0
- keywords/Keyword_500questions.jsonl +0 -0
- questions/4531questions.json +0 -0
- questions/500questions.jsonl +0 -0
- requirements.txt +24 -0
- scripts/create_caption.sh +11 -0
- scripts/eval_ES.sh +23 -0
- scripts/get_ES_captions.sh +8 -0
- src/open_clip/__init__.py +11 -0
- src/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- src/open_clip/constants.py +2 -0
- src/open_clip/factory.py +287 -0
- src/open_clip/hf_configs.py +60 -0
- src/open_clip/hf_model.py +164 -0
- src/open_clip/loss.py +121 -0
- src/open_clip/model.py +413 -0
- src/open_clip/model_configs/RN101-quickgelu.json +22 -0
- src/open_clip/model_configs/RN101.json +21 -0
- src/open_clip/model_configs/RN50-quickgelu.json +22 -0
- src/open_clip/model_configs/RN50.json +21 -0
- src/open_clip/model_configs/RN50x16.json +21 -0
- src/open_clip/model_configs/RN50x4.json +21 -0
- src/open_clip/model_configs/RN50x64.json +21 -0
- src/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
- src/open_clip/model_configs/ViT-B-16-plus.json +16 -0
- src/open_clip/model_configs/ViT-B-16.json +16 -0
- src/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
- src/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
- src/open_clip/model_configs/ViT-B-32.json +16 -0
- src/open_clip/model_configs/ViT-H-14.json +17 -0
- src/open_clip/model_configs/ViT-H-16.json +17 -0
- src/open_clip/model_configs/ViT-L-14-280.json +16 -0
- src/open_clip/model_configs/ViT-L-14-336.json +16 -0
- src/open_clip/model_configs/ViT-L-14.json +16 -0
- src/open_clip/model_configs/ViT-L-16-320.json +16 -0
- src/open_clip/model_configs/ViT-L-16.json +16 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+figures/*.jpg filter=lfs diff=lfs merge=lfs -text
+figures/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,169 @@
# Project Specific
data/
data
datalink

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
**private/*
**ego_base_link/*
**.vscode/*


# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LLM_stage.py
ADDED
@@ -0,0 +1,168 @@
"""
Evaluate a model on the egoschema dataset using LVNet (captions pre-generated)

Sample Run:

python3 LLM_stage.py \
    --output-dir ego_base_link \
    --captions data/ES_captions_gpt4o.jsonl \
    --per-vid-captions 12 \
    --gptmodel "gpt-4o" \
    --temperature 0.0
"""

import argparse
import json
import os

from tqdm import tqdm

from src.run_gpt import run_gpt

# You may add multiple keys to run parallel calls
dict_api = {
    "api_key": "ADD",
}


_PROMPT_TEMPLATE = (
    "Here are descriptions of the video frames at specific times, noted in seconds."
    "\n\n{Putdesc}.\n\nThe descriptions of the frames conclude. Think step-by-step"
    " and I request your selection of the most appropriate response to the following"
    " question\n\nQuestion:\n{Putquestion}\n\nOptions:\n{AllOptions}"
)


def eval_model(args):
    # change split to split
    captions_path, data_path, split, gptmodel, temp, base_dir, job_name = (
        args.captions,
        args.data,
        args.per_vid_captions,
        args.gptmodel,
        args.temperature,
        args.output_dir,
        args.job_name,
    )

    prompt = _PROMPT_TEMPLATE

    os.makedirs(base_dir, exist_ok=True)
    output_dir = f"{base_dir}/egoschema/{job_name}"
    output_dir = os.path.expanduser(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    save_name = captions_path.rsplit("/", 2)[-1].replace(".jsonl", "")
    output_summary_path = f"{output_dir}/{save_name}.jsonl"
    print(f"Saving outputs to:{output_summary_path}")
    output_summary = open(output_summary_path, "w")

    input_summary = [
        json.loads(q) for q in open(os.path.expanduser(captions_path), "r")
    ]
    dataset = json.load(open(os.path.expanduser(data_path), "r"))
    input_len = len(input_summary)
    assert (
        input_len % split == 0
    ), f"input_len%split:{input_len%split}, input_len:{input_len}, split:{split}"
    groups = input_len // split

    final_prompts = []
    final_info = []
    for i in tqdm(range(groups)):
        sidx = i * split
        eidx = (i + 1) * split

        desc = ""
        timeline = []
        for idx, e in enumerate(input_summary[sidx:eidx]):
            cur_data = dataset[e["q_uid"]]
            desc += e["answer"] + " "
            timeline.append(e["timeline"])

            if idx == split - 1:  # the last of split
                action_0 = cur_data["option 0"]
                action_1 = cur_data["option 1"]
                action_2 = cur_data["option 2"]
                action_3 = cur_data["option 3"]
                action_4 = cur_data["option 4"]

                option_list = ""
                option_number_candidate = ["one", "two", "three", "four", "five"]
                option_number = option_number_candidate[4]
                AllOptNumber = "option 0, option 1, option 2, option 3, option 4"
                FocusOptions = ""

                alloptions = f"option 0: {action_0}\noption 1: {action_1}\noption 2: {action_2}\noption 3: {action_3}\noption 4: {action_4}"
                option_list = f"option 0: {action_0}\noption 1: {action_1}\noption 2: {action_2}\noption 3: {action_3}\noption 4: {action_4}"

                FocusOptions += "option 0, option 1, option 2, option 3, option 4"

                question = cur_data["question"]

                curr_prompt = (
                    prompt.replace("{Putdesc}", desc)
                    .replace("{Putquestion}", question)
                    .replace("{Putoptions}", option_list)
                    .replace("{PutOptNumber}", option_number)
                    .replace("{FocusOptions}", FocusOptions)
                    .replace("{AllOptions}", alloptions)
                    .replace("{PutAllOptNumber}", AllOptNumber)
                )

                final_prompts.append(curr_prompt)

                CA_option = {}
                if "CA" in cur_data:
                    CA_option = {"CA": cur_data["CA"]}

                info = {
                    "q_uid": e["q_uid"],
                    "prompt": curr_prompt,
                    "timeline": timeline,
                    "question": question,
                    "option 0": action_0,
                    "option 1": action_1,
                    "option 2": action_2,
                    "option 3": action_3,
                    "option 4": action_4,
                } | CA_option

                final_info.append(info)

    output_VLM = run_gpt(
        texts=final_prompts,
        api_keys=list(dict_api.values()),
        max_tokens=2000,
        model=gptmodel,
        temperature=temp,
        num_threads=20,  # Tune this
        backoff_time=1 * 60,
        silent=False,
        dataset="egoschema",
    )

    output_VLM = list(output_VLM)

    for q_idx, info in enumerate(tqdm(final_info)):  # prompt_list = # Q&C
        info["answer"] = output_VLM[q_idx]
        output_summary.write(json.dumps(info) + "\n")

    # finish the summarization for the current question
    output_summary.close()
    print(f"output_summary_path:{output_summary_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str)
    parser.add_argument("--job-name", type=str, default="run001")
    parser.add_argument("--captions", type=str, default="data/ES_captions_gpt4o.jsonl")
    parser.add_argument("--data", type=str, default="data/ES_qa_data.json")
    parser.add_argument("--per-vid-captions", type=int, default=12)
    parser.add_argument("--gptmodel", type=str, default="gpt-3.5-turbo-1106")
    parser.add_argument("--temperature", type=float, default=None)

    args = parser.parse_args()

    eval_model(args)
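For orientation, below is a minimal sketch of the record shapes LLM_stage.py appears to expect, inferred from the field accesses above; the identifiers and values are hypothetical placeholders, not data shipped with this commit.

```python
# Hypothetical example records, inferred from the field accesses in LLM_stage.py.

# One line of the --captions JSONL file (one object per keyframe caption):
caption_line = {
    "q_uid": "placeholder-question-id",
    "answer": "At 12.4 seconds, C picks up a phone from the table.",
    "timeline": 12.4,
}

# One entry of the --data QA file (a JSON dict keyed by q_uid):
qa_entry = {
    "question": "What is C primarily doing throughout the video?",
    "option 0": "...",
    "option 1": "...",
    "option 2": "...",
    "option 3": "...",
    "option 4": "...",
    # "CA": 2,  # optional ground-truth index; copied through when present
}
```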
README.md
CHANGED
@@ -1,3 +1,122 @@
 ---
 license: apache-2.0
+tags:
+- computer-vision
+- video-question-answering
+- zero-shot
+- 9 pages workshop at neurips2024
 ---
+
+# LVNet
+
+[](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-egoschema-1?p=too-many-frames-not-all-useful-efficient)
+[](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-intentqa?p=too-many-frames-not-all-useful-efficient)
+[](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-next-qa?p=too-many-frames-not-all-useful-efficient)
+[](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-egoschema?p=too-many-frames-not-all-useful-efficient)
+
+Official Code for **_Too Many Frames, Not All Useful_: Efficient Strategies for Long-Form Video QA**
+Accepted to the 9 pages Workshop on Video-Language Models at *NeurIPS 2024*
+
+[**Paper Link**](https://arxiv.org/abs/2406.09396)
+
+---
+
+## Abstract
+
+Long-form videos that span wide temporal intervals are highly information-redundant and contain multiple distinct events or entities that are often loosely-related. Therefore, when performing long-form video question answering (LVQA), all information necessary to generate a correct response can often be contained within a small subset of frames. Recent literature explores the use of large language models (LLMs) in LVQA benchmarks, achieving exceptional performance, while relying on vision language models (VLMs) to convert all visual content within videos into natural language. Such VLMs often independently caption a large number of frames uniformly sampled from long videos, which is not efficient and can mostly be redundant.
+
+Questioning these decision choices, we explore optimal strategies for key-frame selection and sequence-aware captioning that can significantly reduce these redundancies. We propose two novel approaches that improve each aspect, namely **Hierarchical Keyframe Selector** and **Sequential Visual LLM**. Our resulting framework, called **LVNet**, achieves state-of-the-art performance across three benchmark LVQA datasets.
+
+---
+
+## Accuracy vs. Captions on the EgoSchema Subset
+
+- LVNet shows a SOTA **68.2% accuracy**, using only **12** captions.
+- This highlights the quality of keyframes from the Hierarchical Keyframe Selector.
+
+<img src="figures/hkf_graph.png" alt="acc_captions" width="600"/>
+
+---
+
+## Hierarchical Keyframe Selector: Structural Overview
+
+- **Overall strategy**: Generate captions via a Hierarchical Keyframe Selector, then feed them to a separate LLM to answer the question.
+- **Temporal Scene Clustering (TSC)**: Divides the long video into multiple scenes, enabling per-scene subsampling.
+- **Coarse Keyframe Detector (CKD)**: Selects frames best-aligned with keywords relevant to the query.
+- **Fine Keyframe Detector (FKD)**: Refines the keyword alignments within a smaller set of frames via templated visual prompting.
+
+<img src="figures/architecture.png" alt="architecture" width="600"/>
+
+---
+
+## Operational Visualization of HKS
+
+- **Temporal Scene Clustering (TSC)**: 900 frames get clustered into scenes, then uniformly subsampled to produce about 280 frames.
+- **Coarse Keyframe Detector (CKD)**: Among those, 32 frames are selected, based on alignment with query keywords.
+- **Visual Templating**: The coarsely refined keyframes are ordered by confidence and temporal order, grouped into 4 chunks of 8 frames.
+- **Fine Keyframe Detector (FKD)**: Selects the final 12 frames via further keyword alignment checks.
+
+<img src="figures/architecture_qualitative.png" alt="hks_visualization" width="800"/>
+
+---
+
+## Experiments: EgoSchema, NExT-QA, and IntentQA
+
+- LVNet achieves **61.1%**, **72.9%**, and **71.7%** on the three benchmarks, respectively, using **just 12** frames—on par with or exceeding models that use **many** more frames.
+- Models with specialized video-caption pretraining or significantly more captions are shown in gray/light green for fairness comparison.
+
+<img src="tables/table_combined.png" alt="egoschema_table" width="900"/>
+
+---
+
+## Comparison with Other Keyframe Selection Methods
+
+Below is a side-by-side comparison of **LVNet** and **VideoAgent**.
+- **LVNet** starts with uniform sampling but then refines keyframes via TSC, CKD, and FKD. This yields 12 frames, 8 of which show “phone usage,” the correct activity.
+- **VideoAgent** continues uniform sampling due to insufficient initial frames, resulting in 0 relevant frames out of 9 and an incorrect final answer.
+
+<img src="figures/KFSelectionFlowComparison.jpg" alt="kf_selection_flow" width="900"/>
+
+---
+
+## Evaluation
+
+### Generating Answers Using LLM
+
+You can quickly run the LLM to produce answers once you have the keyframe-based captions:
+
+1. **Download the Captions for Dataset**
+
+   * EgoSchema: `bash scripts/get_ES_captions.sh`
+
+2. **Run LLM**: `bash scripts/eval_ES.sh`
+
+### Generate captions using our provided modules
+#### Hierarchical Keyframe Selector (HKS)
+- Temporal Scene Clustering (TSC): temporalSceneClustering.py </br>
+- Coarse Keyframe Detector (CKD): coarseKeyframeDetector.py </br>
+- Fine Keyframe Detector (FKD): fineKeyframeDetector.py </br>
+
+1. **EgoSchema keyframe selection from images**: `bash config/run.sh`
+
+2. **Generate captions based on the keyframes**: `bash scripts/create_caption.sh`
+
+## Data
+### Hierarchical Keyframe Selector hyper-parameters & paths
+- [[LINK]](config/config.py)
+
+### coarseKeyframeDetector.py CLIP model checkpoint
+- ICCV 2023 [Perceptual Grouping in Contrastive Vision-Language Models](https://arxiv.org/abs/2210.09996)
+- Checkpoint: [Download](https://github.com/kahnchana/clippy/releases/download/v1.0/clippy_5k.pt)
+
+
+# Citation
+```
+@inproceedings{Park2024TooMF,
+  title={Too Many Frames, not all Useful: Efficient Strategies for Long-Form Video QA},
+  author={Jongwoo Park and Kanchana Ranasinghe and Kumara Kahatapitiya and Wonjeong Ryoo and Donghyun Kim and Michael S. Ryoo},
+  year={2024}
+}
+```
+
+For more details and updates, please see our [GitHub Repository](https://github.com/jongwoopark7978/LVNet).
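As an illustration of the coarse keyword-to-frame matching the README describes (implemented below in coarseKeyframeDetector.py with a CLIPpy model), here is a minimal, self-contained sketch that assumes generic precomputed keyword and frame embeddings; it mirrors the round-robin selection idea rather than the exact repository code.

```python
# Minimal sketch of coarse keyword-frame matching via cosine similarity.
# Assumes precomputed embeddings; the repo instead encodes frames and keywords with CLIPpy.
import torch
import torch.nn.functional as F

def coarse_select(keyword_embeds: torch.Tensor,  # [K, D], one embedding per keyword
                  frame_embeds: torch.Tensor,    # [N, D], one embedding per frame
                  max_frames: int = 32) -> dict:
    # Cosine similarity between every keyword and every frame: [K, N]
    sim = F.cosine_similarity(keyword_embeds.unsqueeze(1), frame_embeds.unsqueeze(0), dim=-1)
    order = sim.argsort(dim=-1, descending=True)      # per-keyword ranking of frames
    budget = min(max_frames, sim.size(1))
    chosen, col = {}, [0] * sim.size(0)
    # Round-robin over keywords: each keyword claims its best unclaimed frame,
    # until the frame budget is reached (same spirit as SortSimilarity below).
    while len(chosen) < budget:
        for k in range(sim.size(0)):
            while col[k] < sim.size(1) and order[k, col[k]].item() in chosen:
                col[k] += 1                            # skip frames already claimed
            if col[k] < sim.size(1):
                f = order[k, col[k]].item()
                chosen[f] = {"keyword": k, "sim": sim[k, f].item()}
                col[k] += 1
            if len(chosen) >= budget:
                break
    return chosen  # frame index -> matched keyword index and similarity score
```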
VLM_stage.py
ADDED
@@ -0,0 +1,154 @@
import os
import json
import base64
import random
import argparse

import natsort

from PIL import Image
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

from src.run_gpt import run_gpt

random.seed(10)
dict_api = {
    "api_key":"ADD",
}


class CustomDatasetGPT(Dataset):
    def __init__(self, questions, num_kf):
        self.questions = questions
        self.num_kf = num_kf

    def __getitem__(self, index):
        line = self.questions[index]
        group = 4
        newnum_per_group = self.num_kf // group
        oldnum_per_group = len(line["VLM_path"]) // group
        assert oldnum_per_group >= newnum_per_group, f"oldnum_per_group:{oldnum_per_group} is smaller than newnum_per_group:{newnum_per_group}"

        new_kf_paths = []
        new_kf_timelines = []
        for i in range(group):
            start_index = i * oldnum_per_group
            end_index = start_index + oldnum_per_group

            sub_kf_paths = line["VLM_path"][start_index:min(end_index, len(line["VLM_path"]))]
            sub_kf_timelines = line["VLM_timeline"][start_index:min(end_index, len(line["VLM_timeline"]))]
            new_kf_paths.extend(sub_kf_paths[:newnum_per_group])
            new_kf_timelines.extend(sub_kf_timelines[:newnum_per_group])

        kf_paths = natsort.natsorted(new_kf_paths)
        kf_timelines = natsort.natsorted(new_kf_timelines)

        images = []
        images_base64 = []

        for e in kf_paths:
            images.append(Image.open(e).convert('RGB'))
            images_base64.append(encode_image(e))

        return images_base64, kf_paths, kf_timelines

    def __len__(self):
        return len(self.questions)


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def create_data_loader_gpt(questions, num_kf, batch_size=1, num_workers=4):
    assert batch_size == 1, "batch_size must be 1"

    dataset = CustomDatasetGPT(questions, num_kf)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    return data_loader, dataset

def eval_model(args):
    base_dir, question_path, vlm, num_kf, temp = (
        args.output_dir,
        args.question_path,
        args.gptmodel,
        args.num_kf,
        args.temp,
    )

    questions = [json.loads(q) for q in open(os.path.expanduser(question_path), "r")]

    fname = question_path.split('/')[-1]
    answer_path = f"{base_dir}/egoschema/{num_kf}/{fname}"
    os.makedirs(os.path.dirname(answer_path), exist_ok=True)
    print(f"question_path:{question_path}\nanswer_path:{answer_path}")

    ans_file = open(answer_path, "w")
    data_loader, dataset = create_data_loader_gpt(questions, num_kf)

    for (base64_image, kf_paths, kf_timelines), line in tqdm(zip(data_loader, questions), total=len(questions)):
        idx = line["q_uid"]
        CA = line["CA"] if "CA" in line else None
        option0 = line['option 0']
        option1 = line['option 1']
        option2 = line['option 2']
        option3 = line['option 3']
        option4 = line['option 4']
        question = line['question']

        lenwords = "50"
        prompt = f"'C' stands for the cameraman. Describe the activity depicted in this first-person perspective image in less than {lenwords} words. In your answer, don't mention that the image is in first-person perspective, as we already know this."
        prompts = [prompt] * num_kf

        image_paths = [e[0] for e in kf_paths]
        image_timelines = [e[0] for e in kf_timelines]

        output_VLM = run_gpt(
            images=image_paths,
            texts=prompts,
            api_keys=list(dict_api.values()),
            max_tokens=2000,
            model=vlm,
            temperature=temp,
            num_threads=20,  # Tune this
            backoff_time=1 * 60,
            silent=False,
            dataset="egoschema",
            verbose=False,
        )

        output_VLM = list(output_VLM)

        for j, e in enumerate(image_timelines):
            line_frame = line.copy()
            line_frame["answer"] = f"At {str(e)} seconds, {output_VLM[j]}"
            line_frame["AR-VLM_model_id"] = vlm
            line_frame["AR-VLM_prompt"] = prompts[j]
            line_frame["timeline"] = float(e)
            line_frame["frame_idx"] = j
            line_frame["image_paths"] = image_paths

            if "imgidx_kw_dict" in line_frame.keys(): line_frame.pop("imgidx_kw_dict")
            if "google_drive_id" in line_frame.keys(): line_frame.pop("google_drive_id")

            ans_file.write(json.dumps(line_frame)+"\n")

    print(f"question.\nquestion_path:{question_path}\nanswer_path:{answer_path}")

    ans_file.close()
    return "job is done"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str)
    parser.add_argument("--question-path", type=str, default="")
    parser.add_argument("--num-kf", type=int)
    parser.add_argument("--gptmodel", type=str, default="gpt-4o")
    parser.add_argument("--temp", type=float, default=None)
    args = parser.parse_args()
    eval_model(args)
coarseKeyframeDetector.py
ADDED
@@ -0,0 +1,254 @@
import os
import json
import shutil

from tqdm import tqdm
from PIL import Image

import natsort

import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from config import config
from src.open_clip import create_model_and_transforms


class loading_img(Dataset):
    def __init__(self, img_list):
        self.img_list = img_list

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        return self.img_list[idx].squeeze(0)

class CustomDataset(Dataset):
    def __init__(self, questions, clippy, preprocess_val, clip_size, base_dir):
        self.questions = questions
        self.clippy = clippy
        self.clip_size = clip_size
        self.preprocess_val = preprocess_val
        self.device = next(clippy.parameters()).device
        self.base_dir = base_dir

    def __getitem__(self, index):
        line = self.questions[index]
        images_dir = f"{line['q_uid']}"

        if line["Activity"] == "" or ("Activity" not in line): ref1 = []

        else:
            if isinstance(line["Activity"], list): ref1 = line["Activity"]
            else: ref1 = line["Activity"].split(', ')

        keywords = ref1
        clip_size = self.clip_size
        clippy = self.clippy
        preprocess_val = self.preprocess_val

        images = []
        timelines = []
        timelines_int = []
        img_names = []
        image_list = []

        nframes_paths = line["filepath"]
        total_len = len(nframes_paths)
        nframes_paths = natsort.natsorted(nframes_paths)

        img_paths = []
        for img_path in nframes_paths:
            img_path = self.base_dir + "/" + "/".join(img_path.split("/")[-4:])
            img_paths.append(img_path)

            img_names.append(img_path.split('/')[-1].split('.')[0])
            cur_img = Image.open(img_path).resize(clip_size)
            image_list.append(preprocess_val(cur_img))

            timeline = f"{img_names[-1].split('_')[-2]}.{img_names[-1].split('_')[-1]} seconds"
            timeline_int = float(f"{img_names[-1].split('_')[-2]}.{img_names[-1].split('_')[-1]}")
            timelines.append(timeline)
            timelines_int.append(timeline_int)

        return image_list, img_paths, timelines, timelines_int, keywords, img_names

    def __len__(self):
        return len(self.questions)


def disable_torch_init():
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

def SortSimilarity(q_uid, simmat, keywords, nimgtokens, nframes_paths, maximgslen):
    sort_simmat, sort_idx = torch.sort(simmat, dim=-1, descending=True)
    sort_idx = torch.floor(sort_idx/nimgtokens).to(int)

    curimgslen = 0

    imgidx_kw_dict = dict()
    numrow, numcol = sort_simmat.shape

    row_col_list = [0 for _ in range(numrow)]
    token = True

    while token:
        j = 0
        while j < numrow:
            k = 0
            i = row_col_list[j]

            while k < numcol-i:
                col_idx = i+k
                k += 1

                simvalue = sort_simmat[j, col_idx].item()
                img_idx = sort_idx[j, col_idx].item()

                curr_keyword = keywords[j]
                curr_kfpath = nframes_paths[img_idx]

                if img_idx in imgidx_kw_dict: continue

                else:
                    imgidx_kw_dict[img_idx] = {"kw": curr_keyword, "simvalue": simvalue, "kf_path": curr_kfpath, "kw_others": []}
                    curimgslen += 1

                    row_col_list[j] = col_idx + 1
                    if curimgslen == maximgslen: return imgidx_kw_dict
                    else: break

            j += 1

        if sum(row_col_list) >= numrow*(numcol-1): token = False

def create_data_loader(questions, clippy, preprocess_val, clip_size, base_dir, batch_size=1, num_workers=16):
    assert batch_size == 1, "batch_size must be 1"
    dataset = CustomDataset(questions, clippy, preprocess_val, clip_size, base_dir)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    return data_loader

def eval_model():
    disable_torch_init()
    question_path, maximgslen, base_dir, concatname, modelpath, answerpath, concatdir = config.question_path, config.maximgslen, config.base_dir, config.concatname, config.modelpath, config.answerpath, config.concatdir

    pretrained_ckpt = f"{modelpath}"
    clippy, preprocess_train, preprocess_val = create_model_and_transforms(
        "clippy-B-16",
        device="cuda",
        pretrained=pretrained_ckpt
    )
    clip_size = (224,224)
    device = next(clippy.parameters()).device

    questions = [json.loads(q) for q in open(os.path.expanduser(question_path), "r")]

    answer_path = f"{answerpath}"
    print(f"\nquestion_path:{question_path}\nanswer_path:{answer_path}")
    os.makedirs(os.path.dirname(answer_path), exist_ok=True)

    with open(answer_path, "w") as ans_file:
        data_loader = create_data_loader(questions, clippy, preprocess_val, clip_size, base_dir)
        concatimg_dir_base = f"{concatdir}"

        with torch.no_grad():
            for (image_list, nframes_paths, timelines, timelines_int, keywords, img_names), line in tqdm(zip(data_loader, questions), total=len(questions)):
                q_uid = line["q_uid"]
                CA = line["CA"] if "CA" in line else None
                option0 = line['option 0']
                option1 = line['option 1']
                option2 = line['option 2']
                option3 = line['option 3']
                option4 = line['option 4']
                question = line['question']

                pastobj = None
                past_VLM_path = None
                past_VLM_timeline = None

                img_embed = []
                nframes_paths = [e[0] for e in nframes_paths]

                image_set = loading_img(image_list)
                image_loader = DataLoader(image_set, batch_size=64, shuffle=False, num_workers=16)
                for e in image_loader: img_embed.append(clippy.encode_image(e.to(device), pool=False)[:, 1:])
                img_embed = torch.concat(img_embed, dim=0)

                limit_keywords = config.limit_keywords
                keywords = [e[0] for e in keywords][:limit_keywords]
                keyword_embed = clippy.text.encode(keywords, convert_to_tensor=True)

                nframe, nimgtokens, channels = img_embed.shape
                keyword_embed = keyword_embed.unsqueeze(1)
                img_embed = img_embed.flatten(0, 1).unsqueeze(0)

                simmat = F.cosine_similarity(keyword_embed, img_embed, dim=-1).to(torch.float)
                imgidx_kw_dict = SortSimilarity(q_uid, simmat, keywords, nimgtokens, nframes_paths, maximgslen=maximgslen)

                # order of simvalue
                simvalue = np.array([e["simvalue"] for e in imgidx_kw_dict.values()])
                ordered_idx = np.argsort(simvalue)
                simvalue = simvalue[ordered_idx]
                kf_paths = np.array([e["kf_path"] for e in imgidx_kw_dict.values()])[ordered_idx]
                matchingkw = np.array([e["kw"] for e in imgidx_kw_dict.values()])[ordered_idx]

                # order by timeline
                time_kf_paths = np.array(kf_paths[:16])
                timelines_int = np.array([float(f"{e.replace('.jpg', '').split('/')[-1].split('_')[1]}" + "." + f"{e.replace('.jpg', '').split('/')[-1].split('_')[2]}") for e in time_kf_paths])
                time_ordered_idx = np.argsort(timelines_int)

                timelines_int = timelines_int[time_ordered_idx]
                time_simvalue = np.array(simvalue[:16])[time_ordered_idx]
                time_kf_paths = np.array(time_kf_paths)[time_ordered_idx]
                time_matchingkw = np.array(matchingkw[:16])[time_ordered_idx]

                simvalue[:16] = time_simvalue
                kf_paths[:16] = time_kf_paths
                matchingkw[:16] = time_matchingkw

                segment_timeline = f"{timelines[0][0].split(' seconds')[0]}-{timelines[-1][0].split(' seconds')[0]}"

                imgw, imgh = Image.open(kf_paths[0]).size
                redwidth = 20
                newimgw, newimgh = (imgw+redwidth) * 4 + redwidth, (imgh+redwidth) * 2 + redwidth
                concatimg = np.zeros((newimgh, newimgw, 3), dtype=np.uint8)
                concatimg[:, :, 0] = 255
                concatimglist = []
                concatimg_dir = f"{concatimg_dir_base}/{q_uid}"

                for i, cpath in enumerate(kf_paths):
                    cur_img = np.array(Image.open(cpath))
                    whole_frame = 8
                    remainder = i % whole_frame
                    rowremainder = i % (whole_frame//2)
                    startwidth = redwidth + (imgw + redwidth)*rowremainder
                    endwidth = startwidth + imgw

                    if remainder / whole_frame < 0.5: concatimg[redwidth:redwidth+imgh, startwidth:endwidth, :] = cur_img
                    else: concatimg[redwidth+imgh+redwidth:newimgh-redwidth, startwidth:endwidth, :] = cur_img

                    if remainder == whole_frame - 1: concatimglist.append(Image.fromarray(concatimg))

                if os.path.exists(concatimg_dir): shutil.rmtree(concatimg_dir)
                os.makedirs(f"{concatimg_dir}", exist_ok=True)
                for i, img in enumerate(concatimglist): img.save(f"{concatimg_dir}/concat_{i}.jpg")

                line["kf_paths"] = kf_paths.tolist()
                line["keywords"] = matchingkw.tolist()
                line["simvalue"] = simvalue.tolist()
                line["imgidx_kw_dict"] = imgidx_kw_dict
                line["segment_timeline"] = segment_timeline
                line["concatimg_dir"] = concatimg_dir

                ans_file.write(json.dumps(line) + "\n")

    print(f"question_path:{question_path}\nanswer_path:{answer_path}")


if __name__ == "__main__":
    eval_model()
config/config.py
ADDED
@@ -0,0 +1,38 @@
# base
dict_api = {
    "api_key":"ADD",
}
base_dir = "your_path" # your data path ex) img_folder_path/egoschema


# scene clustering
divlam = 12
f_path = "your_path" # files from keywords dir
q_path = "your_path" # files from questions dir
a_path = "your_path"
img_folder = "your_path" # your img folder path ex) img_folder_path/egoschema/frames_900_4531/q_uid/image_sec_millisec.jpg


# coarse key frame detector
maximgslen = 32
limit_keywords = 25
concatname = "LVnet"
modelpath = "your_path" # model path
question_path = "your_path" # recommend using the same path with scene clustering answer path
answerpath = f"{base_dir}/kwkfmatching/kf_{concatname}.jsonl" # kwkfmatching is not necessary.
concatdir = f"{base_dir}/kwkfmatching/concatimg_{concatname}" # kwkfmatching is not necessary.


# fine key frame detector
kf_vlm = "gpt-4o"
kf_temp = None
kf_num_select = 3
kf_num_input_imgs = 32
kf_question_path = "your_path" # recommend using the same path with coarse key frame detector answer path
kf_answer_path = f"{base_dir}/kf_VLM/kf_VLM{kf_num_input_imgs}sel{kf_num_select}_{kf_question_path.split('/')[-1].split('.')[0]}.jsonl" # kf_VLM is not necessary.


# fine key frame detector refine
refine_num_group = 4
refine_kflen = 12
refine_output_path = f"{base_dir}/kf_VLM/refine/" + kf_answer_path.split('/')[-1] # kf_VLM is not necessary.
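For context, the detector scripts in this commit read these values by importing this module (see `from config import config` in coarseKeyframeDetector.py above and fineKeyframeDetector.py below); a minimal usage sketch, with the printed comments only describing the defaults set here:

```python
# Minimal sketch of how the config module is consumed downstream.
from config import config

print(config.maximgslen)      # 32 frames kept by the coarse keyframe detector
print(config.limit_keywords)  # at most 25 keywords considered per question
print(config.kf_answer_path)  # derived from base_dir, kf_num_input_imgs, kf_num_select
```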
config/run.sh
ADDED
@@ -0,0 +1,4 @@
devnum=${1:-0}
CUDA_VISIBLE_DEVICES=$devnum python3 temporalSceneClustering.py
CUDA_VISIBLE_DEVICES=$devnum python3 coarseKeyframeDetector.py
CUDA_VISIBLE_DEVICES=$devnum python3 fineKeyframeDetector.py
extractKeyword.py
ADDED
@@ -0,0 +1,132 @@
import argparse
import json
import os

from tqdm import tqdm
from src.run_gpt import run_gpt

"""
Extract keywords from the given question and options

Sample Run
python3 extractKeyword.py --output-dir ego_base_link --question questions/500questions.jsonl --gptmodel "gpt-4-1106-preview"

"""


# You may add multiple keys if you want parallel calls
dict_api = {
    "api_key": "ADD",
}

PROMPT = (
    "Think step-by-step and for each option, identify all the specified activities. "
    "Each description of activity should use active voice with plain verbs, contain fewer than six words, "
    "and retains as many original terms from the options as possible.\n"
    "Here are the options:\n\n"
    "option 0: {Putop0}\n"
    "option 1: {Putop1}\n"
    "option 2: {Putop2}\n"
    "option 3: {Putop3}\n"
    "option 4: {Putop4}\n"
    "option 5: {Putquestion}.\n"
    "All the options were introduced. 'C' represents the camera operator in the options. "
    "Your answer should follow the JSON format shown below and should only include the JSON result. "
    "Do not output any warnings or notes under any circumstances. Instead, adhere strictly to the provided JSON format example.\n"
    "This is one example output format.\n"
    "{\"option 0\": [\"plays soccer\", \"go to school\"], \"option 1\": [\"go to the gym\", \"go to school\"], "
    "\"option 2\": [\"go to school\", \"dry hair\"], \"option 3\": [\"plays basketball\", \"look at the tree\"], "
    "\"option 4\": [\"plays soccer\", \"drop the ball\"], \"option 5\": [\"turn the table\", \"go to school\"]}"
)


def main(args):
    # 1. Create output directories
    os.makedirs(args.output_dir, exist_ok=True)
    job_dir = os.path.join(args.output_dir, "extractedKeywords")
    os.makedirs(job_dir, exist_ok=True)


    # 2. Build the output file name (based on --question)
    question_file_name = os.path.basename(args.question).replace(".jsonl", "")
    output_summary_path = os.path.join(job_dir, f"{question_file_name}.jsonl")
    print(f"Saving outputs to: {output_summary_path}")

    # 3. Read the question file
    with open(os.path.expanduser(args.question), "r") as f:
        question_data = [json.loads(line) for line in f]

    # 4. Construct final prompts
    final_prompts = []
    final_info = []
    for entry in tqdm(question_data, desc="Building prompts"):
        q_uid = entry["q_uid"]
        # Insert each option + question into the embedded prompt
        cur_prompt = (
            PROMPT
            .replace("{Putop0}", entry["option 0"])
            .replace("{Putop1}", entry["option 1"])
            .replace("{Putop2}", entry["option 2"])
            .replace("{Putop3}", entry["option 3"])
            .replace("{Putop4}", entry["option 4"])
            .replace("{Putquestion}", entry["question"])
        )

        final_prompts.append(cur_prompt)

        # Track data for JSON output
        info = {
            "q_uid": q_uid,
            "prompt": cur_prompt,
            "option 0": entry["option 0"],
            "option 1": entry["option 1"],
            "option 2": entry["option 2"],
            "option 3": entry["option 3"],
            "option 4": entry["option 4"],
            "question": entry["question"],
        }

        # Include ground-truth label if present
        if "CA" in entry:
            info["CA"] = entry["CA"]

        final_info.append(info)

    # 5. Call GPT
    print("Sending prompts to GPT. This may take a while...")
    output_results = run_gpt(
        texts=final_prompts,
        api_keys=list(dict_api.values()),
        max_tokens=2000,
        model=args.gptmodel,
        temperature=args.temperature,
        num_threads=5,  # Adjust as needed
        backoff_time=60,  # Adjust as needed
        silent=False,
        dataset="extractKeyword",
    )

    output_results = list(output_results)

    # 6. Save results
    with open(output_summary_path, "w") as outfile:
        for i, info in enumerate(final_info):
            info["answer"] = output_results[i]
            outfile.write(json.dumps(info) + "\n")

    print(f"Done! Results written to {output_summary_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str, required=True,
                        help="Directory to store the resulting JSONL file.")
    parser.add_argument("--question", type=str, required=True,
                        help="Path to the JSONL file with question data (e.g., 500questions.jsonl).")
    parser.add_argument("--gptmodel", type=str, default="gpt-4-1106-preview",
                        help="The GPT model to call.")
    parser.add_argument("--temperature", type=float, default=None,
                        help="Temperature parameter for GPT.")

    args = parser.parse_args()
    main(args)
figures/KFSelectionFlowComparison.jpg
ADDED
Binary image stored with Git LFS (preview omitted)
figures/architecture.png
ADDED
Binary image stored with Git LFS (preview omitted)
figures/architecture_qualitative.png
ADDED
Binary image stored with Git LFS (preview omitted)
figures/hkf_graph.png
ADDED
Binary image stored with Git LFS (preview omitted)
fineKeyframeDetector.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import base64
|
4 |
+
import natsort
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch.utils.data import Dataset, DataLoader
|
13 |
+
|
14 |
+
from config import config
|
15 |
+
from src.refine import refine_answer
|
16 |
+
from src.run_gpt import run_gpt
|
17 |
+
|
18 |
+
class CustomDatasetGPT(Dataset):
|
19 |
+
def __init__(self, questions, num_input_imgs, num_select):
|
20 |
+
self.questions = questions
|
21 |
+
self.num_input_imgs = num_input_imgs
|
22 |
+
self.num_select = num_select
|
23 |
+
|
24 |
+
def __getitem__(self, index):
|
25 |
+
line = self.questions[index]
|
26 |
+
num_select = self.num_select
|
27 |
+
num_input_imgs = self.num_input_imgs
|
28 |
+
|
29 |
+
giter = 0
|
30 |
+
imgs_group = 8
|
31 |
+
num_groups = num_input_imgs//imgs_group
|
32 |
+
|
33 |
+
kf_paths = line["kf_paths"]
|
34 |
+
keywords = line["keywords"]
|
35 |
+
simvalue = line["simvalue"]
|
36 |
+
concatimg_dir = line['concatimg_dir']
|
37 |
+
|
38 |
+
concatimg_paths = natsort.natsorted([f"{concatimg_dir}/{im}" for im in os.listdir(concatimg_dir) if "ipynb" not in im])
|
39 |
+
|
40 |
+
concatimages = []
|
41 |
+
concatimages_base64 = []
|
42 |
+
qs_org = []
|
43 |
+
kw_perconcat = []
|
44 |
+
kf_paths_perconcat = []
|
45 |
+
simvalue_perconcat = []
|
46 |
+
segment_timeline = []
|
47 |
+
|
48 |
+
for concatidx, img_path in enumerate(concatimg_paths):
|
49 |
+
concatimages.append(Image.open(img_path).convert('RGB'))
|
50 |
+
concatimages_base64.append(img_path)
|
51 |
+
|
52 |
+
kw_sidx = imgs_group*(concatidx)
|
53 |
+
kw_eidx = imgs_group*(concatidx+1)
|
54 |
+
|
55 |
+
concat_kw = keywords[kw_sidx:kw_eidx]
|
56 |
+
qs_org_ = create_question(concat_kw, num_select)
|
57 |
+
|
58 |
+
kw_perconcat.append(concat_kw)
|
59 |
+
qs_org.append(qs_org_)
|
60 |
+
kf_paths_perconcat.append(kf_paths[kw_sidx:kw_eidx])
|
61 |
+
simvalue_perconcat.append(simvalue[kw_sidx:kw_eidx])
|
62 |
+
segment_timeline.append(line["segment_timeline"])
|
63 |
+
|
64 |
+
concatimg_paths = concatimg_paths[-num_groups:]
|
65 |
+
concatimages_base64 = concatimages_base64[-num_groups:]
|
66 |
+
qs_org = qs_org[-num_groups:]
|
67 |
+
kw_perconcat = kw_perconcat[-num_groups:]
|
68 |
+
kf_paths_perconcat = kf_paths_perconcat[-num_groups:]
|
69 |
+
simvalue_perconcat = simvalue_perconcat[-num_groups:]
|
70 |
+
segment_timeline = segment_timeline[-num_groups:]
|
71 |
+
|
72 |
+
return concatimages_base64, concatimages[0].size, kw_perconcat, kf_paths_perconcat, qs_org, segment_timeline, concatimg_paths, simvalue_perconcat
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return len(self.questions)
|
76 |
+
|
77 |
+
def create_question(concat_kw, num_select):
|
78 |
+
imgkw_sen = ""
|
79 |
+
|
80 |
+
for i, e in enumerate(concat_kw):
|
81 |
+
if i < len(concat_kw) - 1: imgkw_sen = imgkw_sen + f"Image_{i}: '{e}', "
|
82 |
+
else: imgkw_sen = imgkw_sen + f"Image_{i}: '{e}'."
|
83 |
+
|
84 |
+
if num_select == 3:
|
85 |
+
        prompt = f"Eight images, having egocentric perspectives, are juxtaposed, separated by a red vertical line and red horizontal line. In the first row, the images from left to right are named as image_0, image_1, image_2, image_3. In the second row, the images from left to right are named as image_4, image_5, image_6, image_7. Here are images and their associated guess words: {imgkw_sen}. Think step-by-step and list only the names of the {num_select} images most closely related to the guessed words. Do not select blurry images in your answer. If none of the images correspond to the provided guess words, choose any two images at random. Your answer should follow the JSON format shown below and should only include the JSON result. Do not output any warnings or notes under any circumstances. Instead, adhere strictly to the provided JSON format example.\n"
        prompt += "{\"image name\": write reason for your selection in 10 words}."
        prompt += " This is one example output format. {\n \"image_0\": \"Person washing a plate; linked to dish cleaning.\",\n \"image_2\": \"Person washing a bowl; linked to dish cleaning.\",\n \"image_6\": \"Person running water on a sponge; related to dish cleaning.\"\n}"

    elif num_select == 4:
        prompt = f"Eight images, having egocentric perspectives, are juxtaposed, separated by a red vertical line and red horizontal line. In the first row, the images from left to right are named as image_0, image_1, image_2, image_3. In the second row, the images from left to right are named as image_4, image_5, image_6, image_7. Here are images and their associated guess words: {imgkw_sen}. Think step-by-step and list only the names of the {num_select} images most closely related to the guessed words. Do not select blurry images in your answer. If none of the images correspond to the provided guess words, choose any two images at random. Your answer should follow the JSON format shown below and should only include the JSON result. Do not output any warnings or notes under any circumstances. Instead, adhere strictly to the provided JSON format example.\n"
        prompt += "{\"image name\": write reason for your selection in 10 words}."
        prompt += " This is one example output format. {\n \"image_0\": \"Person washing a plate; linked to dish cleaning.\",\n \"image_2\": \"Person washing a bowl; linked to dish cleaning.\",\n \"image_6\": \"Person running water on a sponge; related to dish cleaning.\",\n \"image_7\": \"Person moves mouse; linked to working.\"\n}"

    else: assert False, f"num_select:{num_select} is not defined yet"

    return prompt


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def create_data_loader_gpt(questions, num_input_imgs, num_select, batch_size=1, num_workers=4):
    assert batch_size == 1, "batch_size must be 1"
    dataset = CustomDatasetGPT(questions, num_input_imgs, num_select)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    return data_loader, dataset


def eval_model():
    question_path, vlm, num_input_imgs, num_select, temp = config.kf_question_path, config.kf_vlm, config.kf_num_input_imgs, config.kf_num_select, config.kf_temp
    questions = [json.loads(q) for q in open(os.path.expanduser(question_path), "r")]
    num_questions = len(questions)
    giter = 0

    answer_path = config.kf_answer_path
    os.makedirs(os.path.dirname(answer_path), exist_ok=True)

    print(f"question_path:{question_path}\nanswer_path:{answer_path}")
    ans_file = open(answer_path, "w")
    data_loader, dataset = create_data_loader_gpt(questions, num_input_imgs, num_select)

    outputs = ""
    for (image_paths, image_sizes, kw_perconcat, kf_paths_perconcat, cur_prompts, segment_timeline, concatimg_paths, simvalue_perconcat), line in tqdm(zip(data_loader, questions), total=len(questions)):
        idx, q_uid = line["q_uid"], line["q_uid"]
        CA = line["CA"] if "CA" in line else None
        option0 = line['option 0']
        option1 = line['option 1']
        option2 = line['option 2']
        option3 = line['option 3']
        option4 = line['option 4']
        question = line['question']

        pastobj = None
        past_VLM_path = None
        past_VLM_timeline = None

        kw_VLM = []
        kf_paths_VLM = []
        kf_timeline = []

        kw_VLM_ordered = []
        kf_paths_VLM_ordered = []
        kf_timeline_ordered = []

        prompts = [x[0] for x in cur_prompts]
        image_paths = [x[0] for x in image_paths]

        output_VLM = run_gpt(
            images=image_paths,
            texts=prompts,
            api_keys=list(config.dict_api.values()),
            max_tokens=2000,
            model=vlm,
            temperature=temp,
            num_threads=20,
            backoff_time=1*60,
            silent=False,
            dataset="egoschema",
            verbose=False,
        )
        output_VLM = list(output_VLM)

        for j, _ in enumerate(cur_prompts):
            kf_paths_perconcat_ = kf_paths_perconcat[j]
            kf_timeline.append([f"{e[0].split('_')[-2]}.{e[0].split('_')[-1].split('.')[0]}" for e in kf_paths_perconcat_])

        line_frame = line.copy()

        line_frame["output_VLM"] = output_VLM
        line_frame["concatimg_paths"] = concatimg_paths
        line_frame["kf_paths_VLM"] = kf_paths_perconcat
        line_frame["kf_timeline"] = kf_timeline
        line_frame["kw_perconcat_clip"] = kw_perconcat
        line_frame["iter"] = giter

        line_frame.pop("filepath")
        line_frame.pop("kf_paths")
        line_frame.pop("google_drive_id")

        try: ans_file.write(json.dumps(line_frame) + "\n")
        except: assert False, f"line_frame:{line_frame}"

    ans_file.close()
    print(f"question_path:{question_path}\nanswer_path:{answer_path}")
    print("job is done")


if __name__ == "__main__":
    eval_model()
    refine_answer()

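The prompts above assume the eight candidate keyframes have already been tiled into a single two-row grid with red separator lines (the concatenated images tracked in `concatimg_paths`). As a rough illustration of that layout only, the sketch below tiles eight frames with PIL; the helper name `make_concat_grid`, the tile size, and the 10-pixel line width are assumptions for illustration, not the repository's actual implementation.

```python
# Illustrative sketch only: tiles eight frames into the 2x4 grid with red
# separators that the VLM prompt describes (image_0..image_3 on the first row,
# image_4..image_7 on the second). Helper name and sizes are assumptions.
from PIL import Image

def make_concat_grid(frame_paths, tile_size=(336, 336), line_width=10):
    assert len(frame_paths) == 8, "the prompt expects exactly eight frames"
    tw, th = tile_size
    width = 4 * tw + 3 * line_width
    height = 2 * th + line_width
    canvas = Image.new("RGB", (width, height), color=(255, 0, 0))  # red shows through the gaps
    for i, path in enumerate(frame_paths):
        row, col = divmod(i, 4)
        tile = Image.open(path).convert("RGB").resize(tile_size)
        canvas.paste(tile, (col * (tw + line_width), row * (th + line_width)))
    return canvas

# usage: make_concat_grid([f"frame_{i}.jpg" for i in range(8)]).save("concat_0.jpg")
```
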
keywords/Keyword_4531questions.json
ADDED
The diff for this file is too large to render. See raw diff.

keywords/Keyword_500questions.jsonl
ADDED
The diff for this file is too large to render. See raw diff.

questions/4531questions.json
ADDED
The diff for this file is too large to render. See raw diff.

questions/500questions.jsonl
ADDED
The diff for this file is too large to render. See raw diff.

requirements.txt
ADDED
@@ -0,0 +1,24 @@
torch==2.1.2
torchvision==0.16.2
transformers==4.37.2
tokenizers==0.15.1
sentencepiece==0.1.99
shortuuid

peft

numpy

requests

uvicorn
fastapi

rp
tqdm

shutil

PIL
natsort
clip

scripts/create_caption.sh
ADDED
@@ -0,0 +1,11 @@
# run this script from the root of the repo
# fix paths to data and update GPT keys in code

export PYTHONPATH=$PYTHONPATH:$PWD

python3 VLM_stage.py \
    --output-dir ego_base_link_bigtensor \
    --question-path your_question_path.jsonl \
    --gptmodel "gpt-4o" \
    --num-kf 12 \
    --temp 0

scripts/eval_ES.sh
ADDED
@@ -0,0 +1,23 @@
# run this script from the root of the repo
# fix paths to data and update GPT keys in code

export PYTHONPATH=$PYTHONPATH:$PWD

# Evaluate on ES subset (comment it out if you do not want to run it)
python3 LLM_stage.py \
    --output-dir ego_base_link \
    --captions data/ESsub_captions_gpt4o.jsonl \
    --data data/ESsub_qa_data.json \
    --per-vid-captions 12 \
    --gptmodel "gpt-4o" \
    --temperature 0.0


# Evaluate on ES full dataset (uncomment it if you want to run it)
# python3 LLM_stage.py \
#     --output-dir ego_base_link \
#     --captions data/ES_captions_gpt4o.jsonl \
#     --data data/ES_qa_data.json \
#     --per-vid-captions 12 \
#     --gptmodel "gpt-4o" \
#     --temperature 0.0

scripts/get_ES_captions.sh
ADDED
@@ -0,0 +1,8 @@
# download our pre-generated EgoSchema captions

wget https://github.com/jongwoopark7978/LVNet_dev/releases/download/v1.0/ESsub_captions_gpt4o.jsonl
wget TBA
wget TBA
wget TBA

# TODO: update the paths when the repo is public.

src/open_clip/__init__.py
ADDED
@@ -0,0 +1,11 @@
from llava.open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from llava.open_clip.factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
from llava.open_clip.factory import list_models, add_model_config, get_model_config, load_checkpoint
from llava.open_clip.loss import ClipLoss
from llava.open_clip.model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from llava.open_clip.openai import load_openai_model, list_openai_models
from llava.open_clip.pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from llava.open_clip.tokenizer import SimpleTokenizer, tokenize
from llava.open_clip.transform import image_transform

src/open_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917

src/open_clip/constants.py
ADDED
@@ -0,0 +1,2 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

src/open_clip/factory.py
ADDED
@@ -0,0 +1,287 @@
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from llava.open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
13 |
+
from llava.open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
14 |
+
resize_pos_embed, get_cast_dtype
|
15 |
+
from llava.open_clip.openai import load_openai_model
|
16 |
+
from llava.open_clip.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
|
17 |
+
from llava.open_clip.transform import image_transform
|
18 |
+
from llava.open_clip.tokenizer import HFTokenizer, tokenize
|
19 |
+
|
20 |
+
|
21 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
22 |
+
# _MODEL_CONFIG_PATHS = ["/home/mryoo/llava_16/llava/open_clip/model_config/"]
|
23 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
24 |
+
|
25 |
+
|
26 |
+
def _natural_key(string_):
|
27 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
28 |
+
|
29 |
+
|
30 |
+
def _rescan_model_configs():
|
31 |
+
global _MODEL_CONFIGS
|
32 |
+
|
33 |
+
config_ext = ('.json',)
|
34 |
+
config_files = []
|
35 |
+
|
36 |
+
# print(f"_MODEL_CONFIG_PATHS:{_MODEL_CONFIG_PATHS}")
|
37 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
38 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
39 |
+
config_files.append(config_path)
|
40 |
+
elif config_path.is_dir():
|
41 |
+
for ext in config_ext:
|
42 |
+
config_files.extend(config_path.glob(f'*{ext}'))
|
43 |
+
# for ext in config_ext:
|
44 |
+
# config_files.extend(config_path.glob(f'*{ext}'))
|
45 |
+
|
46 |
+
for cf in config_files:
|
47 |
+
with open(cf, 'r') as f:
|
48 |
+
model_cfg = json.load(f)
|
49 |
+
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
50 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
51 |
+
|
52 |
+
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
|
53 |
+
# print(f"_MODEL_CONFIGS:{_MODEL_CONFIGS}")
|
54 |
+
|
55 |
+
|
56 |
+
_rescan_model_configs() # initial populate of model config registry
|
57 |
+
|
58 |
+
|
59 |
+
def list_models():
|
60 |
+
""" enumerate available model architectures based on config files """
|
61 |
+
return list(_MODEL_CONFIGS.keys())
|
62 |
+
|
63 |
+
|
64 |
+
def add_model_config(path):
|
65 |
+
""" add model config path or file and update registry """
|
66 |
+
if not isinstance(path, Path):
|
67 |
+
path = Path(path)
|
68 |
+
_MODEL_CONFIG_PATHS.append(path)
|
69 |
+
_rescan_model_configs()
|
70 |
+
|
71 |
+
|
72 |
+
def get_model_config(model_name):
|
73 |
+
# print(f"_MODEL_CONFIGS:{_MODEL_CONFIGS}")
|
74 |
+
if model_name in _MODEL_CONFIGS:
|
75 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
76 |
+
else:
|
77 |
+
return None
|
78 |
+
|
79 |
+
|
80 |
+
def get_tokenizer(model_name):
|
81 |
+
config = get_model_config(model_name)
|
82 |
+
tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
|
83 |
+
return tokenizer
|
84 |
+
|
85 |
+
|
86 |
+
def load_state_dict(checkpoint_path: str, map_location='cpu'):
|
87 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
88 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
89 |
+
state_dict = checkpoint['state_dict']
|
90 |
+
else:
|
91 |
+
state_dict = checkpoint
|
92 |
+
if next(iter(state_dict.items()))[0].startswith('module'):
|
93 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
94 |
+
return state_dict
|
95 |
+
|
96 |
+
|
97 |
+
def load_checkpoint(model, checkpoint_path, strict=True):
|
98 |
+
state_dict = load_state_dict(checkpoint_path)
|
99 |
+
# detect old format and make compatible with new format
|
100 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
101 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
102 |
+
resize_pos_embed(state_dict, model)
|
103 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
104 |
+
return incompatible_keys
|
105 |
+
|
106 |
+
|
107 |
+
def create_model(
|
108 |
+
model_name: str,
|
109 |
+
pretrained: Optional[str] = None,
|
110 |
+
precision: str = 'fp32',
|
111 |
+
device: Union[str, torch.device] = 'cpu',
|
112 |
+
jit: bool = False,
|
113 |
+
force_quick_gelu: bool = False,
|
114 |
+
force_custom_text: bool = False,
|
115 |
+
force_patch_dropout: Optional[float] = None,
|
116 |
+
pretrained_image: bool = False,
|
117 |
+
pretrained_hf: bool = True,
|
118 |
+
cache_dir: Optional[str] = None,
|
119 |
+
):
|
120 |
+
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
121 |
+
if isinstance(device, str):
|
122 |
+
device = torch.device(device)
|
123 |
+
|
124 |
+
if pretrained and pretrained.lower() == 'openai':
|
125 |
+
logging.info(f'Loading pretrained {model_name} from OpenAI.')
|
126 |
+
model = load_openai_model(
|
127 |
+
model_name,
|
128 |
+
precision=precision,
|
129 |
+
device=device,
|
130 |
+
jit=jit,
|
131 |
+
cache_dir=cache_dir,
|
132 |
+
)
|
133 |
+
else:
|
134 |
+
model_cfg = get_model_config(model_name)
|
135 |
+
if model_cfg is not None:
|
136 |
+
logging.info(f'Loaded {model_name} model config.')
|
137 |
+
else:
|
138 |
+
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
139 |
+
raise RuntimeError(f'Model config for {model_name} not found.')
|
140 |
+
|
141 |
+
if force_quick_gelu:
|
142 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
143 |
+
model_cfg["quick_gelu"] = True
|
144 |
+
|
145 |
+
if force_patch_dropout is not None:
|
146 |
+
# override the default patch dropout value
|
147 |
+
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
148 |
+
|
149 |
+
if pretrained_image:
|
150 |
+
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
|
151 |
+
# pretrained weight loading for timm models set via vision_cfg
|
152 |
+
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
153 |
+
else:
|
154 |
+
assert False, 'pretrained image towers currently only supported for timm models'
|
155 |
+
|
156 |
+
cast_dtype = get_cast_dtype(precision)
|
157 |
+
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or ('hf_model_name' in model_cfg.get('text_cfg', {}))
|
158 |
+
|
159 |
+
if custom_text:
|
160 |
+
if 'hf_model_name' in model_cfg.get('text_cfg', {}):
|
161 |
+
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
|
162 |
+
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
163 |
+
else:
|
164 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
165 |
+
|
166 |
+
pretrained_cfg = {}
|
167 |
+
if pretrained:
|
168 |
+
checkpoint_path = ''
|
169 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
170 |
+
if pretrained_cfg:
|
171 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
172 |
+
elif os.path.exists(pretrained):
|
173 |
+
checkpoint_path = pretrained
|
174 |
+
|
175 |
+
if checkpoint_path:
|
176 |
+
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
177 |
+
load_checkpoint(model, checkpoint_path)
|
178 |
+
else:
|
179 |
+
error_str = (
|
180 |
+
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
181 |
+
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
182 |
+
logging.warning(error_str)
|
183 |
+
raise RuntimeError(error_str)
|
184 |
+
|
185 |
+
model.to(device=device)
|
186 |
+
if precision in ("fp16", "bf16"):
|
187 |
+
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
|
188 |
+
|
189 |
+
# set image / mean metadata from pretrained_cfg if available, or use default
|
190 |
+
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
|
191 |
+
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
|
192 |
+
|
193 |
+
if jit:
|
194 |
+
model = torch.jit.script(model)
|
195 |
+
|
196 |
+
return model
|
197 |
+
|
198 |
+
|
199 |
+
def create_model_and_transforms(
|
200 |
+
model_name: str,
|
201 |
+
pretrained: Optional[str] = None,
|
202 |
+
precision: str = 'fp32',
|
203 |
+
device: Union[str, torch.device] = 'cpu',
|
204 |
+
jit: bool = False,
|
205 |
+
force_quick_gelu: bool = False,
|
206 |
+
force_custom_text: bool = False,
|
207 |
+
force_patch_dropout: Optional[float] = None,
|
208 |
+
pretrained_image: bool = False,
|
209 |
+
pretrained_hf: bool = True,
|
210 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
211 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
212 |
+
cache_dir: Optional[str] = None,
|
213 |
+
):
|
214 |
+
model = create_model(
|
215 |
+
model_name,
|
216 |
+
pretrained,
|
217 |
+
precision=precision,
|
218 |
+
device=device,
|
219 |
+
jit=jit,
|
220 |
+
force_quick_gelu=force_quick_gelu,
|
221 |
+
force_custom_text=force_custom_text,
|
222 |
+
force_patch_dropout=force_patch_dropout,
|
223 |
+
pretrained_image=pretrained_image,
|
224 |
+
pretrained_hf=pretrained_hf,
|
225 |
+
cache_dir=cache_dir,
|
226 |
+
)
|
227 |
+
|
228 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
229 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
230 |
+
preprocess_train = image_transform(
|
231 |
+
model.visual.image_size,
|
232 |
+
is_train=True,
|
233 |
+
mean=image_mean,
|
234 |
+
std=image_std
|
235 |
+
)
|
236 |
+
preprocess_val = image_transform(
|
237 |
+
model.visual.image_size,
|
238 |
+
is_train=False,
|
239 |
+
mean=image_mean,
|
240 |
+
std=image_std
|
241 |
+
)
|
242 |
+
|
243 |
+
return model, preprocess_train, preprocess_val
|
244 |
+
|
245 |
+
|
246 |
+
def create_model_from_pretrained(
|
247 |
+
model_name: str,
|
248 |
+
pretrained: str,
|
249 |
+
precision: str = 'fp32',
|
250 |
+
device: Union[str, torch.device] = 'cpu',
|
251 |
+
jit: bool = False,
|
252 |
+
force_quick_gelu: bool = False,
|
253 |
+
force_custom_text: bool = False,
|
254 |
+
return_transform: bool = True,
|
255 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
256 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
257 |
+
cache_dir: Optional[str] = None,
|
258 |
+
):
|
259 |
+
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
|
260 |
+
raise RuntimeError(
|
261 |
+
f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
|
262 |
+
f' Use open_clip.list_pretrained() to find one.')
|
263 |
+
|
264 |
+
model = create_model(
|
265 |
+
model_name,
|
266 |
+
pretrained,
|
267 |
+
precision=precision,
|
268 |
+
device=device,
|
269 |
+
jit=jit,
|
270 |
+
force_quick_gelu=force_quick_gelu,
|
271 |
+
force_custom_text=force_custom_text,
|
272 |
+
cache_dir=cache_dir,
|
273 |
+
)
|
274 |
+
|
275 |
+
if not return_transform:
|
276 |
+
return model
|
277 |
+
|
278 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
279 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
280 |
+
preprocess = image_transform(
|
281 |
+
model.visual.image_size,
|
282 |
+
is_train=False,
|
283 |
+
mean=image_mean,
|
284 |
+
std=image_std
|
285 |
+
)
|
286 |
+
|
287 |
+
return model, preprocess
|
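The factory above is the usual entry point into this vendored open_clip code: `create_model_and_transforms` returns the model together with train/val preprocessing, and `get_tokenizer` returns the matching tokenizer. The following is a hedged usage sketch, assuming the package is importable as `llava.open_clip` (as exported by `__init__.py` above) and that the `ViT-L-14` config with the `openai` pretrained tag is available; the example captions are arbitrary.

```python
# Sketch of typical factory usage; import path and model/pretrained tag are
# assumptions based on __init__.py and the bundled model_configs.
import torch
from llava.open_clip import create_model_and_transforms, get_tokenizer

model, preprocess_train, preprocess_val = create_model_and_transforms(
    "ViT-L-14", pretrained="openai", device="cpu"
)
tokenizer = get_tokenizer("ViT-L-14")  # falls back to the bundled SimpleTokenizer-based tokenize

texts = tokenizer(["a person washing dishes", "a person typing on a laptop"])
with torch.no_grad():
    text_features = model.encode_text(texts, normalize=True)
print(text_features.shape)  # torch.Size([2, 768]) for ViT-L-14
```
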
src/open_clip/hf_configs.py
ADDED
@@ -0,0 +1,60 @@
# HF architecture dict:
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
    "t5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
}

src/open_clip/hf_model.py
ADDED
@@ -0,0 +1,164 @@
1 |
+
""" huggingface model adapter
|
2 |
+
|
3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import re
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch import TensorType
|
11 |
+
|
12 |
+
try:
|
13 |
+
import transformers
|
14 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
15 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
16 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
17 |
+
except ImportError as e:
|
18 |
+
transformers = None
|
19 |
+
|
20 |
+
|
21 |
+
class BaseModelOutput:
|
22 |
+
pass
|
23 |
+
|
24 |
+
|
25 |
+
class PretrainedConfig:
|
26 |
+
pass
|
27 |
+
|
28 |
+
from .hf_configs import arch_dict
|
29 |
+
|
30 |
+
|
31 |
+
# utils
|
32 |
+
def _camel2snake(s):
|
33 |
+
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
34 |
+
|
35 |
+
|
36 |
+
# TODO: ?last - for gpt-like models
|
37 |
+
_POOLERS = {}
|
38 |
+
|
39 |
+
|
40 |
+
def register_pooler(cls):
|
41 |
+
"""Decorator registering pooler class"""
|
42 |
+
_POOLERS[_camel2snake(cls.__name__)] = cls
|
43 |
+
return cls
|
44 |
+
|
45 |
+
|
46 |
+
@register_pooler
|
47 |
+
class MeanPooler(nn.Module):
|
48 |
+
"""Mean pooling"""
|
49 |
+
|
50 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
51 |
+
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
52 |
+
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
53 |
+
|
54 |
+
|
55 |
+
@register_pooler
|
56 |
+
class MaxPooler(nn.Module):
|
57 |
+
"""Max pooling"""
|
58 |
+
|
59 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
60 |
+
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
61 |
+
return masked_output.max(1).values
|
62 |
+
|
63 |
+
|
64 |
+
@register_pooler
|
65 |
+
class ClsPooler(nn.Module):
|
66 |
+
"""CLS token pooling"""
|
67 |
+
|
68 |
+
def __init__(self, use_pooler_output=True):
|
69 |
+
super().__init__()
|
70 |
+
self.cls_token_position = 0
|
71 |
+
self.use_pooler_output = use_pooler_output
|
72 |
+
|
73 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
74 |
+
if (self.use_pooler_output and
|
75 |
+
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
76 |
+
(x.pooler_output is not None)
|
77 |
+
):
|
78 |
+
return x.pooler_output
|
79 |
+
|
80 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
81 |
+
|
82 |
+
|
83 |
+
class HFTextEncoder(nn.Module):
|
84 |
+
"""HuggingFace model adapter"""
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
model_name_or_path: str,
|
89 |
+
output_dim: int,
|
90 |
+
config: PretrainedConfig = None,
|
91 |
+
pooler_type: str = None,
|
92 |
+
proj: str = None,
|
93 |
+
pretrained: bool = True):
|
94 |
+
super().__init__()
|
95 |
+
|
96 |
+
self.output_dim = output_dim
|
97 |
+
|
98 |
+
# TODO: find better way to get this information
|
99 |
+
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
100 |
+
|
101 |
+
if transformers is None:
|
102 |
+
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
103 |
+
if config is None:
|
104 |
+
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
105 |
+
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
106 |
+
AutoModel.from_config, self.config)
|
107 |
+
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
108 |
+
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
109 |
+
self.transformer = create_func(model_args)
|
110 |
+
self.transformer = self.transformer.encoder
|
111 |
+
else:
|
112 |
+
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
113 |
+
else:
|
114 |
+
self.config = config
|
115 |
+
self.transformer = AutoModel.from_config(config)
|
116 |
+
|
117 |
+
if pooler_type is None: # get default arch pooler
|
118 |
+
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
|
119 |
+
else:
|
120 |
+
self.pooler = _POOLERS[pooler_type]()
|
121 |
+
|
122 |
+
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
123 |
+
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
124 |
+
self.proj = nn.Identity()
|
125 |
+
elif proj == 'linear':
|
126 |
+
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
127 |
+
elif proj == 'mlp':
|
128 |
+
hidden_size = (d_model + output_dim) // 2
|
129 |
+
self.proj = nn.Sequential(
|
130 |
+
nn.Linear(d_model, hidden_size, bias=False),
|
131 |
+
nn.GELU(),
|
132 |
+
nn.Linear(hidden_size, output_dim, bias=False),
|
133 |
+
)
|
134 |
+
|
135 |
+
def forward(self, x: TensorType) -> TensorType:
|
136 |
+
attn_mask = (x != self.config.pad_token_id).long()
|
137 |
+
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
138 |
+
pooled_out = self.pooler(out, attn_mask)
|
139 |
+
|
140 |
+
return self.proj(pooled_out)
|
141 |
+
|
142 |
+
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
143 |
+
if not unlocked_layers: # full freezing
|
144 |
+
for n, p in self.transformer.named_parameters():
|
145 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
146 |
+
return
|
147 |
+
|
148 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
149 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
150 |
+
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
151 |
+
embeddings = getattr(
|
152 |
+
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
153 |
+
modules = [embeddings, *layer_list][:-unlocked_layers]
|
154 |
+
# freeze layers
|
155 |
+
for module in modules:
|
156 |
+
for n, p in module.named_parameters():
|
157 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
158 |
+
|
159 |
+
@torch.jit.ignore
|
160 |
+
def set_grad_checkpointing(self, enable=True):
|
161 |
+
self.transformer.gradient_checkpointing_enable()
|
162 |
+
|
163 |
+
def init_parameters(self):
|
164 |
+
pass
|
src/open_clip/loss.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

try:
    import torch.distributed.nn
    from torch import distributed as dist
    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features


class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def forward(self, image_features, text_features, logit_scale):
        device = image_features.device
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        # calculated ground-truth and cache if enabled
        num_logits = logits_per_image.shape[0]
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2
        return total_loss

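In the single-process case (`world_size == 1`), `ClipLoss` reduces to a symmetric cross-entropy over the scaled image-text similarity matrix. The sketch below exercises that path with random, L2-normalized features standing in for encoder outputs; the batch size, embedding dimension, and import path are illustrative assumptions.

```python
# Minimal single-GPU sketch of ClipLoss; random normalized features replace
# real encoder outputs, and logit_scale mirrors exp(log(1/0.07)) used at init.
import torch
import torch.nn.functional as F
from llava.open_clip import ClipLoss

batch, dim = 8, 512
image_features = F.normalize(torch.randn(batch, dim), dim=-1)
text_features = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # model.logit_scale.exp() at initialization

criterion = ClipLoss()  # defaults: local_loss=False, world_size=1
loss = criterion(image_features, text_features, logit_scale)
print(loss.item())
```
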
src/open_clip/model.py
ADDED
@@ -0,0 +1,413 @@
1 |
+
""" CLIP Model
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
from dataclasses import dataclass
|
6 |
+
import logging
|
7 |
+
import math
|
8 |
+
from typing import Optional, Tuple, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn
|
14 |
+
from torch.utils.checkpoint import checkpoint
|
15 |
+
from sentence_transformers import SentenceTransformer
|
16 |
+
|
17 |
+
from llava.open_clip.hf_model import HFTextEncoder
|
18 |
+
from llava.open_clip.modified_resnet import ModifiedResNet
|
19 |
+
from llava.open_clip.timm_model import TimmModel
|
20 |
+
from llava.open_clip.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
21 |
+
from llava.open_clip.utils import to_2tuple
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class CLIPVisionCfg:
|
26 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
27 |
+
width: int = 768
|
28 |
+
head_width: int = 64
|
29 |
+
mlp_ratio: float = 4.0
|
30 |
+
patch_size: int = 16
|
31 |
+
image_size: Union[Tuple[int, int], int] = 224
|
32 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
33 |
+
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
34 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
35 |
+
global_max_pool: bool = False # whether to max pool, from CLIPpy paper
|
36 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
37 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
38 |
+
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
39 |
+
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
40 |
+
timm_proj_bias: bool = False # enable bias final projection
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class CLIPTextCfg:
|
45 |
+
context_length: int = 77
|
46 |
+
vocab_size: int = 49408
|
47 |
+
width: int = 512
|
48 |
+
heads: int = 8
|
49 |
+
layers: int = 12
|
50 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
51 |
+
hf_model_name: str = None
|
52 |
+
hf_tokenizer_name: str = None
|
53 |
+
hf_model_pretrained: bool = True
|
54 |
+
proj: str = 'mlp'
|
55 |
+
pooler_type: str = 'mean_pooler'
|
56 |
+
|
57 |
+
|
58 |
+
def get_cast_dtype(precision: str):
|
59 |
+
cast_dtype = None
|
60 |
+
if precision == 'bf16':
|
61 |
+
cast_dtype = torch.bfloat16
|
62 |
+
elif precision == 'fp16':
|
63 |
+
cast_dtype = torch.float16
|
64 |
+
return cast_dtype
|
65 |
+
|
66 |
+
|
67 |
+
def _build_vision_tower(
|
68 |
+
embed_dim: int,
|
69 |
+
vision_cfg: CLIPVisionCfg,
|
70 |
+
quick_gelu: bool = False,
|
71 |
+
cast_dtype: Optional[torch.dtype] = None
|
72 |
+
):
|
73 |
+
if isinstance(vision_cfg, dict):
|
74 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
75 |
+
|
76 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
77 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
78 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
79 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
80 |
+
|
81 |
+
if vision_cfg.timm_model_name:
|
82 |
+
visual = TimmModel(
|
83 |
+
vision_cfg.timm_model_name,
|
84 |
+
pretrained=vision_cfg.timm_model_pretrained,
|
85 |
+
pool=vision_cfg.timm_pool,
|
86 |
+
proj=vision_cfg.timm_proj,
|
87 |
+
proj_bias=vision_cfg.timm_proj_bias,
|
88 |
+
embed_dim=embed_dim,
|
89 |
+
image_size=vision_cfg.image_size
|
90 |
+
)
|
91 |
+
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
92 |
+
elif isinstance(vision_cfg.layers, (tuple, list)):
|
93 |
+
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
94 |
+
visual = ModifiedResNet(
|
95 |
+
layers=vision_cfg.layers,
|
96 |
+
output_dim=embed_dim,
|
97 |
+
heads=vision_heads,
|
98 |
+
image_size=vision_cfg.image_size,
|
99 |
+
width=vision_cfg.width
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
103 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
104 |
+
visual = VisionTransformer(
|
105 |
+
image_size=vision_cfg.image_size,
|
106 |
+
patch_size=vision_cfg.patch_size,
|
107 |
+
width=vision_cfg.width,
|
108 |
+
layers=vision_cfg.layers,
|
109 |
+
heads=vision_heads,
|
110 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
111 |
+
ls_init_value=vision_cfg.ls_init_value,
|
112 |
+
patch_dropout=vision_cfg.patch_dropout,
|
113 |
+
global_average_pool=vision_cfg.global_average_pool,
|
114 |
+
global_max_pool=vision_cfg.global_max_pool,
|
115 |
+
output_dim=embed_dim,
|
116 |
+
act_layer=act_layer,
|
117 |
+
norm_layer=norm_layer,
|
118 |
+
)
|
119 |
+
|
120 |
+
return visual
|
121 |
+
|
122 |
+
|
123 |
+
def _build_text_tower(
|
124 |
+
embed_dim: int,
|
125 |
+
text_cfg: CLIPTextCfg,
|
126 |
+
quick_gelu: bool = False,
|
127 |
+
cast_dtype: Optional[torch.dtype] = None,
|
128 |
+
):
|
129 |
+
if isinstance(text_cfg, dict):
|
130 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
131 |
+
|
132 |
+
if text_cfg.hf_model_name:
|
133 |
+
if text_cfg.hf_model_name == "sentence-transformers/sentence-t5-base":
|
134 |
+
text = SentenceTransformer("sentence-transformers/sentence-t5-base")
|
135 |
+
else:
|
136 |
+
text = HFTextEncoder(
|
137 |
+
text_cfg.hf_model_name,
|
138 |
+
output_dim=embed_dim,
|
139 |
+
proj=text_cfg.proj,
|
140 |
+
pooler_type=text_cfg.pooler_type,
|
141 |
+
pretrained=text_cfg.hf_model_pretrained
|
142 |
+
)
|
143 |
+
else:
|
144 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
145 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
146 |
+
|
147 |
+
text = TextTransformer(
|
148 |
+
context_length=text_cfg.context_length,
|
149 |
+
vocab_size=text_cfg.vocab_size,
|
150 |
+
width=text_cfg.width,
|
151 |
+
heads=text_cfg.heads,
|
152 |
+
layers=text_cfg.layers,
|
153 |
+
ls_init_value=text_cfg.ls_init_value,
|
154 |
+
output_dim=embed_dim,
|
155 |
+
act_layer=act_layer,
|
156 |
+
norm_layer=norm_layer,
|
157 |
+
)
|
158 |
+
return text
|
159 |
+
|
160 |
+
|
161 |
+
class CLIP(nn.Module):
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
embed_dim: int,
|
165 |
+
vision_cfg: CLIPVisionCfg,
|
166 |
+
text_cfg: CLIPTextCfg,
|
167 |
+
quick_gelu: bool = False,
|
168 |
+
cast_dtype: Optional[torch.dtype] = None,
|
169 |
+
):
|
170 |
+
super().__init__()
|
171 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
172 |
+
|
173 |
+
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
174 |
+
self.transformer = text.transformer
|
175 |
+
self.vocab_size = text.vocab_size
|
176 |
+
self.token_embedding = text.token_embedding
|
177 |
+
self.positional_embedding = text.positional_embedding
|
178 |
+
self.ln_final = text.ln_final
|
179 |
+
self.text_projection = text.text_projection
|
180 |
+
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
181 |
+
|
182 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
183 |
+
|
184 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
185 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
186 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
187 |
+
|
188 |
+
@torch.jit.ignore
|
189 |
+
def set_grad_checkpointing(self, enable=True):
|
190 |
+
self.visual.set_grad_checkpointing(enable)
|
191 |
+
self.transformer.grad_checkpointing = enable
|
192 |
+
|
193 |
+
def encode_image(self, image, normalize: bool = False):
|
194 |
+
features = self.visual(image)
|
195 |
+
return F.normalize(features, dim=-1) if normalize else features
|
196 |
+
|
197 |
+
def encode_text(self, text, normalize: bool = False):
|
198 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
199 |
+
|
200 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
201 |
+
|
202 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
203 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
204 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
205 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
206 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
207 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
208 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
209 |
+
return F.normalize(x, dim=-1) if normalize else x
|
210 |
+
|
211 |
+
def forward(self, image, text):
|
212 |
+
image_features = self.encode_image(image, normalize=True)
|
213 |
+
text_features = self.encode_text(text, normalize=True)
|
214 |
+
return image_features, text_features, self.logit_scale.exp()
|
215 |
+
|
216 |
+
|
217 |
+
class CustomTextCLIP(nn.Module):
|
218 |
+
def __init__(
|
219 |
+
self,
|
220 |
+
embed_dim: int,
|
221 |
+
vision_cfg: CLIPVisionCfg,
|
222 |
+
text_cfg: CLIPTextCfg,
|
223 |
+
quick_gelu: bool = False,
|
224 |
+
cast_dtype: Optional[torch.dtype] = None,
|
225 |
+
):
|
226 |
+
super().__init__()
|
227 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
228 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
229 |
+
if isinstance(text_cfg, dict):
|
230 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
231 |
+
if text_cfg.hf_model_name:
|
232 |
+
self.use_st = text_cfg.hf_model_name == "sentence-transformers/sentence-t5-base"
|
233 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
234 |
+
|
235 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
236 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
237 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
238 |
+
|
239 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
240 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
241 |
+
|
242 |
+
@torch.jit.ignore
|
243 |
+
def set_grad_checkpointing(self, enable=True):
|
244 |
+
self.visual.set_grad_checkpointing(enable)
|
245 |
+
self.text.set_grad_checkpointing(enable)
|
246 |
+
|
247 |
+
def encode_image(self, image, normalize: bool = False, pool: bool = True):
|
248 |
+
features = self.visual(image, pool=pool)
|
249 |
+
return F.normalize(features, dim=-1) if normalize else features
|
250 |
+
|
251 |
+
def encode_text(self, text, normalize: bool = False):
|
252 |
+
features = self.text(text)
|
253 |
+
return F.normalize(features, dim=-1) if normalize else features
|
254 |
+
|
255 |
+
def forward(self, image, text):
|
256 |
+
image_features = self.encode_image(image, normalize=True)
|
257 |
+
text_features = self.encode_text(text, normalize=True)
|
258 |
+
return image_features, text_features, self.logit_scale.exp()
|
259 |
+
|
260 |
+
|
261 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
262 |
+
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
263 |
+
|
264 |
+
def _convert_weights(l):
|
265 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
266 |
+
l.weight.data = l.weight.data.to(dtype)
|
267 |
+
if l.bias is not None:
|
268 |
+
l.bias.data = l.bias.data.to(dtype)
|
269 |
+
|
270 |
+
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
271 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
272 |
+
tensor = getattr(l, attr)
|
273 |
+
if tensor is not None:
|
274 |
+
tensor.data = tensor.data.to(dtype)
|
275 |
+
|
276 |
+
for name in ["text_projection", "proj"]:
|
277 |
+
if hasattr(l, name):
|
278 |
+
attr = getattr(l, name)
|
279 |
+
if attr is not None:
|
280 |
+
attr.data = attr.data.to(dtype)
|
281 |
+
|
282 |
+
model.apply(_convert_weights)
|
283 |
+
|
284 |
+
|
285 |
+
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
286 |
+
|
287 |
+
|
288 |
+
# used to maintain checkpoint compatibility
|
289 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
|
290 |
+
if 'text_projection' in state_dict:
|
291 |
+
# old format state_dict, move text tower -> .text
|
292 |
+
new_state_dict = {}
|
293 |
+
for k, v in state_dict.items():
|
294 |
+
if any(k.startswith(p) for p in (
|
295 |
+
'text_projection',
|
296 |
+
'positional_embedding',
|
297 |
+
'token_embedding',
|
298 |
+
'transformer',
|
299 |
+
'ln_final',
|
300 |
+
)):
|
301 |
+
k = 'text.' + k
|
302 |
+
new_state_dict[k] = v
|
303 |
+
return new_state_dict
|
304 |
+
return state_dict
|
305 |
+
|
306 |
+
|
307 |
+
def build_model_from_openai_state_dict(
|
308 |
+
state_dict: dict,
|
309 |
+
quick_gelu=True,
|
310 |
+
cast_dtype=torch.float16,
|
311 |
+
):
|
312 |
+
vit = "visual.proj" in state_dict
|
313 |
+
|
314 |
+
if vit:
|
315 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
316 |
+
vision_layers = len(
|
317 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
318 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
319 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
320 |
+
image_size = vision_patch_size * grid_size
|
321 |
+
else:
|
322 |
+
counts: list = [
|
323 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
324 |
+
vision_layers = tuple(counts)
|
325 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
326 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
327 |
+
vision_patch_size = None
|
328 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
329 |
+
image_size = output_width * 32
|
330 |
+
|
331 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
332 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
333 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
334 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
335 |
+
transformer_heads = transformer_width // 64
|
336 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
337 |
+
|
338 |
+
vision_cfg = CLIPVisionCfg(
|
339 |
+
layers=vision_layers,
|
340 |
+
width=vision_width,
|
341 |
+
patch_size=vision_patch_size,
|
342 |
+
image_size=image_size,
|
343 |
+
)
|
344 |
+
text_cfg = CLIPTextCfg(
|
345 |
+
context_length=context_length,
|
346 |
+
vocab_size=vocab_size,
|
347 |
+
width=transformer_width,
|
348 |
+
heads=transformer_heads,
|
349 |
+
layers=transformer_layers
|
350 |
+
)
|
351 |
+
model = CLIP(
|
352 |
+
embed_dim,
|
353 |
+
vision_cfg=vision_cfg,
|
354 |
+
text_cfg=text_cfg,
|
355 |
+
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
356 |
+
cast_dtype=cast_dtype,
|
357 |
+
)
|
358 |
+
|
359 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
360 |
+
state_dict.pop(key, None)
|
361 |
+
|
362 |
+
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
363 |
+
model.load_state_dict(state_dict)
|
364 |
+
return model.eval()
|
365 |
+
|
366 |
+
|
367 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
368 |
+
model.eval()
|
369 |
+
image_size = model.visual.image_size
|
370 |
+
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
371 |
+
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
372 |
+
model = torch.jit.trace_module(
|
373 |
+
model,
|
374 |
+
inputs=dict(
|
375 |
+
forward=(example_images, example_text),
|
376 |
+
encode_text=(example_text,),
|
377 |
+
encode_image=(example_images,)
|
378 |
+
))
|
379 |
+
model.visual.image_size = image_size
|
380 |
+
return model
|
381 |
+
|
382 |
+
|
383 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
384 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
385 |
+
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
386 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
387 |
+
return
|
388 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
389 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
390 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
391 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
392 |
+
return
|
393 |
+
|
394 |
+
if extra_tokens:
|
395 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
396 |
+
else:
|
397 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
398 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
399 |
+
|
400 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
401 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
402 |
+
pos_emb_img = F.interpolate(
|
403 |
+
pos_emb_img,
|
404 |
+
size=grid_size,
|
405 |
+
mode=interpolation,
|
406 |
+
align_corners=True,
|
407 |
+
)
|
408 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
409 |
+
if pos_emb_tok is not None:
|
410 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
411 |
+
else:
|
412 |
+
new_pos_embed = pos_emb_img
|
413 |
+
state_dict['visual.positional_embedding'] = new_pos_embed
|
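At inference time the `CLIP` class above is typically driven through `encode_image` / `encode_text` with `normalize=True`, with `logit_scale.exp()` scaling the cosine similarities before a softmax over candidate texts. A minimal zero-shot scoring sketch follows; the `ViT-B-32` model choice, the dummy image tensor, and the candidate phrases are assumptions for illustration.

```python
# Zero-shot scoring sketch using the CLIP API defined above; the dummy tensor
# stands in for preprocess_val(PIL image).unsqueeze(0).
import torch
from llava.open_clip import create_model_and_transforms, get_tokenizer

model, _, preprocess_val = create_model_and_transforms("ViT-B-32", pretrained="openai")
tokenizer = get_tokenizer("ViT-B-32")
model.eval()

image = torch.randn(1, 3, 224, 224)
texts = tokenizer(["washing dishes", "riding a bike", "typing on a laptop"])

with torch.no_grad():
    image_features = model.encode_image(image, normalize=True)
    text_features = model.encode_text(texts, normalize=True)
    probs = (model.logit_scale.exp() * image_features @ text_features.T).softmax(dim=-1)

print(probs)  # one probability per candidate text
```
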
src/open_clip/model_configs/RN101-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
23,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
src/open_clip/model_configs/RN101.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
23,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/RN50-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
6,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
src/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
6,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/RN50x16.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 384,
|
5 |
+
"layers": [
|
6 |
+
6,
|
7 |
+
8,
|
8 |
+
18,
|
9 |
+
8
|
10 |
+
],
|
11 |
+
"width": 96,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 768,
|
18 |
+
"heads": 12,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/RN50x4.json
ADDED
@@ -0,0 +1,21 @@
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 288,
|
5 |
+
"layers": [
|
6 |
+
4,
|
7 |
+
6,
|
8 |
+
10,
|
9 |
+
6
|
10 |
+
],
|
11 |
+
"width": 80,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 640,
|
18 |
+
"heads": 10,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/RN50x64.json
ADDED
@@ -0,0 +1,21 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 448,
+        "layers": [
+            3,
+            15,
+            36,
+            10
+        ],
+        "width": 128,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-B-16-plus-240.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 240,
+        "layers": 12,
+        "width": 896,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-B-16-plus.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 896,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-B-16.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
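For the ViT configs, image_size and patch_size together fix the length of the positional embedding that the resize code above operates on: the grid is image_size // patch_size per side, plus one class token. A quick check of the numbers (plain arithmetic, not repository code):

# Positional-embedding length implied by a ViT config: (image_size // patch_size)**2 + 1 class token.
def num_positions(image_size: int, patch_size: int) -> int:
    grid = image_size // patch_size
    return grid * grid + 1

print(num_positions(224, 16))  # ViT-B/16: 14x14 grid -> 197 positions
print(num_positions(224, 32))  # ViT-B/32: 7x7 grid   -> 50 positions
print(num_positions(336, 14))  # ViT-L/14 at 336px: 24x24 grid -> 577 positions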
src/open_clip/model_configs/ViT-B-32-plus-256.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 640,
+    "vision_cfg": {
+        "image_size": 256,
+        "layers": 12,
+        "width": 896,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-B-32-quickgelu.json
ADDED
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 512,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-B-32.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-H-14.json
ADDED
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 32,
+        "width": 1280,
+        "head_width": 80,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 24
+    }
+}
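Unlike the smaller ViT configs, the ViT-H entries set head_width explicitly for the vision tower; the head count then follows from the width, presumably as width / head_width (here 1280 / 80 = 16 heads per attention layer), while the text tower states its 16 heads directly. A one-line sanity check of that arithmetic:

# Assumed relation for the vision tower: attention heads = width // head_width.
width, head_width = 1280, 80
print(width // head_width)  # 16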
src/open_clip/model_configs/ViT-H-16.json
ADDED
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 32,
+        "width": 1280,
+        "head_width": 80,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 24
+    }
+}
src/open_clip/model_configs/ViT-L-14-280.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 280,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-L-14-336.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 336,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-L-14.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-L-16-320.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 320,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
src/open_clip/model_configs/ViT-L-16.json
ADDED
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
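Each of these JSON files is keyed by its file name (e.g. "ViT-L-14" or "RN50x16"), so a model variant can be configured by looking up the matching file in src/open_clip/model_configs. A minimal sketch of that lookup, assuming the directory layout added in this commit (illustrative only; the actual loader lives in src/open_clip/factory.py):

# Sketch: resolve a model name to its architecture config by file name.
import json
from pathlib import Path

CONFIG_DIR = Path("src/open_clip/model_configs")  # assumed working directory: repo root

def load_model_config(model_name: str) -> dict:
    cfg_path = CONFIG_DIR / f"{model_name}.json"
    with open(cfg_path) as f:
        return json.load(f)

cfg = load_model_config("ViT-B-32")
print(cfg["embed_dim"], cfg["vision_cfg"]["patch_size"])  # 512 32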