jongwoopark7978 committed on
Commit
54216bc
·
1 Parent(s): 446e69c

chore: add project files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +169 -0
  3. LLM_stage.py +168 -0
  4. README.md +119 -0
  5. VLM_stage.py +154 -0
  6. coarseKeyframeDetector.py +254 -0
  7. config/config.py +38 -0
  8. config/run.sh +4 -0
  9. extractKeyword.py +132 -0
  10. figures/KFSelectionFlowComparison.jpg +3 -0
  11. figures/architecture.png +3 -0
  12. figures/architecture_qualitative.png +3 -0
  13. figures/hkf_graph.png +3 -0
  14. fineKeyframeDetector.py +189 -0
  15. keywords/Keyword_4531questions.json +0 -0
  16. keywords/Keyword_500questions.jsonl +0 -0
  17. questions/4531questions.json +0 -0
  18. questions/500questions.jsonl +0 -0
  19. requirements.txt +24 -0
  20. scripts/create_caption.sh +11 -0
  21. scripts/eval_ES.sh +23 -0
  22. scripts/get_ES_captions.sh +8 -0
  23. src/open_clip/__init__.py +11 -0
  24. src/open_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  25. src/open_clip/constants.py +2 -0
  26. src/open_clip/factory.py +287 -0
  27. src/open_clip/hf_configs.py +60 -0
  28. src/open_clip/hf_model.py +164 -0
  29. src/open_clip/loss.py +121 -0
  30. src/open_clip/model.py +413 -0
  31. src/open_clip/model_configs/RN101-quickgelu.json +22 -0
  32. src/open_clip/model_configs/RN101.json +21 -0
  33. src/open_clip/model_configs/RN50-quickgelu.json +22 -0
  34. src/open_clip/model_configs/RN50.json +21 -0
  35. src/open_clip/model_configs/RN50x16.json +21 -0
  36. src/open_clip/model_configs/RN50x4.json +21 -0
  37. src/open_clip/model_configs/RN50x64.json +21 -0
  38. src/open_clip/model_configs/ViT-B-16-plus-240.json +16 -0
  39. src/open_clip/model_configs/ViT-B-16-plus.json +16 -0
  40. src/open_clip/model_configs/ViT-B-16.json +16 -0
  41. src/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
  42. src/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  43. src/open_clip/model_configs/ViT-B-32.json +16 -0
  44. src/open_clip/model_configs/ViT-H-14.json +17 -0
  45. src/open_clip/model_configs/ViT-H-16.json +17 -0
  46. src/open_clip/model_configs/ViT-L-14-280.json +16 -0
  47. src/open_clip/model_configs/ViT-L-14-336.json +16 -0
  48. src/open_clip/model_configs/ViT-L-14.json +16 -0
  49. src/open_clip/model_configs/ViT-L-16-320.json +16 -0
  50. src/open_clip/model_configs/ViT-L-16.json +16 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/*.jpg filter=lfs diff=lfs merge=lfs -text
37
+ figures/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,169 @@
1
+ # Project Specific
2
+ data/
3
+ data
4
+ datalink
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+ **private/*
160
+ **ego_base_link/*
161
+ **.vscode/*
162
+
163
+
164
+ # PyCharm
165
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
168
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169
+ #.idea/
LLM_stage.py ADDED
@@ -0,0 +1,168 @@
1
+ """
2
+ Evaluate a model on the egoschema dataset using LVNet (captions pre-generated)
3
+
4
+ Sample Run:
5
+
6
+ python3 LLM_stage.py \
7
+ --output-dir ego_base_link \
8
+ --captions data/ES_captions_gpt4o.jsonl \
9
+ --per-vid-captions 12 \
10
+ --gptmodel "gpt-4o" \
11
+ --temperature 0.0
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+
18
+ from tqdm import tqdm
19
+
20
+ from src.run_gpt import run_gpt
21
+
22
+ # You may add multiple keys to run parallel calls
23
+ dict_api = {
24
+ "api_key": "ADD",
25
+ }
26
+
27
+
28
+ _PROMPT_TEMPLATE = (
29
+ "Here are descriptions of the video frames at specific times, noted in seconds."
30
+ "\n\n{Putdesc}.\n\nThe descriptions of the frames conclude. Think step-by-step"
31
+ " and I request your selection of the most appropriate response to the following"
32
+ " question\n\nQuestion:\n{Putquestion}\n\nOptions:\n{AllOptions}"
33
+ )
34
+
35
+
36
+ def eval_model(args):
37
+ # change split to split
38
+ captions_path, data_path, split, gptmodel, temp, base_dir, job_name = (
39
+ args.captions,
40
+ args.data,
41
+ args.per_vid_captions,
42
+ args.gptmodel,
43
+ args.temperature,
44
+ args.output_dir,
45
+ args.job_name,
46
+ )
47
+
48
+ prompt = _PROMPT_TEMPLATE
49
+
50
+ os.makedirs(base_dir, exist_ok=True)
51
+ output_dir = f"{base_dir}/egoschema/{job_name}"
52
+ output_dir = os.path.expanduser(output_dir)
53
+ os.makedirs(output_dir, exist_ok=True)
54
+
55
+ save_name = captions_path.rsplit("/", 2)[-1].replace(".jsonl", "")
56
+ output_summary_path = f"{output_dir}/{save_name}.jsonl"
57
+ print(f"Saving outputs to:{output_summary_path}")
58
+ output_summary = open(output_summary_path, "w")
59
+
60
+ input_summary = [
61
+ json.loads(q) for q in open(os.path.expanduser(captions_path), "r")
62
+ ]
63
+ dataset = json.load(open(os.path.expanduser(data_path), "r"))
64
+ input_len = len(input_summary)
65
+ assert (
66
+ input_len % split == 0
67
+ ), f"input_len%split:{input_len%split}, input_len:{input_len}, split:{split}"
68
+ groups = input_len // split
69
+
70
+ final_prompts = []
71
+ final_info = []
72
+ for i in tqdm(range(groups)):
73
+ sidx = i * split
74
+ eidx = (i + 1) * split
75
+
76
+ desc = ""
77
+ timeline = []
78
+ for idx, e in enumerate(input_summary[sidx:eidx]):
79
+ cur_data = dataset[e["q_uid"]]
80
+ desc += e["answer"] + " "
81
+ timeline.append(e["timeline"])
82
+
83
+ if idx == split - 1: # the last of split
84
+ action_0 = cur_data["option 0"]
85
+ action_1 = cur_data["option 1"]
86
+ action_2 = cur_data["option 2"]
87
+ action_3 = cur_data["option 3"]
88
+ action_4 = cur_data["option 4"]
89
+
90
+ option_list = ""
91
+ option_number_candidate = ["one", "two", "three", "four", "five"]
92
+ option_number = option_number_candidate[4]
93
+ AllOptNumber = "option 0, option 1, option 2, option 3, option 4"
94
+ FocusOptions = ""
95
+
96
+ alloptions = f"option 0: {action_0}\noption 1: {action_1}\noption 2: {action_2}\noption 3: {action_3}\noption 4: {action_4}"
97
+ option_list = f"option 0: {action_0}\noption 1: {action_1}\noption 2: {action_2}\noption 3: {action_3}\noption 4: {action_4}"
98
+
99
+ FocusOptions += "option 0, option 1, option 2, option 3, option 4"
100
+
101
+ question = cur_data["question"]
102
+
103
+ curr_prompt = (
104
+ prompt.replace("{Putdesc}", desc)
105
+ .replace("{Putquestion}", question)
106
+ .replace("{Putoptions}", option_list)
107
+ .replace("{PutOptNumber}", option_number)
108
+ .replace("{FocusOptions}", FocusOptions)
109
+ .replace("{AllOptions}", alloptions)
110
+ .replace("{PutAllOptNumber}", AllOptNumber)
111
+ )
112
+
113
+ final_prompts.append(curr_prompt)
114
+
115
+ CA_option = {}
116
+ if "CA" in cur_data:
117
+ CA_option = {"CA": cur_data["CA"]}
118
+
119
+ info = {
120
+ "q_uid": e["q_uid"],
121
+ "prompt": curr_prompt,
122
+ "timeline": timeline,
123
+ "question": question,
124
+ "option 0": action_0,
125
+ "option 1": action_1,
126
+ "option 2": action_2,
127
+ "option 3": action_3,
128
+ "option 4": action_4,
129
+ } | CA_option
130
+
131
+ final_info.append(info)
132
+
133
+ output_VLM = run_gpt(
134
+ texts=final_prompts,
135
+ api_keys=list(dict_api.values()),
136
+ max_tokens=2000,
137
+ model=gptmodel,
138
+ temperature=temp,
139
+ num_threads=20, # Tune this
140
+ backoff_time=1 * 60,
141
+ silent=False,
142
+ dataset="egoschema",
143
+ )
144
+
145
+ output_VLM = list(output_VLM)
146
+
147
+ for q_idx, info in enumerate(tqdm(final_info)):  # attach each LLM answer to its question record
148
+ info["answer"] = output_VLM[q_idx]
149
+ output_summary.write(json.dumps(info) + "\n")
150
+
151
+ # all answers written; close the output file
152
+ output_summary.close()
153
+ print(f"output_summary_path:{output_summary_path}")
154
+
155
+
156
+ if __name__ == "__main__":
157
+ parser = argparse.ArgumentParser()
158
+ parser.add_argument("--output-dir", type=str)
159
+ parser.add_argument("--job-name", type=str, default="run001")
160
+ parser.add_argument("--captions", type=str, default="data/ES_captions_gpt4o.jsonl")
161
+ parser.add_argument("--data", type=str, default="data/ES_qa_data.json")
162
+ parser.add_argument("--per-vid-captions", type=int, default=12)
163
+ parser.add_argument("--gptmodel", type=str, default="gpt-3.5-turbo-1106")
164
+ parser.add_argument("--temperature", type=float, default=None)
165
+
166
+ args = parser.parse_args()
167
+
168
+ eval_model(args)
README.md CHANGED
@@ -1,3 +1,122 @@
1
  ---
2
  license: apache-2.0
3
+ tags:
4
+ - computer-vision
5
+ - video-question-answering
6
+ - zero-shot
7
+ - neurips-2024-workshop
8
  ---
9
+
10
+ # LVNet
11
+
12
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/too-many-frames-not-all-useful-efficient/zero-shot-video-question-answer-on-egoschema-1)](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-egoschema-1?p=too-many-frames-not-all-useful-efficient)
13
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/too-many-frames-not-all-useful-efficient/zero-shot-video-question-answer-on-intentqa)](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-intentqa?p=too-many-frames-not-all-useful-efficient)
14
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/too-many-frames-not-all-useful-efficient/zero-shot-video-question-answer-on-next-qa)](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-next-qa?p=too-many-frames-not-all-useful-efficient)
15
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/too-many-frames-not-all-useful-efficient/zero-shot-video-question-answer-on-egoschema)](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-egoschema?p=too-many-frames-not-all-useful-efficient)
16
+
17
+ Official Code for **_Too Many Frames, Not All Useful_: Efficient Strategies for Long-Form Video QA**
18
+ Accepted as a 9-page paper to the Workshop on Video-Language Models at *NeurIPS 2024*
19
+
20
+ [**Paper Link**](https://arxiv.org/abs/2406.09396)
21
+
22
+ ---
23
+
24
+ ## Abstract
25
+
26
+ Long-form videos that span wide temporal intervals are highly information-redundant and contain multiple distinct events or entities that are often only loosely related. Therefore, when performing long-form video question answering (LVQA), all the information necessary to generate a correct response can often be contained within a small subset of frames. Recent literature explores the use of large language models (LLMs) in LVQA benchmarks, achieving exceptional performance, while relying on vision language models (VLMs) to convert all visual content within videos into natural language. Such VLMs often independently caption a large number of frames uniformly sampled from long videos, which is inefficient and largely redundant.
27
+
28
+ Questioning these design choices, we explore optimal strategies for keyframe selection and sequence-aware captioning that can significantly reduce these redundancies. We propose two novel approaches that improve each aspect, namely the **Hierarchical Keyframe Selector** and the **Sequential Visual LLM**. Our resulting framework, called **LVNet**, achieves state-of-the-art performance across three benchmark LVQA datasets.
29
+
30
+ ---
31
+
32
+ ## Accuracy vs. Captions on the EgoSchema Subset
33
+
34
+ - LVNet achieves a state-of-the-art (SOTA) **68.2%** accuracy using only **12** captions.
35
+ - This highlights the quality of keyframes from the Hierarchical Keyframe Selector.
36
+
37
+ <img src="figures/hkf_graph.png" alt="acc_captions" width="600"/>
38
+
39
+ ---
40
+
41
+ ## Hierarchical Keyframe Selector: Structural Overview
42
+
43
+ - **Overall strategy**: Generate captions via a Hierarchical Keyframe Selector, then feed them to a separate LLM to answer the question.
44
+ - **Temporal Scene Clustering (TSC)**: Divides the long video into multiple scenes, enabling per-scene subsampling.
45
+ - **Coarse Keyframe Detector (CKD)**: Selects frames best-aligned with keywords relevant to the query.
46
+ - **Fine Keyframe Detector (FKD)**: Refines the keyword alignments within a smaller set of frames via templated visual prompting.
47
+
48
+ <img src="figures/architecture.png" alt="architecture" width="600"/>
49
+
50
+ ---
51
+
52
+ ## Operational Visualization of HKS
53
+
54
+ - **Temporal Scene Clustering (TSC)**: 900 frames are clustered into scenes and then uniformly subsampled to about 280 frames.
55
+ - **Coarse Keyframe Detector (CKD)**: Among those, 32 frames are selected, based on alignment with query keywords.
56
+ - **Visual Templating**: The coarsely refined keyframes are ordered by confidence and temporal order, grouped into 4 chunks of 8 frames.
57
+ - **Fine Keyframe Detector (FKD)**: Selects the final 12 frames via further keyword-alignment checks (a simplified sketch of this funnel follows the figure below).
58
+
59
+ <img src="figures/architecture_qualitative.png" alt="hks_visualization" width="800"/>
60
+
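+ The sketch below is a simplified, self-contained illustration of this funnel (900 candidate frames reduced to roughly 280, then 32, then 12). It uses random embeddings and a plain top-k similarity rule as stand-ins for the CLIP model and the VLM calls in `temporalSceneClustering.py`, `coarseKeyframeDetector.py`, and `fineKeyframeDetector.py`; it is not the actual implementation.
+
+ ```python
+ import numpy as np
+
+ rng = np.random.default_rng(0)
+
+ # Stand-ins for the real inputs: 900 frame embeddings and a few keyword embeddings.
+ frame_emb = rng.normal(size=(900, 512))
+ keyword_emb = rng.normal(size=(5, 512))
+
+ def cosine(a, b):
+     a = a / np.linalg.norm(a, axis=-1, keepdims=True)
+     b = b / np.linalg.norm(b, axis=-1, keepdims=True)
+     return a @ b.T
+
+ # TSC (simplified): split the video into contiguous "scenes" and subsample each one,
+ # keeping about 280 candidate frames.
+ scenes = np.array_split(np.arange(900), 30)
+ candidates = np.concatenate([s[::3] for s in scenes])[:280]
+
+ # CKD (simplified): keep the 32 candidates whose embeddings align best with any keyword.
+ sim = cosine(frame_emb[candidates], keyword_emb).max(axis=1)
+ coarse = candidates[np.argsort(-sim)[:32]]
+
+ # Visual templating: order by time and tile into 4 grids of 8 frames for visual prompting.
+ grids = np.sort(coarse).reshape(4, 8)
+
+ # FKD (simplified): the real pipeline asks a VLM to pick 3 frames per grid (4 x 3 = 12);
+ # here the top-3 by keyword similarity within each grid act as a stand-in.
+ final = [g[np.argsort(-cosine(frame_emb[g], keyword_emb).max(axis=1))[:3]] for g in grids]
+ print(len(np.concatenate(final)))  # -> 12 keyframes
+ ```
+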
61
+ ---
62
+
63
+ ## Experiments: EgoSchema, NExT-QA, and IntentQA
64
+
65
+ - LVNet achieves **61.1%**, **72.9%**, and **71.7%** on the three benchmarks, respectively, using **just 12** frames—on par with or exceeding models that use **many** more frames.
66
+ - Models with specialized video-caption pretraining or significantly more captions are shown in gray/light green for a fair comparison.
67
+
68
+ <img src="tables/table_combined.png" alt="egoschema_table" width="900"/>
69
+
70
+ ---
71
+
72
+ ## Comparison with Other Keyframe Selection Methods
73
+
74
+ Below is a side-by-side comparison of **LVNet** and **VideoAgent**.
75
+ - **LVNet** starts with uniform sampling but then refines keyframes via TSC, CKD, and FKD. This yields 12 frames, 8 of which show “phone usage,” the correct activity.
76
+ - **VideoAgent** continues uniform sampling due to insufficient initial frames, resulting in 0 relevant frames out of 9 and an incorrect final answer.
77
+
78
+ <img src="figures/KFSelectionFlowComparison.jpg" alt="kf_selection_flow" width="900"/>
79
+
80
+ ---
81
+
82
+ ## Evaluation
83
+
84
+ ### Generating Answers Using LLM
85
+
86
+ You can quickly run the LLM to produce answers once you have the keyframe-based captions:
87
+
88
+ 1. **Download the captions for the dataset**
89
+
90
+ * EgoSchema: `bash scripts/get_ES_captions.sh `
91
+
92
+ 2. **Run the LLM**: `bash scripts/eval_ES.sh` (the expected caption format is sketched below)
93
+
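+ For reference, each line of the captions file is a JSON object shaped roughly as below (inferred from how `LLM_stage.py` and `VLM_stage.py` read and write it); the values shown are made up:
+
+ ```python
+ import json
+
+ # LLM_stage.py groups records per video using --per-vid-captions, so 12 consecutive
+ # lines share the same q_uid. The q_uid must also appear as a key in --data
+ # (e.g. ES_qa_data.json), which holds the question and the five options.
+ record = {
+     "q_uid": "example-question-id",                                   # placeholder id
+     "timeline": 42.5,                                                  # keyframe timestamp in seconds
+     "answer": "At 42.5 seconds, C picks up a phone from the table.",   # VLM caption for that keyframe
+ }
+ print(json.dumps(record))
+ ```
+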
94
+ ### Generate captions using our provided modules
95
+ #### Hierarchical Keyframe Selector (HKS)
96
+ - Temporal Scene Clustering (TSC): temporalSceneClustering.py <br/>
97
+ - Coarse Keyframe Detector (CKD): coarseKeyframeDetector.py <br/>
98
+ - Fine Keyframe Detector (FKD): fineKeyframeDetector.py <br/>
99
+
100
+ 1. **EgoSchema keyframe selection from images**: `bash config/run.sh `
101
+
102
+ 2. **Generate captions based on the keyframes**: `bash scripts/create_caption.sh`
103
+
104
+ ## Data
105
+ ### Hierarchical Keyframe Selector hyper-parameters & paths
106
+ - [[LINK]](config/config.py) (the key fields to set are sketched below)
107
+
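+ A minimal sketch of the fields in `config/config.py` that typically need to be set before running `config/run.sh` (the values below are placeholders, not defaults):
+
+ ```python
+ # Paths (placeholders) used by the HKS scripts; see config/config.py for the full list.
+ base_dir = "/data/egoschema"                          # root of the extracted frame folders
+ img_folder = f"{base_dir}/frames_900_4531"            # frames stored as .../q_uid/image_sec_millisec.jpg
+ f_path = "keywords/Keyword_500questions.jsonl"        # keyword file from the keywords/ dir
+ q_path = "questions/500questions.jsonl"               # question file from the questions/ dir
+ a_path = f"{base_dir}/sceneClustering_out.jsonl"      # scene-clustering output (hypothetical name)
+ modelpath = "checkpoints/clippy_5k.pt"                # CLIP checkpoint for coarseKeyframeDetector.py
+ question_path = a_path                                # CKD input; reuse the scene-clustering output
+ ```
+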
108
+ ### coarseKeyframeDetector.py CLIP model checkpoint
109
+ - ICCV 2023 [Perceptual Grouping in Contrastive Vision-Language Models](https://arxiv.org/abs/2210.09996)
110
+ - Checkpoint: [Download](https://github.com/kahnchana/clippy/releases/download/v1.0/clippy_5k.pt)
111
+
112
+
113
+ # Citation
114
+ ```
115
+ @inproceedings{Park2024TooMF,
116
+ title={Too Many Frames, Not All Useful: Efficient Strategies for Long-Form Video QA},
117
+ author={Jongwoo Park and Kanchana Ranasinghe and Kumara Kahatapitiya and Wonjeong Ryoo and Donghyun Kim and Michael S. Ryoo},
118
+ year={2024}
119
+ }
120
+ ```
121
+
122
+ For more details and updates, please see our [GitHub Repository](https://github.com/jongwoopark7978/LVNet).
VLM_stage.py ADDED
@@ -0,0 +1,154 @@
1
+ import os
2
+ import json
3
+ import base64
4
+ import random
5
+ import argparse
6
+
7
+ import natsort
8
+
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset, DataLoader
14
+
15
+ from src.run_gpt import run_gpt
16
+
17
+ random.seed(10)
18
+ dict_api = {
19
+ "api_key":"ADD",
20
+ }
21
+
22
+
23
+ class CustomDatasetGPT(Dataset):
24
+ def __init__(self, questions, num_kf):
25
+ self.questions = questions
26
+ self.num_kf = num_kf
27
+
28
+ def __getitem__(self, index):
29
+ line = self.questions[index]
30
+ group = 4
31
+ newnum_per_group = self.num_kf // group
32
+ oldnum_per_group = len(line["VLM_path"]) // group
33
+ assert oldnum_per_group >= newnum_per_group, f"oldnum_per_group:{oldnum_per_group} is smaller than newnum_per_group:{newnum_per_group}"
34
+
35
+ new_kf_paths = []
36
+ new_kf_timelines = []
37
+ for i in range(group):
38
+ start_index = i * oldnum_per_group
39
+ end_index = start_index + oldnum_per_group
40
+
41
+ sub_kf_paths = line["VLM_path"][start_index:min(end_index, len(line["VLM_path"]))]
42
+ sub_kf_timelines = line["VLM_timeline"][start_index:min(end_index, len(line["VLM_timeline"]))]
43
+ new_kf_paths.extend(sub_kf_paths[:newnum_per_group])
44
+ new_kf_timelines.extend(sub_kf_timelines[:newnum_per_group])
45
+
46
+ kf_paths = natsort.natsorted(new_kf_paths)
47
+ kf_timelines = natsort.natsorted(new_kf_timelines)
48
+
49
+ images = []
50
+ images_base64 = []
51
+
52
+ for e in kf_paths:
53
+ images.append(Image.open(e).convert('RGB'))
54
+ images_base64.append(encode_image(e))
55
+
56
+ return images_base64, kf_paths, kf_timelines
57
+
58
+ def __len__(self):
59
+ return len(self.questions)
60
+
61
+
62
+ def encode_image(image_path):
63
+ with open(image_path, "rb") as image_file:
64
+ return base64.b64encode(image_file.read()).decode('utf-8')
65
+
66
+ def create_data_loader_gpt(questions, num_kf, batch_size=1, num_workers=4):
67
+ assert batch_size == 1, "batch_size must be 1"
68
+
69
+ dataset = CustomDatasetGPT(questions, num_kf)
70
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
71
+
72
+ return data_loader, dataset
73
+
74
+ def eval_model(args):
75
+ base_dir, question_path, vlm, num_kf, temp = (
76
+ args.output_dir,
77
+ args.question_path,
78
+ args.gptmodel,
79
+ args.num_kf,
80
+ args.temp,
81
+ )
82
+
83
+ questions = [json.loads(q) for q in open(os.path.expanduser(question_path), "r")]
84
+
85
+ fname = question_path.split('/')[-1]
86
+ answer_path = f"{base_dir}/egoschema/{num_kf}/{fname}"
87
+ os.makedirs(os.path.dirname(answer_path), exist_ok=True)
88
+ print(f"question_path:{question_path}\nanswer_path:{answer_path}")
89
+
90
+ ans_file = open(answer_path, "w")
91
+ data_loader, dataset = create_data_loader_gpt(questions, num_kf)
92
+
93
+ for (base64_image, kf_paths, kf_timelines), line in tqdm(zip(data_loader, questions), total=len(questions)):
94
+ idx = line["q_uid"]
95
+ CA = line["CA"] if "CA" in line else None
96
+ option0 = line['option 0']
97
+ option1 = line['option 1']
98
+ option2 = line['option 2']
99
+ option3 = line['option 3']
100
+ option4 = line['option 4']
101
+ question = line['question']
102
+
103
+ lenwords = "50"
104
+ prompt = f"'C' stands for the cameraman. Describe the activity depicted in this first-person perspective image in less than {lenwords} words. In your answer, don't mention that the image is in first-person perspective, as we already know this."
105
+ prompts = [prompt] * num_kf
106
+
107
+ image_paths = [e[0] for e in kf_paths]
108
+ image_timelines = [e[0] for e in kf_timelines]
109
+
110
+ output_VLM = run_gpt(
111
+ images=image_paths,
112
+ texts=prompts,
113
+ api_keys=list(dict_api.values()),
114
+ max_tokens=2000,
115
+ model=vlm,
116
+ temperature=temp,
117
+ num_threads=20, # Tune this
118
+ backoff_time=1 * 60,
119
+ silent=False,
120
+ dataset="egoschema",
121
+ verbose=False,
122
+ )
123
+
124
+ output_VLM = list(output_VLM)
125
+
126
+ for j, e in enumerate(image_timelines):
127
+ line_frame = line.copy()
128
+ line_frame["answer"] = f"At {str(e)} seconds, {output_VLM[j]}"
129
+ line_frame["AR-VLM_model_id"] = vlm
130
+ line_frame["AR-VLM_prompt"] = prompts[j]
131
+ line_frame["timeline"] = float(e)
132
+ line_frame["frame_idx"] = j
133
+ line_frame["image_paths"] = image_paths
134
+
135
+ if "imgidx_kw_dict" in line_frame.keys(): line_frame.pop("imgidx_kw_dict")
136
+ if "google_drive_id" in line_frame.keys(): line_frame.pop("google_drive_id")
137
+
138
+ ans_file.write(json.dumps(line_frame)+"\n")
139
+
140
+ print(f"question.\nquestion_path:{question_path}\nanswer_path:{answer_path}")
141
+
142
+ ans_file.close()
143
+ return "job is done"
144
+
145
+
146
+ if __name__ == "__main__":
147
+ parser = argparse.ArgumentParser()
148
+ parser.add_argument("--output-dir", type=str)
149
+ parser.add_argument("--question-path", type=str, default="")
150
+ parser.add_argument("--num-kf", type=int)
151
+ parser.add_argument("--gptmodel", type=str, default="gpt-4o")
152
+ parser.add_argument("--temp", type=float, default=None)
153
+ args = parser.parse_args()
154
+ eval_model(args)
coarseKeyframeDetector.py ADDED
@@ -0,0 +1,254 @@
1
+ import os
2
+ import json
3
+ import shutil
4
+
5
+ from tqdm import tqdm
6
+ from PIL import Image
7
+
8
+ import natsort
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import Dataset, DataLoader
15
+
16
+ from config import config
17
+ from src.open_clip import create_model_and_transforms
18
+
19
+
20
+ class loading_img(Dataset):
21
+ def __init__(self, img_list):
22
+ self.img_list = img_list
23
+
24
+ def __len__(self):
25
+ return len(self.img_list)
26
+
27
+ def __getitem__(self, idx):
28
+ return self.img_list[idx].squeeze(0)
29
+
30
+ class CustomDataset(Dataset):
31
+ def __init__(self, questions, clippy, preprocess_val, clip_size, base_dir):
32
+ self.questions = questions
33
+ self.clippy = clippy
34
+ self.clip_size = clip_size
35
+ self.preprocess_val = preprocess_val
36
+ self.device = next(clippy.parameters()).device
37
+ self.base_dir = base_dir
38
+
39
+ def __getitem__(self, index):
40
+ line = self.questions[index]
41
+ images_dir = f"{line['q_uid']}"
42
+
43
+ if line["Activity"] == "" or ("Activity" not in line): ref1 = []
44
+
45
+ else:
46
+ if isinstance(line["Activity"], list): ref1 = line["Activity"]
47
+ else: ref1 = line["Activity"].split(', ')
48
+
49
+ keywords = ref1
50
+ clip_size = self.clip_size
51
+ clippy = self.clippy
52
+ preprocess_val = self.preprocess_val
53
+
54
+ images = []
55
+ timelines = []
56
+ timelines_int = []
57
+ img_names = []
58
+ image_list = []
59
+
60
+ nframes_paths = line["filepath"]
61
+ total_len = len(nframes_paths)
62
+ nframes_paths = natsort.natsorted(nframes_paths)
63
+
64
+ img_paths = []
65
+ for img_path in nframes_paths:
66
+ img_path = self.base_dir + "/" + "/".join(img_path.split("/")[-4:])
67
+ img_paths.append(img_path)
68
+
69
+ img_names.append(img_path.split('/')[-1].split('.')[0])
70
+ cur_img = Image.open(img_path).resize(clip_size)
71
+ image_list.append(preprocess_val(cur_img))
72
+
73
+ timeline = f"{img_names[-1].split('_')[-2]}.{img_names[-1].split('_')[-1]} seconds"
74
+ timeline_int = float(f"{img_names[-1].split('_')[-2]}.{img_names[-1].split('_')[-1]}")
75
+ timelines.append(timeline)
76
+ timelines_int.append(timeline_int)
77
+
78
+ return image_list, img_paths, timelines, timelines_int, keywords, img_names
79
+
80
+ def __len__(self):
81
+ return len(self.questions)
82
+
83
+
84
+ def disable_torch_init():
85
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
86
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
87
+
88
+ # Greedily pair each keyword with its highest-similarity unused frame, cycling over keywords,
+ # until maximgslen frames are collected or every (keyword, frame) pair has been considered.
+ def SortSimilarity(q_uid, simmat, keywords, nimgtokens, nframes_paths, maximgslen):
89
+ sort_simmat, sort_idx = torch.sort(simmat, dim=-1, descending=True)
90
+ sort_idx = torch.floor(sort_idx/nimgtokens).to(int)
91
+
92
+ curimgslen = 0
93
+
94
+ imgidx_kw_dict = dict()
95
+ numrow, numcol = sort_simmat.shape
96
+
97
+ row_col_list = [0 for _ in range(numrow)]
98
+ token = True
99
+
100
+ while token:
101
+ j = 0
102
+ while j < numrow:
103
+ k = 0
104
+ i = row_col_list[j]
105
+
106
+ while k < numcol-i:
107
+ col_idx = i+k
108
+ k += 1
109
+
110
+ simvalue = sort_simmat[j, col_idx].item()
111
+ img_idx = sort_idx[j, col_idx].item()
112
+
113
+ curr_keyword = keywords[j]
114
+ curr_kfpath = nframes_paths[img_idx]
115
+
116
+ if img_idx in imgidx_kw_dict: continue
117
+
118
+ else:
119
+ imgidx_kw_dict[img_idx] = {"kw": curr_keyword, "simvalue": simvalue, "kf_path": curr_kfpath, "kw_others": []}
120
+ curimgslen += 1
121
+
122
+ row_col_list[j] = col_idx + 1
123
+ if curimgslen == maximgslen: return imgidx_kw_dict
124
+ else: break
125
+
126
+ j += 1
127
+
128
+ if sum(row_col_list) >= numrow*(numcol-1): token = False
129
+
130
+ def create_data_loader(questions, clippy, preprocess_val, clip_size, base_dir, batch_size=1, num_workers=16):
131
+ assert batch_size == 1, "batch_size must be 1"
132
+ dataset = CustomDataset(questions, clippy, preprocess_val, clip_size, base_dir)
133
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
134
+ return data_loader
135
+
136
+ def eval_model():
137
+ disable_torch_init()
138
+ question_path, maximgslen, base_dir, concatname, modelpath, answerpath, concatdir = config.question_path, config.maximgslen, config.base_dir, config.concatname, config.modelpath, config.answerpath, config.concatdir
139
+
140
+ pretrained_ckpt = f"{modelpath}"
141
+ clippy, preprocess_train, preprocess_val = create_model_and_transforms(
142
+ "clippy-B-16",
143
+ device="cuda",
144
+ pretrained=pretrained_ckpt
145
+ )
146
+ clip_size = (224,224)
147
+ device = next(clippy.parameters()).device
148
+
149
+ questions = [json.loads(q) for q in open(os.path.expanduser(question_path), "r")]
150
+
151
+ answer_path = f"{answerpath}"
152
+ print(f"\nquestion_path:{question_path}\nanswer_path:{answer_path}")
153
+ os.makedirs(os.path.dirname(answer_path), exist_ok=True)
154
+
155
+ with open(answer_path, "w") as ans_file:
156
+ data_loader = create_data_loader(questions, clippy, preprocess_val, clip_size, base_dir)
157
+ concatimg_dir_base = f"{concatdir}"
158
+
159
+ with torch.no_grad():
160
+ for (image_list, nframes_paths, timelines, timelines_int, keywords, img_names), line in tqdm(zip(data_loader, questions), total=len(questions)):
161
+ q_uid = line["q_uid"]
162
+ CA = line["CA"] if "CA" in line else None
163
+ option0 = line['option 0']
164
+ option1 = line['option 1']
165
+ option2 = line['option 2']
166
+ option3 = line['option 3']
167
+ option4 = line['option 4']
168
+ question = line['question']
169
+
170
+ pastobj = None
171
+ past_VLM_path = None
172
+ past_VLM_timeline = None
173
+
174
+ img_embed = []
175
+ nframes_paths = [e[0] for e in nframes_paths]
176
+
177
+ image_set = loading_img(image_list)
178
+ image_loader = DataLoader(image_set, batch_size=64, shuffle=False, num_workers=16)
179
+ for e in image_loader: img_embed.append(clippy.encode_image(e.to(device), pool=False)[:, 1:])
180
+ img_embed = torch.concat(img_embed, dim=0)
181
+
182
+ limit_keywords = config.limit_keywords
183
+ keywords = [e[0] for e in keywords][:limit_keywords]
184
+ keyword_embed = clippy.text.encode(keywords, convert_to_tensor=True)
185
+
186
+ nframe, nimgtokens, channels = img_embed.shape
187
+ keyword_embed = keyword_embed.unsqueeze(1)
188
+ img_embed = img_embed.flatten(0, 1).unsqueeze(0)
189
+
190
+ simmat = F.cosine_similarity(keyword_embed, img_embed, dim=-1).to(torch.float)
191
+ imgidx_kw_dict = SortSimilarity(q_uid, simmat, keywords, nimgtokens, nframes_paths, maximgslen=maximgslen)
192
+
193
+ # order of simvalue
194
+ simvalue = np.array([e["simvalue"] for e in imgidx_kw_dict.values()])
195
+ ordered_idx = np.argsort(simvalue)
196
+ simvalue = simvalue[ordered_idx]
197
+ kf_paths = np.array([e["kf_path"] for e in imgidx_kw_dict.values()])[ordered_idx]
198
+ matchingkw = np.array([e["kw"] for e in imgidx_kw_dict.values()])[ordered_idx]
199
+
200
+ #order by timeline
201
+ time_kf_paths = np.array(kf_paths[:16])
202
+ timelines_int = np.array([float(f"{e.replace('.jpg', '').split('/')[-1].split('_')[1]}" + "."+ f"{e.replace('.jpg', '').split('/')[-1].split('_')[2]}") for e in time_kf_paths])
203
+ time_ordered_idx = np.argsort(timelines_int)
204
+
205
+ timelines_int = timelines_int[time_ordered_idx]
206
+ time_simvalue = np.array(simvalue[:16])[time_ordered_idx]
207
+ time_kf_paths = np.array(time_kf_paths)[time_ordered_idx]
208
+ time_matchingkw = np.array(matchingkw[:16])[time_ordered_idx]
209
+
210
+ simvalue[:16] = time_simvalue
211
+ kf_paths[:16] = time_kf_paths
212
+ matchingkw[:16] = time_matchingkw
213
+
214
+ segment_timeline = f"{timelines[0][0].split(' seconds')[0]}-{timelines[-1][0].split(' seconds')[0]}"
215
+
216
+ imgw, imgh = Image.open(kf_paths[0]).size
217
+ redwidth = 20
218
+ newimgw, newimgh = (imgw+redwidth) * 4 + redwidth, (imgh+redwidth) * 2 + redwidth
219
+ concatimg = np.zeros((newimgh, newimgw, 3), dtype=np.uint8)
220
+ concatimg[:, :, 0] = 255
221
+ concatimglist = []
222
+ concatimg_dir = f"{concatimg_dir_base}/{q_uid}"
223
+
224
+ for i, cpath in enumerate(kf_paths):
225
+ cur_img = np.array(Image.open(cpath))
226
+ whole_frame = 8
227
+ remainder = i % whole_frame
228
+ rowremainder = i % (whole_frame//2)
229
+ startwidth = redwidth + (imgw + redwidth)*rowremainder
230
+ endwidth = startwidth + imgw
231
+
232
+ if remainder / whole_frame < 0.5: concatimg[redwidth:redwidth+imgh, startwidth:endwidth, :] = cur_img
233
+ else: concatimg[redwidth+imgh+redwidth:newimgh-redwidth, startwidth:endwidth, :] = cur_img
234
+
235
+ if remainder == whole_frame - 1: concatimglist.append(Image.fromarray(concatimg))
236
+
237
+ if os.path.exists(concatimg_dir): shutil.rmtree(concatimg_dir)
238
+ os.makedirs(f"{concatimg_dir}", exist_ok=True)
239
+ for i, img in enumerate(concatimglist): img.save(f"{concatimg_dir}/concat_{i}.jpg")
240
+
241
+ line["kf_paths"] = kf_paths.tolist()
242
+ line["keywords"] = matchingkw.tolist()
243
+ line["simvalue"] = simvalue.tolist()
244
+ line["imgidx_kw_dict"] = imgidx_kw_dict
245
+ line["segment_timeline"] = segment_timeline
246
+ line["concatimg_dir"] = concatimg_dir
247
+
248
+ ans_file.write(json.dumps(line) + "\n")
249
+
250
+ print(f"question_path:{question_path}\nanswer_path:{answer_path}")
251
+
252
+
253
+ if __name__ == "__main__":
254
+ eval_model()
config/config.py ADDED
@@ -0,0 +1,38 @@
1
+ # base
2
+ dict_api = {
3
+ "api_key":"ADD",
4
+ }
5
+ base_dir = "your_path" # your data path ex) img_folder_path/egoschema
6
+
7
+
8
+ # scene clustering
9
+ divlam = 12
10
+ f_path = "your_path" # files from keywords dir
11
+ q_path = "your_path" # files from questions dir
12
+ a_path = "your_path"
13
+ img_folder = "your_path" # your img folder path ex) img_folder_path/egoschema/frames_900_4531/q_uid/image_sec_millisec.jpg
14
+
15
+
16
+ # coarse key frame detector
17
+ maximgslen = 32
18
+ limit_keywords = 25
19
+ concatname = "LVnet"
20
+ modelpath = "your_path" # model path
21
+ question_path = "your_path" # recommend using the same path with scene clustering answer path
22
+ answerpath = f"{base_dir}/kwkfmatching/kf_{concatname}.jsonl" # kwkfmatching is not necessary.
23
+ concatdir = f"{base_dir}/kwkfmatching/concatimg_{concatname}" # kwkfmatching is not necessary.
24
+
25
+
26
+ # fine key frame detector
27
+ kf_vlm = "gpt-4o"
28
+ kf_temp = None
29
+ kf_num_select = 3
30
+ kf_num_input_imgs = 32
31
+ kf_question_path = "your_path" # recommend using the same path with coarse key frame detector answer path
32
+ kf_answer_path = f"{base_dir}/kf_VLM/kf_VLM{kf_num_input_imgs}sel{kf_num_select}_{kf_question_path.split('/')[-1].split('.')[0]}.jsonl" # kf_VLM is not necessary.
33
+
34
+
35
+ # fine key frame detector refine
36
+ refine_num_group = 4
37
+ refine_kflen = 12
38
+ refine_output_path = f"{base_dir}/kf_VLM/refine/" + kf_answer_path.split('/')[-1] # kf_VLM is not necessary.
config/run.sh ADDED
@@ -0,0 +1,4 @@
1
+ devnum=${1:-0}
2
+ CUDA_VISIBLE_DEVICES=$devnum python3 temporalSceneClustering.py
3
+ CUDA_VISIBLE_DEVICES=$devnum python3 coarseKeyframeDetector.py
4
+ CUDA_VISIBLE_DEVICES=$devnum python3 fineKeyframeDetector.py
extractKeyword.py ADDED
@@ -0,0 +1,132 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ from tqdm import tqdm
6
+ from src.run_gpt import run_gpt
7
+
8
+ """
9
+ Extract keywords from the given question and options
10
+
11
+ Sample Run
12
+ python3 extractKeyword.py --output-dir ego_base_link --question questions/500questions.jsonl --gptmodel "gpt-4-1106-preview"
13
+
14
+ """
15
+
16
+
17
+ # You may add multiple keys if you want parallel calls
18
+ dict_api = {
19
+ "api_key": "ADD",
20
+ }
21
+
22
+ PROMPT = (
23
+ "Think step-by-step and for each option, identify all the specified activities. "
24
+ "Each description of activity should use active voice with plain verbs, contain fewer than six words, "
25
+ "and retains as many original terms from the options as possible.\n"
26
+ "Here are the options:\n\n"
27
+ "option 0: {Putop0}\n"
28
+ "option 1: {Putop1}\n"
29
+ "option 2: {Putop2}\n"
30
+ "option 3: {Putop3}\n"
31
+ "option 4: {Putop4}\n"
32
+ "option 5: {Putquestion}.\n"
33
+ "All the options were introduced. 'C' represents the camera operator in the options. "
34
+ "Your answer should follow the JSON format shown below and should only include the JSON result. "
35
+ "Do not output any warnings or notes under any circumstances. Instead, adhere strictly to the provided JSON format example.\n"
36
+ "This is one example output format.\n"
37
+ "{\"option 0\": [\"plays soccer\", \"go to school\"], \"option 1\": [\"go to the gym\", \"go to school\"], "
38
+ "\"option 2\": [\"go to school\", \"dry hair\"], \"option 3\": [\"plays basketball\", \"look at the tree\"], "
39
+ "\"option 4\": [\"plays soccer\", \"drop the ball\"], \"option 5\": [\"turn the table\", \"go to school\"]}"
40
+ )
41
+
42
+
43
+ def main(args):
44
+ # 1. Create output directories
45
+ os.makedirs(args.output_dir, exist_ok=True)
46
+ job_dir = os.path.join(args.output_dir, "extractedKeywords")
47
+ os.makedirs(job_dir, exist_ok=True)
48
+
49
+
50
+ # 2. Build the output file name (based on --question)
51
+ question_file_name = os.path.basename(args.question).replace(".jsonl", "")
52
+ output_summary_path = os.path.join(job_dir, f"{question_file_name}.jsonl")
53
+ print(f"Saving outputs to: {output_summary_path}")
54
+
55
+ # 3. Read the question file
56
+ with open(os.path.expanduser(args.question), "r") as f:
57
+ question_data = [json.loads(line) for line in f]
58
+
59
+ # 4. Construct final prompts
60
+ final_prompts = []
61
+ final_info = []
62
+ for entry in tqdm(question_data, desc="Building prompts"):
63
+ q_uid = entry["q_uid"]
64
+ # Insert each option + question into the embedded prompt
65
+ cur_prompt = (
66
+ PROMPT
67
+ .replace("{Putop0}", entry["option 0"])
68
+ .replace("{Putop1}", entry["option 1"])
69
+ .replace("{Putop2}", entry["option 2"])
70
+ .replace("{Putop3}", entry["option 3"])
71
+ .replace("{Putop4}", entry["option 4"])
72
+ .replace("{Putquestion}", entry["question"])
73
+ )
74
+
75
+ final_prompts.append(cur_prompt)
76
+
77
+ # Track data for JSON output
78
+ info = {
79
+ "q_uid": q_uid,
80
+ "prompt": cur_prompt,
81
+ "option 0": entry["option 0"],
82
+ "option 1": entry["option 1"],
83
+ "option 2": entry["option 2"],
84
+ "option 3": entry["option 3"],
85
+ "option 4": entry["option 4"],
86
+ "question": entry["question"],
87
+ }
88
+
89
+ # Include ground-truth label if present
90
+ if "CA" in entry:
91
+ info["CA"] = entry["CA"]
92
+
93
+ final_info.append(info)
94
+
95
+ # 5. Call GPT
96
+ print("Sending prompts to GPT. This may take a while...")
97
+ output_results = run_gpt(
98
+ texts=final_prompts,
99
+ api_keys=list(dict_api.values()),
100
+ max_tokens=2000,
101
+ model=args.gptmodel,
102
+ temperature=args.temperature,
103
+ num_threads=5, # Adjust as needed
104
+ backoff_time=60, # Adjust as needed
105
+ silent=False,
106
+ dataset="extractKeyword",
107
+ )
108
+
109
+ output_results = list(output_results)
110
+
111
+ # 6. Save results
112
+ with open(output_summary_path, "w") as outfile:
113
+ for i, info in enumerate(final_info):
114
+ info["answer"] = output_results[i]
115
+ outfile.write(json.dumps(info) + "\n")
116
+
117
+ print(f"Done! Results written to {output_summary_path}")
118
+
119
+
120
+ if __name__ == "__main__":
121
+ parser = argparse.ArgumentParser()
122
+ parser.add_argument("--output-dir", type=str, required=True,
123
+ help="Directory to store the resulting JSONL file.")
124
+ parser.add_argument("--question", type=str, required=True,
125
+ help="Path to the JSONL file with question data (e.g., 500questions.jsonl).")
126
+ parser.add_argument("--gptmodel", type=str, default="gpt-4-1106-preview",
127
+ help="The GPT model to call.")
128
+ parser.add_argument("--temperature", type=float, default=None,
129
+ help="Temperature parameter for GPT.")
130
+
131
+ args = parser.parse_args()
132
+ main(args)
figures/KFSelectionFlowComparison.jpg ADDED

Git LFS Details

  • SHA256: 7c2ef6619fb724aa008ea05ea8e969737ca739a749d9a1da39d27d0de2031b4a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
figures/architecture.png ADDED

Git LFS Details

  • SHA256: 251d2f2286745187b63adb270ca6f350fbd0cc73969a36c27b527b2372d41f91
  • Pointer size: 130 Bytes
  • Size of remote file: 27.2 kB
figures/architecture_qualitative.png ADDED

Git LFS Details

  • SHA256: dadcab8c10896e275c9ea59852a21e6b7b89b9d1417f8c2621a7235cd75da6c0
  • Pointer size: 132 Bytes
  • Size of remote file: 6.45 MB
figures/hkf_graph.png ADDED

Git LFS Details

  • SHA256: 5d872d1e442d90739a2b46544988790dcc64bfad53d22267d08fc03bb524ef3d
  • Pointer size: 131 Bytes
  • Size of remote file: 234 kB
fineKeyframeDetector.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import json
3
+ import base64
4
+ import natsort
5
+
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+
11
+ import torch
12
+ from torch.utils.data import Dataset, DataLoader
13
+
14
+ from config import config
15
+ from src.refine import refine_answer
16
+ from src.run_gpt import run_gpt
17
+
18
+ class CustomDatasetGPT(Dataset):
19
+ def __init__(self, questions, num_input_imgs, num_select):
20
+ self.questions = questions
21
+ self.num_input_imgs = num_input_imgs
22
+ self.num_select = num_select
23
+
24
+ def __getitem__(self, index):
25
+ line = self.questions[index]
26
+ num_select = self.num_select
27
+ num_input_imgs = self.num_input_imgs
28
+
29
+ giter = 0
30
+ imgs_group = 8
31
+ num_groups = num_input_imgs//imgs_group
32
+
33
+ kf_paths = line["kf_paths"]
34
+ keywords = line["keywords"]
35
+ simvalue = line["simvalue"]
36
+ concatimg_dir = line['concatimg_dir']
37
+
38
+ concatimg_paths = natsort.natsorted([f"{concatimg_dir}/{im}" for im in os.listdir(concatimg_dir) if "ipynb" not in im])
39
+
40
+ concatimages = []
41
+ concatimages_base64 = []
42
+ qs_org = []
43
+ kw_perconcat = []
44
+ kf_paths_perconcat = []
45
+ simvalue_perconcat = []
46
+ segment_timeline = []
47
+
48
+ for concatidx, img_path in enumerate(concatimg_paths):
49
+ concatimages.append(Image.open(img_path).convert('RGB'))
50
+ concatimages_base64.append(img_path)
51
+
52
+ kw_sidx = imgs_group*(concatidx)
53
+ kw_eidx = imgs_group*(concatidx+1)
54
+
55
+ concat_kw = keywords[kw_sidx:kw_eidx]
56
+ qs_org_ = create_question(concat_kw, num_select)
57
+
58
+ kw_perconcat.append(concat_kw)
59
+ qs_org.append(qs_org_)
60
+ kf_paths_perconcat.append(kf_paths[kw_sidx:kw_eidx])
61
+ simvalue_perconcat.append(simvalue[kw_sidx:kw_eidx])
62
+ segment_timeline.append(line["segment_timeline"])
63
+
64
+ concatimg_paths = concatimg_paths[-num_groups:]
65
+ concatimages_base64 = concatimages_base64[-num_groups:]
66
+ qs_org = qs_org[-num_groups:]
67
+ kw_perconcat = kw_perconcat[-num_groups:]
68
+ kf_paths_perconcat = kf_paths_perconcat[-num_groups:]
69
+ simvalue_perconcat = simvalue_perconcat[-num_groups:]
70
+ segment_timeline = segment_timeline[-num_groups:]
71
+
72
+ return concatimages_base64, concatimages[0].size, kw_perconcat, kf_paths_perconcat, qs_org, segment_timeline, concatimg_paths, simvalue_perconcat
73
+
74
+ def __len__(self):
75
+ return len(self.questions)
76
+
77
+ def create_question(concat_kw, num_select):
78
+ imgkw_sen = ""
79
+
80
+ for i, e in enumerate(concat_kw):
81
+ if i < len(concat_kw) - 1: imgkw_sen = imgkw_sen + f"Image_{i}: '{e}', "
82
+ else: imgkw_sen = imgkw_sen + f"Image_{i}: '{e}'."
83
+
84
+ if num_select == 3:
85
+ prompt = f"Eight images, having egocentric perspectives, are juxtaposed, separated by a red vertical line and red horizontal line. In the first row, the images from left to right are named as image_0, image_1, image_2, image_3. In the second row, the images from left to right are named as image_4, image_5, image_6, image_7. Here are images and their associated guess words: {imgkw_sen}. Think step-by-step and list only the names of the {num_select} images most closely related to the guessed words. Do not select blurry images in your answer. If none of the images correspond to the provided guess words, choose any two images at random. Your answer should follow the JSON format shown below and should only include the JSON result. Do not output any warnings or notes under any circumstances. Instead, adhere strictly to the provided JSON format example.\n"
86
+ prompt += "{\"image name\": write reason for your selection in 10 words}."
87
+ prompt += " This is one example output format. {\n \"image_0\": \"Person washing a plate; linked to dish cleaning.\",\n \"image_2\": \"Person washing a bowl; linked to dish cleaning.\",\n \"image_6\": \"Person running water on a sponge; related to dish cleaning.\"\n}"
88
+
89
+ elif num_select == 4:
90
+ prompt = f"Eight images, having egocentric perspectives, are juxtaposed, separated by a red vertical line and red horizontal line. In the first row, the images from left to right are named as image_0, image_1, image_2, image_3. In the second row, the images from left to right are named as image_4, image_5, image_6, image_7. Here are images and their associated guess words: {imgkw_sen}. Think step-by-step and list only the names of the {num_select} images most closely related to the guessed words. Do not select blurry images in your answer. If none of the images correspond to the provided guess words, choose any two images at random. Your answer should follow the JSON format shown below and should only include the JSON result. Do not output any warnings or notes under any circumstances. Instead, adhere strictly to the provided JSON format example.\n"
91
+ prompt += "{\"image name\": write reason for your selection in 10 words}."
92
+ prompt += " This is one example output format. {\n \"image_0\": \"Person washing a plate; linked to dish cleaning.\",\n \"image_2\": \"Person washing a bowl; linked to dish cleaning.\",\n \"image_6\": \"Person running water on a sponge; related to dish cleaning.\",\n \"image_7\": \"Person moves mouse; linked to working.\"\n}"
93
+
94
+ else: assert False, f"num_select:{num_select} is not defined yet"
95
+
96
+ return prompt
97
+
98
+ def encode_image(image_path):
99
+ with open(image_path, "rb") as image_file:
100
+ return base64.b64encode(image_file.read()).decode('utf-8')
101
+
102
+ def create_data_loader_gpt(questions, num_input_imgs, num_select, batch_size=1, num_workers=4):
103
+ assert batch_size == 1, "batch_size must be 1"
104
+ dataset = CustomDatasetGPT(questions, num_input_imgs, num_select)
105
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
106
+ return data_loader, dataset
107
+
108
+ def eval_model():
109
+ question_path, vlm, num_input_imgs, num_select, temp = config.kf_question_path, config.kf_vlm, config.kf_num_input_imgs, config.kf_num_select, config.kf_temp
110
+ questions = [json.loads(q) for q in open(os.path.expanduser(question_path), "r")]
111
+ num_questions = len(questions)
112
+ giter = 0
113
+
114
+ answer_path = config.kf_answer_path
115
+ os.makedirs(os.path.dirname(answer_path), exist_ok=True)
116
+
117
+ print(f"question_path:{question_path}\nanswer_path:{answer_path}")
118
+ ans_file = open(answer_path, "w")
119
+ data_loader, dataset = create_data_loader_gpt(questions, num_input_imgs, num_select)
120
+
121
+ outputs = ""
122
+ for (image_paths, image_sizes, kw_perconcat, kf_paths_perconcat, cur_prompts, segment_timeline, concatimg_paths, simvalue_perconcat), line in tqdm(zip(data_loader, questions), total=len(questions)):
123
+ idx, q_uid = line["q_uid"], line["q_uid"]
124
+ CA = line["CA"] if "CA" in line else None
125
+ option0 = line['option 0']
126
+ option1 = line['option 1']
127
+ option2 = line['option 2']
128
+ option3 = line['option 3']
129
+ option4 = line['option 4']
130
+ question = line['question']
131
+
132
+ pastobj = None
133
+ past_VLM_path = None
134
+ past_VLM_timeline = None
135
+
136
+ kw_VLM = []
137
+ kf_paths_VLM = []
138
+ kf_timeline = []
139
+
140
+ kw_VLM_ordered = []
141
+ kf_paths_VLM_ordered = []
142
+ kf_timeline_ordered = []
143
+
144
+ prompts = [x[0] for x in cur_prompts]
145
+ image_paths = [x[0] for x in image_paths]
146
+
147
+ output_VLM = run_gpt(
148
+ images=image_paths,
149
+ texts=prompts,
150
+ api_keys = list(config.dict_api.values()),
151
+ max_tokens=2000,
152
+ model=vlm,
153
+ temperature=temp,
154
+ num_threads=20,
155
+ backoff_time=1*60,
156
+ silent=False,
157
+ dataset="egoschema",
158
+ verbose=False,
159
+ )
160
+ output_VLM = list(output_VLM)
161
+
162
+ for j, _ in enumerate(cur_prompts):
163
+ kf_paths_perconcat_ = kf_paths_perconcat[j]
164
+ kf_timeline.append([f"{e[0].split('_')[-2]}.{e[0].split('_')[-1].split('.')[0]}" for e in kf_paths_perconcat_])
165
+
166
+ line_frame = line.copy()
167
+
168
+ line_frame["output_VLM"] = output_VLM
169
+ line_frame["concatimg_paths"] = concatimg_paths
170
+ line_frame["kf_paths_VLM"] = kf_paths_perconcat
171
+ line_frame["kf_timeline"] = kf_timeline
172
+ line_frame["kw_perconcat_clip"] = kw_perconcat
173
+ line_frame["iter"] = giter
174
+
175
+ line_frame.pop("filepath")
176
+ line_frame.pop("kf_paths")
177
+ line_frame.pop("google_drive_id")
178
+
179
+ try: ans_file.write(json.dumps(line_frame) + "\n")
180
+ except: assert False, f"line_frame:{line_frame}"
181
+
182
+ ans_file.close()
183
+ print(f"question_path:{question_path}\nanswer_path:{answer_path}")
184
+ print("job is done")
185
+
186
+
187
+ if __name__ == "__main__":
188
+ eval_model()
189
+ refine_answer()
keywords/Keyword_4531questions.json ADDED
The diff for this file is too large to render. See raw diff
 
keywords/Keyword_500questions.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
questions/4531questions.json ADDED
The diff for this file is too large to render. See raw diff
 
questions/500questions.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,24 @@
1
+ torch==2.1.2
2
+ torchvision==0.16.2
3
+ transformers==4.37.2
4
+ tokenizers==0.15.1
5
+ sentencepiece==0.1.99
6
+ shortuuid
7
+
8
+ peft
9
+
10
+ numpy
11
+
12
+ requests
13
+
14
+ uvicorn
15
+ fastapi
16
+
17
+ rp
18
+ tqdm
19
+
20
+ # shutil is part of the Python standard library and does not need to be installed
21
+
22
+ Pillow
23
+ natsort
24
+ clip
scripts/create_caption.sh ADDED
@@ -0,0 +1,11 @@
1
+ # run this script from the root of the repo
2
+ # fix paths to data and update GPT keys in code
3
+
4
+ export PYTHONPATH=$PYTHONPATH:$PWD
5
+
6
+ python3 VLM_stage.py \
7
+ --output-dir ego_base_link_bigtensor \
8
+ --question-path your_question_path.jsonl \
9
+ --gptmodel "gpt-4o" \
10
+ --num-kf 12 \
11
+ --temp 0
scripts/eval_ES.sh ADDED
@@ -0,0 +1,23 @@
1
+ # run this script from the root of the repo
2
+ # fix paths to data and update GPT keys in code
3
+
4
+ export PYTHONPATH=$PYTHONPATH:$PWD
5
+
6
+ # Evaluate on the ES subset (comment this block out if you do not want to run it)
7
+ python3 LLM_stage.py \
8
+ --output-dir ego_base_link \
9
+ --captions data/ESsub_captions_gpt4o.jsonl \
10
+ --data data/ESsub_qa_data.json \
11
+ --per-vid-captions 12 \
12
+ --gptmodel "gpt-4o" \
13
+ --temperature 0.0
14
+
15
+
16
+ # Evaluate on ES full dataset (uncomment it if you want to run it)
17
+ # python3 LLM_stage.py \
18
+ # --output-dir ego_base_link \
19
+ # --captions data/ES_captions_gpt4o.jsonl \
20
+ # --data data/ES_qa_data.json \
21
+ # --per-vid-captions 12 \
22
+ # --gptmodel "gpt-4o" \
23
+ # --temperature 0.0
scripts/get_ES_captions.sh ADDED
@@ -0,0 +1,8 @@
1
+ # download our pre-generated EgoSchema captions
2
+
3
+ wget https://github.com/jongwoopark7978/LVNet_dev/releases/download/v1.0/ESsub_captions_gpt4o.jsonl
4
+ wget TBA
5
+ wget TBA
6
+ wget TBA
7
+
8
+ # TODO: update the paths when the repo is public.
src/open_clip/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from src.open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2
+ from src.open_clip.factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
3
+ from src.open_clip.factory import list_models, add_model_config, get_model_config, load_checkpoint
4
+ from src.open_clip.loss import ClipLoss
5
+ from src.open_clip.model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
6
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7
+ from src.open_clip.openai import load_openai_model, list_openai_models
8
+ from src.open_clip.pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10
+ from src.open_clip.tokenizer import SimpleTokenizer, tokenize
11
+ from src.open_clip.transform import image_transform
src/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
src/open_clip/constants.py ADDED
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
src/open_clip/factory.py ADDED
@@ -0,0 +1,287 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+
12
+ from llava.open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from llava.open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
14
+ resize_pos_embed, get_cast_dtype
15
+ from llava.open_clip.openai import load_openai_model
16
+ from llava.open_clip.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
17
+ from llava.open_clip.transform import image_transform
18
+ from llava.open_clip.tokenizer import HFTokenizer, tokenize
19
+
20
+
21
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
22
+ # _MODEL_CONFIG_PATHS = ["/home/mryoo/llava_16/llava/open_clip/model_config/"]
23
+ _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
24
+
25
+
26
+ def _natural_key(string_):
27
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
28
+
29
+
30
+ def _rescan_model_configs():
31
+ global _MODEL_CONFIGS
32
+
33
+ config_ext = ('.json',)
34
+ config_files = []
35
+
36
+ # print(f"_MODEL_CONFIG_PATHS:{_MODEL_CONFIG_PATHS}")
37
+ for config_path in _MODEL_CONFIG_PATHS:
38
+ if config_path.is_file() and config_path.suffix in config_ext:
39
+ config_files.append(config_path)
40
+ elif config_path.is_dir():
41
+ for ext in config_ext:
42
+ config_files.extend(config_path.glob(f'*{ext}'))
43
+ # for ext in config_ext:
44
+ # config_files.extend(config_path.glob(f'*{ext}'))
45
+
46
+ for cf in config_files:
47
+ with open(cf, 'r') as f:
48
+ model_cfg = json.load(f)
49
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
50
+ _MODEL_CONFIGS[cf.stem] = model_cfg
51
+
52
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
53
+ # print(f"_MODEL_CONFIGS:{_MODEL_CONFIGS}")
54
+
55
+
56
+ _rescan_model_configs() # initial populate of model config registry
57
+
58
+
59
+ def list_models():
60
+ """ enumerate available model architectures based on config files """
61
+ return list(_MODEL_CONFIGS.keys())
62
+
63
+
64
+ def add_model_config(path):
65
+ """ add model config path or file and update registry """
66
+ if not isinstance(path, Path):
67
+ path = Path(path)
68
+ _MODEL_CONFIG_PATHS.append(path)
69
+ _rescan_model_configs()
70
+
71
+
72
+ def get_model_config(model_name):
73
+ # print(f"_MODEL_CONFIGS:{_MODEL_CONFIGS}")
74
+ if model_name in _MODEL_CONFIGS:
75
+ return deepcopy(_MODEL_CONFIGS[model_name])
76
+ else:
77
+ return None
78
+
79
+
80
+ def get_tokenizer(model_name):
81
+ config = get_model_config(model_name)
82
+ tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
83
+ return tokenizer
84
+
85
+
86
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
87
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
88
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
89
+ state_dict = checkpoint['state_dict']
90
+ else:
91
+ state_dict = checkpoint
92
+ if next(iter(state_dict.items()))[0].startswith('module'):
93
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
94
+ return state_dict
95
+
96
+
97
+ def load_checkpoint(model, checkpoint_path, strict=True):
98
+ state_dict = load_state_dict(checkpoint_path)
99
+ # detect old format and make compatible with new format
100
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
101
+ state_dict = convert_to_custom_text_state_dict(state_dict)
102
+ resize_pos_embed(state_dict, model)
103
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
104
+ return incompatible_keys
105
+
106
+
107
+ def create_model(
108
+ model_name: str,
109
+ pretrained: Optional[str] = None,
110
+ precision: str = 'fp32',
111
+ device: Union[str, torch.device] = 'cpu',
112
+ jit: bool = False,
113
+ force_quick_gelu: bool = False,
114
+ force_custom_text: bool = False,
115
+ force_patch_dropout: Optional[float] = None,
116
+ pretrained_image: bool = False,
117
+ pretrained_hf: bool = True,
118
+ cache_dir: Optional[str] = None,
119
+ ):
120
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
121
+ if isinstance(device, str):
122
+ device = torch.device(device)
123
+
124
+ if pretrained and pretrained.lower() == 'openai':
125
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
126
+ model = load_openai_model(
127
+ model_name,
128
+ precision=precision,
129
+ device=device,
130
+ jit=jit,
131
+ cache_dir=cache_dir,
132
+ )
133
+ else:
134
+ model_cfg = get_model_config(model_name)
135
+ if model_cfg is not None:
136
+ logging.info(f'Loaded {model_name} model config.')
137
+ else:
138
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
139
+ raise RuntimeError(f'Model config for {model_name} not found.')
140
+
141
+ if force_quick_gelu:
142
+ # override for use of QuickGELU on non-OpenAI transformer models
143
+ model_cfg["quick_gelu"] = True
144
+
145
+ if force_patch_dropout is not None:
146
+ # override the default patch dropout value
147
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
148
+
149
+ if pretrained_image:
150
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
151
+ # pretrained weight loading for timm models set via vision_cfg
152
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
153
+ else:
154
+ assert False, 'pretrained image towers currently only supported for timm models'
155
+
156
+ cast_dtype = get_cast_dtype(precision)
157
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or ('hf_model_name' in model_cfg.get('text_cfg', {}))
158
+
159
+ if custom_text:
160
+ if 'hf_model_name' in model_cfg.get('text_cfg', {}):
161
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
162
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
163
+ else:
164
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
165
+
166
+ pretrained_cfg = {}
167
+ if pretrained:
168
+ checkpoint_path = ''
169
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
170
+ if pretrained_cfg:
171
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
172
+ elif os.path.exists(pretrained):
173
+ checkpoint_path = pretrained
174
+
175
+ if checkpoint_path:
176
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
177
+ load_checkpoint(model, checkpoint_path)
178
+ else:
179
+ error_str = (
180
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
181
+ f' Available pretrained tags: {list_pretrained_tags_by_model(model_name)}.')
182
+ logging.warning(error_str)
183
+ raise RuntimeError(error_str)
184
+
185
+ model.to(device=device)
186
+ if precision in ("fp16", "bf16"):
187
+ convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
188
+
189
+ # set image / mean metadata from pretrained_cfg if available, or use default
190
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
191
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
192
+
193
+ if jit:
194
+ model = torch.jit.script(model)
195
+
196
+ return model
197
+
198
+
199
+ def create_model_and_transforms(
200
+ model_name: str,
201
+ pretrained: Optional[str] = None,
202
+ precision: str = 'fp32',
203
+ device: Union[str, torch.device] = 'cpu',
204
+ jit: bool = False,
205
+ force_quick_gelu: bool = False,
206
+ force_custom_text: bool = False,
207
+ force_patch_dropout: Optional[float] = None,
208
+ pretrained_image: bool = False,
209
+ pretrained_hf: bool = True,
210
+ image_mean: Optional[Tuple[float, ...]] = None,
211
+ image_std: Optional[Tuple[float, ...]] = None,
212
+ cache_dir: Optional[str] = None,
213
+ ):
214
+ model = create_model(
215
+ model_name,
216
+ pretrained,
217
+ precision=precision,
218
+ device=device,
219
+ jit=jit,
220
+ force_quick_gelu=force_quick_gelu,
221
+ force_custom_text=force_custom_text,
222
+ force_patch_dropout=force_patch_dropout,
223
+ pretrained_image=pretrained_image,
224
+ pretrained_hf=pretrained_hf,
225
+ cache_dir=cache_dir,
226
+ )
227
+
228
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
229
+ image_std = image_std or getattr(model.visual, 'image_std', None)
230
+ preprocess_train = image_transform(
231
+ model.visual.image_size,
232
+ is_train=True,
233
+ mean=image_mean,
234
+ std=image_std
235
+ )
236
+ preprocess_val = image_transform(
237
+ model.visual.image_size,
238
+ is_train=False,
239
+ mean=image_mean,
240
+ std=image_std
241
+ )
242
+
243
+ return model, preprocess_train, preprocess_val
244
+
245
+
246
+ def create_model_from_pretrained(
247
+ model_name: str,
248
+ pretrained: str,
249
+ precision: str = 'fp32',
250
+ device: Union[str, torch.device] = 'cpu',
251
+ jit: bool = False,
252
+ force_quick_gelu: bool = False,
253
+ force_custom_text: bool = False,
254
+ return_transform: bool = True,
255
+ image_mean: Optional[Tuple[float, ...]] = None,
256
+ image_std: Optional[Tuple[float, ...]] = None,
257
+ cache_dir: Optional[str] = None,
258
+ ):
259
+ if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
260
+ raise RuntimeError(
261
+ f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
262
+ f' Use open_clip.list_pretrained() to find one.')
263
+
264
+ model = create_model(
265
+ model_name,
266
+ pretrained,
267
+ precision=precision,
268
+ device=device,
269
+ jit=jit,
270
+ force_quick_gelu=force_quick_gelu,
271
+ force_custom_text=force_custom_text,
272
+ cache_dir=cache_dir,
273
+ )
274
+
275
+ if not return_transform:
276
+ return model
277
+
278
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
279
+ image_std = image_std or getattr(model.visual, 'image_std', None)
280
+ preprocess = image_transform(
281
+ model.visual.image_size,
282
+ is_train=False,
283
+ mean=image_mean,
284
+ std=image_std
285
+ )
286
+
287
+ return model, preprocess
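
A hedged inference sketch for the factory above. The checkpoint path, image file, and prompts are placeholders; `create_model_from_pretrained` requires either a valid pretrained tag or an existing checkpoint file, and the example assumes `llava.open_clip` and its dependencies are importable.

# placeholders throughout -- adapt the paths before running
import torch
from PIL import Image
from llava.open_clip.factory import create_model_from_pretrained, get_tokenizer

model, preprocess = create_model_from_pretrained(
    "ViT-L-14",
    pretrained="/path/to/checkpoint.pt",   # or a tag from list_pretrained()
    device="cpu",
)
tokenizer = get_tokenizer("ViT-L-14")

image = preprocess(Image.open("frame_000.jpg")).unsqueeze(0)
text = tokenizer(["a person cooking", "a person driving"])
with torch.no_grad():
    img_f = model.encode_image(image, normalize=True)
    txt_f = model.encode_text(text, normalize=True)
    probs = (100.0 * img_f @ txt_f.T).softmax(dim=-1)
print(probs)   # per-prompt probabilities for the single image
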
src/open_clip/hf_configs.py ADDED
@@ -0,0 +1,60 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ "t5": {
46
+ "config_names": {
47
+ # unlimited seqlen
48
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
49
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
50
+ "context_length": "",
51
+ "vocab_size": "vocab_size",
52
+ "width": "d_model",
53
+ "heads": "num_heads",
54
+ "layers": "num_layers",
55
+ "layer_attr": "block",
56
+ "token_embeddings_attr": "embed_tokens"
57
+ },
58
+ "pooler": "mean_pooler",
59
+ },
60
+ }
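
arch_dict maps generic field names ("width", "layers", ...) onto the attribute names each HuggingFace config actually uses, so the HF text-tower code can stay architecture-agnostic. A small sketch, assuming `transformers` is installed; RobertaConfig() is constructed locally, so no download is needed.

# how hf_model.py consumes arch_dict (sketch)
from transformers import RobertaConfig
from llava.open_clip.hf_configs import arch_dict

config = RobertaConfig()                                  # default roberta-style config
names = arch_dict[config.model_type]["config_names"]
width = getattr(config, names["width"])                   # -> config.hidden_size
layers = getattr(config, names["layers"])                 # -> config.num_hidden_layers
print(config.model_type, width, layers)                   # roberta 768 12
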
src/open_clip/hf_model.py ADDED
@@ -0,0 +1,164 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import TensorType
11
+
12
+ try:
13
+ import transformers
14
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
+ BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+
21
+ class BaseModelOutput:
22
+ pass
23
+
24
+
25
+ class PretrainedConfig:
26
+ pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+
31
+ # utils
32
+ def _camel2snake(s):
33
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
34
+
35
+
36
+ # TODO: ?last - for gpt-like models
37
+ _POOLERS = {}
38
+
39
+
40
+ def register_pooler(cls):
41
+ """Decorator registering pooler class"""
42
+ _POOLERS[_camel2snake(cls.__name__)] = cls
43
+ return cls
44
+
45
+
46
+ @register_pooler
47
+ class MeanPooler(nn.Module):
48
+ """Mean pooling"""
49
+
50
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
51
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
52
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
53
+
54
+
55
+ @register_pooler
56
+ class MaxPooler(nn.Module):
57
+ """Max pooling"""
58
+
59
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
60
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
61
+ return masked_output.max(1).values
62
+
63
+
64
+ @register_pooler
65
+ class ClsPooler(nn.Module):
66
+ """CLS token pooling"""
67
+
68
+ def __init__(self, use_pooler_output=True):
69
+ super().__init__()
70
+ self.cls_token_position = 0
71
+ self.use_pooler_output = use_pooler_output
72
+
73
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
74
+ if (self.use_pooler_output and
75
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
76
+ (x.pooler_output is not None)
77
+ ):
78
+ return x.pooler_output
79
+
80
+ return x.last_hidden_state[:, self.cls_token_position, :]
81
+
82
+
83
+ class HFTextEncoder(nn.Module):
84
+ """HuggingFace model adapter"""
85
+
86
+ def __init__(
87
+ self,
88
+ model_name_or_path: str,
89
+ output_dim: int,
90
+ config: PretrainedConfig = None,
91
+ pooler_type: str = None,
92
+ proj: str = None,
93
+ pretrained: bool = True):
94
+ super().__init__()
95
+
96
+ self.output_dim = output_dim
97
+
98
+ # TODO: find better way to get this information
99
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
100
+
101
+ if transformers is None:
102
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
103
+ if config is None:
104
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
105
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
106
+ AutoModel.from_config, self.config)
107
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
108
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
109
+ self.transformer = create_func(model_args)
110
+ self.transformer = self.transformer.encoder
111
+ else:
112
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
113
+ else:
114
+ self.config = config
115
+ self.transformer = AutoModel.from_config(config)
116
+
117
+ if pooler_type is None: # get default arch pooler
118
+ self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
119
+ else:
120
+ self.pooler = _POOLERS[pooler_type]()
121
+
122
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
123
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
124
+ self.proj = nn.Identity()
125
+ elif proj == 'linear':
126
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
127
+ elif proj == 'mlp':
128
+ hidden_size = (d_model + output_dim) // 2
129
+ self.proj = nn.Sequential(
130
+ nn.Linear(d_model, hidden_size, bias=False),
131
+ nn.GELU(),
132
+ nn.Linear(hidden_size, output_dim, bias=False),
133
+ )
134
+
135
+ def forward(self, x: TensorType) -> TensorType:
136
+ attn_mask = (x != self.config.pad_token_id).long()
137
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
138
+ pooled_out = self.pooler(out, attn_mask)
139
+
140
+ return self.proj(pooled_out)
141
+
142
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
143
+ if not unlocked_layers: # full freezing
144
+ for n, p in self.transformer.named_parameters():
145
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
146
+ return
147
+
148
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
149
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
150
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
151
+ embeddings = getattr(
152
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
153
+ modules = [embeddings, *layer_list][:-unlocked_layers]
154
+ # freeze layers
155
+ for module in modules:
156
+ for n, p in module.named_parameters():
157
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
158
+
159
+ @torch.jit.ignore
160
+ def set_grad_checkpointing(self, enable=True):
161
+ self.transformer.gradient_checkpointing_enable()
162
+
163
+ def init_parameters(self):
164
+ pass
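
The register_pooler decorator keys each pooler by the snake_case form of its class name. A minimal sketch on dummy tensors, assuming `transformers` and the `llava.open_clip` package are importable; no pretrained weights are required.

# exercise the pooler registry on random data
import torch
from transformers.modeling_outputs import BaseModelOutput
from llava.open_clip.hf_model import _POOLERS

hidden = torch.randn(2, 5, 16)                            # (batch, seq_len, width)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])   # 1 = real token, 0 = padding
out = BaseModelOutput(last_hidden_state=hidden)

pooler = _POOLERS["mean_pooler"]()                        # MeanPooler registered above
pooled = pooler(out, mask)
print(sorted(_POOLERS))                                   # ['cls_pooler', 'max_pooler', 'mean_pooler']
print(pooled.shape)                                       # torch.Size([2, 16])
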
src/open_clip/loss.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ try:
6
+ import torch.distributed.nn
7
+ from torch import distributed as dist
8
+ has_distributed = True
9
+ except ImportError:
10
+ has_distributed = False
11
+
12
+ try:
13
+ import horovod.torch as hvd
14
+ except ImportError:
15
+ hvd = None
16
+
17
+
18
+ def gather_features(
19
+ image_features,
20
+ text_features,
21
+ local_loss=False,
22
+ gather_with_grad=False,
23
+ rank=0,
24
+ world_size=1,
25
+ use_horovod=False
26
+ ):
27
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
28
+ if use_horovod:
29
+ assert hvd is not None, 'Please install horovod'
30
+ if gather_with_grad:
31
+ all_image_features = hvd.allgather(image_features)
32
+ all_text_features = hvd.allgather(text_features)
33
+ else:
34
+ with torch.no_grad():
35
+ all_image_features = hvd.allgather(image_features)
36
+ all_text_features = hvd.allgather(text_features)
37
+ if not local_loss:
38
+ # ensure grads for local rank when all_* features don't have a gradient
39
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
40
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
41
+ gathered_image_features[rank] = image_features
42
+ gathered_text_features[rank] = text_features
43
+ all_image_features = torch.cat(gathered_image_features, dim=0)
44
+ all_text_features = torch.cat(gathered_text_features, dim=0)
45
+ else:
46
+ # We gather tensors from all gpus
47
+ if gather_with_grad:
48
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
49
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
50
+ else:
51
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
52
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
53
+ dist.all_gather(gathered_image_features, image_features)
54
+ dist.all_gather(gathered_text_features, text_features)
55
+ if not local_loss:
56
+ # ensure grads for local rank when all_* features don't have a gradient
57
+ gathered_image_features[rank] = image_features
58
+ gathered_text_features[rank] = text_features
59
+ all_image_features = torch.cat(gathered_image_features, dim=0)
60
+ all_text_features = torch.cat(gathered_text_features, dim=0)
61
+
62
+ return all_image_features, all_text_features
63
+
64
+
65
+ class ClipLoss(nn.Module):
66
+
67
+ def __init__(
68
+ self,
69
+ local_loss=False,
70
+ gather_with_grad=False,
71
+ cache_labels=False,
72
+ rank=0,
73
+ world_size=1,
74
+ use_horovod=False,
75
+ ):
76
+ super().__init__()
77
+ self.local_loss = local_loss
78
+ self.gather_with_grad = gather_with_grad
79
+ self.cache_labels = cache_labels
80
+ self.rank = rank
81
+ self.world_size = world_size
82
+ self.use_horovod = use_horovod
83
+
84
+ # cache state
85
+ self.prev_num_logits = 0
86
+ self.labels = {}
87
+
88
+ def forward(self, image_features, text_features, logit_scale):
89
+ device = image_features.device
90
+ if self.world_size > 1:
91
+ all_image_features, all_text_features = gather_features(
92
+ image_features, text_features,
93
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
94
+
95
+ if self.local_loss:
96
+ logits_per_image = logit_scale * image_features @ all_text_features.T
97
+ logits_per_text = logit_scale * text_features @ all_image_features.T
98
+ else:
99
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
100
+ logits_per_text = logits_per_image.T
101
+ else:
102
+ logits_per_image = logit_scale * image_features @ text_features.T
103
+ logits_per_text = logit_scale * text_features @ image_features.T
104
+
105
+ # calculate ground-truth labels and cache them if enabled
106
+ num_logits = logits_per_image.shape[0]
107
+ if self.prev_num_logits != num_logits or device not in self.labels:
108
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
109
+ if self.world_size > 1 and self.local_loss:
110
+ labels = labels + num_logits * self.rank
111
+ if self.cache_labels:
112
+ self.labels[device] = labels
113
+ self.prev_num_logits = num_logits
114
+ else:
115
+ labels = self.labels[device]
116
+
117
+ total_loss = (
118
+ F.cross_entropy(logits_per_image, labels) +
119
+ F.cross_entropy(logits_per_text, labels)
120
+ ) / 2
121
+ return total_loss
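
With world_size=1 (the default) ClipLoss takes the single-process branch and needs no distributed setup. A minimal sketch on random, L2-normalized features:

# single-process ClipLoss sketch
import torch
import torch.nn.functional as F
from llava.open_clip.loss import ClipLoss

batch = 8
image_features = F.normalize(torch.randn(batch, 512), dim=-1)
text_features = F.normalize(torch.randn(batch, 512), dim=-1)
logit_scale = torch.tensor(1 / 0.07)          # what the model passes as logit_scale.exp()

loss = ClipLoss()(image_features, text_features, logit_scale)
print(loss.item())                            # roughly ln(8) ≈ 2.08 for random features
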
src/open_clip/model.py ADDED
@@ -0,0 +1,413 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+ from sentence_transformers import SentenceTransformer
16
+
17
+ from llava.open_clip.hf_model import HFTextEncoder
18
+ from llava.open_clip.modified_resnet import ModifiedResNet
19
+ from llava.open_clip.timm_model import TimmModel
20
+ from llava.open_clip.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
21
+ from llava.open_clip.utils import to_2tuple
22
+
23
+
24
+ @dataclass
25
+ class CLIPVisionCfg:
26
+ layers: Union[Tuple[int, int, int, int], int] = 12
27
+ width: int = 768
28
+ head_width: int = 64
29
+ mlp_ratio: float = 4.0
30
+ patch_size: int = 16
31
+ image_size: Union[Tuple[int, int], int] = 224
32
+ ls_init_value: Optional[float] = None # layer scale initial value
33
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
34
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
35
+ global_max_pool: bool = False # whether to max pool, from CLIPpy paper
36
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
37
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
38
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
39
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
40
+ timm_proj_bias: bool = False # enable bias final projection
41
+
42
+
43
+ @dataclass
44
+ class CLIPTextCfg:
45
+ context_length: int = 77
46
+ vocab_size: int = 49408
47
+ width: int = 512
48
+ heads: int = 8
49
+ layers: int = 12
50
+ ls_init_value: Optional[float] = None # layer scale initial value
51
+ hf_model_name: str = None
52
+ hf_tokenizer_name: str = None
53
+ hf_model_pretrained: bool = True
54
+ proj: str = 'mlp'
55
+ pooler_type: str = 'mean_pooler'
56
+
57
+
58
+ def get_cast_dtype(precision: str):
59
+ cast_dtype = None
60
+ if precision == 'bf16':
61
+ cast_dtype = torch.bfloat16
62
+ elif precision == 'fp16':
63
+ cast_dtype = torch.float16
64
+ return cast_dtype
65
+
66
+
67
+ def _build_vision_tower(
68
+ embed_dim: int,
69
+ vision_cfg: CLIPVisionCfg,
70
+ quick_gelu: bool = False,
71
+ cast_dtype: Optional[torch.dtype] = None
72
+ ):
73
+ if isinstance(vision_cfg, dict):
74
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
75
+
76
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
77
+ # memory efficient in recent PyTorch releases (>= 1.10).
78
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
79
+ act_layer = QuickGELU if quick_gelu else nn.GELU
80
+
81
+ if vision_cfg.timm_model_name:
82
+ visual = TimmModel(
83
+ vision_cfg.timm_model_name,
84
+ pretrained=vision_cfg.timm_model_pretrained,
85
+ pool=vision_cfg.timm_pool,
86
+ proj=vision_cfg.timm_proj,
87
+ proj_bias=vision_cfg.timm_proj_bias,
88
+ embed_dim=embed_dim,
89
+ image_size=vision_cfg.image_size
90
+ )
91
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
92
+ elif isinstance(vision_cfg.layers, (tuple, list)):
93
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
94
+ visual = ModifiedResNet(
95
+ layers=vision_cfg.layers,
96
+ output_dim=embed_dim,
97
+ heads=vision_heads,
98
+ image_size=vision_cfg.image_size,
99
+ width=vision_cfg.width
100
+ )
101
+ else:
102
+ vision_heads = vision_cfg.width // vision_cfg.head_width
103
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
104
+ visual = VisionTransformer(
105
+ image_size=vision_cfg.image_size,
106
+ patch_size=vision_cfg.patch_size,
107
+ width=vision_cfg.width,
108
+ layers=vision_cfg.layers,
109
+ heads=vision_heads,
110
+ mlp_ratio=vision_cfg.mlp_ratio,
111
+ ls_init_value=vision_cfg.ls_init_value,
112
+ patch_dropout=vision_cfg.patch_dropout,
113
+ global_average_pool=vision_cfg.global_average_pool,
114
+ global_max_pool=vision_cfg.global_max_pool,
115
+ output_dim=embed_dim,
116
+ act_layer=act_layer,
117
+ norm_layer=norm_layer,
118
+ )
119
+
120
+ return visual
121
+
122
+
123
+ def _build_text_tower(
124
+ embed_dim: int,
125
+ text_cfg: CLIPTextCfg,
126
+ quick_gelu: bool = False,
127
+ cast_dtype: Optional[torch.dtype] = None,
128
+ ):
129
+ if isinstance(text_cfg, dict):
130
+ text_cfg = CLIPTextCfg(**text_cfg)
131
+
132
+ if text_cfg.hf_model_name:
133
+ if text_cfg.hf_model_name == "sentence-transformers/sentence-t5-base":
134
+ text = SentenceTransformer("sentence-transformers/sentence-t5-base")
135
+ else:
136
+ text = HFTextEncoder(
137
+ text_cfg.hf_model_name,
138
+ output_dim=embed_dim,
139
+ proj=text_cfg.proj,
140
+ pooler_type=text_cfg.pooler_type,
141
+ pretrained=text_cfg.hf_model_pretrained
142
+ )
143
+ else:
144
+ act_layer = QuickGELU if quick_gelu else nn.GELU
145
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
146
+
147
+ text = TextTransformer(
148
+ context_length=text_cfg.context_length,
149
+ vocab_size=text_cfg.vocab_size,
150
+ width=text_cfg.width,
151
+ heads=text_cfg.heads,
152
+ layers=text_cfg.layers,
153
+ ls_init_value=text_cfg.ls_init_value,
154
+ output_dim=embed_dim,
155
+ act_layer=act_layer,
156
+ norm_layer=norm_layer,
157
+ )
158
+ return text
159
+
160
+
161
+ class CLIP(nn.Module):
162
+ def __init__(
163
+ self,
164
+ embed_dim: int,
165
+ vision_cfg: CLIPVisionCfg,
166
+ text_cfg: CLIPTextCfg,
167
+ quick_gelu: bool = False,
168
+ cast_dtype: Optional[torch.dtype] = None,
169
+ ):
170
+ super().__init__()
171
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
172
+
173
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
174
+ self.transformer = text.transformer
175
+ self.vocab_size = text.vocab_size
176
+ self.token_embedding = text.token_embedding
177
+ self.positional_embedding = text.positional_embedding
178
+ self.ln_final = text.ln_final
179
+ self.text_projection = text.text_projection
180
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
181
+
182
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
183
+
184
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
185
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
186
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
187
+
188
+ @torch.jit.ignore
189
+ def set_grad_checkpointing(self, enable=True):
190
+ self.visual.set_grad_checkpointing(enable)
191
+ self.transformer.grad_checkpointing = enable
192
+
193
+ def encode_image(self, image, normalize: bool = False):
194
+ features = self.visual(image)
195
+ return F.normalize(features, dim=-1) if normalize else features
196
+
197
+ def encode_text(self, text, normalize: bool = False):
198
+ cast_dtype = self.transformer.get_cast_dtype()
199
+
200
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
201
+
202
+ x = x + self.positional_embedding.to(cast_dtype)
203
+ x = x.permute(1, 0, 2) # NLD -> LND
204
+ x = self.transformer(x, attn_mask=self.attn_mask)
205
+ x = x.permute(1, 0, 2) # LND -> NLD
206
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
207
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
208
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
209
+ return F.normalize(x, dim=-1) if normalize else x
210
+
211
+ def forward(self, image, text):
212
+ image_features = self.encode_image(image, normalize=True)
213
+ text_features = self.encode_text(text, normalize=True)
214
+ return image_features, text_features, self.logit_scale.exp()
215
+
216
+
217
+ class CustomTextCLIP(nn.Module):
218
+ def __init__(
219
+ self,
220
+ embed_dim: int,
221
+ vision_cfg: CLIPVisionCfg,
222
+ text_cfg: CLIPTextCfg,
223
+ quick_gelu: bool = False,
224
+ cast_dtype: Optional[torch.dtype] = None,
225
+ ):
226
+ super().__init__()
227
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
228
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
229
+ if isinstance(text_cfg, dict):
230
+ text_cfg = CLIPTextCfg(**text_cfg)
231
+ if text_cfg.hf_model_name:
232
+ self.use_st = text_cfg.hf_model_name == "sentence-transformers/sentence-t5-base"
233
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
234
+
235
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
236
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
237
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
238
+
239
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
240
+ self.text.lock(unlocked_layers, freeze_layer_norm)
241
+
242
+ @torch.jit.ignore
243
+ def set_grad_checkpointing(self, enable=True):
244
+ self.visual.set_grad_checkpointing(enable)
245
+ self.text.set_grad_checkpointing(enable)
246
+
247
+ def encode_image(self, image, normalize: bool = False, pool: bool = True):
248
+ features = self.visual(image, pool=pool)
249
+ return F.normalize(features, dim=-1) if normalize else features
250
+
251
+ def encode_text(self, text, normalize: bool = False):
252
+ features = self.text(text)
253
+ return F.normalize(features, dim=-1) if normalize else features
254
+
255
+ def forward(self, image, text):
256
+ image_features = self.encode_image(image, normalize=True)
257
+ text_features = self.encode_text(text, normalize=True)
258
+ return image_features, text_features, self.logit_scale.exp()
259
+
260
+
261
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
262
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
263
+
264
+ def _convert_weights(l):
265
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
266
+ l.weight.data = l.weight.data.to(dtype)
267
+ if l.bias is not None:
268
+ l.bias.data = l.bias.data.to(dtype)
269
+
270
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
271
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
272
+ tensor = getattr(l, attr)
273
+ if tensor is not None:
274
+ tensor.data = tensor.data.to(dtype)
275
+
276
+ for name in ["text_projection", "proj"]:
277
+ if hasattr(l, name):
278
+ attr = getattr(l, name)
279
+ if attr is not None:
280
+ attr.data = attr.data.to(dtype)
281
+
282
+ model.apply(_convert_weights)
283
+
284
+
285
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
286
+
287
+
288
+ # used to maintain checkpoint compatibility
289
+ def convert_to_custom_text_state_dict(state_dict: dict):
290
+ if 'text_projection' in state_dict:
291
+ # old format state_dict, move text tower -> .text
292
+ new_state_dict = {}
293
+ for k, v in state_dict.items():
294
+ if any(k.startswith(p) for p in (
295
+ 'text_projection',
296
+ 'positional_embedding',
297
+ 'token_embedding',
298
+ 'transformer',
299
+ 'ln_final',
300
+ )):
301
+ k = 'text.' + k
302
+ new_state_dict[k] = v
303
+ return new_state_dict
304
+ return state_dict
305
+
306
+
307
+ def build_model_from_openai_state_dict(
308
+ state_dict: dict,
309
+ quick_gelu=True,
310
+ cast_dtype=torch.float16,
311
+ ):
312
+ vit = "visual.proj" in state_dict
313
+
314
+ if vit:
315
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
316
+ vision_layers = len(
317
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
318
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
319
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
320
+ image_size = vision_patch_size * grid_size
321
+ else:
322
+ counts: list = [
323
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
324
+ vision_layers = tuple(counts)
325
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
326
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
327
+ vision_patch_size = None
328
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
329
+ image_size = output_width * 32
330
+
331
+ embed_dim = state_dict["text_projection"].shape[1]
332
+ context_length = state_dict["positional_embedding"].shape[0]
333
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
334
+ transformer_width = state_dict["ln_final.weight"].shape[0]
335
+ transformer_heads = transformer_width // 64
336
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
337
+
338
+ vision_cfg = CLIPVisionCfg(
339
+ layers=vision_layers,
340
+ width=vision_width,
341
+ patch_size=vision_patch_size,
342
+ image_size=image_size,
343
+ )
344
+ text_cfg = CLIPTextCfg(
345
+ context_length=context_length,
346
+ vocab_size=vocab_size,
347
+ width=transformer_width,
348
+ heads=transformer_heads,
349
+ layers=transformer_layers
350
+ )
351
+ model = CLIP(
352
+ embed_dim,
353
+ vision_cfg=vision_cfg,
354
+ text_cfg=text_cfg,
355
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
356
+ cast_dtype=cast_dtype,
357
+ )
358
+
359
+ for key in ["input_resolution", "context_length", "vocab_size"]:
360
+ state_dict.pop(key, None)
361
+
362
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
363
+ model.load_state_dict(state_dict)
364
+ return model.eval()
365
+
366
+
367
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
368
+ model.eval()
369
+ image_size = model.visual.image_size
370
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
371
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
372
+ model = torch.jit.trace_module(
373
+ model,
374
+ inputs=dict(
375
+ forward=(example_images, example_text),
376
+ encode_text=(example_text,),
377
+ encode_image=(example_images,)
378
+ ))
379
+ model.visual.image_size = image_size
380
+ return model
381
+
382
+
383
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
384
+ # Rescale the grid of position embeddings when loading from state_dict
385
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
386
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
387
+ return
388
+ grid_size = to_2tuple(model.visual.grid_size)
389
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
390
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
391
+ if new_seq_len == old_pos_embed.shape[0]:
392
+ return
393
+
394
+ if extra_tokens:
395
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
396
+ else:
397
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
398
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
399
+
400
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
401
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
402
+ pos_emb_img = F.interpolate(
403
+ pos_emb_img,
404
+ size=grid_size,
405
+ mode=interpolation,
406
+ align_corners=True,
407
+ )
408
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
409
+ if pos_emb_tok is not None:
410
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
411
+ else:
412
+ new_pos_embed = pos_emb_img
413
+ state_dict['visual.positional_embedding'] = new_pos_embed
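
A hedged sketch constructing CLIP directly from config dicts that mirror model_configs/ViT-B-32.json. It depends on the package's transformer/timm/sentence-transformers imports (not shown in this view), uses random weights, and random token ids stand in for tokenized text.

# illustrative forward pass with a ViT-B-32-shaped config
import torch
from llava.open_clip.model import CLIP

vision_cfg = {"image_size": 224, "layers": 12, "width": 768, "patch_size": 32}
text_cfg = {"context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12}

model = CLIP(embed_dim=512, vision_cfg=vision_cfg, text_cfg=text_cfg).eval()

images = torch.randn(2, 3, 224, 224)
tokens = torch.randint(0, 49408, (2, 77))
with torch.no_grad():
    image_features, text_features, scale = model(images, tokens)
print(image_features.shape, text_features.shape, round(scale.item(), 2))  # (2, 512) (2, 512) 14.29
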
src/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
src/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
src/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x64.json ADDED
@@ -0,0 +1,21 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 448,
5
+ "layers": [
6
+ 3,
7
+ 15,
8
+ 36,
9
+ 10
10
+ ],
11
+ "width": 128,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 1024,
18
+ "heads": 16,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/ViT-B-16-plus-240.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 240,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-B-16-plus.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-B-16.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-B-32-plus-256.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 256,
5
+ "layers": 12,
6
+ "width": 896,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 640,
13
+ "heads": 10,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": 12,
7
+ "width": 768,
8
+ "patch_size": 32
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ }
17
+ }
src/open_clip/model_configs/ViT-B-32.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 32
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 512,
13
+ "heads": 8,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-H-14.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 14
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }
src/open_clip/model_configs/ViT-H-16.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 1280,
7
+ "head_width": 80,
8
+ "patch_size": 16
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ }
17
+ }
src/open_clip/model_configs/ViT-L-14-280.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 280,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-L-14-336.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-L-14.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-L-16-320.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 320,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/model_configs/ViT-L-16.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 16
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
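
Each JSON above is picked up by factory._rescan_model_configs() and registered under its filename stem, so the file name is the model name. A short sketch of the round trip into the dataclasses defined in model.py (assuming `llava.open_clip` is importable):

# config JSON -> dataclasses round trip
from llava.open_clip.factory import list_models, get_model_config
from llava.open_clip.model import CLIPVisionCfg, CLIPTextCfg

print(list_models())                        # filename stems: 'RN50', 'ViT-B-32', 'ViT-L-14', ...
cfg = get_model_config("ViT-L-14")          # deep copy of ViT-L-14.json shown above
vision = CLIPVisionCfg(**cfg["vision_cfg"])
text = CLIPTextCfg(**cfg["text_cfg"])
print(cfg["embed_dim"], vision.patch_size, text.layers)   # 768 14 12
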