First commit
- .gitattributes +0 -34
- .github/workflows/update_space.yml +28 -0
- .gitignore +216 -0
- .idea/.gitignore +8 -0
- .idea/DeepLearning.iml +8 -0
- .idea/deployment.xml +29 -0
- .idea/inspectionProfiles/Project_Default.xml +34 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- .python-version +1 -0
- PikaPikaTraining.ipynb +112 -0
- pikapikagen/PikaPikaGen.ipynb +2241 -0
- pikapikagen/README.md +6 -0
- pikapikagen/__init__.py +0 -0
- pikapikagen/data_loader.py +100 -0
- pikapikagen/dataset.py +141 -0
- pikapikagen/discriminators.py +161 -0
- pikapikagen/evaluate_kid.py +141 -0
- pikapikagen/gradio_demo.py +291 -0
- pikapikagen/losses.py +103 -0
- pikapikagen/model.py +46 -0
- pikapikagen/model_blocks/decoder_block.py +59 -0
- pikapikagen/model_blocks/image_cross_attention.py +49 -0
- pikapikagen/model_blocks/image_decoder.py +122 -0
- pikapikagen/model_blocks/text_encoder.py +43 -0
- pikapikagen/model_checkpoint/checkpoint_epoch_150.pth +3 -0
- pikapikagen/plots.py +428 -0
- pikapikagen/utils.py +12 -0
- pyproject.toml +18 -0
- uv.lock +0 -0
.gitattributes
CHANGED
@@ -1,35 +1 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
.github/workflows/update_space.yml
ADDED
@@ -0,0 +1,28 @@
+name: Run Python script
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+    - name: Checkout
+      uses: actions/checkout@v2
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.9'
+
+    - name: Install Gradio
+      run: python -m pip install gradio
+
+    - name: Log in to Hugging Face
+      run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+    - name: Deploy to Spaces
+      run: gradio deploy
.gitignore
ADDED
@@ -0,0 +1,216 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Cursor
+# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
+
+# Project
+/pikapikagen/dataset
+/pikapikagen/training_output
+/dataset
+/old_notebooks/dataset
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
.idea/DeepLearning.iml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/deployment.xml
ADDED
@@ -0,0 +1,29 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false">
+    <serverData>
+      <paths name="root@salad:22 agent">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/tmp/pycharm_project_71" local="$PROJECT_DIR$" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="val@46.101.132.64:22 key">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="val@46.101.132.64:22 key (2)">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+    <option name="myAutoUpload" value="ALWAYS" />
+  </component>
+</project>
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,34 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <Languages>
+        <language minSize="49" name="Python" />
+      </Languages>
+    </inspection_tool>
+    <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+    <inspection_tool class="Mypy" enabled="true" level="TYPO" enabled_by_default="true" editorAttributes="TYPO" />
+    <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <option name="ignoredErrors">
+        <list>
+          <option value="N802" />
+          <option value="N803" />
+          <option value="N806" />
+        </list>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredIdentifiers">
+        <list>
+          <option value="fitz.fitz.Page.MediaBox" />
+          <option value="color_tol" />
+        </list>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
+      <option name="processCode" value="true" />
+      <option name="processLiterals" value="true" />
+      <option name="processComments" value="true" />
+    </inspection_tool>
+  </profile>
+</component>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/DeepLearning.iml" filepath="$PROJECT_DIR$/.idea/DeepLearning.iml" />
+    </modules>
+  </component>
+</project>
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
.python-version
ADDED
@@ -0,0 +1 @@
+3.12
PikaPikaTraining.ipynb
ADDED
@@ -0,0 +1,112 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# PikaPikaGen: Model Training\n",
+    "\n",
+    "This notebook automates the setup and launch of training for the PikaPikaGen model.\n",
+    "\n",
+    "The steps performed are:\n",
+    "1. Clone the public GitHub repository.\n",
+    "2. Install the required dependencies via `uv`.\n",
+    "3. Run the training script `main.py`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"Installing the required dependencies...\")\n",
+    "\n",
+    "# Make sure uv is installed\n",
+    "%pip install uv\n",
+    "print(\"✅ uv installed successfully.\")\n",
+    "\n",
+    "# Check whether torch is already installed\n",
+    "try:\n",
+    "    import torch\n",
+    "    print(f\"✅ PyTorch already installed (version: {torch.__version__})\")\n",
+    "    torch_installed = True\n",
+    "except ImportError:\n",
+    "    print(\"❌ PyTorch not found, it will be installed\")\n",
+    "    torch_installed = False\n",
+    "\n",
+    "# List of the project's main dependencies\n",
+    "dependencies = [\n",
+    "    \"transformers\",\n",
+    "    \"pandas\",\n",
+    "    \"tqdm\",\n",
+    "    \"matplotlib\",\n",
+    "    \"Pillow\",\n",
+    "    \"requests\",\n",
+    "    \"ipywidgets\"\n",
+    "]\n",
+    "\n",
+    "# Add torch and torchvision only if they are not already installed\n",
+    "if not torch_installed:\n",
+    "    dependencies.extend([\"torch\", \"torchvision\"])\n",
+    "\n",
+    "print(\"Installing the dependencies with uv...\")\n",
+    "deps_str = \" \".join(dependencies)\n",
+    "if torch_installed:\n",
+    "    !uv pip install {deps_str}\n",
+    "else:\n",
+    "    !uv pip install {deps_str} --torch-backend=auto\n",
+    "print(\"✅ Main dependencies installed successfully.\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "repo_url = \"https://github.com/val-2/DeepLearning\"\n",
+    "branch = \"main\"\n",
+    "repo_name = repo_url.split('/')[-1]\n",
+    "\n",
+    "print(f\"Cloning repository: {repo_url}\")\n",
+    "\n",
+    "# Check if we're already in the repo directory\n",
+    "current_dir = os.path.basename(os.getcwd())\n",
+    "if current_dir == repo_name:\n",
+    "    print(f\"Already in the repository directory '{repo_name}'. Updating...\")\n",
+    "    !git fetch\n",
+    "    !git pull\n",
+    "    !git checkout {branch}\n",
+    "elif os.path.exists(repo_name):\n",
+    "    print(f\"The directory '{repo_name}' already exists. Updating the repository...\")\n",
+    "    os.chdir(repo_name)\n",
+    "    !git fetch\n",
+    "    !git pull\n",
+    "    !git checkout {branch}\n",
+    "else:\n",
+    "    print(\"Cloning the repository...\")\n",
+    "    !git clone -b {branch} {repo_url}\n",
+    "    os.chdir(repo_name)\n",
+    "\n",
+    "# Move into the repository directory\n",
+    "print(f\"Current working directory: {os.getcwd()}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.12.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
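
The notebook above is a thin setup wrapper. A rough script equivalent of its clone-and-install cells, as a sketch (repo URL, branch, and dependency list are taken from the notebook; using subprocess instead of notebook shell magics is an assumption for illustration):

import subprocess

# Clone the public repository on the main branch, as the notebook does via shell magics.
subprocess.run(["git", "clone", "-b", "main", "https://github.com/val-2/DeepLearning"], check=True)

# Install the same dependency list with uv; torch and torchvision are appended when missing.
deps = ["transformers", "pandas", "tqdm", "matplotlib", "Pillow", "requests",
        "ipywidgets", "torch", "torchvision"]
subprocess.run(["uv", "pip", "install", *deps], check=True)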
pikapikagen/PikaPikaGen.ipynb
ADDED
@@ -0,0 +1,2241 @@
+{
+ "cells": [
+  {
+   "cell_type": "raw",
+   "metadata": {
+    "id": "VDSaH9SVsnNl",
+    "vscode": {
+     "languageId": "raw"
+    }
+   },
+   "source": [
+    "# PikaPikaGen: Text-to-Image Pokemon Sprite Generation with GAN\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Install required packages\n",
+    "#!pip install torch torchvision transformers pandas pillow requests matplotlib tqdm ipywidgets gradio torch-fidelity\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import os\n",
+    "from tqdm import tqdm\n",
+    "from transformers import AutoTokenizer\n",
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')\n",
+    "\n",
+    "# Set device\n",
+    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+    "print(f\"Using device: {device}\")\n",
+    "\n",
+    "# Set random seeds for reproducibility\n",
+    "RANDOM_SEED = 42\n",
+    "torch.manual_seed(RANDOM_SEED)\n",
+    "np.random.seed(RANDOM_SEED)"
+   ]
+  },
+  {
+   "cell_type": "raw",
+   "metadata": {
+    "id": "-rrtsHGqsnNo",
+    "vscode": {
+     "languageId": "raw"
+    }
+   },
+   "source": [
+    "## 1. Data Loading and Preprocessing\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "id": "aeVuv1YCsnNp"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torchvision.transforms as T\n",
+    "\n",
+    "\n",
+    "class AugmentationPipeline:\n",
+    "    def __init__(self, p=0.8):\n",
+    "        self.p = p\n",
+    "        self.transforms = T.RandomApply([\n",
+    "            T.RandomHorizontalFlip(p=0.5),\n",
+    "\n",
+    "            T.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=1),\n",
+    "\n",
+    "            T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
+    "\n",
+    "            T.RandomErasing(p=0.15, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random'),\n",
+    "        ], p=self.p)\n",
+    "\n",
+    "    def apply(self, images):\n",
+    "        return self.transforms(images)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000,
+     "referenced_widgets": [
+      "5efdceae0bac4c978d3a7226247e237f",
+      "a39c5c623a3e42448e109fb9ec6bc263",
+      "a6ed2ddb1c6f4d1aa945c5a39372f781",
+      "8cf950b898e142c1af9b4db92019aa4d",
+      "8ed7abd0602c43a1bfc0f96d7611d429",
+      "65ba2d78fde14bb2baf5ae1101d7e5ff",
+      "4795a78a75dc439a8da7df58bf738940",
+      "4545ff199b874d3680a83918513e1d4b",
+      "cad8fd90586443778568a1babb8c40e6",
+      "57e526d188b9414dabb3b1c895373864",
+      "8226a55726c54abba3a48dbfa8e1b6f6",
+      "86a3c1a4e9eb4989b23364f21e5df531",
+      "5ba39d9d997a45ca848e3e2ffd0e7307",
+      "4c22e1b396f342ffb90c1b50a0051862",
+      "370e5663868f411697bfb24f4e3efa09",
+      "3a338ac4d2944030a07843d8ea24e9fd",
+      "128f4312bcdc4166b9e24d8cdd34184d",
+      "1b65d6c8540e4f458886d5e7075ab30a",
+      "a5a9f8607fdd4f9cad7519eca573f3dc",
+      "926149594f94457295c60b4fad9cbac7",
+      "7e89bc79516f405e9684eacdce7b4551",
+      "c917f3a000fb44338e4afbeabeaab55f"
+     ]
+    },
+    "id": "ppTYW-n5snNp",
+    "outputId": "4d7a3003-296a-458c-a339-aeacf5232c91"
+   },
+   "outputs": [],
+   "source": [
+    "from data_loader import create_training_setup\n",
+    "from utils import denormalize_image\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-mini')\n",
+    "\n",
+    "# train_augmentation_pipeline = AugmentationPipeline()\n",
+    "# Create the complete training setup using the function from pokemon_dataset.py\n",
+    "print(\"Creating training setup with train/val split and fixed batches...\")\n",
+    "training_setup = create_training_setup(\n",
+    "    tokenizer=tokenizer,\n",
+    "    test_set_size=0.2,\n",
+    "    val_set_size=0.1,\n",
+    "    batch_size=16,\n",
+    "    num_workers=0,\n",
+    "    num_viz_samples=4,\n",
+    "    random_seed=42,\n",
+    "    train_augmentation_pipeline=None\n",
+    ")\n",
+    "\n",
+    "# Extract components\n",
+    "train_loader = training_setup['train_loader']\n",
+    "val_loader = training_setup['val_loader']\n",
+    "fixed_train_batch = training_setup['fixed_train_batch']\n",
+    "fixed_val_batch = training_setup['fixed_val_batch']\n",
+    "fixed_train_attention_batch = training_setup['fixed_train_attention_batch']\n",
+    "fixed_val_attention_batch = training_setup['fixed_val_attention_batch']\n",
+    "\n",
+    "print(\"Training setup complete!\")\n",
+    "print(f\"Train loader batches: {len(train_loader)}\")\n",
+    "print(f\"Val loader batches: {len(val_loader)}\")\n",
+    "\n",
+    "# Test the training setup with fixed batches\n",
+    "print(\"\\nFixed batch shapes:\")\n",
+    "print(f\" Train batch - Images: {fixed_train_batch['image'].shape}\")\n",
+    "print(f\" Train batch - Text: {fixed_train_batch['text'].shape}\")\n",
+    "print(f\" Train batch - Attention: {fixed_train_batch['attention_mask'].shape}\")\n",
+    "print(f\" Val batch - Images: {fixed_val_batch['image'].shape}\")\n",
+    "\n",
+    "# Display sample images from fixed batches\n",
+    "fig, axes = plt.subplots(2, 4, figsize=(16, 8))\n",
+    "for i in range(4):\n",
+    "    # Fixed train batch images\n",
+    "    train_img = denormalize_image(fixed_train_batch['image'][i])\n",
+    "    axes[0, i].imshow(train_img.permute(1, 2, 0))\n",
+    "    axes[0, i].set_title(f\"Train: {fixed_train_batch['pokemon_name'][i]}\")\n",
+    "    axes[0, i].axis('off')\n",
+    "\n",
+    "    # Fixed val batch images\n",
+    "    val_img = denormalize_image(fixed_val_batch['image'][i])\n",
+    "    axes[1, i].imshow(val_img.permute(1, 2, 0))\n",
+    "    axes[1, i].set_title(f\"Val: {fixed_val_batch['pokemon_name'][i]}\")\n",
+    "    axes[1, i].axis('off')\n",
+    "\n",
+    "plt.suptitle(\"Fixed Batches for Training Visualization\", fontsize=16)\n",
+    "plt.tight_layout()\n",
+    "plt.show()\n",
+    "\n",
+    "\n",
+    "print(\"\\n✅ Dataset and batches loaded successfully from pokemon_dataset.py functionality!\")\n",
+    "print(\"Ready for training with proper train/val split and fixed visualization batches.\")\n"
+   ]
+  },
+  {
+   "cell_type": "raw",
+   "metadata": {
+    "id": "eJSVrf3ysnNq",
+    "vscode": {
+     "languageId": "raw"
+    }
+   },
+   "source": [
+    "## 2. Model Architecture Implementation\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 923,
+     "referenced_widgets": [
+      "bdf500351aea42698c6d6dd5a99021f3",
+      "ab61b90c1a5b4a2b9bb5c9d5a215bb3f",
+      "dc03fed540b74f3aa4a1b17ebf2c81d3",
+      "5837f2c4668646c0a6db2407aebb46e3",
+      "edeb423e9ff84e5c8a0d790368d68bba",
+      "bf8eb066cdaf4ac096dc14392d085daf",
+      "4e32e76c44fb449c8cb767abeb17868a",
+      "5c3cb981f324446eae642f7c23a539f0",
+      "2fe9614fe5984fa6b887d1e1b3e18b04",
+      "64277772cc30408e8ea29f0e268c8880",
+      "5b0d55ea20714104818097bd7d1f509a",
+      "7e21c6a9c7f44496b6f28513caefb631",
+      "439eba0eb4184c0ab83f65fc26bbe388",
+      "eee695744ec64aa7b71b9e85968c6f8f",
+      "c4ecdc9d982f49129368893c1c0aece9",
+      "5f5e7ff6e4c845b99602a4fa00ad550a",
+      "304d50e74ad744cdb3a7cc88739cb923",
+      "bfcc6d01c9ff4db698afa4318e7c91ac",
+      "b2bf751bb96746e4a828241f70e52050",
+      "828b227361fe45cd83964149e7475503",
+      "58ab975eaba2485cb0945482c26ecf3d",
+      "d0b4e43ab5cd4edda6cc061b36bf10a3"
+     ]
+    },
+    "id": "RnNQM3_ysnNr",
+    "outputId": "6905696e-05d6-4d97-dd9b-9dc36eea95b7"
+   },
+   "outputs": [],
+   "source": [
+    "from model import Generator\n",
+    "\n",
+    "# Test the generator\n",
+    "generator = Generator().to(device)\n",
+    "with torch.no_grad():\n",
+    "    generated_images_256, generated_images_64 = generator(\n",
+    "        fixed_train_batch['text'][:2].to(device),\n",
+    "        fixed_train_batch['attention_mask'][:2].to(device)\n",
+    "    )\n",
+    "print(f\"Generator output shape 256x256: {generated_images_256.shape}\")\n",
+    "print(f\"Generator output shape 64x64: {generated_images_64.shape}\")\n",
+    "\n",
+    "print(\"Generator test\")\n",
+    "plt.figure(figsize=(12, 8))\n",
+    "for i in range(2):\n",
+    "    # 256x256 images\n",
+    "    plt.subplot(2, 2, i+1)\n",
+    "    img_256 = denormalize_image(generated_images_256[i].cpu())\n",
+    "    plt.imshow(img_256.permute(1, 2, 0))\n",
+    "    plt.title(f\"Generated 256x256 Sample {i+1}\")\n",
+    "    plt.axis('off')\n",
+    "\n",
+    "    # 64x64 images\n",
+    "    plt.subplot(2, 2, i+3)\n",
+    "    img_64 = denormalize_image(generated_images_64[i].cpu())\n",
+    "    plt.imshow(img_64.permute(1, 2, 0))\n",
+    "    plt.title(f\"Generated 64x64 Sample {i+1}\")\n",
+    "    plt.axis('off')\n",
+    "plt.tight_layout()\n",
+    "plt.show()\n"
+   ]
+  },
+  {
+   "cell_type": "raw",
+   "metadata": {
+    "id": "7drCU21JsnNs",
+    "vscode": {
+     "languageId": "raw"
+    }
+   },
+   "source": [
+    "## 3. Training Setup and Utilities\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "iQdhzEQQsnNs",
+    "outputId": "2dbee275-3b6d-43da-8929-97e21403821f"
+   },
+   "outputs": [],
+   "source": [
+    "from discriminators import Discriminator256, Discriminator64\n",
+    "from losses import VGGPerceptualLoss, SobelLoss\n",
+    "from plots import save_attention_visualization\n",
+    "\n",
+    "def weights_init(m):\n",
+    "    \"\"\"Initialize model weights according to the original DCGAN paper\"\"\"\n",
+    "    classname = m.__class__.__name__\n",
+    "    if classname.find('Conv') != -1:\n",
+    "        nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
+    "    elif classname.find('BatchNorm') != -1:\n",
+    "        nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
+    "        nn.init.constant_(m.bias.data, 0)\n",
+    "\n",
+    "generator = Generator().to(device)\n",
+    "discriminator_256 = Discriminator256().to(device)\n",
+    "discriminator_64 = Discriminator64().to(device)\n",
+    "\n",
+    "generator.apply(weights_init)\n",
+    "discriminator_256.apply(weights_init)\n",
+    "discriminator_64.apply(weights_init)\n",
+    "\n",
+    "\n",
+    "# Optimizer params\n",
+    "lr = 0.0002\n",
+    "beta1 = 0.5\n",
+    "beta2 = 0.999\n",
+    "\n",
+    "optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))\n",
+    "optimizer_D_256 = optim.Adam(discriminator_256.parameters(), lr=lr, betas=(beta1, beta2))\n",
+    "optimizer_D_64 = optim.Adam(discriminator_64.parameters(), lr=lr, betas=(beta1, beta2))\n",
+    "\n",
+    "adv_criterion = nn.BCEWithLogitsLoss().to(device)  # no sigmoid at the end of discriminators\n",
+    "l1_criterion = nn.L1Loss().to(device)\n",
+    "perc_criterion = VGGPerceptualLoss(device)\n",
+    "sobel_criterion = SobelLoss().to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import TypedDict\n",
+    "import torch\n",
+    "from plots import save_comparison_grid\n",
+    "\n",
+    "# Create checkpoint saving directory\n",
+    "os.makedirs('models', exist_ok=True)\n",
+    "\n",
+    "# TypedDicts to pass and return many object at once, without\n",
+    "class LossesDict(TypedDict):\n",
+    "    \"\"\"History of training losses\"\"\"\n",
+    "    generator: list[float]\n",
+    "    discriminator: list[float]\n",
+    "    l1: list[float]\n",
+    "    perceptual: list[float]\n",
+    "    sobel: list[float]\n",
+    "\n",
+    "class ValidationLossesDict(TypedDict):\n",
+    "    \"\"\"History of validation losses\"\"\"\n",
+    "    l1: list[float]\n",
+    "    perceptual: list[float]\n",
+    "    sobel: list[float]\n",
+    "    total: list[float]\n",
+    "\n",
+    "class DiscriminatorComponentsDict(TypedDict):\n",
+    "    \"\"\"Components of the discriminator loss\"\"\"\n",
+    "    real_uncond: float\n",
+    "    real_cond: float\n",
+    "    real_cond_wrong: float\n",
+    "    fake_uncond: float\n",
+    "\n",
+    "class ValidationResultsDict(TypedDict):\n",
+    "    \"\"\"Single losses for validation\"\"\"\n",
+    "    l1: float\n",
+    "    perceptual: float\n",
+    "    sobel: float\n",
+    "    total: float\n",
+    "\n",
+    "# Training history\n",
+    "losses: LossesDict = {\n",
+    "    'generator': [],\n",
+    "    'discriminator': [],\n",
+    "    'l1': [],\n",
+    "    'perceptual': [],\n",
+    "    'sobel': [],\n",
+    "}\n",
+    "\n",
+    "# Validation history\n",
+    "val_losses: ValidationLossesDict = {\n",
+    "    'l1': [],\n",
+    "    'perceptual': [],\n",
+    "    'sobel': [],\n",
+    "    'total': [],\n",
+    "}\n",
+    "\n",
+    "def validate_model(generator, val_loader, device, l1_criterion, perc_criterion, sobel_criterion) -> ValidationResultsDict:\n",
+    "    \"\"\"\n",
+    "    Validate the model on the validation set\n",
+    "    Returns validation losses\n",
+    "    \"\"\"\n",
+    "    generator.eval()\n",
+    "\n",
+    "    val_l1_loss = 0.0\n",
+    "    val_perc_loss = 0.0\n",
+    "    val_sobel_loss = 0.0\n",
+    "    num_batches = 0\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        for batch in val_loader:\n",
+    "            # Move data to device\n",
+    "            real_images = batch['image'].to(device)\n",
+    "            text_ids = batch['text'].to(device)\n",
+    "            attention_mask = batch['attention_mask'].to(device)\n",
+    "\n",
+    "            # Generate images\n",
+    "            generated_images, _ = generator(text_ids, attention_mask)\n",
+    "\n",
+    "            # Calculate validation losses (no adversarial loss)\n",
+    "            batch_l1_loss = l1_criterion(generated_images, real_images)\n",
+    "            batch_perc_loss = perc_criterion(generated_images, real_images)\n",
+    "            batch_sobel_loss = sobel_criterion(generated_images, real_images)\n",
+    "\n",
+    "            val_l1_loss += batch_l1_loss.item()\n",
+    "            val_perc_loss += batch_perc_loss.item()\n",
+    "            val_sobel_loss += batch_sobel_loss.item()\n",
+    "            num_batches += 1\n",
+    "\n",
+    "    # Calculate averages\n",
+    "    avg_val_l1 = val_l1_loss / num_batches\n",
+    "    avg_val_perc = val_perc_loss / num_batches\n",
+    "    avg_val_sobel = val_sobel_loss / num_batches\n",
+    "    avg_val_total = avg_val_l1 + avg_val_perc + avg_val_sobel\n",
+    "\n",
+    "    # Set models back to training mode\n",
+    "    generator.train()\n",
+    "\n",
+    "    return ValidationResultsDict(\n",
+    "        l1=avg_val_l1,\n",
+    "        perceptual=avg_val_perc,\n",
+    "        sobel=avg_val_sobel,\n",
+    "        total=avg_val_total\n",
+    "    )\n",
+    "\n",
+    "def create_mismatched_text_batch(text_ids, attention_mask):\n",
+    "    \"\"\"Create a batch with mismatched text for wrong text conditioning\"\"\"\n",
+    "    batch_size = text_ids.size(0)\n",
+    "    indices = torch.randperm(batch_size)\n",
+    "    return text_ids[indices], attention_mask[indices]\n",
+    "\n",
+    "def compute_discriminator_loss(\n",
+    "    discriminator,\n",
+    "    real_images,\n",
+    "    fake_images,\n",
+    "    text_ids,\n",
+    "    attention_mask,\n",
+    "    wrong_text_ids,\n",
+    "    wrong_attention_mask,\n",
+    "    real_labels,\n",
+    "    fake_labels,\n",
+    "    adv_criterion\n",
+    ") -> tuple[torch.Tensor, DiscriminatorComponentsDict]:\n",
+    "    \"\"\"Compute discriminator loss with the 4 components.\n",
+    "    Returns the total loss and the 4 components.\"\"\"\n",
+    "    # Real images with correct text\n",
+    "    real_uncond, real_cond = discriminator(real_images, text_ids, attention_mask, return_both=True)\n",
+    "    real_uncond_loss = adv_criterion(real_uncond, real_labels)\n",
+    "    real_cond_loss = adv_criterion(real_cond, real_labels)\n",
+    "\n",
+    "    # Real images with wrong text\n",
+    "    _, real_cond_wrong = discriminator(real_images, wrong_text_ids, wrong_attention_mask, return_both=True)\n",
+    "    real_cond_wrong_loss = adv_criterion(real_cond_wrong, fake_labels)\n",
+    "\n",
+    "    # Fake images with wrong text\n",
+    "    fake_uncond, _ = discriminator(fake_images.detach(), wrong_text_ids, wrong_attention_mask, return_both=True)\n",
+    "    fake_uncond_loss = adv_criterion(fake_uncond, fake_labels)\n",
+    "\n",
+    "    total_loss = (real_uncond_loss + real_cond_loss + real_cond_wrong_loss + fake_uncond_loss) / 4\n",
+    "\n",
+    "    components: DiscriminatorComponentsDict = {\n",
+    "        'real_uncond': real_uncond_loss.item(),\n",
+    "        'real_cond': real_cond_loss.item(),\n",
+    "        'real_cond_wrong': real_cond_wrong_loss.item(),\n",
+    "        'fake_uncond': fake_uncond_loss.item(),\n",
+    "    }\n",
+    "\n",
+    "    return total_loss, components\n",
+    "\n",
+    "def compute_generator_adversarial_loss(\n",
+    "    discriminator,\n",
+    "    fake_images,\n",
+    "    text_ids,\n",
+    "    attention_mask,\n",
+    "    real_labels,\n",
+    "    adv_criterion\n",
+    ") -> torch.Tensor:\n",
+    "    \"\"\"Compute generator adversarial loss for one discriminator\"\"\"\n",
+    "    fake_uncond, fake_cond = discriminator(fake_images, text_ids, attention_mask, return_both=True)\n",
+    "    uncond_loss = adv_criterion(fake_uncond, real_labels)\n",
+    "    cond_loss = adv_criterion(fake_cond, real_labels)\n",
+    "    return (uncond_loss + cond_loss) / 2\n",
+    "\n",
+    "def compute_loss(fake_images, real_images, criterion, lmd):\n",
+    "    \"\"\"Compute a reconstruction loss only if its lambda > 0\"\"\"\n",
+    "    return criterion(fake_images, real_images) if lmd > 0 else torch.tensor(0.0, device=device)\n",
+    "\n",
+    "\n",
+    "epoch = 0\n",
+    "noise_dim = 100\n"
+   ]
+  },
+  {
+   "cell_type": "raw",
+   "metadata": {
+    "id": "Oenm8AkasnNt",
+    "vscode": {
+     "languageId": "raw"
+    }
+   },
+   "source": [
+    "## 4. GAN Training Loop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "id": "gmo0Mi6osnNt",
+    "outputId": "a4691684-def3-4d29-d00d-20ae40287c8c"
+   },
+   "outputs": [],
+   "source": [
+    "from IPython.display import clear_output\n",
+    "\n",
+    "total_epochs = 150\n",
+    "display_interval = 1  # To show generation of training sample\n",
+    "save_interval = 15  # To save checkpoint\n",
+    "clear_interval = 22  # To clear cell output. If too high or not present, Kaggle page would crash\n",
+    "\n",
+    "lambda_l1 = 1.0\n",
+    "lambda_adv = 1.0\n",
+    "lambda_perceptual = 0.0\n",
+    "lambda_sobel = 0.0\n",
+    "\n",
+    "real_label = 1.0\n",
+    "fake_label = 0.0\n",
+    "\n",
+    "print(\"Starting training with dual discriminators...\")\n",
+    "\n",
+    "for epoch in range(epoch, total_epochs):\n",
+    "    epoch_g_loss = 0.0\n",
+    "    epoch_d_loss_64 = 0.0\n",
+    "    epoch_d_loss_256 = 0.0\n",
+    "    epoch_l1_loss = 0.0\n",
+    "    epoch_perc_loss = 0.0\n",
+    "    epoch_sobel_loss = 0.0\n",
+    "\n",
+    "    # Track discriminator loss components\n",
+    "    epoch_d256_components: DiscriminatorComponentsDict = {'real_uncond': 0.0, 'real_cond': 0.0, 'real_cond_wrong': 0.0, 'fake_uncond': 0.0}\n",
+    "    epoch_d64_components: DiscriminatorComponentsDict = {'real_uncond': 0.0, 'real_cond': 0.0, 'real_cond_wrong': 0.0, 'fake_uncond': 0.0}\n",
+    "\n",
+    "    progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{total_epochs}\")\n",
+    "\n",
+    "    for i, batch in enumerate(progress_bar):\n",
+    "        batch_size = batch['image'].size(0)\n",
+    "\n",
+    "        # Move data to device\n",
+    "        real_images = batch['image'].to(device)\n",
+    "        text_ids = batch['text'].to(device)\n",
+    "        attention_mask = batch['attention_mask'].to(device)\n",
+    "\n",
+    "        # Create labels and mismatched text for GAN training\n",
+    "        real_labels = torch.full((batch_size, 1), real_label, device=device, dtype=torch.float)\n",
+    "        fake_labels = torch.full((batch_size, 1), fake_label, device=device, dtype=torch.float)\n",
+    "        wrong_text_ids, wrong_attention_mask = create_mismatched_text_batch(text_ids, attention_mask)\n",
+    "\n",
+    "        # Generate fake images\n",
+    "        fake_images_256, fake_images_64 = generator(text_ids, attention_mask)\n",
+    "        real_images_64 = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)\n",
+    "\n",
+    "        # Training both discriminators\n",
+    "        optimizer_D_256.zero_grad()\n",
+    "        optimizer_D_64.zero_grad()\n",
+    "\n",
+    "        d_loss_256, d256_components = compute_discriminator_loss(\n",
+    "            discriminator_256, real_images, fake_images_256,\n",
+    "            text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,\n",
+    "            real_labels, fake_labels, adv_criterion\n",
+    "        )\n",
+    "        d_loss_256.backward()\n",
+    "\n",
+    "        d_loss_64, d64_components = compute_discriminator_loss(\n",
+    "            discriminator_64, real_images_64, fake_images_64,\n",
+    "            text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,\n",
+    "            real_labels, fake_labels, adv_criterion\n",
+    "        )\n",
+    "        d_loss_64.backward()\n",
+    "\n",
+    "        optimizer_D_256.step()\n",
+    "        optimizer_D_64.step()\n",
+    "\n",
+    "        # Training generator\n",
+    "        optimizer_G.zero_grad()\n",
+    "\n",
+    "        # Adversarial losses for both discriminators\n",
+    "        g_adv_loss_256 = compute_generator_adversarial_loss(\n",
+    "            discriminator_256, fake_images_256, text_ids, attention_mask, real_labels, adv_criterion\n",
+    "        )\n",
+    "        g_adv_loss_64 = compute_generator_adversarial_loss(\n",
+    "            discriminator_64, fake_images_64, text_ids, attention_mask, real_labels, adv_criterion\n",
+    "        )\n",
+    "        adversarial_loss = (g_adv_loss_256 + g_adv_loss_64) / 2\n",
+    "\n",
+    "        # Compute losses if their lambda is > 0\n",
+    "        l1_loss = compute_loss(fake_images_256, real_images, l1_criterion, lambda_l1)\n",
+    "        perc_loss = compute_loss(fake_images_256, real_images, perc_criterion, lambda_perceptual)\n",
+    "        sobel_loss = compute_loss(fake_images_256, real_images, sobel_criterion, lambda_sobel)\n",
+    "\n",
+    "        # Total generator loss\n",
+    "        g_loss = (\n",
+    "            lambda_adv * adversarial_loss +\n",
+    "            lambda_l1 * l1_loss +\n",
+    "            lambda_perceptual * perc_loss +\n",
+    "            lambda_sobel * sobel_loss\n",
+    "        )\n",
+    "        g_loss.backward()\n",
+    "        optimizer_G.step()\n",
+    "\n",
+    "        # Update loss tracking\n",
+    "        epoch_g_loss += g_loss.item()\n",
+    "        epoch_d_loss_256 += d_loss_256.item()\n",
+    "        epoch_d_loss_64 += d_loss_64.item()\n",
+    "        epoch_l1_loss += l1_loss.item()\n",
+    "        epoch_perc_loss += perc_loss.item()\n",
+    "        epoch_sobel_loss += sobel_loss.item()\n",
+    "\n",
+    "        # Update discriminator component tracking\n",
+    "        for key in epoch_d256_components:\n",
+    "            epoch_d256_components[key] += d256_components[key]\n",
+    "            epoch_d64_components[key] += d64_components[key]\n",
+    "\n",
+    "        # Update progress bar with detailed losses and loss components\n",
+    "        progress_bar.set_postfix({\n",
+    "            'G': f'{g_loss.item():.3f}',\n",
+    "            'L1': f'{l1_loss.item():.3f}',\n",
+    "            'Adv': f'{adversarial_loss.item():.3f}',\n",
+    "            'D256': f'{d_loss_256.item():.3f}',\n",
+    "            'D256_ru': f'{d256_components[\"real_uncond\"]:.3f}',\n",
+    "            'D256_rc': f'{d256_components[\"real_cond\"]:.3f}',\n",
+    "            'D256_rcw': f'{d256_components[\"real_cond_wrong\"]:.3f}',\n",
+    "            'D256_fu': f'{d256_components[\"fake_uncond\"]:.3f}',\n",
+    "            'D64': f'{d_loss_64.item():.3f}',\n",
+    "            'D64_ru': f'{d64_components[\"real_uncond\"]:.3f}',\n",
+    "            'D64_rc': f'{d64_components[\"real_cond\"]:.3f}',\n",
+    "            'D64_rcw': f'{d64_components[\"real_cond_wrong\"]:.3f}',\n",
+    "            'D64_fu': f'{d64_components[\"fake_uncond\"]:.3f}',\n",
+    "        })\n",
+    "\n",
+    "    # Calculate average losses for the epoch\n",
+    "    avg_g_loss = epoch_g_loss / len(train_loader)\n",
+    "    avg_d_loss_256 = epoch_d_loss_256 / len(train_loader)\n",
+    "    avg_d_loss_64 = epoch_d_loss_64 / len(train_loader)\n",
+    "    avg_l1_loss = epoch_l1_loss / len(train_loader)\n",
+    "    avg_perc_loss = epoch_perc_loss / len(train_loader)\n",
+    "    avg_sobel_loss = epoch_sobel_loss / len(train_loader)\n",
+    "\n",
+    "    # Calculate average discriminator components for epoch\n",
+    "    avg_d256_components = {key: val / len(train_loader) for key, val in epoch_d256_components.items()}\n",
+    "    avg_d64_components = {key: val / len(train_loader) for key, val in epoch_d64_components.items()}\n",
+    "\n",
+    "    # Store losses (combine discriminator losses)\n",
+    "    losses['generator'].append(avg_g_loss)\n",
+    "    losses['discriminator'].append((avg_d_loss_256 + avg_d_loss_64) / 2)\n",
+    "    losses['l1'].append(avg_l1_loss)\n",
+    "    losses['perceptual'].append(avg_perc_loss)\n",
+    "    losses['sobel'].append(avg_sobel_loss)\n",
+    "\n",
+    "    print(f\"Running validation for epoch {epoch+1}...\")\n",
+    "    validation_results = validate_model(generator, val_loader, device, l1_criterion, perc_criterion, sobel_criterion)\n",
+    "\n",
+    "    # Store validation losses\n",
+    "\n",
+    "    for k, v in validation_results.items():\n",
+    "        val_losses[k].append(v)\n",
+    "\n",
+    "    if (epoch + 1) % clear_interval == 0:\n",
+    "        clear_output(wait=True)\n",
+    "\n",
+    "    print(f\"Epoch [{epoch+1}/{total_epochs}]\")\n",
+    "    print(f\" Train - D_256: {avg_d_loss_256:.4f}, D_64: {avg_d_loss_64:.4f}, G_loss: {avg_g_loss:.4f}\")\n",
+    "    print(f\" D_256 Components - RU: {avg_d256_components['real_uncond']:.4f}, RC: {avg_d256_components['real_cond']:.4f}, RCW: {avg_d256_components['real_cond_wrong']:.4f}, FU: {avg_d256_components['fake_uncond']:.4f}\")\n",
+    "    print(f\" D_64 Components - RU: {avg_d64_components['real_uncond']:.4f}, RC: {avg_d64_components['real_cond']:.4f}, RCW: {avg_d64_components['real_cond_wrong']:.4f}, FU: {avg_d64_components['fake_uncond']:.4f}\")\n",
+    "    print(f\" Train - L1: {avg_l1_loss:.4f}, Perceptual: {avg_perc_loss:.4f}, Sobel: {avg_sobel_loss:.4f}\")\n",
+    "    print(f\" Val - L1: {validation_results['l1']:.4f}, Perceptual: {validation_results['perceptual']:.4f}, Sobel: {validation_results['sobel']:.4f}, Total: {validation_results['total']:.4f}\")\n",
+    "    print(\" Legend: RU=RealUncond, RC=RealCond, RCW=RealCondWrong, FU=FakeUncond\")\n",
|
692 |
+
"\n",
|
693 |
+
" # Display generated images\n",
|
694 |
+
" if (epoch + 1) % display_interval == 0:\n",
|
695 |
+
" print(f\"\\nGenerating sample images at epoch {epoch+1}:\")\n",
|
696 |
+
" print(\"256x256 Training Images:\")\n",
|
697 |
+
" save_comparison_grid(epoch+1, generator, fixed_train_batch, \"train_256\", device, show_inline=True)\n",
|
698 |
+
" print(\"64x64 Training Images:\")\n",
|
699 |
+
" save_comparison_grid(epoch+1, generator, fixed_train_batch, \"train_64\", device, show_inline=True)\n",
|
700 |
+
"\n",
|
701 |
+
" # Save checkpoint and show visualizations\n",
|
702 |
+
" if (epoch + 1) % save_interval == 0:\n",
|
703 |
+
" checkpoint_path = f'models/checkpoint_epoch_{epoch+1}.pth'\n",
|
704 |
+
" all_losses = {'train': losses, 'val': val_losses}\n",
|
705 |
+
" checkpoint = {\n",
|
706 |
+
" 'epoch': epoch,\n",
|
707 |
+
" 'generator_state_dict': generator.state_dict(),\n",
|
708 |
+
" 'discriminator_256_state_dict': discriminator_256.state_dict(),\n",
|
709 |
+
" 'discriminator_64_state_dict': discriminator_64.state_dict(),\n",
|
710 |
+
" 'g_optimizer_state_dict': optimizer_G.state_dict(),\n",
|
711 |
+
" 'd_optimizer_state_dict': optimizer_D_256.state_dict(),\n",
|
712 |
+
" 'd_64_optimizer_state_dict': optimizer_D_64.state_dict(),\n",
|
713 |
+
" 'losses': all_losses\n",
|
714 |
+
" }\n",
|
715 |
+
" torch.save(checkpoint, checkpoint_path)\n",
|
716 |
+
" print(f\"Checkpoint saved to {checkpoint_path}\")\n",
|
717 |
+
"\n",
|
718 |
+
" print(\"256x256 Validation Images:\")\n",
|
719 |
+
" save_comparison_grid(epoch+1, generator, fixed_val_batch, \"val_256\", device, show_inline=True)\n",
|
720 |
+
" print(\"64x64 Validation Images:\")\n",
|
721 |
+
" save_comparison_grid(epoch+1, generator, fixed_val_batch, \"val_64\", device, show_inline=True)\n",
|
722 |
+
" save_attention_visualization(epoch+1, generator, tokenizer, fixed_train_batch, device, \"train\", show_inline=True)\n",
|
723 |
+
" save_attention_visualization(epoch+1, generator, tokenizer, fixed_val_batch, device, \"val\", show_inline=True)\n",
|
724 |
+
"\n",
|
725 |
+
"print(\"Training completed!\")\n"
|
726 |
+
]
|
727 |
+
},
|
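The component names tracked above (real_uncond, real_cond, real_cond_wrong, fake_uncond) point to a matching-aware conditional objective in the spirit of GAN-INT-CLS/StackGAN: each discriminator scores real images, real images paired with mismatched text, and generated images. The actual compute_discriminator_loss is defined earlier in the notebook (and in losses.py); the sketch below only illustrates what those four terms plausibly are, and it assumes the discriminator can be called both unconditionally and with (text_ids, attention_mask).

# Hedged sketch only: the real compute_discriminator_loss lives in losses.py.
# Assumes discriminator(images) returns unconditional logits and
# discriminator(images, text_ids, attention_mask) returns text-conditional logits.
def compute_discriminator_loss_sketch(discriminator, real_images, fake_images,
                                      text_ids, attention_mask,
                                      wrong_text_ids, wrong_attention_mask,
                                      real_labels, fake_labels, adv_criterion):
    # Real image, no conditioning: should be classified as real.
    real_uncond = adv_criterion(discriminator(real_images), real_labels)
    # Real image with its matching description: real.
    real_cond = adv_criterion(discriminator(real_images, text_ids, attention_mask), real_labels)
    # Real image with a mismatched description: fake (the matching-aware term).
    real_cond_wrong = adv_criterion(discriminator(real_images, wrong_text_ids, wrong_attention_mask), fake_labels)
    # Generated image, no conditioning: fake. detach() keeps gradients out of the generator.
    fake_uncond = adv_criterion(discriminator(fake_images.detach()), fake_labels)

    d_loss = real_uncond + real_cond + real_cond_wrong + fake_uncond
    components = {
        "real_uncond": real_uncond.item(),
        "real_cond": real_cond.item(),
        "real_cond_wrong": real_cond_wrong.item(),
        "fake_uncond": fake_uncond.item(),
    }
    return d_loss, components

Penalizing real images paired with the wrong description is what forces the generator to respect the text rather than simply produce any plausible sprite.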
## 5. Training Analysis and Visualization

from plots import save_plot_losses, save_plot_non_gan_losses


save_plot_losses(
    losses_g=losses['generator'],
    losses_d=losses['discriminator'],
    output_dir="training_output",
    show_inline=True
)

# Plot training vs validation losses for non-GAN components (so except "generator" and "discriminator" from losses)
# Convert to list of dicts format expected by save_plot_non_gan_losses
train_losses_history = []
val_losses_history = []

for i in range(len(losses['l1'])):
    train_losses_history.append({
        'l1': losses['l1'][i],
        'perceptual': losses['perceptual'][i],
        'sobel': losses['sobel'][i],
        'total': losses['l1'][i] + losses['perceptual'][i] + losses['sobel'][i]
    })

for i in range(len(val_losses['l1'])):
    val_losses_history.append({
        'l1': val_losses['l1'][i],
        'perceptual': val_losses['perceptual'][i],
        'sobel': val_losses['sobel'][i],
        'total': val_losses['total'][i]
    })

save_plot_non_gan_losses(
    train_losses_history=train_losses_history,
    val_losses_history=val_losses_history,
    output_dir="training_output",
    show_inline=True
)

# Print final statistics
print(f"Final Train - Generator Loss: {losses['generator'][-1]:.4f}")
print(f"Final Train - Discriminator Loss: {losses['discriminator'][-1]:.4f}")
print(f"Final Train - L1 Loss: {losses['l1'][-1]:.4f}")
print(f"Final Train - Perceptual Loss: {losses['perceptual'][-1]:.4f}")
print(f"Final Train - Sobel Loss: {losses['sobel'][-1]:.4f}")

print(f"Final Val - L1 Loss: {val_losses['l1'][-1]:.4f}")
print(f"Final Val - Perceptual Loss: {val_losses['perceptual'][-1]:.4f}")
print(f"Final Val - Sobel Loss: {val_losses['sobel'][-1]:.4f}")
print(f"Final Val - Total Loss: {val_losses['total'][-1]:.4f}")
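Because the training loop stores the full training state every save_interval epochs, these curves and the model weights can also be recovered from disk rather than from the in-memory losses/val_losses dictionaries. A minimal resume sketch, assuming a checkpoint from epoch 150 exists under models/ (the epoch number and map_location are illustrative; the key names mirror the dict written by torch.save above):

import torch

checkpoint = torch.load("models/checkpoint_epoch_150.pth", map_location=device)

generator.load_state_dict(checkpoint["generator_state_dict"])
discriminator_256.load_state_dict(checkpoint["discriminator_256_state_dict"])
discriminator_64.load_state_dict(checkpoint["discriminator_64_state_dict"])
optimizer_G.load_state_dict(checkpoint["g_optimizer_state_dict"])
optimizer_D_256.load_state_dict(checkpoint["d_optimizer_state_dict"])
optimizer_D_64.load_state_dict(checkpoint["d_64_optimizer_state_dict"])

start_epoch = checkpoint["epoch"] + 1   # resume after the saved epoch
history = checkpoint["losses"]          # {'train': {...}, 'val': {...}}
losses, val_losses = history["train"], history["val"]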
# Generate a grid of final results
print("Final Results - Generated Pokemon Sprites (256x256):")
batch = next(iter(train_loader))
save_comparison_grid(0, generator, batch, "final_256", device, show_inline=True)

print("Final Results - Generated Pokemon Sprites (64x64):")
save_comparison_grid(0, generator, batch, "final_64", device, show_inline=True)
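The grids above reuse batches from train_loader; generating a sprite for a free-form description only needs the tokenizer plus one generator forward pass. A hedged sketch follows: the tokenizer arguments (padding, max_length=128) and the [-1, 1] output range assumed for display are guesses about the preprocessing, not taken from this cell.

import matplotlib.pyplot as plt
import torch

description = "A small yellow electric mouse with red cheeks and a lightning-bolt tail."

encoded = tokenizer(
    description,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=128,   # assumed maximum description length
)

generator.eval()
with torch.no_grad():
    image_256, image_64 = generator(
        encoded["input_ids"].to(device),
        encoded["attention_mask"].to(device),
    )

# Assume tanh output in [-1, 1]; rescale to [0, 1] for display.
sprite = (image_256[0].cpu() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0)
plt.imshow(sprite)
plt.axis("off")
plt.show()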
## 7. Model Analysis

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Generator parameters: {count_parameters(generator):,}")
print(f"Discriminator (256) parameters: {count_parameters(discriminator_256):,}")
print(f"Discriminator (64) parameters: {count_parameters(discriminator_64):,}")
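count_parameters only counts trainable parameters. A small companion sketch for estimating each network's in-memory size, assuming nothing beyond standard nn.Module parameter and buffer accounting:

def model_size_mb(model):
    # Parameters plus buffers (e.g. normalization running statistics) at their stored dtype.
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / 1024**2

for name, m in [("Generator", generator),
                ("Discriminator (256)", discriminator_256),
                ("Discriminator (64)", discriminator_64)]:
    print(f"{name}: {model_size_mb(m):.1f} MB")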
[Notebook metadata: Colab GPU runtime (gpuType T4), kernelspec "Python 3 (ipykernel)", language_info for Python 3.12.11, nbformat 4.4. The remaining roughly 1,400 metadata lines are auto-saved ipywidgets state (layout, style, and progress-bar models) left over from downloading the text-encoder files: vocab.txt, config.json, model.safetensors, and pytorch_model.bin (about 45.1 MB for each weight file).]
pikapikagen/README.md
ADDED
@@ -0,0 +1,6 @@
+---
+title: pikapikagen
+app_file: gradio_demo.py
+sdk: gradio
+sdk_version: 5.35.0
+---
pikapikagen/__init__.py
ADDED
File without changes
pikapikagen/data_loader.py
ADDED
@@ -0,0 +1,100 @@
+from torch.utils.data import DataLoader, Subset
+import torch
+from dataset import PokemonDataset
+import math
+
+def create_training_setup(
+    tokenizer,
+    test_set_size,
+    val_set_size,
+    batch_size,
+    num_workers=0,
+    num_viz_samples=4,
+    random_seed=42,
+    train_augmentation_pipeline=None,
+):
+    """
+    Create a complete setup for training with dataset, dataloaders and fixed batches for visualization.
+    """
+    assert 0 <= test_set_size < 1.0, "test_set_size must be a float between 0 and 1"
+    assert 0 <= val_set_size < 1.0, "val_set_size must be a float between 0 and 1"
+    assert (test_set_size + val_set_size) < 1.0, "The sum of test and validation sizes must be less than 1"
+
+    train_full_dataset = PokemonDataset(tokenizer=tokenizer, augmentation_transforms=train_augmentation_pipeline)
+    # Don't use augmentation for test and validation
+    test_val_full_dataset = PokemonDataset(tokenizer=tokenizer)
+
+    dataset_size = len(train_full_dataset)
+
+    # Create a random reproducible permutation
+    generator = torch.Generator().manual_seed(random_seed)
+    shuffled_indices = torch.randperm(dataset_size, generator=generator)
+
+    val_count = math.ceil(val_set_size * dataset_size)
+    test_count = math.ceil(test_set_size * dataset_size)
+    train_count = dataset_size - val_count - test_count
+
+    # Partition based on the computed splits
+    train_indices = shuffled_indices[:train_count].tolist()
+    test_indices = shuffled_indices[train_count : train_count + test_count].tolist()
+    val_indices = shuffled_indices[train_count + test_count :].tolist()
+
+    # Create the subsets based on the indices
+    train_dataset = Subset(train_full_dataset, train_indices)
+    test_dataset = Subset(test_val_full_dataset, test_indices)
+    val_dataset = Subset(test_val_full_dataset, val_indices)
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+        pin_memory=True,
+    )
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=True,
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=True,
+    )
+
+    # Batch for visualization
+    vis_generator = torch.Generator().manual_seed(random_seed)
+
+    fixed_train_batch = next(
+        iter(DataLoader(train_dataset, batch_size=num_viz_samples, shuffle=True, generator=vis_generator))
+    )
+    # Since no shuffle, a generator is not needed
+    fixed_test_batch = next(iter(DataLoader(test_dataset, batch_size=num_viz_samples, shuffle=False)))
+    fixed_val_batch = next(iter(DataLoader(val_dataset, batch_size=num_viz_samples, shuffle=False)))
+
+    # Batch (size 1) for attention map visualization
+    vis_generator.manual_seed(random_seed)
+    fixed_train_attention_batch = next(
+        iter(DataLoader(train_dataset, batch_size=1, shuffle=True, generator=vis_generator))
+    )
+    fixed_test_attention_batch = next(iter(DataLoader(test_dataset, batch_size=1, shuffle=False)))
+    fixed_val_attention_batch = next(iter(DataLoader(val_dataset, batch_size=1, shuffle=False)))
+
+    return {
+        'train_loader': train_loader,
+        'val_loader': val_loader,
+        'test_loader': test_loader,
+        'train_dataset': train_dataset,
+        'val_dataset': val_dataset,
+        'test_dataset': test_dataset,
+        'fixed_train_batch': fixed_train_batch,
+        'fixed_val_batch': fixed_val_batch,
+        'fixed_test_batch': fixed_test_batch,
+        'fixed_train_attention_batch': fixed_train_attention_batch,
+        'fixed_val_attention_batch': fixed_val_attention_batch,
+        'fixed_test_attention_batch': fixed_test_attention_batch,
+    }
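A minimal usage sketch for create_training_setup, not part of the commit above; the tokenizer name matches the one used elsewhere in this repo, and the split sizes are illustrative.

# Hypothetical usage sketch for data_loader.create_training_setup (not part of this commit).
from transformers import AutoTokenizer
from data_loader import create_training_setup

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")

setup = create_training_setup(
    tokenizer=tokenizer,
    test_set_size=0.2,   # 20% of the Pokemon go to the test split
    val_set_size=0.1,    # 10% to validation, the rest to training
    batch_size=16,
)

print(len(setup["train_dataset"]), len(setup["val_dataset"]), len(setup["test_dataset"]))
for batch in setup["train_loader"]:
    # each batch is a dict with "text", "attention_mask", "image", "description", ...
    break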
pikapikagen/dataset.py
ADDED
@@ -0,0 +1,141 @@
+import os
+import urllib.request
+import zipfile
+from torch.utils.data import Dataset
+import pandas as pd
+from pathlib import Path
+import torchvision.transforms as transforms
+from PIL import Image
+from typing import TypedDict
+import torch
+
+
+class PokemonSample(TypedDict):
+    text: torch.Tensor  # Text already tokenized
+    image: torch.Tensor
+    description: str  # Text before tokenization
+    pokemon_name: str
+    idx: int
+    attention_mask: torch.Tensor
+
+
+def reporthook(block_num, block_size, total_size):
+    if block_num % 16384 == 0:
+        print(f"Downloading... {block_num * block_size / (1024 * 1024):.2f} MB")
+
+
+def download_dataset_if_not_exists():
+    dataset_dir = "dataset"
+    pokedex_main_dir = os.path.join(dataset_dir, "pokedex-main")
+    zip_url = "https://github.com/cristobalmitchell/pokedex/archive/refs/heads/main.zip"
+    zip_path = "pokedex_main.zip"
+
+    if os.path.exists(pokedex_main_dir):
+        print(f"{pokedex_main_dir} already exists. Skipping download.")
+        return
+
+    os.makedirs(dataset_dir, exist_ok=True)
+
+    print("Downloading dataset...")
+    urllib.request.urlretrieve(zip_url, zip_path, reporthook)
+    print("Download complete.")
+
+    print("Extracting dataset...")
+    with zipfile.ZipFile(zip_path, "r") as zip_ref:
+        zip_ref.extractall(dataset_dir)
+    print("Extraction complete.")
+
+    os.remove(zip_path)
+
+
+class PokemonDataset(Dataset):
+    def __init__(
+        self,
+        tokenizer,
+        csv_path="dataset/pokedex-main/data/pokemon.csv",
+        image_dir="dataset/pokedex-main/images/small_images",
+        max_length=128,
+        augmentation_transforms=None,
+    ):
+        self.df = pd.read_csv(csv_path, encoding="utf-16 LE", delimiter="\t")
+        self.image_dir = Path(image_dir)
+        print(f"Dataset loaded: {len(self.df)} Pokemon with descriptions and images")
+
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+        if augmentation_transforms is not None:
+            self.final_transform = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Resize((256, 256), antialias=True),
+                    augmentation_transforms,
+                    transforms.Normalize(
+                        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
+                    ),  # Normalize to [-1, 1]
+                ]
+            )
+        else:
+            self.final_transform = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Resize((256, 256), antialias=True),
+                    transforms.Normalize(
+                        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
+                    ),  # Normalize to [-1, 1]
+                ]
+            )
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx: int) -> PokemonSample:
+        # Get the corresponding row
+        row = self.df.iloc[idx]
+
+        # === TEXT PREPROCESSING ===
+        description = str(row["description"])
+
+        # Tokenize the text
+        encoded = self.tokenizer(
+            description,
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        # Extract token_ids and attention_mask
+        text_ids = encoded["input_ids"].squeeze(0)  # Remove the batch dimension
+        attention_mask = encoded["attention_mask"].squeeze(0)
+
+        # === IMAGE LOADING AND PREPROCESSING ===
+        # Build the image path
+        image_filename = f"{row['national_number']:03d}.png"
+        image_path = self.image_dir / image_filename
+
+        # Load the image
+        image_rgba = Image.open(image_path).convert("RGBA")
+
+        # Handle transparency: composite the image onto a white background
+        background = Image.new("RGB", image_rgba.size, (255, 255, 255))
+        background.paste(image_rgba, mask=image_rgba.split()[-1])
+
+        # Apply the final transforms (ToTensor, Resize, Normalize)
+        image_tensor = self.final_transform(background)
+
+        # Build the result (matches pokemon_dataset.py structure)
+        sample = {
+            "text": text_ids,
+            "image": image_tensor,
+            "description": description,
+            "pokemon_name": row["english_name"],
+            "idx": idx,
+            "attention_mask": attention_mask,
+        }
+
+        return sample
+
+
+download_dataset_if_not_exists()
+print("Dataset ready!")
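A small sketch of what a single PokemonDataset sample looks like, assuming the CSV and images have already been fetched by download_dataset_if_not_exists(); not part of the commit, and the example name is illustrative.

# Hypothetical inspection of one dataset sample (not part of this commit).
from transformers import AutoTokenizer
from dataset import PokemonDataset

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")
ds = PokemonDataset(tokenizer=tokenizer)

sample = ds[0]
print(sample["pokemon_name"])          # e.g. "Bulbasaur" (first row of the pokedex CSV)
print(sample["text"].shape)            # torch.Size([128]) token ids, padded/truncated to max_length
print(sample["attention_mask"].sum())  # number of real (non-padding) tokens
print(sample["image"].shape)           # torch.Size([3, 256, 256]), values normalized to [-1, 1]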
pikapikagen/discriminators.py
ADDED
@@ -0,0 +1,161 @@
+import torch
+import torch.nn as nn
+from model_blocks.text_encoder import TextEncoder
+
+
+class Discriminator256(nn.Module):
+    def __init__(self, text_dim=256, img_channels=3):
+        super(Discriminator256, self).__init__()
+
+        self.text_encoder = TextEncoder()  # Separate text encoder for discriminators
+
+        self.img_path = nn.Sequential(
+            # 256x256 -> 128x128
+            nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # 128x128 -> 64x64
+            nn.Conv2d(16, 32, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # 64x64 -> 32x32
+            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(64),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # 32x32 -> 16x16
+            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # 16x16 -> 8x8
+            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # 8x8 -> 4x4
+            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+        self.text_path = nn.Sequential(
+            nn.Linear(text_dim, 1024),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Linear(1024, 512)
+        )
+
+        # Unconditional classifier (real/fake without text conditioning)
+        self.unconditional_classifier = nn.Sequential(
+            nn.Linear(512 * 4 * 4, 1024),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Dropout(0.5),
+            nn.Linear(1024, 1),
+        )
+
+        # Conditional classifier (text-conditioned real/fake)
+        self.conditional_classifier = nn.Sequential(
+            nn.Linear(512 * 4 * 4 + 512, 1024),  # size: sum of flattened image and text embedding
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Dropout(0.5),
+            nn.Linear(1024, 1),
+        )
+
+    def forward(self, images, text_features=None, text_mask=None, return_both=True):
+        # Encode image
+        img_features = self.img_path(images)
+        img_features_flat = img_features.view(img_features.size(0), -1)  # Flatten
+
+        unconditional_output = self.unconditional_classifier(img_features_flat)
+
+        if not return_both:
+            return unconditional_output
+
+        if text_features is None or text_mask is None:
+            raise AttributeError("text_features and text_mask necessary for text conditioning")
+
+        # Encode text (mean pooling)
+        global_full_text = self.text_encoder(text_features, text_mask)
+        global_text = global_full_text.mean(dim=1)
+        text_features_encoded = self.text_path(global_text)
+
+        # Combine features
+        combined = torch.cat([img_features_flat, text_features_encoded], dim=1)
+        conditional_output = self.conditional_classifier(combined)
+
+        return unconditional_output, conditional_output
+
+
+class Discriminator64(nn.Module):
+    def __init__(self, text_dim=256, img_channels=3):
+        super(Discriminator64, self).__init__()
+
+        self.text_encoder = TextEncoder()
+
+        self.img_path = nn.Sequential(
+            # 64x64 -> 32x32
+            nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # 32x32 -> 16x16
+            nn.Conv2d(16, 32, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # 16x16 -> 8x8
+            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(64),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # 8x8 -> 4x4
+            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+        # Text encoder for discriminator
+        self.text_path = nn.Sequential(
+            nn.Linear(text_dim, 1024),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Linear(1024, 512)
+        )
+
+        # Unconditional classifier (real/fake without text conditioning)
+        self.unconditional_classifier = nn.Sequential(
+            nn.Linear(128 * 4 * 4, 1024),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Dropout(0.5),
+            nn.Linear(1024, 1),
+        )
+
+        # Conditional classifier (text-conditioned real/fake)
+        self.conditional_classifier = nn.Sequential(
+            nn.Linear(128 * 4 * 4 + 512, 1024),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Dropout(0.5),
+            nn.Linear(1024, 1),
+        )
+
+    def forward(self, images, text_features=None, text_mask=None, return_both=True):
+        img_features = self.img_path(images)
+        img_features_flat = img_features.view(img_features.size(0), -1)  # Flatten
+
+        unconditional_output = self.unconditional_classifier(img_features_flat)
+
+        if not return_both:
+            return unconditional_output
+
+        if text_features is None or text_mask is None:
+            raise AttributeError("text_features and text_mask necessary for text conditioning")
+
+        # Encode text (mean pooling)
+        global_full_text = self.text_encoder(text_features, text_mask)
+        global_text = global_full_text.mean(dim=1)
+        text_features_encoded = self.text_path(global_text)
+
+        combined = torch.cat([img_features_flat, text_features_encoded], dim=1)
+        conditional_output = self.conditional_classifier(combined)
+
+        return unconditional_output, conditional_output
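The two heads support a matching-aware GAN setup. The training loop is not part of this commit, so the scoring sketch below, including the BCE wiring and the random stand-in tensors, is an assumption about how the heads could be used rather than the project's actual training code.

# Hypothetical sketch of scoring a batch with Discriminator256 (not part of this commit).
import torch
import torch.nn as nn
from discriminators import Discriminator256

disc = Discriminator256()
bce = nn.BCEWithLogitsLoss()

images = torch.randn(2, 3, 256, 256)              # stand-in for real or generated images in [-1, 1]
token_ids = torch.randint(0, 30000, (2, 128))     # stand-in bert-mini token ids
attention_mask = torch.ones(2, 128, dtype=torch.long)

# The discriminator runs its own TextEncoder on the token ids + mask
uncond_logits, cond_logits = disc(images, text_features=token_ids, text_mask=attention_mask)
print(uncond_logits.shape, cond_logits.shape)     # torch.Size([2, 1]) torch.Size([2, 1])

# Assumed loss wiring: both heads pushed towards "real" (label 1) for real images
labels = torch.ones(2, 1)
loss = bce(uncond_logits, labels) + bce(cond_logits, labels)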
pikapikagen/evaluate_kid.py
ADDED
@@ -0,0 +1,141 @@
+import torch
+from transformers import AutoTokenizer
+from model import Generator as PikaPikaGen
+from data_loader import create_training_setup
+from utils import denormalize_image
+from torch_fidelity import calculate_metrics
+import os
+import tempfile
+from PIL import Image
+import shutil
+
+CHECKPOINT_PATH = "pikapikagen/model_checkpoint/checkpoint_epoch_150.pth"
+
+TOKENIZER_NAME = "prajjwal1/bert-mini"
+
+BATCH_SIZE = 16  # Batch size for generating images
+NUM_WORKERS = 2  # Number of workers for the data loader
+
+KID_SUBSET_SIZE = 50
+KID_NUM_SUBSETS = 20
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+class PokemonKIDEvaluator:
+    """Evaluator class for computing KID metrics on PikaPikaGen."""
+
+    def __init__(self, checkpoint_path, device=DEVICE):
+        self.device = device
+        self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
+        self.checkpoint_path = checkpoint_path
+
+        self._load_model()  # As in gradio demo
+
+    def _load_model(self):
+        self.generator = PikaPikaGen().to(self.device)
+
+        checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=True)
+        self.generator.load_state_dict(checkpoint['generator_state_dict'])
+        self.generator.eval()
+
+    @staticmethod
+    def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
+        denormalized = denormalize_image(tensor)
+        uint8_tensor = (denormalized * 255).clamp(0, 255).to(torch.uint8)
+        img_np = uint8_tensor.cpu().permute(1, 2, 0).numpy()
+        return Image.fromarray(img_np)
+
+    def _save_images_to_temp_dir(self, images_tensor: torch.Tensor, prefix: str) -> str:
+        """Save a batch of image tensors to a new temporary directory."""
+        temp_dir = tempfile.mkdtemp(prefix=f"pikakid_{prefix}_")
+        for i, img_tensor in enumerate(images_tensor):
+            pil_img = self._tensor_to_pil(img_tensor)
+            img_path = os.path.join(temp_dir, f"{i:06d}.png")
+            pil_img.save(img_path)
+        return temp_dir
+
+    def evaluate_kid(self, test_loader, resolution="256x256"):
+
+        all_real_images = []
+        all_generated_images = []
+
+        with torch.no_grad():
+            for batch in test_loader:
+                text_ids = batch["text"].to(self.device)
+                attention_mask = batch["attention_mask"].to(self.device)
+                real_images_256 = batch["image"]  # (B, 3, 256, 256)
+
+                generated_256, generated_64 = self.generator(text_ids, attention_mask)
+
+                # Select the correct resolution for both real and generated images
+                if resolution == "256x256":
+                    generated_images = generated_256
+                    processed_real_images = real_images_256
+                elif resolution == "64x64":
+                    generated_images = generated_64
+                    processed_real_images = torch.nn.functional.interpolate(
+                        real_images_256, size=(64, 64), mode='bilinear', align_corners=False
+                    )
+                else:
+                    raise ValueError(f"Unsupported resolution: {resolution}")
+
+                all_real_images.append(processed_real_images.cpu())
+                all_generated_images.append(generated_images.cpu())
+
+        # Combine all batches into single tensors
+        all_real_images = torch.cat(all_real_images, dim=0)
+        all_generated_images = torch.cat(all_generated_images, dim=0)
+
+        # Save images to temporary directories for torch-fidelity
+        real_temp_dir = self._save_images_to_temp_dir(all_real_images, "real")
+        generated_temp_dir = self._save_images_to_temp_dir(all_generated_images, "generated")
+
+        metrics = calculate_metrics(
+            input1=generated_temp_dir,  # Path to generated (fake) images
+            input2=real_temp_dir,  # Path to real images
+            kid=True,
+            kid_subset_size=KID_SUBSET_SIZE,
+            kid_subsets=KID_NUM_SUBSETS,
+            batch_size=BATCH_SIZE,
+            device=self.device
+        )
+
+        kid_mean = metrics['kernel_inception_distance_mean']
+        kid_std = metrics['kernel_inception_distance_std']
+
+        # Clean up the temporary directories
+        shutil.rmtree(real_temp_dir)
+        shutil.rmtree(generated_temp_dir)
+
+        return kid_mean, kid_std
+
+def main():
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
+    training_setup = create_training_setup(
+        tokenizer=tokenizer,
+        test_set_size=0.2,
+        val_set_size=0.1,
+        batch_size=BATCH_SIZE,
+        num_workers=NUM_WORKERS,
+        random_seed=42,  # Use a fixed seed for a reproducible split
+    )
+    test_loader = training_setup['test_loader']
+    test_set_size = len(test_loader.dataset)
+
+    evaluator = PokemonKIDEvaluator(checkpoint_path=CHECKPOINT_PATH)
+
+    resolutions_to_test = ['64x64', '256x256']
+
+    print(f"Checkpoint: {CHECKPOINT_PATH}")
+    print(f"Test samples: {test_set_size}")
+    print(f"KID Subset Size: {KID_SUBSET_SIZE}")
+    print(f"KID Subsets: {KID_NUM_SUBSETS}")
+
+    for res in resolutions_to_test:
+        kid_mean, kid_std = evaluator.evaluate_kid(test_loader, resolution=res)
+        print(f"Resolution {res}:\t KID = {kid_mean:.6f} ± {kid_std:.6f}")
+
+
+if __name__ == "__main__":
+    main()
pikapikagen/gradio_demo.py
ADDED
@@ -0,0 +1,291 @@
+import gradio as gr
+import gradio.themes
+import torch
+import numpy as np
+from PIL import Image
+from transformers import AutoTokenizer
+from model import Generator as PikaPikaGen
+from utils import denormalize_image
+from plots import plot_attention_visualization
+import os
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+CHECKPOINT_PATH = "model_checkpoints/checkpoint_epoch_150.pth"
+TOKENIZER_NAME = "prajjwal1/bert-mini"
+
+
+class PokemonGenerator:
+    """Main class for the Pokemon generation demo"""
+
+    def __init__(self):
+        self.device = DEVICE
+        self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
+
+        self._load_model()
+
+
+    def _load_model(self):
+        """Load the trained PikaPikaGen model"""
+        try:
+            # Initialize model
+            self.generator = PikaPikaGen().to(self.device)
+
+            # Load checkpoint
+            checkpoint = torch.load(CHECKPOINT_PATH, map_location=self.device, weights_only=True)
+
+            # Load saved weights into model
+            self.generator.load_state_dict(checkpoint['generator_state_dict'])
+            print(f"✅ Generator loaded from checkpoint (epoch {checkpoint.get('epoch', 'unknown')})")
+
+            # No training
+            self.generator.eval()
+
+        except Exception as e:
+            print(f"❌ Error loading model: {e}")
+            raise
+
+    def _tensor_to_pil(self, tensor):
+        """Convert tensor to PIL Image"""
+        # tensor shape: (3, H, W)
+        img_np = tensor.permute(1, 2, 0).clamp(0, 1).numpy()
+        img_np = (img_np * 255).astype(np.uint8)
+        return Image.fromarray(img_np)
+
+    def generate_pokemon(self, description, num_samples=4, show_attention=False, resolution="both"):
+        """
+        Generate Pokemon sprites from text description
+
+        Args:
+            description (str): Text description of the desired Pokemon
+            num_samples (int): Number of samples to generate (1-8)
+            show_attention (bool): Whether to show attention visualization
+            resolution (str): Output resolution - "256x256", "64x64", or "both"
+
+        Returns:
+            tuple: (generated_images, attention_plot)
+        """
+        if not description.strip():
+            return [], "❌ Please enter a description."
+
+        # No reason to compute gradients
+        with torch.no_grad():
+            tokens = self.tokenizer(
+                description,
+                max_length=128,
+                padding='max_length',
+                truncation=True,
+                return_tensors='pt'
+            )
+
+            text_ids = tokens['input_ids'].repeat(num_samples, 1).to(self.device)
+            attention_mask = tokens['attention_mask'].repeat(num_samples, 1).to(self.device)
+
+            generated_256, generated_64, attention_maps, initial_weights = self.generator(
+                text_ids, attention_mask, return_attentions=True
+            )
+
+        # Convert tensors to PIL images
+        output_images = []
+        images_to_process = []
+        if resolution in ["256x256", "both"]:
+            images_to_process.append(generated_256)
+        if resolution in ["64x64", "both"]:
+            images_to_process.append(generated_64)
+
+        for img_batch in images_to_process:
+            img_batch_denorm = denormalize_image(img_batch.cpu())
+            for i in range(num_samples):
+                img_pil = self._tensor_to_pil(img_batch_denorm[i])
+                output_images.append(img_pil)
+
+        attention_plot = None
+        if show_attention:
+            # Create directory if it doesn't exist
+            output_dir = "attention_visualizations"
+            os.makedirs(output_dir, exist_ok=True)
+
+            # Create a more descriptive ID for the file
+            # To avoid overwriting the same file with the same name
+            plot_id = description.strip().replace(" ", "_")[:30]
+
+
+            # Use the first sample for the attention visualization
+            attention_plot = plot_attention_visualization(
+                epoch=0,
+                set_name="demo",
+                output_dir=output_dir,
+
+                generated_images=generated_256,
+
+                # Full batch data from the model
+                decoder_attention_maps=attention_maps,
+                initial_context_weights=initial_weights,
+
+                token_ids=text_ids,
+                attention_mask=attention_mask,
+                tokenizer=self.tokenizer,
+
+                # Metadata for the specific sample
+                description=description,
+                pokemon_id=plot_id,
+
+                sample_idx=0,
+                show_inline=False
+            )
+
+        return output_images, attention_plot
+
+
+print("Initializing PikaPikaGen Demo...")
+pokemon_gen = PokemonGenerator()
+
+def generate_pokemon_interface(description, num_samples, show_attention, resolution):
+    images, attention_plot = pokemon_gen.generate_pokemon(
+        description=description,
+        num_samples=num_samples,
+        show_attention=show_attention,
+        resolution=resolution
+    )
+
+    if images is None:
+        return [], attention_plot  # attention_plot contains error message if error
+
+    status_msg = f"Generated {len(images)} Pokemon sprites"
+    if resolution == "both":
+        status_msg += f" ({num_samples} at 256x256 + {num_samples} at 64x64)"
+    else:
+        status_msg += f" at {resolution}"
+
+    return images, attention_plot
+
+def create_interface():
+    with gr.Blocks(
+        title="PikaPikaGen: AI Pokemon Generator",
+        theme=gradio.themes.Soft(),
+        css="""
+        .main-header {
+            text-align: center;
+            background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
+            -webkit-background-clip: text;
+            -webkit-text-fill-color: transparent;
+            font-size: 2.5em;
+            font-weight: bold;
+            margin-bottom: 0.5em;
+        }
+        .description {
+            text-align: center;
+            font-size: 1.1em;
+            color: #666;
+            margin-bottom: 1em;
+        }
+        """
+    ) as demo:
+
+        gr.HTML("""
+        <div class="main-header">🎮 PikaPikaGen: AI Pokemon Generator</div>
+        <div class="description">
+        Creation of Pokemon sprites from text descriptions using Transformer attention and CNN generation.
+        </div>
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("### 📝 Input")
+
+                description_input = gr.Textbox(
+                    label="Pokemon Description",
+                    placeholder="Describe your Pokemon! e.g., 'A fire dragon with golden scales and ruby eyes'",
+                    lines=3,
+                    value="A legendary fire dragon pokemon with golden scales and red eyes"
+                )
+
+                with gr.Row():
+                    num_samples = gr.Slider(
+                        minimum=1, maximum=8, value=4, step=1,
+                        label="Number of samples"
+                    )
+
+                    resolution = gr.Radio(
+                        choices=["256x256", "64x64", "both"],
+                        value="256x256",
+                        label="Output resolution"
+                    )
+
+                show_attention = gr.Checkbox(
+                    label="Show attention visualization",
+                    value=True,
+                    info="Visualize which words the model focuses on"
+                )
+
+                generate_btn = gr.Button(
+                    "🎨 Generate Pokemon!",
+                    variant="primary",
+                    size="lg"
+                )
+
+            with gr.Column(scale=2):
+                gr.Markdown("### 🎨 Generated Pokemon")
+
+                output_gallery = gr.Gallery(
+                    label="Generated Pokemon sprites",
+                    show_label=True,
+                    elem_id="gallery",
+                    columns=2,
+                    rows=2,
+                    height="auto",
+                    allow_preview=True
+                )
+
+                attention_output = gr.Image(
+                    label="Attention visualization",
+                    show_label=True,
+                    interactive=False
+                )
+
+        # Examples section
+        gr.Markdown("### 🌟 Examples to try")
+        gr.Examples(
+            examples=[
+                ["A fire dragon with golden scales and red eyes", 4, True, "256x256"],
+                ["An electric mouse with yellow fur and lightning bolts", 3, False, "both"],
+                ["A water turtle with blue shell and powerful jaws", 2, True, "256x256"],
+                ["A psychic cat with purple fur and mystical powers", 4, True, "256x256"],
+                ["A grass serpent with emerald scales and vine whips", 3, False, "64x64"],
+                ["An ice phoenix with crystal wings and frozen flames", 4, True, "256x256"],
+                ["A dark wolf with shadow abilities and glowing eyes", 2, True, "both"],
+                ["A steel robot pokemon with metallic armor and laser beams", 3, False, "256x256"]
+            ],
+            inputs=[description_input, num_samples, show_attention, resolution],
+            outputs=[output_gallery, attention_output],
+            fn=generate_pokemon_interface,
+            cache_examples=False
+        )
+
+        # Event handlers
+        generate_btn.click(
+            fn=generate_pokemon_interface,
+            inputs=[description_input, num_samples, show_attention, resolution],
+            outputs=[output_gallery, attention_output]
+        )
+
+        # Footer
+        gr.Markdown("""
+        ---
+        **PikaPikaGen** - Text-to-Image Pokemon Generation using Transformer + CNN
+        """)
+
+    return demo
+
+if __name__ == "__main__":
+    print("Starting PikaPikaGen Demo...")
+
+    # Create and launch interface
+    demo = create_interface()
+
+    demo.launch(
+        server_name="0.0.0.0",  # Allow external access
+        share=False,  # Set to True for public sharing
+        debug=False,
+        show_error=True,
+        inbrowser=True  # Auto-open browser
+    )
pikapikagen/losses.py
ADDED
@@ -0,0 +1,103 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+from torchvision.models import VGG19_Weights
+
+
+class VGGPerceptualLoss(nn.Module):
+    """
+    Perceptual loss using VGG19 pretrained on ImageNet.
+    We extract features at:
+    - relu1_2 (index: 3)
+    - relu2_2 (index: 8)
+    - relu3_2 (index: 17)
+    - relu4_2 (index: 26)
+    Then compute L1 distance between those feature maps.
+    Input images are in [-1,1]. We convert to [0,1], then normalize with ImageNet stats.
+    """
+    def __init__(self, device):
+        super(VGGPerceptualLoss, self).__init__()
+        vgg19_features = models.vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
+        # We only need layers up to 26 (relu4_2)
+        self.slices = nn.ModuleDict({
+            "relu1_2": nn.Sequential(*list(vgg19_features.children())[:4]),    # conv1_1, relu1_1, conv1_2, relu1_2
+            "relu2_2": nn.Sequential(*list(vgg19_features.children())[4:9]),   # pool1, conv2_1, relu2_1, conv2_2, relu2_2
+            "relu3_2": nn.Sequential(*list(vgg19_features.children())[9:18]),  # pool2, conv3_1, relu3_1, conv3_2, relu3_2, ...
+            "relu4_2": nn.Sequential(*list(vgg19_features.children())[18:27])  # pool3, conv4_1, relu4_1, conv4_2, relu4_2
+        })
+        for param in self.parameters():
+            param.requires_grad = False
+
+        self.l1 = nn.L1Loss()
+        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
+        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))
+
+    def forward(self, img_gen, img_ref):
+        """
+        img_gen, img_ref: [B,3,H,W] in range [-1,1].
+        Return: sum of L1 distances between VGG feature maps at chosen layers.
+        """
+        # Convert to [0,1]
+        gen = (img_gen + 1.0) / 2.0
+        ref = (img_ref + 1.0) / 2.0
+        # Normalize
+        gen_norm = (gen - self.mean) / self.std
+        ref_norm = (ref - self.mean) / self.std
+
+        loss = 0.0
+        x_gen = gen_norm
+        x_ref = ref_norm
+        for slice_mod in self.slices.values():
+            x_gen = slice_mod(x_gen)
+            x_ref = slice_mod(x_ref)
+            loss += self.l1(x_gen, x_ref)
+        return loss
+
+
+class SobelLoss(nn.Module):
+    """
+    Computes the Sobel loss between two images, which encourages edge similarity.
+    This loss operates on the grayscale versions of the input images.
+    """
+    def __init__(self):
+        super(SobelLoss, self).__init__()
+        # Sobel kernels for edge detection
+        self.kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
+        self.kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
+        self.l1 = nn.L1Loss()
+
+        # Grayscale conversion weights (ITU-R BT.601)
+        self.rgb_to_gray_weights = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
+
+    def _get_edges(self, img):
+        """
+        Converts an RGB image to grayscale and applies Sobel filters.
+        Args:
+            img: [B, 3, H, W] image tensor in range [-1, 1].
+        Returns:
+            Gradient magnitude map [B, 1, H, W].
+        """
+
+        # Convert from [-1, 1] to [0, 1]
+        img = (img + 1.0) / 2.0
+
+        # Convert to grayscale
+        grayscale_img = F.conv2d(img, self.rgb_to_gray_weights.to(img.device))
+
+        # Apply Sobel filters
+        grad_x = F.conv2d(grayscale_img, self.kernel_x.to(img.device), padding=1)
+        grad_y = F.conv2d(grayscale_img, self.kernel_y.to(img.device), padding=1)
+
+        # Compute gradient magnitude
+        edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6)  # add epsilon for stability
+        return edges
+
+    def forward(self, img_gen, img_ref):
+        """
+        img_gen, img_ref: [B, 3, H, W] in range [-1, 1].
+        Returns: L1 loss between the edge maps of the two images.
+        """
+        edges_gen = self._get_edges(img_gen)
+        edges_ref = self._get_edges(img_ref)
+        return self.l1(edges_gen, edges_ref)
pikapikagen/model.py
ADDED
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+from model_blocks.text_encoder import TextEncoder
+from model_blocks.image_decoder import ImageDecoder
+
+class Generator(nn.Module):
+    """
+    Complete model that joins the text encoder and the image decoder.
+    """
+    def __init__(self, text_encoder_model_name="prajjwal1/bert-mini", noise_dim=100):
+        super().__init__()
+        self.text_encoder = TextEncoder(
+            model_name=text_encoder_model_name,
+        )
+
+        text_embed_dim = 256
+
+        self.image_decoder = ImageDecoder(
+            noise_dim=noise_dim,
+            text_embed_dim=text_embed_dim
+        )
+
+        self.noise_dim = noise_dim
+
+    def forward(self, token_ids, attention_mask, return_attentions=False):
+        # token_ids.shape: (batch_size, seq_len)
+        # attention_mask.shape: (batch_size, seq_len)
+        # Generate random noise for the batch
+        batch_size = token_ids.size(0)
+        # noise.shape: (batch_size, noise_dim)
+        noise = torch.randn(batch_size, self.noise_dim, device=token_ids.device)
+
+        # 1. Encode the text to obtain a vector for each token
+        # encoder_output.shape: (batch_size, seq_len, text_embed_dim)
+        encoder_output = self.text_encoder(token_ids, attention_mask=attention_mask)
+
+        # 2. Generate the image using the full encoder output.
+        # The decoder internally computes both the initial context (ATTENTION #1)
+        # and the per-step attention (ATTENTION #2)
+        # generated_image_256.shape: (batch_size, 3, 256, 256)
+        # generated_image_64.shape: (batch_size, 3, 64, 64)
+        generated_image_256, generated_image_64, attention_maps, initial_attention_weights = self.image_decoder(noise, encoder_output, attention_mask)
+
+        if return_attentions:
+            return generated_image_256, generated_image_64, attention_maps, initial_attention_weights
+        return generated_image_256, generated_image_64
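A minimal end-to-end shape check for the generator, assuming the prajjwal1/bert-mini tokenizer; no checkpoint is needed just to verify the two output resolutions. Not part of the commit.

# Hypothetical shape check for the Generator (not part of this commit).
import torch
from transformers import AutoTokenizer
from model import Generator

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")
generator = Generator().eval()

tokens = tokenizer(
    "A small blue turtle with a hard shell",
    max_length=128, padding="max_length", truncation=True, return_tensors="pt",
)

with torch.no_grad():
    img_256, img_64 = generator(tokens["input_ids"], tokens["attention_mask"])

print(img_256.shape)  # torch.Size([1, 3, 256, 256]), values in [-1, 1] (Tanh output)
print(img_64.shape)   # torch.Size([1, 3, 64, 64])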
pikapikagen/model_blocks/decoder_block.py
ADDED
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+from model_blocks.image_cross_attention import ImageCrossAttention
+
+class DecoderBlock(nn.Module):
+    """
+    Image decoder block
+    Channel adaptation (if necessary) -> Attention (optional) -> Merge -> Residual connection
+    -> Upsampling (ConvTranspose) -> Normalization -> Activation.
+    """
+    def __init__(self, in_channels, out_channels, use_attention=True, text_embed_dim=256, nhead=4):
+        super().__init__()
+        self.use_attention = use_attention
+
+        if self.use_attention:
+            # If in_channels is different from text_embed_dim, add a 1x1 conv to adapt the channel size
+            if in_channels != text_embed_dim:
+                self.channel_adapter = nn.Conv2d(in_channels, text_embed_dim, kernel_size=1, bias=False)
+            else:
+                self.channel_adapter = None
+
+            self.cross_attention = ImageCrossAttention(embed_dim=text_embed_dim, num_heads=nhead)
+            # Convolution to merge the adapted features and the cross-attention output
+            self.fusion_conv = nn.Conv2d(text_embed_dim * 2, in_channels, kernel_size=1, bias=False)
+
+        # Upsample block as described in the instructions
+        self.upsample_block = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.GroupNorm(1, out_channels),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x, encoder_output=None, attention_mask=None):
+        attn_weights = None
+        if self.use_attention:
+            if encoder_output is None or attention_mask is None:
+                raise ValueError("encoder_output and attention_mask must be provided for attention.")
+
+            # Adapt channel size if necessary
+            if self.channel_adapter is not None:
+                x_adapted = self.channel_adapter(x)
+            else:
+                x_adapted = x
+
+            attn_output, attn_weights = self.cross_attention(
+                image_features=x_adapted,
+                text_features=encoder_output,
+                key_padding_mask=attention_mask
+            )
+
+            # Concatenate the features with the cross-attention output,
+            # then conv 1x1 and residual connection
+            fused_features = torch.cat([x_adapted, attn_output], dim=1)  # Shape: (B, 2*text_embed_dim, H, W)
+            skip = self.fusion_conv(fused_features)  # Shape: (B, in_channels, H, W)
+            x = x + skip  # Shape: (B, in_channels, H, W)
+
+
+        x = self.upsample_block(x)
+        return x, attn_weights
pikapikagen/model_blocks/image_cross_attention.py
ADDED
@@ -0,0 +1,49 @@
+import torch.nn as nn
+
+class ImageCrossAttention(nn.Module):
+    """
+    Image cross-attention module
+    Allows a sequence of queries (from the image) to "pay attention"
+    to a sequence of keys/values (from the text), internally managing
+    the reshaping of tensors and the attention mask.
+    """
+    def __init__(self, embed_dim, num_heads):
+        super().__init__()
+        self.attention = nn.MultiheadAttention(
+            embed_dim=embed_dim, num_heads=num_heads, batch_first=True
+        )
+        self.layer_norm = nn.LayerNorm(embed_dim)
+
+    def forward(self, image_features, text_features, key_padding_mask=None):
+        # query: (B, C, H, W) - Image features
+        # key/value: (B, seq_len, embed_dim) - Text encoder output
+        # key_padding_mask: (B, seq_len) - Attention mask from the tokenizer
+
+        B, C, H, W = image_features.shape
+
+        # Reshape from image to sequence: (B, C, H, W) -> (B, H*W, C)
+        query_seq = image_features.view(B, C, H * W).permute(0, 2, 1)
+        query_norm = self.layer_norm(query_seq)
+
+        # Prepare the padding mask from the attention mask
+        # The HuggingFace mask is 1 for real tokens, 0 for padding.
+        # MultiheadAttention expects True for positions to ignore.
+        if key_padding_mask is not None:
+            mask = (key_padding_mask == 0)
+        else:
+            mask = None
+
+        attn_output, attn_weights = self.attention(
+            query=query_norm,
+            key=text_features,
+            value=text_features,
+            key_padding_mask=mask,
+            need_weights=True
+        )
+        # attn_output: (B, H*W, C)
+
+        # Convert output back into its original size
+        # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
+        attn_output_spatial = attn_output.permute(0, 2, 1).view(B, C, H, W)
+
+        return attn_output_spatial, attn_weights
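A shape sketch for the cross-attention block in isolation, with random stand-in tensors; not part of the commit.

# Hypothetical shape check for ImageCrossAttention (not part of this commit).
import torch
from model_blocks.image_cross_attention import ImageCrossAttention

attn = ImageCrossAttention(embed_dim=256, num_heads=4)

image_features = torch.randn(2, 256, 8, 8)              # (B, C, H, W) feature map from the decoder
text_features = torch.randn(2, 128, 256)                # (B, seq_len, embed_dim) text encoder output
attention_mask = torch.ones(2, 128, dtype=torch.long)   # 1 = real token, 0 = padding

out, weights = attn(image_features, text_features, key_padding_mask=attention_mask)
print(out.shape)      # torch.Size([2, 256, 8, 8]) - same spatial layout as the query
print(weights.shape)  # torch.Size([2, 64, 128]) - one row of token weights per image position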
pikapikagen/model_blocks/image_decoder.py
ADDED
@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+from model_blocks.decoder_block import DecoderBlock
+
+class ImageDecoder(nn.Module):
+    """
+    CNN decoder (generator) that synthesizes the image.
+    This version uses per-step attention from the very beginning.
+    """
+    def __init__(self, noise_dim, text_embed_dim, final_image_channels=3):
+        super().__init__()
+
+        # Mechanism to calculate attention scores for the initial context.
+        self.initial_context_scorer = nn.Sequential(
+            nn.Linear(in_features=text_embed_dim, out_features=512),
+            nn.Tanh(),
+            nn.Linear(in_features=512, out_features=1)
+            # Softmax applied in forward pass to use the attention mask
+        )
+
+        # Initial linear projection to a 4x4 feature map.
+        self.initial_projection = nn.Sequential(
+            nn.Linear(noise_dim + text_embed_dim, 256 * 4 * 4),
+            nn.GroupNorm(1, 256 * 4 * 4),
+            nn.LeakyReLU(inplace=True)
+        )
+
+        # Shared blocks for both resolutions (up to 64x64)
+        self.blocks_64 = nn.ModuleList([
+            # Input: (B, 256, 4, 4) -> Output: (B, 256, 8, 8)
+            DecoderBlock(in_channels=256, out_channels=256, use_attention=True),
+            # Input: (B, 256, 8, 8) -> Output: (B, 256, 16, 16)
+            DecoderBlock(in_channels=256, out_channels=256, use_attention=True),
+            # Input: (B, 256, 16, 16) -> Output: (B, 128, 32, 32)
+            DecoderBlock(in_channels=256, out_channels=128, use_attention=True),
+            # Input: (B, 128, 32, 32) -> Output: (B, 64, 64, 64)
+            DecoderBlock(in_channels=128, out_channels=64, use_attention=False),
+        ])
+
+        # ModuleList is used instead of a Sequential, for example because
+        # of the branching based on use_attention in the forward pass
+
+        # Blocks only for 256x256 (from 64x64 to 256x256)
+        self.blocks_256 = nn.ModuleList([
+            # Input: (B, 64, 64, 64) -> Output: (B, 32, 128, 128)
+            DecoderBlock(in_channels=64, out_channels=32, use_attention=True),
+            # Input: (B, 32, 128, 128) -> Output: (B, 16, 256, 256)
+            DecoderBlock(in_channels=32, out_channels=16, use_attention=False),
+        ])
+
+        # Last layer to get to RGB channels - 256x256
+        # Input: (B, 16, 256, 256) -> Output: (B, 3, 256, 256)
+        self.final_conv_256 = nn.Conv2d(16, final_image_channels, kernel_size=3, padding=1)
+        self.final_activation_256 = nn.Tanh()
+
+        # Last layer to get to RGB channels - 64x64
+        # Input: (B, 64, 64, 64) -> Output: (B, 3, 64, 64)
+        self.final_conv_64 = nn.Conv2d(64, final_image_channels, kernel_size=3, padding=1)
+        self.final_activation_64 = nn.Tanh()
+
+    def forward(self, noise, encoder_output_full, attention_mask):
+        # noise.shape: (B, noise_dim)
+        # encoder_output_full.shape: (B, seq_len, text_embed_dim)
+        # attention_mask.shape: (B, seq_len)
+
+        # 1. Compute the first attention, with the scores (logits) for each token
+        attn_scores = self.initial_context_scorer(encoder_output_full)
+
+        # Apply attention mask before Softmax.
+        # Set the scores of the padding tokens, where attention mask is 0, to -inf.
+        # The mask is (B, seq_len), the scores (B, seq_len, 1)
+        # The unsqueeze takes care of the dimension difference.
+        attn_scores.masked_fill_(attention_mask.unsqueeze(-1) == 0, -1e9)
+
+        # attention_weights.shape: (B, seq_len, 1)
+        attention_weights = torch.softmax(attn_scores, dim=1)
+
+        # Weighted average of the encoder output
+        # context_vector.shape: (B, text_embed_dim)
+        context_vector = torch.sum(attention_weights * encoder_output_full, dim=1)
+
+        # 2. Merge the noise and the context vector for the initial projection
+        # initial_input.shape: (B, noise_dim + text_embed_dim)
+        initial_input = torch.cat([noise, context_vector], dim=1)
+
+        # 3. Initial projection and reshape to fit the transposed convolutions
+        # x.shape: (B, 256 * 4 * 4)
+        x = self.initial_projection(initial_input)
+        # x.shape: (B, 256, 4, 4)
+        x = x.view(x.size(0), 256, 4, 4)
+
+        # 4. Pass through the decoder blocks
+        attention_maps = []
+
+        # Shared path for both resolutions (up to 64x64)
+        for block in self.blocks_64:
+            encoder_ctx = encoder_output_full if block.use_attention else None
+            mask_ctx = attention_mask if block.use_attention else None
+            x, attn_weights = block(x, encoder_ctx, mask_ctx)
+            if attn_weights is not None:
+                attention_maps.append(attn_weights)
+
+        # Now x has size (B, 64, 64, 64)
+
+        # 64x64-only path
+        image_64 = self.final_conv_64(x)
+        image_64 = self.final_activation_64(image_64)
+
+        # 5. 256x256-only path
+        for block in self.blocks_256:
+            encoder_ctx = encoder_output_full if block.use_attention else None
+            mask_ctx = attention_mask if block.use_attention else None
+            x, attn_weights = block(x, encoder_ctx, mask_ctx)
+            if attn_weights is not None:
+                attention_maps.append(attn_weights)
+
+        # Final layer for 256x256
+        # x_256.shape: (B, 16, 256, 256) -> (B, 3, 256, 256)
+        image_256 = self.final_conv_256(x)
+        image_256 = self.final_activation_256(image_256)
+
+        return image_256, image_64, attention_maps, attention_weights
pikapikagen/model_blocks/text_encoder.py
ADDED
@@ -0,0 +1,43 @@
import torch.nn as nn
from transformers import AutoModel

class TextEncoder(nn.Module):
    """
    Text encoder.
    Uses bert-mini embeddings and passes them through a Transformer encoder.
    """
    def __init__(self, model_name="prajjwal1/bert-mini", fine_tune_embeddings=True):
        super().__init__()
        # Load the pre-trained bert-mini model for embeddings
        bert_mini_model = AutoModel.from_pretrained(model_name)

        self.embedding = bert_mini_model.embeddings

        # Set whether to fine-tune the embeddings during training
        for param in self.embedding.parameters():
            param.requires_grad = fine_tune_embeddings

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=256, nhead=4, dim_feedforward=1024, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)

    def forward(self, token_ids, attention_mask=None):
        # Get the embeddings from the tokens
        # Shape: (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        embedded_text = self.embedding(token_ids)

        # Prepare the padding mask for TransformerEncoder.
        # The HuggingFace mask is 1 for real tokens, 0 for padding.
        # TransformerEncoder expects True for positions to ignore (padding).
        src_key_padding_mask = None
        if attention_mask is not None:
            src_key_padding_mask = (attention_mask == 0)

        # Pass the embeddings through the Transformer encoder with the mask
        # Shape: (batch_size, seq_len, embedding_dim) -> (batch_size, seq_len, embedding_dim)
        encoder_output = self.transformer_encoder(
            src=embedded_text,
            src_key_padding_mask=src_key_padding_mask
        )
        return encoder_output
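A quick smoke test of the encoder, assuming the matching `prajjwal1/bert-mini` tokenizer is used (the same checkpoint name as in the constructor); the max length of 128 is an illustrative choice, not a value taken from this file:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")
encoder = TextEncoder()

batch = tokenizer(
    ["A small yellow electric mouse with red cheeks."],
    padding="max_length", truncation=True, max_length=128, return_tensors="pt",
)
with torch.no_grad():
    out = encoder(batch["input_ids"], batch["attention_mask"])
print(out.shape)  # (1, 128, 256): one sequence, 128 tokens, d_model=256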
pikapikagen/model_checkpoint/checkpoint_epoch_150.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7902ac75581c4a54ec5345ccf2bd30440a99a4c1031b5c12af6cabb318dde225
size 789795998
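The checkpoint is stored via Git LFS (roughly 790 MB). A hedged sketch of restoring it is shown below; the internal key `"model_state_dict"` is a guess at the saved layout, not something confirmed by this pointer file:

import torch

checkpoint = torch.load(
    "pikapikagen/model_checkpoint/checkpoint_epoch_150.pth",
    map_location="cpu",
)
# The key below is an assumption; inspect checkpoint.keys() to see what was actually saved.
state_dict = checkpoint.get("model_state_dict", checkpoint)
# model.load_state_dict(state_dict)  # with `model` built as in model.py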
pikapikagen/plots.py
ADDED
@@ -0,0 +1,428 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import io
from PIL import Image
from utils import denormalize_image
import torch.nn.functional as F
from transformers import AutoTokenizer


def save_attention_visualization(
    epoch, model, tokenizer, batch, device, set_name, output_dir, show_inline=False
):
    print(f"Epoch {epoch}: Generating attention visualization for {set_name} set...")

    attention_data = generate_attention_data(model, tokenizer, batch, device)

    if attention_data:
        plot_attention_visualization(
            epoch=epoch,
            set_name=set_name,
            output_dir=output_dir,
            show_inline=show_inline,
            **attention_data,
        )
        print(f"Epoch {epoch}: Attention visualization saved for Pokémon #{attention_data['pokemon_id']}.")
    else:
        print(f"Epoch {epoch}: Skipped attention visualization due to missing data.")


def generate_attention_data(model, tokenizer, batch, device):
    """
    Runs the model to generate the image and attention maps, filtering the padding tokens.
    """
    model.eval()

    with torch.no_grad():
        token_ids = batch["text"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Ensure batch size is 1 for visualization
        if token_ids.dim() > 1:
            token_ids = token_ids[0].unsqueeze(0)
            attention_mask = attention_mask[0].unsqueeze(0)

        # Get the first sample from the batch
        pokemon_id = batch["idx"][0]
        description = batch["description"][0]

        generated_image, attention_maps, initial_context_weights = model(
            token_ids, attention_mask, return_attentions=True
        )

    decoder_attention_maps = [m for m in attention_maps if m is not None]

    if not decoder_attention_maps or initial_context_weights is None:
        print("Attention maps not available. Skipping data generation.")
        return None

    # Extract valid tokens to display
    tokens_all = tokenizer.convert_ids_to_tokens(token_ids.squeeze(0))
    display_tokens = []
    for i, token in enumerate(tokens_all):
        if (
            token not in [tokenizer.sep_token, tokenizer.pad_token]
            and attention_mask[0, i] == 1
        ):
            display_tokens.append({"token": token, "index": i})

    if not display_tokens:
        print(f"No valid tokens to display for '{description}'. Skipping.")
        return None

    return {
        "generated_image": generated_image.cpu(),
        "decoder_attention_maps": [m.cpu() for m in decoder_attention_maps],
        "initial_context_weights": initial_context_weights.cpu(),
        "display_tokens": display_tokens,
        "description": description,
        "pokemon_id": pokemon_id,
    }


def plot_attention_visualization(
    # Plot identification arguments
    epoch: int,
    set_name: str,
    output_dir: str | None,
    # Data generated by the model (can be full batches)
    generated_images: torch.Tensor,
    decoder_attention_maps: list[torch.Tensor],
    initial_context_weights: torch.Tensor,
    # Original text input (can be a full batch)
    token_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    tokenizer: AutoTokenizer,
    # Batch metadata (for the specific sample)
    description: str,
    pokemon_id: int | str,
    # Control options
    sample_idx: int = 0,
    show_inline: bool = False,
):
    """
    Generates and saves an attention visualization for a single sample from a batch.

    This function is self-contained: it accepts full batch tensors and internally
    handles sample selection and token preparation.

    Args:
        epoch (int): Epoch number (for title/filename).
        set_name (str): Set name (e.g., 'train', for title/filename).
        output_dir (str, optional): Folder to save the image. If None, the plot is not saved.

        generated_images (torch.Tensor): Tensor of generated images.
            Shape: (B, C, H, W).
        decoder_attention_maps (list[torch.Tensor]): List of attention tensors.
            Each tensor shape: (B, num_patches, seq_length).
        initial_context_weights (torch.Tensor): Initial attention weights.
            Shape: (B, 1, seq_length).

        token_ids (torch.Tensor): Input token ids.
            Shape: (B, seq_length).
        attention_mask (torch.Tensor): Attention mask for tokens.
            Shape: (B, seq_length).
        tokenizer: The tokenizer object for id -> token conversion.

        description (str): The text prompt for the selected sample.
        pokemon_id (int or str): The ID of the selected sample.

        sample_idx (int, optional): Index of the sample in the batch to visualize.
            Defaults to 0.
        show_inline (bool, optional): If True, shows the plot. Defaults to False.
    """
    # Select the specific sample using sample_idx and move to CPU
    img_tensor = generated_images[sample_idx].cpu()
    layer_maps = [m[sample_idx].cpu() for m in decoder_attention_maps if m is not None]
    initial_weights = initial_context_weights[sample_idx].cpu()
    token_ids_sample = token_ids[sample_idx].cpu()
    attention_mask_sample = attention_mask[sample_idx].cpu()

    # Token filtering logic
    tokens_all = tokenizer.convert_ids_to_tokens(token_ids_sample)
    display_tokens = []
    for i, token in enumerate(tokens_all):
        if (
            token not in [tokenizer.sep_token, tokenizer.pad_token]
            and attention_mask_sample[i] == 1
        ):
            display_tokens.append({"token": token, "index": i})

    img_tensor_cpu = denormalize_image(img_tensor).permute(1, 2, 0)
    num_decoder_layers = len(layer_maps)
    num_tokens = len(display_tokens)
    token_indices_to_display = [t["index"] for t in display_tokens]

    cols = min(num_tokens, 8)
    rows_per_layer = (num_tokens + cols - 1) // cols
    height_ratios = [3, 2] + [2 * rows_per_layer] * num_decoder_layers
    fig_height = sum(height_ratios)
    fig_width = max(20, 2.5 * cols)

    fig = plt.figure(figsize=(fig_width, fig_height))
    gs_main = fig.add_gridspec(len(height_ratios), 1, height_ratios=height_ratios, hspace=1.2)
    fig.suptitle(f"Epoch {epoch}: Attention for Pokémon #{pokemon_id} ({set_name.capitalize()})", fontsize=24)

    ax_main_img = fig.add_subplot(gs_main[0])
    ax_main_img.imshow(img_tensor_cpu)
    ax_main_img.set_title("Generated Image", fontsize=18)
    ax_main_img.text(0.5, -0.1, f"Prompt: {description}", ha="center", va="top",
                     transform=ax_main_img.transAxes, fontsize=14, wrap=True)
    ax_main_img.axis("off")

    ax_initial_attn = fig.add_subplot(gs_main[1])
    initial_weights_squeezed = initial_weights.squeeze().numpy()
    token_strings = [t["token"] for t in display_tokens]
    relevant_weights = initial_weights_squeezed[[t["index"] for t in display_tokens]]
    ax_initial_attn.bar(np.arange(len(token_strings)), relevant_weights, color="skyblue")
    ax_initial_attn.set_xticks(np.arange(len(token_strings)))
    ax_initial_attn.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=10)
    ax_initial_attn.set_title("Initial Context Attention (Global)", fontsize=16)
    ax_initial_attn.set_ylabel("Weight", fontsize=12)
    ax_initial_attn.grid(axis="y", linestyle="--", alpha=0.7)

    # Iterate through each decoder layer's attention maps
    for i, layer_attn_map in enumerate(layer_maps):
        # layer_attn_map shape is now (num_patches, seq_len)
        map_size_flat = layer_attn_map.shape[0]
        map_side = int(np.sqrt(map_size_flat))
        layer_title = f"Decoder Cross-Attention Layer {i+1} (Size: {map_side}x{map_side})"

        # Extract attention weights only for tokens we want to display
        relevant_attn_maps = layer_attn_map[:, token_indices_to_display]
        vmin, vmax = relevant_attn_maps.min(), relevant_attn_maps.max()

        # Create subplot grid for this layer
        gs_layer = gs_main[2 + i].subgridspec(rows_per_layer, cols + 1, wspace=0.2, hspace=0.4, width_ratios=[*([1] * cols), 0.1])
        axes_in_layer = [fig.add_subplot(gs_layer[r, c]) for r in range(rows_per_layer) for c in range(cols)]

        # Add layer title above the token attention maps
        if axes_in_layer:
            y_pos = axes_in_layer[0].get_position().y1
            fig.text(0.5, y_pos + 0.01, layer_title, ha="center", va="bottom", fontsize=16, weight="bold")

        # Plot attention heatmap for each token
        im = None
        for j, token_info in enumerate(display_tokens):
            if j >= len(axes_in_layer):
                break
            ax = axes_in_layer[j]
            attn_for_token = layer_attn_map[:, token_info["index"]]
            # Reshape flat attention to spatial grid
            heatmap = attn_for_token.reshape(map_side, map_side)
            im = ax.imshow(heatmap, cmap="jet", interpolation="nearest", vmin=vmin, vmax=vmax)
            ax.set_title(f"'{token_info['token']}'", fontsize=12)
            ax.axis("off")

        # Add colorbar for the layer
        if im:
            cax = fig.add_subplot(gs_layer[:, -1])
            cbar = fig.colorbar(im, cax=cax)
            cbar.ax.tick_params(labelsize=10)
            cbar.set_label("Attention Weight", rotation=270, labelpad=15, fontsize=12)

        # Hide unused subplots
        for j in range(num_tokens, len(axes_in_layer)):
            axes_in_layer[j].axis("off")

    plt.tight_layout(rect=(0, 0.03, 1, 0.96))
    if output_dir is not None:
        save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_attention_visualization_{pokemon_id}.png")
        plt.savefig(save_path, bbox_inches="tight")

    # Save figure to bytes for potential further use (e.g., logging)
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
    buf.seek(0)

    # Convert to PIL image
    attention_plot = Image.open(buf)

    if show_inline:
        plt.show()
    plt.close(fig)

    return attention_plot


def save_plot_losses(losses_g, losses_d, output_dir="training_output", show_inline=True):
    """
    Generates and saves a plot of the generator and discriminator losses.
    """
    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(losses_g, label="Generator Loss", color="blue")
    ax.plot(losses_d, label="Discriminator Loss", color="red")
    ax.set_title("Training Losses")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True)

    save_path = os.path.join(output_dir, "training_losses.png")
    plt.savefig(save_path)
    print(f"Loss plot saved to: {save_path}")

    if show_inline:
        plt.show()
    else:
        plt.close(fig)

def save_plot_non_gan_losses(train_losses_history, val_losses_history, output_dir="training_output", show_inline=True, filter_losses=None):
    """
    Generates and saves plots of losses for non-GAN models with multiple loss components.

    Args:
        train_losses_history (list[dict]): List of dicts containing training losses per epoch.
            e.g., [{'l1': 0.5, 'sobel': 0.3}, ...]
        val_losses_history (list[dict]): List of dicts containing validation losses per epoch.
        output_dir (str): Directory to save the plot.
        show_inline (bool): Whether to display the plot inline.
        filter_losses (list[str], optional): List of loss names to plot.
            If None, plots all found losses.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Extract all unique loss keys from both training and validation
    all_keys = set()
    for losses_dict in train_losses_history + val_losses_history:
        all_keys.update(losses_dict.keys())

    # Filter out non-numeric keys if any
    loss_keys = [key for key in all_keys if key not in ['epoch']]

    # Apply filter if specified
    if filter_losses is not None:
        loss_keys = [key for key in loss_keys if key in filter_losses]

    loss_keys = sorted(loss_keys)  # Sort for consistent ordering

    # Create subplots
    n_losses = len(loss_keys)
    cols = min(3, n_losses)  # Max 3 columns
    rows = (n_losses + cols - 1) // cols  # Ceiling division

    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows))
    if n_losses == 1:
        axes = [axes]
    elif rows > 1:
        axes = axes.flatten()

    fig.suptitle("Training and Validation Losses", fontsize=16, y=0.98)

    for i, loss_key in enumerate(loss_keys):
        ax = axes[i]

        # Extract train and validation losses for this key
        train_values = [losses.get(loss_key, 0) for losses in train_losses_history]
        val_values = [losses.get(loss_key, 0) for losses in val_losses_history]

        epochs_train = range(1, len(train_values) + 1)
        epochs_val = range(1, len(val_values) + 1)

        # Plot training and validation curves
        if train_values:
            ax.plot(epochs_train, train_values, label=f"Train {loss_key}", color="blue", linewidth=1.5)
        if val_values:
            ax.plot(epochs_val, val_values, label=f"Val {loss_key}", color="red", linewidth=1.5, linestyle='--')

        ax.set_title(f"{loss_key.capitalize()} Loss", fontsize=12)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylim(bottom=0)

    # Hide unused subplots
    for i in range(n_losses, len(axes)):
        axes[i].set_visible(False)

    plt.tight_layout()

    # Save the plot
    save_path = os.path.join(output_dir, "non_gan_training_losses.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Non-GAN training losses plot saved to: {save_path}")

    if show_inline:
        plt.show()
    else:
        plt.close(fig)


def save_comparison_grid(epoch, model, batch, set_name, device, output_dir="training_output", show_inline=True):
    """
    Generates and saves/shows a horizontal comparison grid (real vs. generated).
    Automatically handles 256x256 or 64x64 output based on set_name.
    """
    os.makedirs(output_dir, exist_ok=True)

    model.eval()
    token_ids = batch["text"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    real_images = batch["image"]
    pokemon_ids = batch["idx"]
    descriptions = batch["description"]
    num_images = real_images.size(0)

    with torch.no_grad():
        generated_images = model(token_ids, attention_mask)
        # Handle tuple output from generator (e.g., 256px and 64px images)
        if isinstance(generated_images, tuple):
            # Check if we want 64x64 or 256x256 based on set_name
            if "64" in set_name:
                generated_images = generated_images[1]  # Use 64x64 output
                # Resize real images to 64x64 for comparison
                real_images = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)
            else:
                generated_images = generated_images[0]  # Use 256x256 output

    fig, axs = plt.subplots(2, num_images, figsize=(4 * num_images, 8.5))
    resolution = "64x64" if "64" in set_name else "256x256"
    fig.suptitle(
        f"Epoch {epoch} - {set_name.capitalize()} Comparison ({resolution})", fontsize=16, y=0.98
    )

    for i in range(num_images):
        ax_real = axs[0, i]
        ax_real.imshow(denormalize_image(real_images[i].cpu()).permute(1, 2, 0))
        ax_real.set_title(f"#{pokemon_ids[i]}: {descriptions[i][:35]}...", fontsize=10)
        ax_real.axis("off")

        ax_gen = axs[1, i]
        ax_gen.imshow(denormalize_image(generated_images[i].cpu()).permute(1, 2, 0))
        ax_gen.axis("off")

    axs[0, 0].text(
        -0.1,
        0.5,
        "Real",
        ha="center",
        va="center",
        rotation="vertical",
        fontsize=14,
        transform=axs[0, 0].transAxes,
    )
    axs[1, 0].text(
        -0.1,
        0.5,
        "Generated",
        ha="center",
        va="center",
        rotation="vertical",
        fontsize=14,
        transform=axs[1, 0].transAxes,
    )

    plt.tight_layout(rect=(0, 0, 1, 0.95))

    # Save the figure and optionally show it
    save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_comparison.png")
    plt.savefig(save_path)

    if show_inline:
        plt.show()
    else:
        plt.close(fig)
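As a usage sketch for the helpers above: after a training run that accumulates per-epoch losses, the loss plot can be produced directly, while the comparison grid additionally needs the trained model and a DataLoader batch. The loss values and the `"train_64"` set name below are illustrative assumptions, not values fixed by plots.py:

from plots import save_plot_losses, save_comparison_grid

# Hypothetical per-epoch values collected during training.
losses_g = [1.2, 0.9, 0.8]
losses_d = [0.7, 0.6, 0.65]
save_plot_losses(losses_g, losses_d, output_dir="training_output", show_inline=False)

# The call below assumes `model` and a `batch` with "text", "attention_mask",
# "image", "idx", and "description" entries, as used throughout plots.py.
# save_comparison_grid(epoch=150, model=model, batch=batch, set_name="train_64",
#                      device="cuda", output_dir="training_output", show_inline=False)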
pikapikagen/utils.py
ADDED
@@ -0,0 +1,12 @@
def denormalize_image(tensor):
    """
    Denormalizes an image tensor from the [-1, 1] range to [0, 1] for visualization.

    Args:
        tensor (torch.Tensor): The image tensor, with values in [-1, 1].

    Returns:
        torch.Tensor: The denormalized tensor with values in [0, 1].
    """
    tensor = (tensor + 1) / 2
    return tensor.clamp(0, 1)
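A quick check of the helper on a dummy tensor in the generator's Tanh output range:

import torch
from utils import denormalize_image

fake_output = torch.rand(3, 64, 64) * 2 - 1  # values in [-1, 1], like the Tanh output
img = denormalize_image(fake_output)
print(img.min().item() >= 0.0, img.max().item() <= 1.0)  # True True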
pyproject.toml
ADDED
@@ -0,0 +1,18 @@
[project]
name = "pikapikagen"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "gradio>=5.35.0",
    "ipykernel>=6.29.5",
    "ipywidgets>=8.1.7",
    "jupyterlab>=4.4.5",
    "matplotlib>=3.10.3",
    "pandas>=2.3.0",
    "sentence-transformers>=5.0.0",
    "torch-fidelity>=0.3.0",
    "torchvision>=0.22.1",
    "transformers>=4.53.0",
]
uv.lock
ADDED
The diff for this file is too large to render.