Commit 66347a3 · Val-2 committed · 1 parent: 07c3151

First commit

.gitattributes CHANGED
@@ -1,35 +1 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+ name: Run Python script
+
+ on:
+   push:
+     branches:
+       - main
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.9'
+
+       - name: Install Gradio
+         run: python -m pip install gradio
+
+       - name: Log in to Hugging Face
+         run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+       - name: Deploy to Spaces
+         run: gradio deploy
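
Note on the login step above: the workflow injects the `hf_token` repository secret straight into an inline `python -c` call before `gradio deploy` runs. A minimal local sketch of the same login, assuming the token is instead exposed through a hypothetical `HF_TOKEN` environment variable, could look like this:

import os
from huggingface_hub import login

# Hypothetical setup: read the token from an environment variable; the workflow
# above instead passes `${{ secrets.hf_token }}` inline on the command line.
login(token=os.environ["HF_TOKEN"])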
.gitignore ADDED
@@ -0,0 +1,216 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+ #poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ #pdm.lock
+ #pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ #pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
+
+ # Streamlit
+ .streamlit/secrets.toml
+
+ # Project
+ /pikapikagen/dataset
+ /pikapikagen/training_output
+ /dataset
+ /old_notebooks/dataset
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/DeepLearning.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$" />
+     <orderEntry type="jdk" jdkName="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>
.idea/deployment.xml ADDED
@@ -0,0 +1,29 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="PublishConfigData" autoUpload="Always" remoteFilesAllowedToDisappearOnAutoupload="false">
+     <serverData>
+       <paths name="root@salad:22 agent">
+         <serverdata>
+           <mappings>
+             <mapping deploy="/tmp/pycharm_project_71" local="$PROJECT_DIR$" />
+           </mappings>
+         </serverdata>
+       </paths>
+       <paths name="val@46.101.132.64:22 key">
+         <serverdata>
+           <mappings>
+             <mapping local="$PROJECT_DIR$" web="/" />
+           </mappings>
+         </serverdata>
+       </paths>
+       <paths name="val@46.101.132.64:22 key (2)">
+         <serverdata>
+           <mappings>
+             <mapping local="$PROJECT_DIR$" web="/" />
+           </mappings>
+         </serverdata>
+       </paths>
+     </serverData>
+     <option name="myAutoUpload" value="ALWAYS" />
+   </component>
+ </project>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,34 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+       <Languages>
+         <language minSize="49" name="Python" />
+       </Languages>
+     </inspection_tool>
+     <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+     <inspection_tool class="Mypy" enabled="true" level="TYPO" enabled_by_default="true" editorAttributes="TYPO" />
+     <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+       <option name="ignoredErrors">
+         <list>
+           <option value="N802" />
+           <option value="N803" />
+           <option value="N806" />
+         </list>
+       </option>
+     </inspection_tool>
+     <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ignoredIdentifiers">
+         <list>
+           <option value="fitz.fitz.Page.MediaBox" />
+           <option value="color_tol" />
+         </list>
+       </option>
+     </inspection_tool>
+     <inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
+       <option name="processCode" value="true" />
+       <option name="processLiterals" value="true" />
+       <option name="processComments" value="true" />
+     </inspection_tool>
+   </profile>
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="Black">
+     <option name="sdkName" value="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" />
+   </component>
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 virtualenv at C:\Users\valer\Mega\Programming\DeepLearning\.venv" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/DeepLearning.iml" filepath="$PROJECT_DIR$/.idea/DeepLearning.iml" />
+     </modules>
+   </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="" vcs="Git" />
+   </component>
+ </project>
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
PikaPikaTraining.ipynb ADDED
@@ -0,0 +1,112 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# PikaPikaGen: Model Training\n",
+ "\n",
+ "This notebook automates the setup and training launch for the PikaPikaGen model.\n",
+ "\n",
+ "The steps performed are:\n",
+ "1. Clone the public GitHub repository.\n",
+ "2. Install the required dependencies via `uv`.\n",
+ "3. Run the training script `main.py`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Installing the required dependencies...\")\n",
+ "\n",
+ "# Make sure uv is installed\n",
+ "%pip install uv\n",
+ "print(\"✅ uv installed successfully.\")\n",
+ "\n",
+ "# Check whether torch is already installed\n",
+ "try:\n",
+ "    import torch\n",
+ "    print(f\"✅ PyTorch already installed (version: {torch.__version__})\")\n",
+ "    torch_installed = True\n",
+ "except ImportError:\n",
+ "    print(\"❌ PyTorch not found, it will be installed\")\n",
+ "    torch_installed = False\n",
+ "\n",
+ "# Main project dependencies\n",
+ "dependencies = [\n",
+ "    \"transformers\",\n",
+ "    \"pandas\",\n",
+ "    \"tqdm\",\n",
+ "    \"matplotlib\",\n",
+ "    \"Pillow\",\n",
+ "    \"requests\",\n",
+ "    \"ipywidgets\"\n",
+ "]\n",
+ "\n",
+ "# Add torch and torchvision only if they are not already installed\n",
+ "if not torch_installed:\n",
+ "    dependencies.extend([\"torch\", \"torchvision\"])\n",
+ "\n",
+ "print(\"Installing dependencies with uv...\")\n",
+ "deps_str = \" \".join(dependencies)\n",
+ "if torch_installed:\n",
+ "    !uv pip install {deps_str}\n",
+ "else:\n",
+ "    !uv pip install {deps_str} --torch-backend=auto\n",
+ "print(\"✅ Main dependencies installed successfully.\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "repo_url = \"https://github.com/val-2/DeepLearning\"\n",
+ "branch = \"main\"\n",
+ "repo_name = repo_url.split('/')[-1]\n",
+ "\n",
+ "print(f\"Cloning the repository: {repo_url}\")\n",
+ "\n",
+ "# Check if we're already in the repo directory\n",
+ "current_dir = os.path.basename(os.getcwd())\n",
+ "if current_dir == repo_name:\n",
+ "    print(f\"Already in the repository directory '{repo_name}'. Updating...\")\n",
+ "    !git fetch\n",
+ "    !git pull\n",
+ "    !git checkout {branch}\n",
+ "elif os.path.exists(repo_name):\n",
+ "    print(f\"The directory '{repo_name}' already exists. Updating the repository...\")\n",
+ "    os.chdir(repo_name)\n",
+ "    !git fetch\n",
+ "    !git pull\n",
+ "    !git checkout {branch}\n",
+ "else:\n",
+ "    print(\"Cloning the repository...\")\n",
+ "    !git clone -b {branch} {repo_url}\n",
+ "    os.chdir(repo_name)\n",
+ "\n",
+ "# Move into the repository directory\n",
+ "print(f\"Current working directory: {os.getcwd()}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
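
The markdown cell in PikaPikaTraining.ipynb lists running `main.py` as step 3, but the cells shown here stop after installing dependencies and cloning the repository. A minimal sketch of that final step, assuming the entry point is `pikapikagen/main.py` (the exact path is not shown in this commit), might be:

import subprocess
import sys

# Assumed location of the training entry point; adjust the path if main.py lives elsewhere.
subprocess.run([sys.executable, "pikapikagen/main.py"], check=True)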
pikapikagen/PikaPikaGen.ipynb ADDED
@@ -0,0 +1,2241 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "raw",
5
+ "metadata": {
6
+ "id": "VDSaH9SVsnNl",
7
+ "vscode": {
8
+ "languageId": "raw"
9
+ }
10
+ },
11
+ "source": [
12
+ "# PikaPikaGen: Text-to-Image Pokemon Sprite Generation with GAN\n"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "# Install required packages\n",
22
+ "#!pip install torch torchvision transformers pandas pillow requests matplotlib tqdm ipywidgets gradio torch-fidelity\n",
23
+ "\n",
24
+ "import torch\n",
25
+ "import torch.nn as nn\n",
26
+ "import torch.optim as optim\n",
27
+ "import torch.nn.functional as F\n",
28
+ "\n",
29
+ "import numpy as np\n",
30
+ "import matplotlib.pyplot as plt\n",
31
+ "import os\n",
32
+ "from tqdm import tqdm\n",
33
+ "from transformers import AutoTokenizer\n",
34
+ "import warnings\n",
35
+ "warnings.filterwarnings('ignore')\n",
36
+ "\n",
37
+ "# Set device\n",
38
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
39
+ "print(f\"Using device: {device}\")\n",
40
+ "\n",
41
+ "# Set random seeds for reproducibility\n",
42
+ "RANDOM_SEED = 42\n",
43
+ "torch.manual_seed(RANDOM_SEED)\n",
44
+ "np.random.seed(RANDOM_SEED)"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "raw",
49
+ "metadata": {
50
+ "id": "-rrtsHGqsnNo",
51
+ "vscode": {
52
+ "languageId": "raw"
53
+ }
54
+ },
55
+ "source": [
56
+ "## 1. Data Loading and Preprocessing\n"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 4,
62
+ "metadata": {
63
+ "id": "aeVuv1YCsnNp"
64
+ },
65
+ "outputs": [],
66
+ "source": [
67
+ "import torch\n",
68
+ "import torchvision.transforms as T\n",
69
+ "\n",
70
+ "\n",
71
+ "class AugmentationPipeline:\n",
72
+ " def __init__(self, p=0.8):\n",
73
+ " self.p = p\n",
74
+ " self.transforms = T.RandomApply([\n",
75
+ " T.RandomHorizontalFlip(p=0.5),\n",
76
+ "\n",
77
+ " T.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=1),\n",
78
+ "\n",
79
+ " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
80
+ "\n",
81
+ " T.RandomErasing(p=0.15, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random'),\n",
82
+ " ], p=self.p)\n",
83
+ "\n",
84
+ " def apply(self, images):\n",
85
+ " return self.transforms(images)"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {
92
+ "colab": {
93
+ "base_uri": "https://localhost:8080/",
94
+ "height": 1000,
95
+ "referenced_widgets": [
96
+ "5efdceae0bac4c978d3a7226247e237f",
97
+ "a39c5c623a3e42448e109fb9ec6bc263",
98
+ "a6ed2ddb1c6f4d1aa945c5a39372f781",
99
+ "8cf950b898e142c1af9b4db92019aa4d",
100
+ "8ed7abd0602c43a1bfc0f96d7611d429",
101
+ "65ba2d78fde14bb2baf5ae1101d7e5ff",
102
+ "4795a78a75dc439a8da7df58bf738940",
103
+ "4545ff199b874d3680a83918513e1d4b",
104
+ "cad8fd90586443778568a1babb8c40e6",
105
+ "57e526d188b9414dabb3b1c895373864",
106
+ "8226a55726c54abba3a48dbfa8e1b6f6",
107
+ "86a3c1a4e9eb4989b23364f21e5df531",
108
+ "5ba39d9d997a45ca848e3e2ffd0e7307",
109
+ "4c22e1b396f342ffb90c1b50a0051862",
110
+ "370e5663868f411697bfb24f4e3efa09",
111
+ "3a338ac4d2944030a07843d8ea24e9fd",
112
+ "128f4312bcdc4166b9e24d8cdd34184d",
113
+ "1b65d6c8540e4f458886d5e7075ab30a",
114
+ "a5a9f8607fdd4f9cad7519eca573f3dc",
115
+ "926149594f94457295c60b4fad9cbac7",
116
+ "7e89bc79516f405e9684eacdce7b4551",
117
+ "c917f3a000fb44338e4afbeabeaab55f"
118
+ ]
119
+ },
120
+ "id": "ppTYW-n5snNp",
121
+ "outputId": "4d7a3003-296a-458c-a339-aeacf5232c91"
122
+ },
123
+ "outputs": [],
124
+ "source": [
125
+ "from data_loader import create_training_setup\n",
126
+ "from utils import denormalize_image\n",
127
+ "\n",
128
+ "tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-mini')\n",
129
+ "\n",
130
+ "# train_augmentation_pipeline = AugmentationPipeline()\n",
131
+ "# Create the complete training setup using the function from pokemon_dataset.py\n",
132
+ "print(\"Creating training setup with train/val split and fixed batches...\")\n",
133
+ "training_setup = create_training_setup(\n",
134
+ " tokenizer=tokenizer,\n",
135
+ " test_set_size=0.2,\n",
136
+ " val_set_size=0.1,\n",
137
+ " batch_size=16,\n",
138
+ " num_workers=0,\n",
139
+ " num_viz_samples=4,\n",
140
+ " random_seed=42,\n",
141
+ " train_augmentation_pipeline=None\n",
142
+ ")\n",
143
+ "\n",
144
+ "# Extract components\n",
145
+ "train_loader = training_setup['train_loader']\n",
146
+ "val_loader = training_setup['val_loader']\n",
147
+ "fixed_train_batch = training_setup['fixed_train_batch']\n",
148
+ "fixed_val_batch = training_setup['fixed_val_batch']\n",
149
+ "fixed_train_attention_batch = training_setup['fixed_train_attention_batch']\n",
150
+ "fixed_val_attention_batch = training_setup['fixed_val_attention_batch']\n",
151
+ "\n",
152
+ "print(\"Training setup complete!\")\n",
153
+ "print(f\"Train loader batches: {len(train_loader)}\")\n",
154
+ "print(f\"Val loader batches: {len(val_loader)}\")\n",
155
+ "\n",
156
+ "# Test the training setup with fixed batches\n",
157
+ "print(\"\\nFixed batch shapes:\")\n",
158
+ "print(f\" Train batch - Images: {fixed_train_batch['image'].shape}\")\n",
159
+ "print(f\" Train batch - Text: {fixed_train_batch['text'].shape}\")\n",
160
+ "print(f\" Train batch - Attention: {fixed_train_batch['attention_mask'].shape}\")\n",
161
+ "print(f\" Val batch - Images: {fixed_val_batch['image'].shape}\")\n",
162
+ "\n",
163
+ "# Display sample images from fixed batches\n",
164
+ "fig, axes = plt.subplots(2, 4, figsize=(16, 8))\n",
165
+ "for i in range(4):\n",
166
+ " # Fixed train batch images\n",
167
+ " train_img = denormalize_image(fixed_train_batch['image'][i])\n",
168
+ " axes[0, i].imshow(train_img.permute(1, 2, 0))\n",
169
+ " axes[0, i].set_title(f\"Train: {fixed_train_batch['pokemon_name'][i]}\")\n",
170
+ " axes[0, i].axis('off')\n",
171
+ "\n",
172
+ " # Fixed val batch images\n",
173
+ " val_img = denormalize_image(fixed_val_batch['image'][i])\n",
174
+ " axes[1, i].imshow(val_img.permute(1, 2, 0))\n",
175
+ " axes[1, i].set_title(f\"Val: {fixed_val_batch['pokemon_name'][i]}\")\n",
176
+ " axes[1, i].axis('off')\n",
177
+ "\n",
178
+ "plt.suptitle(\"Fixed Batches for Training Visualization\", fontsize=16)\n",
179
+ "plt.tight_layout()\n",
180
+ "plt.show()\n",
181
+ "\n",
182
+ "\n",
183
+ "print(\"\\n✅ Dataset and batches loaded successfully from pokemon_dataset.py functionality!\")\n",
184
+ "print(\"Ready for training with proper train/val split and fixed visualization batches.\")\n"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "raw",
189
+ "metadata": {
190
+ "id": "eJSVrf3ysnNq",
191
+ "vscode": {
192
+ "languageId": "raw"
193
+ }
194
+ },
195
+ "source": [
196
+ "## 2. Model Architecture Implementation\n"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "metadata": {
203
+ "colab": {
204
+ "base_uri": "https://localhost:8080/",
205
+ "height": 923,
206
+ "referenced_widgets": [
207
+ "bdf500351aea42698c6d6dd5a99021f3",
208
+ "ab61b90c1a5b4a2b9bb5c9d5a215bb3f",
209
+ "dc03fed540b74f3aa4a1b17ebf2c81d3",
210
+ "5837f2c4668646c0a6db2407aebb46e3",
211
+ "edeb423e9ff84e5c8a0d790368d68bba",
212
+ "bf8eb066cdaf4ac096dc14392d085daf",
213
+ "4e32e76c44fb449c8cb767abeb17868a",
214
+ "5c3cb981f324446eae642f7c23a539f0",
215
+ "2fe9614fe5984fa6b887d1e1b3e18b04",
216
+ "64277772cc30408e8ea29f0e268c8880",
217
+ "5b0d55ea20714104818097bd7d1f509a",
218
+ "7e21c6a9c7f44496b6f28513caefb631",
219
+ "439eba0eb4184c0ab83f65fc26bbe388",
220
+ "eee695744ec64aa7b71b9e85968c6f8f",
221
+ "c4ecdc9d982f49129368893c1c0aece9",
222
+ "5f5e7ff6e4c845b99602a4fa00ad550a",
223
+ "304d50e74ad744cdb3a7cc88739cb923",
224
+ "bfcc6d01c9ff4db698afa4318e7c91ac",
225
+ "b2bf751bb96746e4a828241f70e52050",
226
+ "828b227361fe45cd83964149e7475503",
227
+ "58ab975eaba2485cb0945482c26ecf3d",
228
+ "d0b4e43ab5cd4edda6cc061b36bf10a3"
229
+ ]
230
+ },
231
+ "id": "RnNQM3_ysnNr",
232
+ "outputId": "6905696e-05d6-4d97-dd9b-9dc36eea95b7"
233
+ },
234
+ "outputs": [],
235
+ "source": [
236
+ "from model import Generator\n",
237
+ "\n",
238
+ "# Test the generator\n",
239
+ "generator = Generator().to(device)\n",
240
+ "with torch.no_grad():\n",
241
+ " generated_images_256, generated_images_64 = generator(\n",
242
+ " fixed_train_batch['text'][:2].to(device),\n",
243
+ " fixed_train_batch['attention_mask'][:2].to(device)\n",
244
+ " )\n",
245
+ "print(f\"Generator output shape 256x256: {generated_images_256.shape}\")\n",
246
+ "print(f\"Generator output shape 64x64: {generated_images_64.shape}\")\n",
247
+ "\n",
248
+ "print(\"Generator test\")\n",
249
+ "plt.figure(figsize=(12, 8))\n",
250
+ "for i in range(2):\n",
251
+ " # 256x256 images\n",
252
+ " plt.subplot(2, 2, i+1)\n",
253
+ " img_256 = denormalize_image(generated_images_256[i].cpu())\n",
254
+ " plt.imshow(img_256.permute(1, 2, 0))\n",
255
+ " plt.title(f\"Generated 256x256 Sample {i+1}\")\n",
256
+ " plt.axis('off')\n",
257
+ "\n",
258
+ " # 64x64 images\n",
259
+ " plt.subplot(2, 2, i+3)\n",
260
+ " img_64 = denormalize_image(generated_images_64[i].cpu())\n",
261
+ " plt.imshow(img_64.permute(1, 2, 0))\n",
262
+ " plt.title(f\"Generated 64x64 Sample {i+1}\")\n",
263
+ " plt.axis('off')\n",
264
+ "plt.tight_layout()\n",
265
+ "plt.show()\n"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "raw",
270
+ "metadata": {
271
+ "id": "7drCU21JsnNs",
272
+ "vscode": {
273
+ "languageId": "raw"
274
+ }
275
+ },
276
+ "source": [
277
+ "## 3. Training Setup and Utilities\n"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": null,
283
+ "metadata": {
284
+ "colab": {
285
+ "base_uri": "https://localhost:8080/"
286
+ },
287
+ "id": "iQdhzEQQsnNs",
288
+ "outputId": "2dbee275-3b6d-43da-8929-97e21403821f"
289
+ },
290
+ "outputs": [],
291
+ "source": [
292
+ "from discriminators import Discriminator256, Discriminator64\n",
293
+ "from losses import VGGPerceptualLoss, SobelLoss\n",
294
+ "from plots import save_attention_visualization\n",
295
+ "\n",
296
+ "def weights_init(m):\n",
297
+ " \"\"\"Initialize model weights according to the original DCGAN paper\"\"\"\n",
298
+ " classname = m.__class__.__name__\n",
299
+ " if classname.find('Conv') != -1:\n",
300
+ " nn.init.normal_(m.weight.data, 0.0, 0.02)\n",
301
+ " elif classname.find('BatchNorm') != -1:\n",
302
+ " nn.init.normal_(m.weight.data, 1.0, 0.02)\n",
303
+ " nn.init.constant_(m.bias.data, 0)\n",
304
+ "\n",
305
+ "generator = Generator().to(device)\n",
306
+ "discriminator_256 = Discriminator256().to(device)\n",
307
+ "discriminator_64 = Discriminator64().to(device)\n",
308
+ "\n",
309
+ "generator.apply(weights_init)\n",
310
+ "discriminator_256.apply(weights_init)\n",
311
+ "discriminator_64.apply(weights_init)\n",
312
+ "\n",
313
+ "\n",
314
+ "# Optimizer params\n",
315
+ "lr = 0.0002\n",
316
+ "beta1 = 0.5\n",
317
+ "beta2 = 0.999\n",
318
+ "\n",
319
+ "optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))\n",
320
+ "optimizer_D_256 = optim.Adam(discriminator_256.parameters(), lr=lr, betas=(beta1, beta2))\n",
321
+ "optimizer_D_64 = optim.Adam(discriminator_64.parameters(), lr=lr, betas=(beta1, beta2))\n",
322
+ "\n",
323
+ "adv_criterion = nn.BCEWithLogitsLoss().to(device) # no sigmoid at the end of discriminators\n",
324
+ "l1_criterion = nn.L1Loss().to(device)\n",
325
+ "perc_criterion = VGGPerceptualLoss(device)\n",
326
+ "sobel_criterion = SobelLoss().to(device)"
327
+ ]
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "execution_count": null,
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "from typing import TypedDict\n",
336
+ "import torch\n",
337
+ "from plots import save_comparison_grid\n",
338
+ "\n",
339
+ "# Create checkpoint saving directory\n",
340
+ "os.makedirs('models', exist_ok=True)\n",
341
+ "\n",
342
+ "# TypedDicts to pass and return many object at once, without\n",
343
+ "class LossesDict(TypedDict):\n",
344
+ " \"\"\"History of training losses\"\"\"\n",
345
+ " generator: list[float]\n",
346
+ " discriminator: list[float]\n",
347
+ " l1: list[float]\n",
348
+ " perceptual: list[float]\n",
349
+ " sobel: list[float]\n",
350
+ "\n",
351
+ "class ValidationLossesDict(TypedDict):\n",
352
+ " \"\"\"History of validation losses\"\"\"\n",
353
+ " l1: list[float]\n",
354
+ " perceptual: list[float]\n",
355
+ " sobel: list[float]\n",
356
+ " total: list[float]\n",
357
+ "\n",
358
+ "class DiscriminatorComponentsDict(TypedDict):\n",
359
+ " \"\"\"Components of the discriminator loss\"\"\"\n",
360
+ " real_uncond: float\n",
361
+ " real_cond: float\n",
362
+ " real_cond_wrong: float\n",
363
+ " fake_uncond: float\n",
364
+ "\n",
365
+ "class ValidationResultsDict(TypedDict):\n",
366
+ " \"\"\"Single losses for validation\"\"\"\n",
367
+ " l1: float\n",
368
+ " perceptual: float\n",
369
+ " sobel: float\n",
370
+ " total: float\n",
371
+ "\n",
372
+ "# Training history\n",
373
+ "losses: LossesDict = {\n",
374
+ " 'generator': [],\n",
375
+ " 'discriminator': [],\n",
376
+ " 'l1': [],\n",
377
+ " 'perceptual': [],\n",
378
+ " 'sobel': [],\n",
379
+ "}\n",
380
+ "\n",
381
+ "# Validation history\n",
382
+ "val_losses: ValidationLossesDict = {\n",
383
+ " 'l1': [],\n",
384
+ " 'perceptual': [],\n",
385
+ " 'sobel': [],\n",
386
+ " 'total': [],\n",
387
+ "}\n",
388
+ "\n",
389
+ "def validate_model(generator, val_loader, device, l1_criterion, perc_criterion, sobel_criterion) -> ValidationResultsDict:\n",
390
+ " \"\"\"\n",
391
+ " Validate the model on the validation set\n",
392
+ " Returns validation losses\n",
393
+ " \"\"\"\n",
394
+ " generator.eval()\n",
395
+ "\n",
396
+ " val_l1_loss = 0.0\n",
397
+ " val_perc_loss = 0.0\n",
398
+ " val_sobel_loss = 0.0\n",
399
+ " num_batches = 0\n",
400
+ "\n",
401
+ " with torch.no_grad():\n",
402
+ " for batch in val_loader:\n",
403
+ " # Move data to device\n",
404
+ " real_images = batch['image'].to(device)\n",
405
+ " text_ids = batch['text'].to(device)\n",
406
+ " attention_mask = batch['attention_mask'].to(device)\n",
407
+ "\n",
408
+ " # Generate images\n",
409
+ " generated_images, _ = generator(text_ids, attention_mask)\n",
410
+ "\n",
411
+ " # Calculate validation losses (no adversarial loss)\n",
412
+ " batch_l1_loss = l1_criterion(generated_images, real_images)\n",
413
+ " batch_perc_loss = perc_criterion(generated_images, real_images)\n",
414
+ " batch_sobel_loss = sobel_criterion(generated_images, real_images)\n",
415
+ "\n",
416
+ " val_l1_loss += batch_l1_loss.item()\n",
417
+ " val_perc_loss += batch_perc_loss.item()\n",
418
+ " val_sobel_loss += batch_sobel_loss.item()\n",
419
+ " num_batches += 1\n",
420
+ "\n",
421
+ " # Calculate averages\n",
422
+ " avg_val_l1 = val_l1_loss / num_batches\n",
423
+ " avg_val_perc = val_perc_loss / num_batches\n",
424
+ " avg_val_sobel = val_sobel_loss / num_batches\n",
425
+ " avg_val_total = avg_val_l1 + avg_val_perc + avg_val_sobel\n",
426
+ "\n",
427
+ " # Set models back to training mode\n",
428
+ " generator.train()\n",
429
+ "\n",
430
+ " return ValidationResultsDict(\n",
431
+ " l1=avg_val_l1,\n",
432
+ " perceptual=avg_val_perc,\n",
433
+ " sobel=avg_val_sobel,\n",
434
+ " total=avg_val_total\n",
435
+ " )\n",
436
+ "\n",
437
+ "def create_mismatched_text_batch(text_ids, attention_mask):\n",
438
+ " \"\"\"Create a batch with mismatched text for wrong text conditioning\"\"\"\n",
439
+ " batch_size = text_ids.size(0)\n",
440
+ " indices = torch.randperm(batch_size)\n",
441
+ " return text_ids[indices], attention_mask[indices]\n",
442
+ "\n",
443
+ "def compute_discriminator_loss(\n",
444
+ " discriminator,\n",
445
+ " real_images,\n",
446
+ " fake_images,\n",
447
+ " text_ids,\n",
448
+ " attention_mask,\n",
449
+ " wrong_text_ids,\n",
450
+ " wrong_attention_mask,\n",
451
+ " real_labels,\n",
452
+ " fake_labels,\n",
453
+ " adv_criterion\n",
454
+ ") -> tuple[torch.Tensor, DiscriminatorComponentsDict]:\n",
455
+ " \"\"\"Compute discriminator loss with the 4 components.\n",
456
+ " Returns the total loss and the 4 components.\"\"\"\n",
457
+ " # Real images with correct text\n",
458
+ " real_uncond, real_cond = discriminator(real_images, text_ids, attention_mask, return_both=True)\n",
459
+ " real_uncond_loss = adv_criterion(real_uncond, real_labels)\n",
460
+ " real_cond_loss = adv_criterion(real_cond, real_labels)\n",
461
+ "\n",
462
+ " # Real images with wrong text\n",
463
+ " _, real_cond_wrong = discriminator(real_images, wrong_text_ids, wrong_attention_mask, return_both=True)\n",
464
+ " real_cond_wrong_loss = adv_criterion(real_cond_wrong, fake_labels)\n",
465
+ "\n",
466
+ " # Fake images with wrong text\n",
467
+ " fake_uncond, _ = discriminator(fake_images.detach(), wrong_text_ids, wrong_attention_mask, return_both=True)\n",
468
+ " fake_uncond_loss = adv_criterion(fake_uncond, fake_labels)\n",
469
+ "\n",
470
+ " total_loss = (real_uncond_loss + real_cond_loss + real_cond_wrong_loss + fake_uncond_loss) / 4\n",
471
+ "\n",
472
+ " components: DiscriminatorComponentsDict = {\n",
473
+ " 'real_uncond': real_uncond_loss.item(),\n",
474
+ " 'real_cond': real_cond_loss.item(),\n",
475
+ " 'real_cond_wrong': real_cond_wrong_loss.item(),\n",
476
+ " 'fake_uncond': fake_uncond_loss.item(),\n",
477
+ " }\n",
478
+ "\n",
479
+ " return total_loss, components\n",
480
+ "\n",
481
+ "def compute_generator_adversarial_loss(\n",
482
+ " discriminator,\n",
483
+ " fake_images,\n",
484
+ " text_ids,\n",
485
+ " attention_mask,\n",
486
+ " real_labels,\n",
487
+ " adv_criterion\n",
488
+ ") -> torch.Tensor:\n",
489
+ " \"\"\"Compute generator adversarial loss for one discriminator\"\"\"\n",
490
+ " fake_uncond, fake_cond = discriminator(fake_images, text_ids, attention_mask, return_both=True)\n",
491
+ " uncond_loss = adv_criterion(fake_uncond, real_labels)\n",
492
+ " cond_loss = adv_criterion(fake_cond, real_labels)\n",
493
+ " return (uncond_loss + cond_loss) / 2\n",
494
+ "\n",
495
+ "def compute_loss(fake_images, real_images, criterion, lmd):\n",
496
+ " \"\"\"Compute a reconstruction loss only if its lambda > 0\"\"\"\n",
497
+ " return criterion(fake_images, real_images) if lmd > 0 else torch.tensor(0.0, device=device)\n",
498
+ "\n",
499
+ "\n",
500
+ "epoch = 0\n",
501
+ "noise_dim = 100\n"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "raw",
506
+ "metadata": {
507
+ "id": "Oenm8AkasnNt",
508
+ "vscode": {
509
+ "languageId": "raw"
510
+ }
511
+ },
512
+ "source": [
513
+ "## 4. GAN Training Loop"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "metadata": {
520
+ "colab": {
521
+ "base_uri": "https://localhost:8080/",
522
+ "height": 1000
523
+ },
524
+ "id": "gmo0Mi6osnNt",
525
+ "outputId": "a4691684-def3-4d29-d00d-20ae40287c8c"
526
+ },
527
+ "outputs": [],
528
+ "source": [
529
+ "from IPython.display import clear_output\n",
530
+ "\n",
531
+ "total_epochs = 150\n",
532
+ "display_interval = 1 # To show generation of training sample\n",
533
+ "save_interval = 15 # To save checkpoint\n",
534
+ "clear_interval = 22 # To clear cell output. If too high or not present, Kaggle page would crash\n",
535
+ "\n",
536
+ "lambda_l1 = 1.0\n",
537
+ "lambda_adv = 1.0\n",
538
+ "lambda_perceptual = 0.0\n",
539
+ "lambda_sobel = 0.0\n",
540
+ "\n",
541
+ "real_label = 1.0\n",
542
+ "fake_label = 0.0\n",
543
+ "\n",
544
+ "print(\"Starting training with dual discriminators...\")\n",
545
+ "\n",
546
+ "for epoch in range(epoch, total_epochs):\n",
547
+ " epoch_g_loss = 0.0\n",
548
+ " epoch_d_loss_64 = 0.0\n",
549
+ " epoch_d_loss_256 = 0.0\n",
550
+ " epoch_l1_loss = 0.0\n",
551
+ " epoch_perc_loss = 0.0\n",
552
+ " epoch_sobel_loss = 0.0\n",
553
+ "\n",
554
+ " # Track discriminator loss components\n",
555
+ " epoch_d256_components: DiscriminatorComponentsDict = {'real_uncond': 0.0, 'real_cond': 0.0, 'real_cond_wrong': 0.0, 'fake_uncond': 0.0}\n",
556
+ " epoch_d64_components: DiscriminatorComponentsDict = {'real_uncond': 0.0, 'real_cond': 0.0, 'real_cond_wrong': 0.0, 'fake_uncond': 0.0}\n",
557
+ "\n",
558
+ " progress_bar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{total_epochs}\")\n",
559
+ "\n",
560
+ " for i, batch in enumerate(progress_bar):\n",
561
+ " batch_size = batch['image'].size(0)\n",
562
+ "\n",
563
+ " # Move data to device\n",
564
+ " real_images = batch['image'].to(device)\n",
565
+ " text_ids = batch['text'].to(device)\n",
566
+ " attention_mask = batch['attention_mask'].to(device)\n",
567
+ "\n",
568
+ " # Create labels and mismatched text for GAN training\n",
569
+ " real_labels = torch.full((batch_size, 1), real_label, device=device, dtype=torch.float)\n",
570
+ " fake_labels = torch.full((batch_size, 1), fake_label, device=device, dtype=torch.float)\n",
571
+ " wrong_text_ids, wrong_attention_mask = create_mismatched_text_batch(text_ids, attention_mask)\n",
572
+ "\n",
573
+ " # Generate fake images\n",
574
+ " fake_images_256, fake_images_64 = generator(text_ids, attention_mask)\n",
575
+ " real_images_64 = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)\n",
576
+ "\n",
577
+ " # Training both discriminators\n",
578
+ " optimizer_D_256.zero_grad()\n",
579
+ " optimizer_D_64.zero_grad()\n",
580
+ "\n",
581
+ " d_loss_256, d256_components = compute_discriminator_loss(\n",
582
+ " discriminator_256, real_images, fake_images_256,\n",
583
+ " text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,\n",
584
+ " real_labels, fake_labels, adv_criterion\n",
585
+ " )\n",
586
+ " d_loss_256.backward()\n",
587
+ "\n",
588
+ " d_loss_64, d64_components = compute_discriminator_loss(\n",
589
+ " discriminator_64, real_images_64, fake_images_64,\n",
590
+ " text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,\n",
591
+ " real_labels, fake_labels, adv_criterion\n",
592
+ " )\n",
593
+ " d_loss_64.backward()\n",
594
+ "\n",
595
+ " optimizer_D_256.step()\n",
596
+ " optimizer_D_64.step()\n",
597
+ "\n",
598
+ " # Training generator\n",
599
+ " optimizer_G.zero_grad()\n",
600
+ "\n",
601
+ " # Adversarial losses for both discriminators\n",
602
+ " g_adv_loss_256 = compute_generator_adversarial_loss(\n",
603
+ " discriminator_256, fake_images_256, text_ids, attention_mask, real_labels, adv_criterion\n",
604
+ " )\n",
605
+ " g_adv_loss_64 = compute_generator_adversarial_loss(\n",
606
+ " discriminator_64, fake_images_64, text_ids, attention_mask, real_labels, adv_criterion\n",
607
+ " )\n",
608
+ " adversarial_loss = (g_adv_loss_256 + g_adv_loss_64) / 2\n",
609
+ "\n",
610
+ " # Compute losses if their lambda is > 0\n",
611
+ " l1_loss = compute_loss(fake_images_256, real_images, l1_criterion, lambda_l1)\n",
612
+ " perc_loss = compute_loss(fake_images_256, real_images, perc_criterion, lambda_perceptual)\n",
613
+ " sobel_loss = compute_loss(fake_images_256, real_images, sobel_criterion, lambda_sobel)\n",
614
+ "\n",
615
+ " # Total generator loss\n",
616
+ " g_loss = (\n",
617
+ " lambda_adv * adversarial_loss +\n",
618
+ " lambda_l1 * l1_loss +\n",
619
+ " lambda_perceptual * perc_loss +\n",
620
+ " lambda_sobel * sobel_loss\n",
621
+ " )\n",
622
+ " g_loss.backward()\n",
623
+ " optimizer_G.step()\n",
624
+ "\n",
625
+ " # Update loss tracking\n",
626
+ " epoch_g_loss += g_loss.item()\n",
627
+ " epoch_d_loss_256 += d_loss_256.item()\n",
628
+ " epoch_d_loss_64 += d_loss_64.item()\n",
629
+ " epoch_l1_loss += l1_loss.item()\n",
630
+ " epoch_perc_loss += perc_loss.item()\n",
631
+ " epoch_sobel_loss += sobel_loss.item()\n",
632
+ "\n",
633
+ " # Update discriminator component tracking\n",
634
+ " for key in epoch_d256_components:\n",
635
+ " epoch_d256_components[key] += d256_components[key]\n",
636
+ " epoch_d64_components[key] += d64_components[key]\n",
637
+ "\n",
638
+ " # Update progress bar with detailed losses and loss components\n",
639
+ " progress_bar.set_postfix({\n",
640
+ " 'G': f'{g_loss.item():.3f}',\n",
641
+ " 'L1': f'{l1_loss.item():.3f}',\n",
642
+ " 'Adv': f'{adversarial_loss.item():.3f}',\n",
643
+ " 'D256': f'{d_loss_256.item():.3f}',\n",
644
+ " 'D256_ru': f'{d256_components[\"real_uncond\"]:.3f}',\n",
645
+ " 'D256_rc': f'{d256_components[\"real_cond\"]:.3f}',\n",
646
+ " 'D256_rcw': f'{d256_components[\"real_cond_wrong\"]:.3f}',\n",
647
+ " 'D256_fu': f'{d256_components[\"fake_uncond\"]:.3f}',\n",
648
+ " 'D64': f'{d_loss_64.item():.3f}',\n",
649
+ " 'D64_ru': f'{d64_components[\"real_uncond\"]:.3f}',\n",
650
+ " 'D64_rc': f'{d64_components[\"real_cond\"]:.3f}',\n",
651
+ " 'D64_rcw': f'{d64_components[\"real_cond_wrong\"]:.3f}',\n",
652
+ " 'D64_fu': f'{d64_components[\"fake_uncond\"]:.3f}',\n",
653
+ " })\n",
654
+ "\n",
655
+ " # Calculate average losses for the epoch\n",
656
+ " avg_g_loss = epoch_g_loss / len(train_loader)\n",
657
+ " avg_d_loss_256 = epoch_d_loss_256 / len(train_loader)\n",
658
+ " avg_d_loss_64 = epoch_d_loss_64 / len(train_loader)\n",
659
+ " avg_l1_loss = epoch_l1_loss / len(train_loader)\n",
660
+ " avg_perc_loss = epoch_perc_loss / len(train_loader)\n",
661
+ " avg_sobel_loss = epoch_sobel_loss / len(train_loader)\n",
662
+ "\n",
663
+ " # Calculate average discriminator components for epoch\n",
664
+ " avg_d256_components = {key: val / len(train_loader) for key, val in epoch_d256_components.items()}\n",
665
+ " avg_d64_components = {key: val / len(train_loader) for key, val in epoch_d64_components.items()}\n",
666
+ "\n",
667
+ " # Store losses (combine discriminator losses)\n",
668
+ " losses['generator'].append(avg_g_loss)\n",
669
+ " losses['discriminator'].append((avg_d_loss_256 + avg_d_loss_64) / 2)\n",
670
+ " losses['l1'].append(avg_l1_loss)\n",
671
+ " losses['perceptual'].append(avg_perc_loss)\n",
672
+ " losses['sobel'].append(avg_sobel_loss)\n",
673
+ "\n",
674
+ " print(f\"Running validation for epoch {epoch+1}...\")\n",
675
+ " validation_results = validate_model(generator, val_loader, device, l1_criterion, perc_criterion, sobel_criterion)\n",
676
+ "\n",
677
+ " # Store validation losses\n",
678
+ "\n",
679
+ " for k, v in validation_results.items():\n",
680
+ " val_losses[k].append(v)\n",
681
+ "\n",
682
+ " if (epoch + 1) % clear_interval == 0:\n",
683
+ " clear_output(wait=True)\n",
684
+ "\n",
685
+ " print(f\"Epoch [{epoch+1}/{total_epochs}]\")\n",
686
+ " print(f\" Train - D_256: {avg_d_loss_256:.4f}, D_64: {avg_d_loss_64:.4f}, G_loss: {avg_g_loss:.4f}\")\n",
687
+ " print(f\" D_256 Components - RU: {avg_d256_components['real_uncond']:.4f}, RC: {avg_d256_components['real_cond']:.4f}, RCW: {avg_d256_components['real_cond_wrong']:.4f}, FU: {avg_d256_components['fake_uncond']:.4f}\")\n",
688
+ " print(f\" D_64 Components - RU: {avg_d64_components['real_uncond']:.4f}, RC: {avg_d64_components['real_cond']:.4f}, RCW: {avg_d64_components['real_cond_wrong']:.4f}, FU: {avg_d64_components['fake_uncond']:.4f}\")\n",
689
+ " print(f\" Train - L1: {avg_l1_loss:.4f}, Perceptual: {avg_perc_loss:.4f}, Sobel: {avg_sobel_loss:.4f}\")\n",
690
+ " print(f\" Val - L1: {validation_results['l1']:.4f}, Perceptual: {validation_results['perceptual']:.4f}, Sobel: {validation_results['sobel']:.4f}, Total: {validation_results['total']:.4f}\")\n",
691
+ " print(\" Legend: RU=RealUncond, RC=RealCond, RCW=RealCondWrong, FU=FakeUncond\")\n",
692
+ "\n",
693
+ " # Display generated images\n",
694
+ " if (epoch + 1) % display_interval == 0:\n",
695
+ " print(f\"\\nGenerating sample images at epoch {epoch+1}:\")\n",
696
+ " print(\"256x256 Training Images:\")\n",
697
+ " save_comparison_grid(epoch+1, generator, fixed_train_batch, \"train_256\", device, show_inline=True)\n",
698
+ " print(\"64x64 Training Images:\")\n",
699
+ " save_comparison_grid(epoch+1, generator, fixed_train_batch, \"train_64\", device, show_inline=True)\n",
700
+ "\n",
701
+ " # Save checkpoint and show visualizations\n",
702
+ " if (epoch + 1) % save_interval == 0:\n",
703
+ " checkpoint_path = f'models/checkpoint_epoch_{epoch+1}.pth'\n",
704
+ " all_losses = {'train': losses, 'val': val_losses}\n",
705
+ " checkpoint = {\n",
706
+ " 'epoch': epoch,\n",
707
+ " 'generator_state_dict': generator.state_dict(),\n",
708
+ " 'discriminator_256_state_dict': discriminator_256.state_dict(),\n",
709
+ " 'discriminator_64_state_dict': discriminator_64.state_dict(),\n",
710
+ " 'g_optimizer_state_dict': optimizer_G.state_dict(),\n",
711
+ " 'd_optimizer_state_dict': optimizer_D_256.state_dict(),\n",
712
+ " 'd_64_optimizer_state_dict': optimizer_D_64.state_dict(),\n",
713
+ " 'losses': all_losses\n",
714
+ " }\n",
715
+ " torch.save(checkpoint, checkpoint_path)\n",
716
+ " print(f\"Checkpoint saved to {checkpoint_path}\")\n",
717
+ "\n",
718
+ " print(\"256x256 Validation Images:\")\n",
719
+ " save_comparison_grid(epoch+1, generator, fixed_val_batch, \"val_256\", device, show_inline=True)\n",
720
+ " print(\"64x64 Validation Images:\")\n",
721
+ " save_comparison_grid(epoch+1, generator, fixed_val_batch, \"val_64\", device, show_inline=True)\n",
722
+ " save_attention_visualization(epoch+1, generator, tokenizer, fixed_train_batch, device, \"train\", show_inline=True)\n",
723
+ " save_attention_visualization(epoch+1, generator, tokenizer, fixed_val_batch, device, \"val\", show_inline=True)\n",
724
+ "\n",
725
+ "print(\"Training completed!\")\n"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "raw",
730
+ "metadata": {
731
+ "id": "rbv1Wz4csnNu",
732
+ "vscode": {
733
+ "languageId": "raw"
734
+ }
735
+ },
736
+ "source": [
737
+ "## 5. Training Analysis and Visualization\n"
738
+ ]
739
+ },
740
+ {
741
+ "cell_type": "code",
742
+ "execution_count": null,
743
+ "metadata": {
744
+ "id": "l_90zE2CsnNu"
745
+ },
746
+ "outputs": [],
747
+ "source": [
748
+ "from plots import save_plot_losses, save_plot_non_gan_losses\n",
749
+ "\n",
750
+ "\n",
751
+ "save_plot_losses(\n",
752
+ " losses_g=losses['generator'],\n",
753
+ " losses_d=losses['discriminator'],\n",
754
+ " output_dir=\"training_output\",\n",
755
+ " show_inline=True\n",
756
+ ")\n",
757
+ "\n",
758
+ "# Plot training vs validation losses for non-GAN components (so except \"generator\" and \"discriminator\" from losses)\n",
759
+ "# Convert to list of dicts format expected by save_plot_non_gan_losses\n",
760
+ "train_losses_history = []\n",
761
+ "val_losses_history = []\n",
762
+ "\n",
763
+ "for i in range(len(losses['l1'])):\n",
764
+ " train_losses_history.append({\n",
765
+ " 'l1': losses['l1'][i],\n",
766
+ " 'perceptual': losses['perceptual'][i],\n",
767
+ " 'sobel': losses['sobel'][i],\n",
768
+ " 'total': losses['l1'][i] + losses['perceptual'][i] + losses['sobel'][i]\n",
769
+ " })\n",
770
+ "\n",
771
+ "for i in range(len(val_losses['l1'])):\n",
772
+ " val_losses_history.append({\n",
773
+ " 'l1': val_losses['l1'][i],\n",
774
+ " 'perceptual': val_losses['perceptual'][i],\n",
775
+ " 'sobel': val_losses['sobel'][i],\n",
776
+ " 'total': val_losses['total'][i]\n",
777
+ " })\n",
778
+ "\n",
779
+ "save_plot_non_gan_losses(\n",
780
+ " train_losses_history=train_losses_history,\n",
781
+ " val_losses_history=val_losses_history,\n",
782
+ " output_dir=\"training_output\",\n",
783
+ " show_inline=True\n",
784
+ ")\n",
785
+ "\n",
786
+ "# Print final statistics\n",
787
+ "print(f\"Final Train - Generator Loss: {losses['generator'][-1]:.4f}\")\n",
788
+ "print(f\"Final Train - Discriminator Loss: {losses['discriminator'][-1]:.4f}\")\n",
789
+ "print(f\"Final Train - L1 Loss: {losses['l1'][-1]:.4f}\")\n",
790
+ "print(f\"Final Train - Perceptual Loss: {losses['perceptual'][-1]:.4f}\")\n",
791
+ "print(f\"Final Train - Sobel Loss: {losses['sobel'][-1]:.4f}\")\n",
792
+ "\n",
793
+ "print(f\"Final Val - L1 Loss: {val_losses['l1'][-1]:.4f}\")\n",
794
+ "print(f\"Final Val - Perceptual Loss: {val_losses['perceptual'][-1]:.4f}\")\n",
795
+ "print(f\"Final Val - Sobel Loss: {val_losses['sobel'][-1]:.4f}\")\n",
796
+ "print(f\"Final Val - Total Loss: {val_losses['total'][-1]:.4f}\")\n"
797
+ ]
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "execution_count": null,
802
+ "metadata": {
803
+ "id": "Io7I7RTqsnNu"
804
+ },
805
+ "outputs": [],
806
+ "source": [
807
+ "# Generate a grid of final results\n",
808
+ "print(\"Final Results - Generated Pokemon Sprites (256x256):\")\n",
809
+ "batch = next(iter(train_loader))\n",
810
+ "save_comparison_grid(0, generator, batch, \"final_256\", device, show_inline=True)\n",
811
+ "\n",
812
+ "print(\"Final Results - Generated Pokemon Sprites (64x64):\")\n",
813
+ "save_comparison_grid(0, generator, batch, \"final_64\", device, show_inline=True)\n"
814
+ ]
815
+ },
816
+ {
817
+ "cell_type": "raw",
818
+ "metadata": {
819
+ "id": "3a_jxGvCsnNu",
820
+ "vscode": {
821
+ "languageId": "raw"
822
+ }
823
+ },
824
+ "source": [
825
+ "## 7. Model Analysis\n"
826
+ ]
827
+ },
828
+ {
829
+ "cell_type": "code",
830
+ "execution_count": null,
831
+ "metadata": {},
832
+ "outputs": [],
833
+ "source": [
834
+ "def count_parameters(model):\n",
835
+ " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
836
+ "\n",
837
+ "print(f\"Generator parameters: {count_parameters(generator):,}\")\n",
838
+ "print(f\"Discriminator (256) parameters: {count_parameters(discriminator_256):,}\")\n",
839
+ "print(f\"Discriminator (64) parameters: {count_parameters(discriminator_64):,}\")\n"
840
+ ]
841
+ }
842
+ ],
843
+ "metadata": {
844
+ "accelerator": "GPU",
845
+ "colab": {
846
+ "gpuType": "T4",
847
+ "provenance": []
848
+ },
849
+ "kernelspec": {
850
+ "display_name": "Python 3 (ipykernel)",
851
+ "language": "python",
852
+ "name": "python3"
853
+ },
854
+ "language_info": {
855
+ "codemirror_mode": {
856
+ "name": "ipython",
857
+ "version": 3
858
+ },
859
+ "file_extension": ".py",
860
+ "mimetype": "text/x-python",
861
+ "name": "python",
862
+ "nbconvert_exporter": "python",
863
+ "pygments_lexer": "ipython3",
864
+ "version": "3.12.11"
865
+ },
866
+ "widgets": {
867
+ "application/vnd.jupyter.widget-state+json": {
868
+ "128f4312bcdc4166b9e24d8cdd34184d": {
869
+ "model_module": "@jupyter-widgets/base",
870
+ "model_module_version": "1.2.0",
871
+ "model_name": "LayoutModel",
872
+ "state": {
873
+ "_model_module": "@jupyter-widgets/base",
874
+ "_model_module_version": "1.2.0",
875
+ "_model_name": "LayoutModel",
876
+ "_view_count": null,
877
+ "_view_module": "@jupyter-widgets/base",
878
+ "_view_module_version": "1.2.0",
879
+ "_view_name": "LayoutView",
880
+ "align_content": null,
881
+ "align_items": null,
882
+ "align_self": null,
883
+ "border": null,
884
+ "bottom": null,
885
+ "display": null,
886
+ "flex": null,
887
+ "flex_flow": null,
888
+ "grid_area": null,
889
+ "grid_auto_columns": null,
890
+ "grid_auto_flow": null,
891
+ "grid_auto_rows": null,
892
+ "grid_column": null,
893
+ "grid_gap": null,
894
+ "grid_row": null,
895
+ "grid_template_areas": null,
896
+ "grid_template_columns": null,
897
+ "grid_template_rows": null,
898
+ "height": null,
899
+ "justify_content": null,
900
+ "justify_items": null,
901
+ "left": null,
902
+ "margin": null,
903
+ "max_height": null,
904
+ "max_width": null,
905
+ "min_height": null,
906
+ "min_width": null,
907
+ "object_fit": null,
908
+ "object_position": null,
909
+ "order": null,
910
+ "overflow": null,
911
+ "overflow_x": null,
912
+ "overflow_y": null,
913
+ "padding": null,
914
+ "right": null,
915
+ "top": null,
916
+ "visibility": null,
917
+ "width": null
918
+ }
919
+ },
920
+ "1b65d6c8540e4f458886d5e7075ab30a": {
921
+ "model_module": "@jupyter-widgets/controls",
922
+ "model_module_version": "1.5.0",
923
+ "model_name": "DescriptionStyleModel",
924
+ "state": {
925
+ "_model_module": "@jupyter-widgets/controls",
926
+ "_model_module_version": "1.5.0",
927
+ "_model_name": "DescriptionStyleModel",
928
+ "_view_count": null,
929
+ "_view_module": "@jupyter-widgets/base",
930
+ "_view_module_version": "1.2.0",
931
+ "_view_name": "StyleView",
932
+ "description_width": ""
933
+ }
934
+ },
935
+ "2fe9614fe5984fa6b887d1e1b3e18b04": {
936
+ "model_module": "@jupyter-widgets/controls",
937
+ "model_module_version": "1.5.0",
938
+ "model_name": "ProgressStyleModel",
939
+ "state": {
940
+ "_model_module": "@jupyter-widgets/controls",
941
+ "_model_module_version": "1.5.0",
942
+ "_model_name": "ProgressStyleModel",
943
+ "_view_count": null,
944
+ "_view_module": "@jupyter-widgets/base",
945
+ "_view_module_version": "1.2.0",
946
+ "_view_name": "StyleView",
947
+ "bar_color": null,
948
+ "description_width": ""
949
+ }
950
+ },
951
+ "304d50e74ad744cdb3a7cc88739cb923": {
952
+ "model_module": "@jupyter-widgets/base",
953
+ "model_module_version": "1.2.0",
954
+ "model_name": "LayoutModel",
955
+ "state": {
956
+ "_model_module": "@jupyter-widgets/base",
957
+ "_model_module_version": "1.2.0",
958
+ "_model_name": "LayoutModel",
959
+ "_view_count": null,
960
+ "_view_module": "@jupyter-widgets/base",
961
+ "_view_module_version": "1.2.0",
962
+ "_view_name": "LayoutView",
963
+ "align_content": null,
964
+ "align_items": null,
965
+ "align_self": null,
966
+ "border": null,
967
+ "bottom": null,
968
+ "display": null,
969
+ "flex": null,
970
+ "flex_flow": null,
971
+ "grid_area": null,
972
+ "grid_auto_columns": null,
973
+ "grid_auto_flow": null,
974
+ "grid_auto_rows": null,
975
+ "grid_column": null,
976
+ "grid_gap": null,
977
+ "grid_row": null,
978
+ "grid_template_areas": null,
979
+ "grid_template_columns": null,
980
+ "grid_template_rows": null,
981
+ "height": null,
982
+ "justify_content": null,
983
+ "justify_items": null,
984
+ "left": null,
985
+ "margin": null,
986
+ "max_height": null,
987
+ "max_width": null,
988
+ "min_height": null,
989
+ "min_width": null,
990
+ "object_fit": null,
991
+ "object_position": null,
992
+ "order": null,
993
+ "overflow": null,
994
+ "overflow_x": null,
995
+ "overflow_y": null,
996
+ "padding": null,
997
+ "right": null,
998
+ "top": null,
999
+ "visibility": null,
1000
+ "width": null
1001
+ }
1002
+ },
1003
+ "370e5663868f411697bfb24f4e3efa09": {
1004
+ "model_module": "@jupyter-widgets/controls",
1005
+ "model_module_version": "1.5.0",
1006
+ "model_name": "HTMLModel",
1007
+ "state": {
1008
+ "_dom_classes": [],
1009
+ "_model_module": "@jupyter-widgets/controls",
1010
+ "_model_module_version": "1.5.0",
1011
+ "_model_name": "HTMLModel",
1012
+ "_view_count": null,
1013
+ "_view_module": "@jupyter-widgets/controls",
1014
+ "_view_module_version": "1.5.0",
1015
+ "_view_name": "HTMLView",
1016
+ "description": "",
1017
+ "description_tooltip": null,
1018
+ "layout": "IPY_MODEL_7e89bc79516f405e9684eacdce7b4551",
1019
+ "placeholder": "​",
1020
+ "style": "IPY_MODEL_c917f3a000fb44338e4afbeabeaab55f",
1021
+ "value": " 232k/? [00:00&lt;00:00, 12.0MB/s]"
1022
+ }
1023
+ },
1024
+ "3a338ac4d2944030a07843d8ea24e9fd": {
1025
+ "model_module": "@jupyter-widgets/base",
1026
+ "model_module_version": "1.2.0",
1027
+ "model_name": "LayoutModel",
1028
+ "state": {
1029
+ "_model_module": "@jupyter-widgets/base",
1030
+ "_model_module_version": "1.2.0",
1031
+ "_model_name": "LayoutModel",
1032
+ "_view_count": null,
1033
+ "_view_module": "@jupyter-widgets/base",
1034
+ "_view_module_version": "1.2.0",
1035
+ "_view_name": "LayoutView",
1036
+ "align_content": null,
1037
+ "align_items": null,
1038
+ "align_self": null,
1039
+ "border": null,
1040
+ "bottom": null,
1041
+ "display": null,
1042
+ "flex": null,
1043
+ "flex_flow": null,
1044
+ "grid_area": null,
1045
+ "grid_auto_columns": null,
1046
+ "grid_auto_flow": null,
1047
+ "grid_auto_rows": null,
1048
+ "grid_column": null,
1049
+ "grid_gap": null,
1050
+ "grid_row": null,
1051
+ "grid_template_areas": null,
1052
+ "grid_template_columns": null,
1053
+ "grid_template_rows": null,
1054
+ "height": null,
1055
+ "justify_content": null,
1056
+ "justify_items": null,
1057
+ "left": null,
1058
+ "margin": null,
1059
+ "max_height": null,
1060
+ "max_width": null,
1061
+ "min_height": null,
1062
+ "min_width": null,
1063
+ "object_fit": null,
1064
+ "object_position": null,
1065
+ "order": null,
1066
+ "overflow": null,
1067
+ "overflow_x": null,
1068
+ "overflow_y": null,
1069
+ "padding": null,
1070
+ "right": null,
1071
+ "top": null,
1072
+ "visibility": null,
1073
+ "width": null
1074
+ }
1075
+ },
1076
+ "439eba0eb4184c0ab83f65fc26bbe388": {
1077
+ "model_module": "@jupyter-widgets/controls",
1078
+ "model_module_version": "1.5.0",
1079
+ "model_name": "HTMLModel",
1080
+ "state": {
1081
+ "_dom_classes": [],
1082
+ "_model_module": "@jupyter-widgets/controls",
1083
+ "_model_module_version": "1.5.0",
1084
+ "_model_name": "HTMLModel",
1085
+ "_view_count": null,
1086
+ "_view_module": "@jupyter-widgets/controls",
1087
+ "_view_module_version": "1.5.0",
1088
+ "_view_name": "HTMLView",
1089
+ "description": "",
1090
+ "description_tooltip": null,
1091
+ "layout": "IPY_MODEL_304d50e74ad744cdb3a7cc88739cb923",
1092
+ "placeholder": "​",
1093
+ "style": "IPY_MODEL_bfcc6d01c9ff4db698afa4318e7c91ac",
1094
+ "value": "model.safetensors: 100%"
1095
+ }
1096
+ },
1097
+ "4545ff199b874d3680a83918513e1d4b": {
1098
+ "model_module": "@jupyter-widgets/base",
1099
+ "model_module_version": "1.2.0",
1100
+ "model_name": "LayoutModel",
1101
+ "state": {
1102
+ "_model_module": "@jupyter-widgets/base",
1103
+ "_model_module_version": "1.2.0",
1104
+ "_model_name": "LayoutModel",
1105
+ "_view_count": null,
1106
+ "_view_module": "@jupyter-widgets/base",
1107
+ "_view_module_version": "1.2.0",
1108
+ "_view_name": "LayoutView",
1109
+ "align_content": null,
1110
+ "align_items": null,
1111
+ "align_self": null,
1112
+ "border": null,
1113
+ "bottom": null,
1114
+ "display": null,
1115
+ "flex": null,
1116
+ "flex_flow": null,
1117
+ "grid_area": null,
1118
+ "grid_auto_columns": null,
1119
+ "grid_auto_flow": null,
1120
+ "grid_auto_rows": null,
1121
+ "grid_column": null,
1122
+ "grid_gap": null,
1123
+ "grid_row": null,
1124
+ "grid_template_areas": null,
1125
+ "grid_template_columns": null,
1126
+ "grid_template_rows": null,
1127
+ "height": null,
1128
+ "justify_content": null,
1129
+ "justify_items": null,
1130
+ "left": null,
1131
+ "margin": null,
1132
+ "max_height": null,
1133
+ "max_width": null,
1134
+ "min_height": null,
1135
+ "min_width": null,
1136
+ "object_fit": null,
1137
+ "object_position": null,
1138
+ "order": null,
1139
+ "overflow": null,
1140
+ "overflow_x": null,
1141
+ "overflow_y": null,
1142
+ "padding": null,
1143
+ "right": null,
1144
+ "top": null,
1145
+ "visibility": null,
1146
+ "width": null
1147
+ }
1148
+ },
1149
+ "4795a78a75dc439a8da7df58bf738940": {
1150
+ "model_module": "@jupyter-widgets/controls",
1151
+ "model_module_version": "1.5.0",
1152
+ "model_name": "DescriptionStyleModel",
1153
+ "state": {
1154
+ "_model_module": "@jupyter-widgets/controls",
1155
+ "_model_module_version": "1.5.0",
1156
+ "_model_name": "DescriptionStyleModel",
1157
+ "_view_count": null,
1158
+ "_view_module": "@jupyter-widgets/base",
1159
+ "_view_module_version": "1.2.0",
1160
+ "_view_name": "StyleView",
1161
+ "description_width": ""
1162
+ }
1163
+ },
1164
+ "4c22e1b396f342ffb90c1b50a0051862": {
1165
+ "model_module": "@jupyter-widgets/controls",
1166
+ "model_module_version": "1.5.0",
1167
+ "model_name": "FloatProgressModel",
1168
+ "state": {
1169
+ "_dom_classes": [],
1170
+ "_model_module": "@jupyter-widgets/controls",
1171
+ "_model_module_version": "1.5.0",
1172
+ "_model_name": "FloatProgressModel",
1173
+ "_view_count": null,
1174
+ "_view_module": "@jupyter-widgets/controls",
1175
+ "_view_module_version": "1.5.0",
1176
+ "_view_name": "ProgressView",
1177
+ "bar_style": "success",
1178
+ "description": "",
1179
+ "description_tooltip": null,
1180
+ "layout": "IPY_MODEL_a5a9f8607fdd4f9cad7519eca573f3dc",
1181
+ "max": 1,
1182
+ "min": 0,
1183
+ "orientation": "horizontal",
1184
+ "style": "IPY_MODEL_926149594f94457295c60b4fad9cbac7",
1185
+ "value": 1
1186
+ }
1187
+ },
1188
+ "4e32e76c44fb449c8cb767abeb17868a": {
1189
+ "model_module": "@jupyter-widgets/controls",
1190
+ "model_module_version": "1.5.0",
1191
+ "model_name": "DescriptionStyleModel",
1192
+ "state": {
1193
+ "_model_module": "@jupyter-widgets/controls",
1194
+ "_model_module_version": "1.5.0",
1195
+ "_model_name": "DescriptionStyleModel",
1196
+ "_view_count": null,
1197
+ "_view_module": "@jupyter-widgets/base",
1198
+ "_view_module_version": "1.2.0",
1199
+ "_view_name": "StyleView",
1200
+ "description_width": ""
1201
+ }
1202
+ },
1203
+ "57e526d188b9414dabb3b1c895373864": {
1204
+ "model_module": "@jupyter-widgets/base",
1205
+ "model_module_version": "1.2.0",
1206
+ "model_name": "LayoutModel",
1207
+ "state": {
1208
+ "_model_module": "@jupyter-widgets/base",
1209
+ "_model_module_version": "1.2.0",
1210
+ "_model_name": "LayoutModel",
1211
+ "_view_count": null,
1212
+ "_view_module": "@jupyter-widgets/base",
1213
+ "_view_module_version": "1.2.0",
1214
+ "_view_name": "LayoutView",
1215
+ "align_content": null,
1216
+ "align_items": null,
1217
+ "align_self": null,
1218
+ "border": null,
1219
+ "bottom": null,
1220
+ "display": null,
1221
+ "flex": null,
1222
+ "flex_flow": null,
1223
+ "grid_area": null,
1224
+ "grid_auto_columns": null,
1225
+ "grid_auto_flow": null,
1226
+ "grid_auto_rows": null,
1227
+ "grid_column": null,
1228
+ "grid_gap": null,
1229
+ "grid_row": null,
1230
+ "grid_template_areas": null,
1231
+ "grid_template_columns": null,
1232
+ "grid_template_rows": null,
1233
+ "height": null,
1234
+ "justify_content": null,
1235
+ "justify_items": null,
1236
+ "left": null,
1237
+ "margin": null,
1238
+ "max_height": null,
1239
+ "max_width": null,
1240
+ "min_height": null,
1241
+ "min_width": null,
1242
+ "object_fit": null,
1243
+ "object_position": null,
1244
+ "order": null,
1245
+ "overflow": null,
1246
+ "overflow_x": null,
1247
+ "overflow_y": null,
1248
+ "padding": null,
1249
+ "right": null,
1250
+ "top": null,
1251
+ "visibility": null,
1252
+ "width": null
1253
+ }
1254
+ },
1255
+ "5837f2c4668646c0a6db2407aebb46e3": {
1256
+ "model_module": "@jupyter-widgets/controls",
1257
+ "model_module_version": "1.5.0",
1258
+ "model_name": "HTMLModel",
1259
+ "state": {
1260
+ "_dom_classes": [],
1261
+ "_model_module": "@jupyter-widgets/controls",
1262
+ "_model_module_version": "1.5.0",
1263
+ "_model_name": "HTMLModel",
1264
+ "_view_count": null,
1265
+ "_view_module": "@jupyter-widgets/controls",
1266
+ "_view_module_version": "1.5.0",
1267
+ "_view_name": "HTMLView",
1268
+ "description": "",
1269
+ "description_tooltip": null,
1270
+ "layout": "IPY_MODEL_64277772cc30408e8ea29f0e268c8880",
1271
+ "placeholder": "​",
1272
+ "style": "IPY_MODEL_5b0d55ea20714104818097bd7d1f509a",
1273
+ "value": " 45.1M/45.1M [00:00&lt;00:00, 112MB/s]"
1274
+ }
1275
+ },
1276
+ "58ab975eaba2485cb0945482c26ecf3d": {
1277
+ "model_module": "@jupyter-widgets/base",
1278
+ "model_module_version": "1.2.0",
1279
+ "model_name": "LayoutModel",
1280
+ "state": {
1281
+ "_model_module": "@jupyter-widgets/base",
1282
+ "_model_module_version": "1.2.0",
1283
+ "_model_name": "LayoutModel",
1284
+ "_view_count": null,
1285
+ "_view_module": "@jupyter-widgets/base",
1286
+ "_view_module_version": "1.2.0",
1287
+ "_view_name": "LayoutView",
1288
+ "align_content": null,
1289
+ "align_items": null,
1290
+ "align_self": null,
1291
+ "border": null,
1292
+ "bottom": null,
1293
+ "display": null,
1294
+ "flex": null,
1295
+ "flex_flow": null,
1296
+ "grid_area": null,
1297
+ "grid_auto_columns": null,
1298
+ "grid_auto_flow": null,
1299
+ "grid_auto_rows": null,
1300
+ "grid_column": null,
1301
+ "grid_gap": null,
1302
+ "grid_row": null,
1303
+ "grid_template_areas": null,
1304
+ "grid_template_columns": null,
1305
+ "grid_template_rows": null,
1306
+ "height": null,
1307
+ "justify_content": null,
1308
+ "justify_items": null,
1309
+ "left": null,
1310
+ "margin": null,
1311
+ "max_height": null,
1312
+ "max_width": null,
1313
+ "min_height": null,
1314
+ "min_width": null,
1315
+ "object_fit": null,
1316
+ "object_position": null,
1317
+ "order": null,
1318
+ "overflow": null,
1319
+ "overflow_x": null,
1320
+ "overflow_y": null,
1321
+ "padding": null,
1322
+ "right": null,
1323
+ "top": null,
1324
+ "visibility": null,
1325
+ "width": null
1326
+ }
1327
+ },
1328
+ "5b0d55ea20714104818097bd7d1f509a": {
1329
+ "model_module": "@jupyter-widgets/controls",
1330
+ "model_module_version": "1.5.0",
1331
+ "model_name": "DescriptionStyleModel",
1332
+ "state": {
1333
+ "_model_module": "@jupyter-widgets/controls",
1334
+ "_model_module_version": "1.5.0",
1335
+ "_model_name": "DescriptionStyleModel",
1336
+ "_view_count": null,
1337
+ "_view_module": "@jupyter-widgets/base",
1338
+ "_view_module_version": "1.2.0",
1339
+ "_view_name": "StyleView",
1340
+ "description_width": ""
1341
+ }
1342
+ },
1343
+ "5ba39d9d997a45ca848e3e2ffd0e7307": {
1344
+ "model_module": "@jupyter-widgets/controls",
1345
+ "model_module_version": "1.5.0",
1346
+ "model_name": "HTMLModel",
1347
+ "state": {
1348
+ "_dom_classes": [],
1349
+ "_model_module": "@jupyter-widgets/controls",
1350
+ "_model_module_version": "1.5.0",
1351
+ "_model_name": "HTMLModel",
1352
+ "_view_count": null,
1353
+ "_view_module": "@jupyter-widgets/controls",
1354
+ "_view_module_version": "1.5.0",
1355
+ "_view_name": "HTMLView",
1356
+ "description": "",
1357
+ "description_tooltip": null,
1358
+ "layout": "IPY_MODEL_128f4312bcdc4166b9e24d8cdd34184d",
1359
+ "placeholder": "​",
1360
+ "style": "IPY_MODEL_1b65d6c8540e4f458886d5e7075ab30a",
1361
+ "value": "vocab.txt: "
1362
+ }
1363
+ },
1364
+ "5c3cb981f324446eae642f7c23a539f0": {
1365
+ "model_module": "@jupyter-widgets/base",
1366
+ "model_module_version": "1.2.0",
1367
+ "model_name": "LayoutModel",
1368
+ "state": {
1369
+ "_model_module": "@jupyter-widgets/base",
1370
+ "_model_module_version": "1.2.0",
1371
+ "_model_name": "LayoutModel",
1372
+ "_view_count": null,
1373
+ "_view_module": "@jupyter-widgets/base",
1374
+ "_view_module_version": "1.2.0",
1375
+ "_view_name": "LayoutView",
1376
+ "align_content": null,
1377
+ "align_items": null,
1378
+ "align_self": null,
1379
+ "border": null,
1380
+ "bottom": null,
1381
+ "display": null,
1382
+ "flex": null,
1383
+ "flex_flow": null,
1384
+ "grid_area": null,
1385
+ "grid_auto_columns": null,
1386
+ "grid_auto_flow": null,
1387
+ "grid_auto_rows": null,
1388
+ "grid_column": null,
1389
+ "grid_gap": null,
1390
+ "grid_row": null,
1391
+ "grid_template_areas": null,
1392
+ "grid_template_columns": null,
1393
+ "grid_template_rows": null,
1394
+ "height": null,
1395
+ "justify_content": null,
1396
+ "justify_items": null,
1397
+ "left": null,
1398
+ "margin": null,
1399
+ "max_height": null,
1400
+ "max_width": null,
1401
+ "min_height": null,
1402
+ "min_width": null,
1403
+ "object_fit": null,
1404
+ "object_position": null,
1405
+ "order": null,
1406
+ "overflow": null,
1407
+ "overflow_x": null,
1408
+ "overflow_y": null,
1409
+ "padding": null,
1410
+ "right": null,
1411
+ "top": null,
1412
+ "visibility": null,
1413
+ "width": null
1414
+ }
1415
+ },
1416
+ "5efdceae0bac4c978d3a7226247e237f": {
1417
+ "model_module": "@jupyter-widgets/controls",
1418
+ "model_module_version": "1.5.0",
1419
+ "model_name": "HBoxModel",
1420
+ "state": {
1421
+ "_dom_classes": [],
1422
+ "_model_module": "@jupyter-widgets/controls",
1423
+ "_model_module_version": "1.5.0",
1424
+ "_model_name": "HBoxModel",
1425
+ "_view_count": null,
1426
+ "_view_module": "@jupyter-widgets/controls",
1427
+ "_view_module_version": "1.5.0",
1428
+ "_view_name": "HBoxView",
1429
+ "box_style": "",
1430
+ "children": [
1431
+ "IPY_MODEL_a39c5c623a3e42448e109fb9ec6bc263",
1432
+ "IPY_MODEL_a6ed2ddb1c6f4d1aa945c5a39372f781",
1433
+ "IPY_MODEL_8cf950b898e142c1af9b4db92019aa4d"
1434
+ ],
1435
+ "layout": "IPY_MODEL_8ed7abd0602c43a1bfc0f96d7611d429"
1436
+ }
1437
+ },
1438
+ "5f5e7ff6e4c845b99602a4fa00ad550a": {
1439
+ "model_module": "@jupyter-widgets/base",
1440
+ "model_module_version": "1.2.0",
1441
+ "model_name": "LayoutModel",
1442
+ "state": {
1443
+ "_model_module": "@jupyter-widgets/base",
1444
+ "_model_module_version": "1.2.0",
1445
+ "_model_name": "LayoutModel",
1446
+ "_view_count": null,
1447
+ "_view_module": "@jupyter-widgets/base",
1448
+ "_view_module_version": "1.2.0",
1449
+ "_view_name": "LayoutView",
1450
+ "align_content": null,
1451
+ "align_items": null,
1452
+ "align_self": null,
1453
+ "border": null,
1454
+ "bottom": null,
1455
+ "display": null,
1456
+ "flex": null,
1457
+ "flex_flow": null,
1458
+ "grid_area": null,
1459
+ "grid_auto_columns": null,
1460
+ "grid_auto_flow": null,
1461
+ "grid_auto_rows": null,
1462
+ "grid_column": null,
1463
+ "grid_gap": null,
1464
+ "grid_row": null,
1465
+ "grid_template_areas": null,
1466
+ "grid_template_columns": null,
1467
+ "grid_template_rows": null,
1468
+ "height": null,
1469
+ "justify_content": null,
1470
+ "justify_items": null,
1471
+ "left": null,
1472
+ "margin": null,
1473
+ "max_height": null,
1474
+ "max_width": null,
1475
+ "min_height": null,
1476
+ "min_width": null,
1477
+ "object_fit": null,
1478
+ "object_position": null,
1479
+ "order": null,
1480
+ "overflow": null,
1481
+ "overflow_x": null,
1482
+ "overflow_y": null,
1483
+ "padding": null,
1484
+ "right": null,
1485
+ "top": null,
1486
+ "visibility": null,
1487
+ "width": null
1488
+ }
1489
+ },
1490
+ "64277772cc30408e8ea29f0e268c8880": {
1491
+ "model_module": "@jupyter-widgets/base",
1492
+ "model_module_version": "1.2.0",
1493
+ "model_name": "LayoutModel",
1494
+ "state": {
1495
+ "_model_module": "@jupyter-widgets/base",
1496
+ "_model_module_version": "1.2.0",
1497
+ "_model_name": "LayoutModel",
1498
+ "_view_count": null,
1499
+ "_view_module": "@jupyter-widgets/base",
1500
+ "_view_module_version": "1.2.0",
1501
+ "_view_name": "LayoutView",
1502
+ "align_content": null,
1503
+ "align_items": null,
1504
+ "align_self": null,
1505
+ "border": null,
1506
+ "bottom": null,
1507
+ "display": null,
1508
+ "flex": null,
1509
+ "flex_flow": null,
1510
+ "grid_area": null,
1511
+ "grid_auto_columns": null,
1512
+ "grid_auto_flow": null,
1513
+ "grid_auto_rows": null,
1514
+ "grid_column": null,
1515
+ "grid_gap": null,
1516
+ "grid_row": null,
1517
+ "grid_template_areas": null,
1518
+ "grid_template_columns": null,
1519
+ "grid_template_rows": null,
1520
+ "height": null,
1521
+ "justify_content": null,
1522
+ "justify_items": null,
1523
+ "left": null,
1524
+ "margin": null,
1525
+ "max_height": null,
1526
+ "max_width": null,
1527
+ "min_height": null,
1528
+ "min_width": null,
1529
+ "object_fit": null,
1530
+ "object_position": null,
1531
+ "order": null,
1532
+ "overflow": null,
1533
+ "overflow_x": null,
1534
+ "overflow_y": null,
1535
+ "padding": null,
1536
+ "right": null,
1537
+ "top": null,
1538
+ "visibility": null,
1539
+ "width": null
1540
+ }
1541
+ },
1542
+ "65ba2d78fde14bb2baf5ae1101d7e5ff": {
1543
+ "model_module": "@jupyter-widgets/base",
1544
+ "model_module_version": "1.2.0",
1545
+ "model_name": "LayoutModel",
1546
+ "state": {
1547
+ "_model_module": "@jupyter-widgets/base",
1548
+ "_model_module_version": "1.2.0",
1549
+ "_model_name": "LayoutModel",
1550
+ "_view_count": null,
1551
+ "_view_module": "@jupyter-widgets/base",
1552
+ "_view_module_version": "1.2.0",
1553
+ "_view_name": "LayoutView",
1554
+ "align_content": null,
1555
+ "align_items": null,
1556
+ "align_self": null,
1557
+ "border": null,
1558
+ "bottom": null,
1559
+ "display": null,
1560
+ "flex": null,
1561
+ "flex_flow": null,
1562
+ "grid_area": null,
1563
+ "grid_auto_columns": null,
1564
+ "grid_auto_flow": null,
1565
+ "grid_auto_rows": null,
1566
+ "grid_column": null,
1567
+ "grid_gap": null,
1568
+ "grid_row": null,
1569
+ "grid_template_areas": null,
1570
+ "grid_template_columns": null,
1571
+ "grid_template_rows": null,
1572
+ "height": null,
1573
+ "justify_content": null,
1574
+ "justify_items": null,
1575
+ "left": null,
1576
+ "margin": null,
1577
+ "max_height": null,
1578
+ "max_width": null,
1579
+ "min_height": null,
1580
+ "min_width": null,
1581
+ "object_fit": null,
1582
+ "object_position": null,
1583
+ "order": null,
1584
+ "overflow": null,
1585
+ "overflow_x": null,
1586
+ "overflow_y": null,
1587
+ "padding": null,
1588
+ "right": null,
1589
+ "top": null,
1590
+ "visibility": null,
1591
+ "width": null
1592
+ }
1593
+ },
1594
+ "7e21c6a9c7f44496b6f28513caefb631": {
1595
+ "model_module": "@jupyter-widgets/controls",
1596
+ "model_module_version": "1.5.0",
1597
+ "model_name": "HBoxModel",
1598
+ "state": {
1599
+ "_dom_classes": [],
1600
+ "_model_module": "@jupyter-widgets/controls",
1601
+ "_model_module_version": "1.5.0",
1602
+ "_model_name": "HBoxModel",
1603
+ "_view_count": null,
1604
+ "_view_module": "@jupyter-widgets/controls",
1605
+ "_view_module_version": "1.5.0",
1606
+ "_view_name": "HBoxView",
1607
+ "box_style": "",
1608
+ "children": [
1609
+ "IPY_MODEL_439eba0eb4184c0ab83f65fc26bbe388",
1610
+ "IPY_MODEL_eee695744ec64aa7b71b9e85968c6f8f",
1611
+ "IPY_MODEL_c4ecdc9d982f49129368893c1c0aece9"
1612
+ ],
1613
+ "layout": "IPY_MODEL_5f5e7ff6e4c845b99602a4fa00ad550a"
1614
+ }
1615
+ },
1616
+ "7e89bc79516f405e9684eacdce7b4551": {
1617
+ "model_module": "@jupyter-widgets/base",
1618
+ "model_module_version": "1.2.0",
1619
+ "model_name": "LayoutModel",
1620
+ "state": {
1621
+ "_model_module": "@jupyter-widgets/base",
1622
+ "_model_module_version": "1.2.0",
1623
+ "_model_name": "LayoutModel",
1624
+ "_view_count": null,
1625
+ "_view_module": "@jupyter-widgets/base",
1626
+ "_view_module_version": "1.2.0",
1627
+ "_view_name": "LayoutView",
1628
+ "align_content": null,
1629
+ "align_items": null,
1630
+ "align_self": null,
1631
+ "border": null,
1632
+ "bottom": null,
1633
+ "display": null,
1634
+ "flex": null,
1635
+ "flex_flow": null,
1636
+ "grid_area": null,
1637
+ "grid_auto_columns": null,
1638
+ "grid_auto_flow": null,
1639
+ "grid_auto_rows": null,
1640
+ "grid_column": null,
1641
+ "grid_gap": null,
1642
+ "grid_row": null,
1643
+ "grid_template_areas": null,
1644
+ "grid_template_columns": null,
1645
+ "grid_template_rows": null,
1646
+ "height": null,
1647
+ "justify_content": null,
1648
+ "justify_items": null,
1649
+ "left": null,
1650
+ "margin": null,
1651
+ "max_height": null,
1652
+ "max_width": null,
1653
+ "min_height": null,
1654
+ "min_width": null,
1655
+ "object_fit": null,
1656
+ "object_position": null,
1657
+ "order": null,
1658
+ "overflow": null,
1659
+ "overflow_x": null,
1660
+ "overflow_y": null,
1661
+ "padding": null,
1662
+ "right": null,
1663
+ "top": null,
1664
+ "visibility": null,
1665
+ "width": null
1666
+ }
1667
+ },
1668
+ "8226a55726c54abba3a48dbfa8e1b6f6": {
1669
+ "model_module": "@jupyter-widgets/controls",
1670
+ "model_module_version": "1.5.0",
1671
+ "model_name": "DescriptionStyleModel",
1672
+ "state": {
1673
+ "_model_module": "@jupyter-widgets/controls",
1674
+ "_model_module_version": "1.5.0",
1675
+ "_model_name": "DescriptionStyleModel",
1676
+ "_view_count": null,
1677
+ "_view_module": "@jupyter-widgets/base",
1678
+ "_view_module_version": "1.2.0",
1679
+ "_view_name": "StyleView",
1680
+ "description_width": ""
1681
+ }
1682
+ },
1683
+ "828b227361fe45cd83964149e7475503": {
1684
+ "model_module": "@jupyter-widgets/controls",
1685
+ "model_module_version": "1.5.0",
1686
+ "model_name": "ProgressStyleModel",
1687
+ "state": {
1688
+ "_model_module": "@jupyter-widgets/controls",
1689
+ "_model_module_version": "1.5.0",
1690
+ "_model_name": "ProgressStyleModel",
1691
+ "_view_count": null,
1692
+ "_view_module": "@jupyter-widgets/base",
1693
+ "_view_module_version": "1.2.0",
1694
+ "_view_name": "StyleView",
1695
+ "bar_color": null,
1696
+ "description_width": ""
1697
+ }
1698
+ },
1699
+ "86a3c1a4e9eb4989b23364f21e5df531": {
1700
+ "model_module": "@jupyter-widgets/controls",
1701
+ "model_module_version": "1.5.0",
1702
+ "model_name": "HBoxModel",
1703
+ "state": {
1704
+ "_dom_classes": [],
1705
+ "_model_module": "@jupyter-widgets/controls",
1706
+ "_model_module_version": "1.5.0",
1707
+ "_model_name": "HBoxModel",
1708
+ "_view_count": null,
1709
+ "_view_module": "@jupyter-widgets/controls",
1710
+ "_view_module_version": "1.5.0",
1711
+ "_view_name": "HBoxView",
1712
+ "box_style": "",
1713
+ "children": [
1714
+ "IPY_MODEL_5ba39d9d997a45ca848e3e2ffd0e7307",
1715
+ "IPY_MODEL_4c22e1b396f342ffb90c1b50a0051862",
1716
+ "IPY_MODEL_370e5663868f411697bfb24f4e3efa09"
1717
+ ],
1718
+ "layout": "IPY_MODEL_3a338ac4d2944030a07843d8ea24e9fd"
1719
+ }
1720
+ },
1721
+ "8cf950b898e142c1af9b4db92019aa4d": {
1722
+ "model_module": "@jupyter-widgets/controls",
1723
+ "model_module_version": "1.5.0",
1724
+ "model_name": "HTMLModel",
1725
+ "state": {
1726
+ "_dom_classes": [],
1727
+ "_model_module": "@jupyter-widgets/controls",
1728
+ "_model_module_version": "1.5.0",
1729
+ "_model_name": "HTMLModel",
1730
+ "_view_count": null,
1731
+ "_view_module": "@jupyter-widgets/controls",
1732
+ "_view_module_version": "1.5.0",
1733
+ "_view_name": "HTMLView",
1734
+ "description": "",
1735
+ "description_tooltip": null,
1736
+ "layout": "IPY_MODEL_57e526d188b9414dabb3b1c895373864",
1737
+ "placeholder": "​",
1738
+ "style": "IPY_MODEL_8226a55726c54abba3a48dbfa8e1b6f6",
1739
+ "value": " 286/286 [00:00&lt;00:00, 25.2kB/s]"
1740
+ }
1741
+ },
1742
+ "8ed7abd0602c43a1bfc0f96d7611d429": {
1743
+ "model_module": "@jupyter-widgets/base",
1744
+ "model_module_version": "1.2.0",
1745
+ "model_name": "LayoutModel",
1746
+ "state": {
1747
+ "_model_module": "@jupyter-widgets/base",
1748
+ "_model_module_version": "1.2.0",
1749
+ "_model_name": "LayoutModel",
1750
+ "_view_count": null,
1751
+ "_view_module": "@jupyter-widgets/base",
1752
+ "_view_module_version": "1.2.0",
1753
+ "_view_name": "LayoutView",
1754
+ "align_content": null,
1755
+ "align_items": null,
1756
+ "align_self": null,
1757
+ "border": null,
1758
+ "bottom": null,
1759
+ "display": null,
1760
+ "flex": null,
1761
+ "flex_flow": null,
1762
+ "grid_area": null,
1763
+ "grid_auto_columns": null,
1764
+ "grid_auto_flow": null,
1765
+ "grid_auto_rows": null,
1766
+ "grid_column": null,
1767
+ "grid_gap": null,
1768
+ "grid_row": null,
1769
+ "grid_template_areas": null,
1770
+ "grid_template_columns": null,
1771
+ "grid_template_rows": null,
1772
+ "height": null,
1773
+ "justify_content": null,
1774
+ "justify_items": null,
1775
+ "left": null,
1776
+ "margin": null,
1777
+ "max_height": null,
1778
+ "max_width": null,
1779
+ "min_height": null,
1780
+ "min_width": null,
1781
+ "object_fit": null,
1782
+ "object_position": null,
1783
+ "order": null,
1784
+ "overflow": null,
1785
+ "overflow_x": null,
1786
+ "overflow_y": null,
1787
+ "padding": null,
1788
+ "right": null,
1789
+ "top": null,
1790
+ "visibility": null,
1791
+ "width": null
1792
+ }
1793
+ },
1794
+ "926149594f94457295c60b4fad9cbac7": {
1795
+ "model_module": "@jupyter-widgets/controls",
1796
+ "model_module_version": "1.5.0",
1797
+ "model_name": "ProgressStyleModel",
1798
+ "state": {
1799
+ "_model_module": "@jupyter-widgets/controls",
1800
+ "_model_module_version": "1.5.0",
1801
+ "_model_name": "ProgressStyleModel",
1802
+ "_view_count": null,
1803
+ "_view_module": "@jupyter-widgets/base",
1804
+ "_view_module_version": "1.2.0",
1805
+ "_view_name": "StyleView",
1806
+ "bar_color": null,
1807
+ "description_width": ""
1808
+ }
1809
+ },
1810
+ "a39c5c623a3e42448e109fb9ec6bc263": {
1811
+ "model_module": "@jupyter-widgets/controls",
1812
+ "model_module_version": "1.5.0",
1813
+ "model_name": "HTMLModel",
1814
+ "state": {
1815
+ "_dom_classes": [],
1816
+ "_model_module": "@jupyter-widgets/controls",
1817
+ "_model_module_version": "1.5.0",
1818
+ "_model_name": "HTMLModel",
1819
+ "_view_count": null,
1820
+ "_view_module": "@jupyter-widgets/controls",
1821
+ "_view_module_version": "1.5.0",
1822
+ "_view_name": "HTMLView",
1823
+ "description": "",
1824
+ "description_tooltip": null,
1825
+ "layout": "IPY_MODEL_65ba2d78fde14bb2baf5ae1101d7e5ff",
1826
+ "placeholder": "​",
1827
+ "style": "IPY_MODEL_4795a78a75dc439a8da7df58bf738940",
1828
+ "value": "config.json: 100%"
1829
+ }
1830
+ },
1831
+ "a5a9f8607fdd4f9cad7519eca573f3dc": {
1832
+ "model_module": "@jupyter-widgets/base",
1833
+ "model_module_version": "1.2.0",
1834
+ "model_name": "LayoutModel",
1835
+ "state": {
1836
+ "_model_module": "@jupyter-widgets/base",
1837
+ "_model_module_version": "1.2.0",
1838
+ "_model_name": "LayoutModel",
1839
+ "_view_count": null,
1840
+ "_view_module": "@jupyter-widgets/base",
1841
+ "_view_module_version": "1.2.0",
1842
+ "_view_name": "LayoutView",
1843
+ "align_content": null,
1844
+ "align_items": null,
1845
+ "align_self": null,
1846
+ "border": null,
1847
+ "bottom": null,
1848
+ "display": null,
1849
+ "flex": null,
1850
+ "flex_flow": null,
1851
+ "grid_area": null,
1852
+ "grid_auto_columns": null,
1853
+ "grid_auto_flow": null,
1854
+ "grid_auto_rows": null,
1855
+ "grid_column": null,
1856
+ "grid_gap": null,
1857
+ "grid_row": null,
1858
+ "grid_template_areas": null,
1859
+ "grid_template_columns": null,
1860
+ "grid_template_rows": null,
1861
+ "height": null,
1862
+ "justify_content": null,
1863
+ "justify_items": null,
1864
+ "left": null,
1865
+ "margin": null,
1866
+ "max_height": null,
1867
+ "max_width": null,
1868
+ "min_height": null,
1869
+ "min_width": null,
1870
+ "object_fit": null,
1871
+ "object_position": null,
1872
+ "order": null,
1873
+ "overflow": null,
1874
+ "overflow_x": null,
1875
+ "overflow_y": null,
1876
+ "padding": null,
1877
+ "right": null,
1878
+ "top": null,
1879
+ "visibility": null,
1880
+ "width": "20px"
1881
+ }
1882
+ },
1883
+ "a6ed2ddb1c6f4d1aa945c5a39372f781": {
1884
+ "model_module": "@jupyter-widgets/controls",
1885
+ "model_module_version": "1.5.0",
1886
+ "model_name": "FloatProgressModel",
1887
+ "state": {
1888
+ "_dom_classes": [],
1889
+ "_model_module": "@jupyter-widgets/controls",
1890
+ "_model_module_version": "1.5.0",
1891
+ "_model_name": "FloatProgressModel",
1892
+ "_view_count": null,
1893
+ "_view_module": "@jupyter-widgets/controls",
1894
+ "_view_module_version": "1.5.0",
1895
+ "_view_name": "ProgressView",
1896
+ "bar_style": "success",
1897
+ "description": "",
1898
+ "description_tooltip": null,
1899
+ "layout": "IPY_MODEL_4545ff199b874d3680a83918513e1d4b",
1900
+ "max": 286,
1901
+ "min": 0,
1902
+ "orientation": "horizontal",
1903
+ "style": "IPY_MODEL_cad8fd90586443778568a1babb8c40e6",
1904
+ "value": 286
1905
+ }
1906
+ },
1907
+ "ab61b90c1a5b4a2b9bb5c9d5a215bb3f": {
1908
+ "model_module": "@jupyter-widgets/controls",
1909
+ "model_module_version": "1.5.0",
1910
+ "model_name": "HTMLModel",
1911
+ "state": {
1912
+ "_dom_classes": [],
1913
+ "_model_module": "@jupyter-widgets/controls",
1914
+ "_model_module_version": "1.5.0",
1915
+ "_model_name": "HTMLModel",
1916
+ "_view_count": null,
1917
+ "_view_module": "@jupyter-widgets/controls",
1918
+ "_view_module_version": "1.5.0",
1919
+ "_view_name": "HTMLView",
1920
+ "description": "",
1921
+ "description_tooltip": null,
1922
+ "layout": "IPY_MODEL_bf8eb066cdaf4ac096dc14392d085daf",
1923
+ "placeholder": "​",
1924
+ "style": "IPY_MODEL_4e32e76c44fb449c8cb767abeb17868a",
1925
+ "value": "pytorch_model.bin: 100%"
1926
+ }
1927
+ },
1928
+ "b2bf751bb96746e4a828241f70e52050": {
1929
+ "model_module": "@jupyter-widgets/base",
1930
+ "model_module_version": "1.2.0",
1931
+ "model_name": "LayoutModel",
1932
+ "state": {
1933
+ "_model_module": "@jupyter-widgets/base",
1934
+ "_model_module_version": "1.2.0",
1935
+ "_model_name": "LayoutModel",
1936
+ "_view_count": null,
1937
+ "_view_module": "@jupyter-widgets/base",
1938
+ "_view_module_version": "1.2.0",
1939
+ "_view_name": "LayoutView",
1940
+ "align_content": null,
1941
+ "align_items": null,
1942
+ "align_self": null,
1943
+ "border": null,
1944
+ "bottom": null,
1945
+ "display": null,
1946
+ "flex": null,
1947
+ "flex_flow": null,
1948
+ "grid_area": null,
1949
+ "grid_auto_columns": null,
1950
+ "grid_auto_flow": null,
1951
+ "grid_auto_rows": null,
1952
+ "grid_column": null,
1953
+ "grid_gap": null,
1954
+ "grid_row": null,
1955
+ "grid_template_areas": null,
1956
+ "grid_template_columns": null,
1957
+ "grid_template_rows": null,
1958
+ "height": null,
1959
+ "justify_content": null,
1960
+ "justify_items": null,
1961
+ "left": null,
1962
+ "margin": null,
1963
+ "max_height": null,
1964
+ "max_width": null,
1965
+ "min_height": null,
1966
+ "min_width": null,
1967
+ "object_fit": null,
1968
+ "object_position": null,
1969
+ "order": null,
1970
+ "overflow": null,
1971
+ "overflow_x": null,
1972
+ "overflow_y": null,
1973
+ "padding": null,
1974
+ "right": null,
1975
+ "top": null,
1976
+ "visibility": null,
1977
+ "width": null
1978
+ }
1979
+ },
1980
+ "bdf500351aea42698c6d6dd5a99021f3": {
1981
+ "model_module": "@jupyter-widgets/controls",
1982
+ "model_module_version": "1.5.0",
1983
+ "model_name": "HBoxModel",
1984
+ "state": {
1985
+ "_dom_classes": [],
1986
+ "_model_module": "@jupyter-widgets/controls",
1987
+ "_model_module_version": "1.5.0",
1988
+ "_model_name": "HBoxModel",
1989
+ "_view_count": null,
1990
+ "_view_module": "@jupyter-widgets/controls",
1991
+ "_view_module_version": "1.5.0",
1992
+ "_view_name": "HBoxView",
1993
+ "box_style": "",
1994
+ "children": [
1995
+ "IPY_MODEL_ab61b90c1a5b4a2b9bb5c9d5a215bb3f",
1996
+ "IPY_MODEL_dc03fed540b74f3aa4a1b17ebf2c81d3",
1997
+ "IPY_MODEL_5837f2c4668646c0a6db2407aebb46e3"
1998
+ ],
1999
+ "layout": "IPY_MODEL_edeb423e9ff84e5c8a0d790368d68bba"
2000
+ }
2001
+ },
2002
+ "bf8eb066cdaf4ac096dc14392d085daf": {
2003
+ "model_module": "@jupyter-widgets/base",
2004
+ "model_module_version": "1.2.0",
2005
+ "model_name": "LayoutModel",
2006
+ "state": {
2007
+ "_model_module": "@jupyter-widgets/base",
2008
+ "_model_module_version": "1.2.0",
2009
+ "_model_name": "LayoutModel",
2010
+ "_view_count": null,
2011
+ "_view_module": "@jupyter-widgets/base",
2012
+ "_view_module_version": "1.2.0",
2013
+ "_view_name": "LayoutView",
2014
+ "align_content": null,
2015
+ "align_items": null,
2016
+ "align_self": null,
2017
+ "border": null,
2018
+ "bottom": null,
2019
+ "display": null,
2020
+ "flex": null,
2021
+ "flex_flow": null,
2022
+ "grid_area": null,
2023
+ "grid_auto_columns": null,
2024
+ "grid_auto_flow": null,
2025
+ "grid_auto_rows": null,
2026
+ "grid_column": null,
2027
+ "grid_gap": null,
2028
+ "grid_row": null,
2029
+ "grid_template_areas": null,
2030
+ "grid_template_columns": null,
2031
+ "grid_template_rows": null,
2032
+ "height": null,
2033
+ "justify_content": null,
2034
+ "justify_items": null,
2035
+ "left": null,
2036
+ "margin": null,
2037
+ "max_height": null,
2038
+ "max_width": null,
2039
+ "min_height": null,
2040
+ "min_width": null,
2041
+ "object_fit": null,
2042
+ "object_position": null,
2043
+ "order": null,
2044
+ "overflow": null,
2045
+ "overflow_x": null,
2046
+ "overflow_y": null,
2047
+ "padding": null,
2048
+ "right": null,
2049
+ "top": null,
2050
+ "visibility": null,
2051
+ "width": null
2052
+ }
2053
+ },
2054
+ "bfcc6d01c9ff4db698afa4318e7c91ac": {
2055
+ "model_module": "@jupyter-widgets/controls",
2056
+ "model_module_version": "1.5.0",
2057
+ "model_name": "DescriptionStyleModel",
2058
+ "state": {
2059
+ "_model_module": "@jupyter-widgets/controls",
2060
+ "_model_module_version": "1.5.0",
2061
+ "_model_name": "DescriptionStyleModel",
2062
+ "_view_count": null,
2063
+ "_view_module": "@jupyter-widgets/base",
2064
+ "_view_module_version": "1.2.0",
2065
+ "_view_name": "StyleView",
2066
+ "description_width": ""
2067
+ }
2068
+ },
2069
+ "c4ecdc9d982f49129368893c1c0aece9": {
2070
+ "model_module": "@jupyter-widgets/controls",
2071
+ "model_module_version": "1.5.0",
2072
+ "model_name": "HTMLModel",
2073
+ "state": {
2074
+ "_dom_classes": [],
2075
+ "_model_module": "@jupyter-widgets/controls",
2076
+ "_model_module_version": "1.5.0",
2077
+ "_model_name": "HTMLModel",
2078
+ "_view_count": null,
2079
+ "_view_module": "@jupyter-widgets/controls",
2080
+ "_view_module_version": "1.5.0",
2081
+ "_view_name": "HTMLView",
2082
+ "description": "",
2083
+ "description_tooltip": null,
2084
+ "layout": "IPY_MODEL_58ab975eaba2485cb0945482c26ecf3d",
2085
+ "placeholder": "​",
2086
+ "style": "IPY_MODEL_d0b4e43ab5cd4edda6cc061b36bf10a3",
2087
+ "value": " 45.1M/45.1M [00:00&lt;00:00, 89.1MB/s]"
2088
+ }
2089
+ },
2090
+ "c917f3a000fb44338e4afbeabeaab55f": {
2091
+ "model_module": "@jupyter-widgets/controls",
2092
+ "model_module_version": "1.5.0",
2093
+ "model_name": "DescriptionStyleModel",
2094
+ "state": {
2095
+ "_model_module": "@jupyter-widgets/controls",
2096
+ "_model_module_version": "1.5.0",
2097
+ "_model_name": "DescriptionStyleModel",
2098
+ "_view_count": null,
2099
+ "_view_module": "@jupyter-widgets/base",
2100
+ "_view_module_version": "1.2.0",
2101
+ "_view_name": "StyleView",
2102
+ "description_width": ""
2103
+ }
2104
+ },
2105
+ "cad8fd90586443778568a1babb8c40e6": {
2106
+ "model_module": "@jupyter-widgets/controls",
2107
+ "model_module_version": "1.5.0",
2108
+ "model_name": "ProgressStyleModel",
2109
+ "state": {
2110
+ "_model_module": "@jupyter-widgets/controls",
2111
+ "_model_module_version": "1.5.0",
2112
+ "_model_name": "ProgressStyleModel",
2113
+ "_view_count": null,
2114
+ "_view_module": "@jupyter-widgets/base",
2115
+ "_view_module_version": "1.2.0",
2116
+ "_view_name": "StyleView",
2117
+ "bar_color": null,
2118
+ "description_width": ""
2119
+ }
2120
+ },
2121
+ "d0b4e43ab5cd4edda6cc061b36bf10a3": {
2122
+ "model_module": "@jupyter-widgets/controls",
2123
+ "model_module_version": "1.5.0",
2124
+ "model_name": "DescriptionStyleModel",
2125
+ "state": {
2126
+ "_model_module": "@jupyter-widgets/controls",
2127
+ "_model_module_version": "1.5.0",
2128
+ "_model_name": "DescriptionStyleModel",
2129
+ "_view_count": null,
2130
+ "_view_module": "@jupyter-widgets/base",
2131
+ "_view_module_version": "1.2.0",
2132
+ "_view_name": "StyleView",
2133
+ "description_width": ""
2134
+ }
2135
+ },
2136
+ "dc03fed540b74f3aa4a1b17ebf2c81d3": {
2137
+ "model_module": "@jupyter-widgets/controls",
2138
+ "model_module_version": "1.5.0",
2139
+ "model_name": "FloatProgressModel",
2140
+ "state": {
2141
+ "_dom_classes": [],
2142
+ "_model_module": "@jupyter-widgets/controls",
2143
+ "_model_module_version": "1.5.0",
2144
+ "_model_name": "FloatProgressModel",
2145
+ "_view_count": null,
2146
+ "_view_module": "@jupyter-widgets/controls",
2147
+ "_view_module_version": "1.5.0",
2148
+ "_view_name": "ProgressView",
2149
+ "bar_style": "success",
2150
+ "description": "",
2151
+ "description_tooltip": null,
2152
+ "layout": "IPY_MODEL_5c3cb981f324446eae642f7c23a539f0",
2153
+ "max": 45106985,
2154
+ "min": 0,
2155
+ "orientation": "horizontal",
2156
+ "style": "IPY_MODEL_2fe9614fe5984fa6b887d1e1b3e18b04",
2157
+ "value": 45106985
2158
+ }
2159
+ },
2160
+ "edeb423e9ff84e5c8a0d790368d68bba": {
2161
+ "model_module": "@jupyter-widgets/base",
2162
+ "model_module_version": "1.2.0",
2163
+ "model_name": "LayoutModel",
2164
+ "state": {
2165
+ "_model_module": "@jupyter-widgets/base",
2166
+ "_model_module_version": "1.2.0",
2167
+ "_model_name": "LayoutModel",
2168
+ "_view_count": null,
2169
+ "_view_module": "@jupyter-widgets/base",
2170
+ "_view_module_version": "1.2.0",
2171
+ "_view_name": "LayoutView",
2172
+ "align_content": null,
2173
+ "align_items": null,
2174
+ "align_self": null,
2175
+ "border": null,
2176
+ "bottom": null,
2177
+ "display": null,
2178
+ "flex": null,
2179
+ "flex_flow": null,
2180
+ "grid_area": null,
2181
+ "grid_auto_columns": null,
2182
+ "grid_auto_flow": null,
2183
+ "grid_auto_rows": null,
2184
+ "grid_column": null,
2185
+ "grid_gap": null,
2186
+ "grid_row": null,
2187
+ "grid_template_areas": null,
2188
+ "grid_template_columns": null,
2189
+ "grid_template_rows": null,
2190
+ "height": null,
2191
+ "justify_content": null,
2192
+ "justify_items": null,
2193
+ "left": null,
2194
+ "margin": null,
2195
+ "max_height": null,
2196
+ "max_width": null,
2197
+ "min_height": null,
2198
+ "min_width": null,
2199
+ "object_fit": null,
2200
+ "object_position": null,
2201
+ "order": null,
2202
+ "overflow": null,
2203
+ "overflow_x": null,
2204
+ "overflow_y": null,
2205
+ "padding": null,
2206
+ "right": null,
2207
+ "top": null,
2208
+ "visibility": null,
2209
+ "width": null
2210
+ }
2211
+ },
2212
+ "eee695744ec64aa7b71b9e85968c6f8f": {
2213
+ "model_module": "@jupyter-widgets/controls",
2214
+ "model_module_version": "1.5.0",
2215
+ "model_name": "FloatProgressModel",
2216
+ "state": {
2217
+ "_dom_classes": [],
2218
+ "_model_module": "@jupyter-widgets/controls",
2219
+ "_model_module_version": "1.5.0",
2220
+ "_model_name": "FloatProgressModel",
2221
+ "_view_count": null,
2222
+ "_view_module": "@jupyter-widgets/controls",
2223
+ "_view_module_version": "1.5.0",
2224
+ "_view_name": "ProgressView",
2225
+ "bar_style": "success",
2226
+ "description": "",
2227
+ "description_tooltip": null,
2228
+ "layout": "IPY_MODEL_b2bf751bb96746e4a828241f70e52050",
2229
+ "max": 45084768,
2230
+ "min": 0,
2231
+ "orientation": "horizontal",
2232
+ "style": "IPY_MODEL_828b227361fe45cd83964149e7475503",
2233
+ "value": 45084768
2234
+ }
2235
+ }
2236
+ }
2237
+ }
2238
+ },
2239
+ "nbformat": 4,
2240
+ "nbformat_minor": 4
2241
+ }
pikapikagen/README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ---
2
+ title: pikapikagen
3
+ app_file: gradio_demo.py
4
+ sdk: gradio
5
+ sdk_version: 5.35.0
6
+ ---
pikapikagen/__init__.py ADDED
File without changes
pikapikagen/data_loader.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader, Subset
2
+ import torch
3
+ from dataset import PokemonDataset
4
+ import math
5
+
6
+ def create_training_setup(
7
+ tokenizer,
8
+ test_set_size,
9
+ val_set_size,
10
+ batch_size,
11
+ num_workers=0,
12
+ num_viz_samples=4,
13
+ random_seed=42,
14
+ train_augmentation_pipeline=None,
15
+ ):
16
+ """
17
+ Create a complete setup for training with dataset, dataloaders and fixed batches for visualization.
18
+ """
19
+ assert 0 <= test_set_size < 1.0, "test_set_size must be a float between 0 and 1"
20
+ assert 0 <= val_set_size < 1.0, "val_set_size must be a float between 0 and 1"
21
+ assert (test_set_size + val_set_size) < 1.0, "The sum of test and validation sizes must be less than 1"
22
+
23
+ train_full_dataset = PokemonDataset(tokenizer=tokenizer, augmentation_transforms=train_augmentation_pipeline)
24
+ # Don't use augmentation for test and validation
25
+ test_val_full_dataset = PokemonDataset(tokenizer=tokenizer)
26
+
27
+ dataset_size = len(train_full_dataset)
28
+
29
+ # Create a random reproducible permutation
30
+ generator = torch.Generator().manual_seed(random_seed)
31
+ shuffled_indices = torch.randperm(dataset_size, generator=generator)
32
+
33
+ val_count = math.ceil(val_set_size * dataset_size)
34
+ test_count = math.ceil(test_set_size * dataset_size)
35
+ train_count = dataset_size - val_count - test_count
36
+
37
+ # Partition based on the computed splits
38
+ train_indices = shuffled_indices[:train_count].tolist()
39
+ test_indices = shuffled_indices[train_count : train_count + test_count].tolist()
40
+ val_indices = shuffled_indices[train_count + test_count :].tolist()
41
+
42
+ # Create the subsets based on the indices
43
+ train_dataset = Subset(train_full_dataset, train_indices)
44
+ test_dataset = Subset(test_val_full_dataset, test_indices)
45
+ val_dataset = Subset(test_val_full_dataset, val_indices)
46
+
47
+ train_loader = DataLoader(
48
+ train_dataset,
49
+ batch_size=batch_size,
50
+ shuffle=True,
51
+ num_workers=num_workers,
52
+ pin_memory=True,
53
+ )
54
+ test_loader = DataLoader(
55
+ test_dataset,
56
+ batch_size=batch_size,
57
+ shuffle=False,
58
+ num_workers=num_workers,
59
+ pin_memory=True,
60
+ )
61
+ val_loader = DataLoader(
62
+ val_dataset,
63
+ batch_size=batch_size,
64
+ shuffle=False,
65
+ num_workers=num_workers,
66
+ pin_memory=True,
67
+ )
68
+
69
+ # Batch for visualization
70
+ vis_generator = torch.Generator().manual_seed(random_seed)
71
+
72
+ fixed_train_batch = next(
73
+ iter(DataLoader(train_dataset, batch_size=num_viz_samples, shuffle=True, generator=vis_generator))
74
+ )
75
+ # Since shuffle is False, no generator is needed
76
+ fixed_test_batch = next(iter(DataLoader(test_dataset, batch_size=num_viz_samples, shuffle=False)))
77
+ fixed_val_batch = next(iter(DataLoader(val_dataset, batch_size=num_viz_samples, shuffle=False)))
78
+
79
+ # Batch (size 1) for attention map visualization
80
+ vis_generator.manual_seed(random_seed)
81
+ fixed_train_attention_batch = next(
82
+ iter(DataLoader(train_dataset, batch_size=1, shuffle=True, generator=vis_generator))
83
+ )
84
+ fixed_test_attention_batch = next(iter(DataLoader(test_dataset, batch_size=1, shuffle=False)))
85
+ fixed_val_attention_batch = next(iter(DataLoader(val_dataset, batch_size=1, shuffle=False)))
86
+
87
+ return {
88
+ 'train_loader': train_loader,
89
+ 'val_loader': val_loader,
90
+ 'test_loader': test_loader,
91
+ 'train_dataset': train_dataset,
92
+ 'val_dataset': val_dataset,
93
+ 'test_dataset': test_dataset,
94
+ 'fixed_train_batch': fixed_train_batch,
95
+ 'fixed_val_batch': fixed_val_batch,
96
+ 'fixed_test_batch': fixed_test_batch,
97
+ 'fixed_train_attention_batch': fixed_train_attention_batch,
98
+ 'fixed_val_attention_batch': fixed_val_attention_batch,
99
+ 'fixed_test_attention_batch': fixed_test_attention_batch,
100
+ }
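Usage sketch for create_training_setup (illustrative only; it assumes the prajjwal1/bert-mini tokenizer used elsewhere in this repo, and the split fractions and batch size below are placeholder values, not taken from the source):

    from transformers import AutoTokenizer
    from data_loader import create_training_setup

    tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")  # same tokenizer as evaluate_kid.py
    setup = create_training_setup(
        tokenizer=tokenizer,
        test_set_size=0.1,   # illustrative split fraction
        val_set_size=0.1,    # illustrative split fraction
        batch_size=16,       # illustrative batch size
        num_workers=2,
    )
    train_loader = setup["train_loader"]      # shuffled training split
    val_loader = setup["val_loader"]          # deterministic, no augmentation
    fixed_batch = setup["fixed_train_batch"]  # reproducible batch for visualization
    print(len(setup["train_dataset"]), len(setup["val_dataset"]), len(setup["test_dataset"]))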
pikapikagen/dataset.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib.request
3
+ import zipfile
4
+ from torch.utils.data import Dataset
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from typing import TypedDict
10
+ import torch
11
+
12
+
13
+ class PokemonSample(TypedDict):
14
+ text: torch.Tensor # Text already tokenized
15
+ image: torch.Tensor
16
+ description: str # Text before tokenization
17
+ pokemon_name: str
18
+ idx: int
19
+ attention_mask: torch.Tensor
20
+
21
+
22
+ def reporthook(block_num, block_size, total_size):
23
+ if block_num % 16384 == 0:
24
+ print(f"Downloading... {block_num * block_size / (1024 * 1024):.2f} MB")
25
+
26
+
27
+ def download_dataset_if_not_exists():
28
+ dataset_dir = "dataset"
29
+ pokedex_main_dir = os.path.join(dataset_dir, "pokedex-main")
30
+ zip_url = "https://github.com/cristobalmitchell/pokedex/archive/refs/heads/main.zip"
31
+ zip_path = "pokedex_main.zip"
32
+
33
+ if os.path.exists(pokedex_main_dir):
34
+ print(f"{pokedex_main_dir} already exists. Skipping download.")
35
+ return
36
+
37
+ os.makedirs(dataset_dir, exist_ok=True)
38
+
39
+ print("Downloading dataset...")
40
+ urllib.request.urlretrieve(zip_url, zip_path, reporthook)
41
+ print("Download complete.")
42
+
43
+ print("Extracting dataset...")
44
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
45
+ zip_ref.extractall(dataset_dir)
46
+ print("Extraction complete.")
47
+
48
+ os.remove(zip_path)
49
+
50
+
51
+ class PokemonDataset(Dataset):
52
+ def __init__(
53
+ self,
54
+ tokenizer,
55
+ csv_path="dataset/pokedex-main/data/pokemon.csv",
56
+ image_dir="dataset/pokedex-main/images/small_images",
57
+ max_length=128,
58
+ augmentation_transforms=None,
59
+ ):
60
+ self.df = pd.read_csv(csv_path, encoding="utf-16 LE", delimiter="\t")
61
+ self.image_dir = Path(image_dir)
62
+ print(f"Dataset caricato: {len(self.df)} Pokemon con descrizioni e immagini")
63
+
64
+ self.tokenizer = tokenizer
65
+ self.max_length = max_length
66
+
67
+ if augmentation_transforms is not None:
68
+ self.final_transform = transforms.Compose(
69
+ [
70
+ transforms.ToTensor(),
71
+ transforms.Resize((256, 256), antialias=True),
72
+ augmentation_transforms,
73
+ transforms.Normalize(
74
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
75
+ ), # Normalize to [-1, 1]
76
+ ]
77
+ )
78
+ else:
79
+ self.final_transform = transforms.Compose(
80
+ [
81
+ transforms.ToTensor(),
82
+ transforms.Resize((256, 256), antialias=True),
83
+ transforms.Normalize(
84
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
85
+ ), # Normalize to [-1, 1]
86
+ ]
87
+ )
88
+
89
+ def __len__(self):
90
+ return len(self.df)
91
+
92
+ def __getitem__(self, idx: int) -> PokemonSample:
93
+ # Get the row for this index
94
+ row = self.df.iloc[idx]
95
+
96
+ # === TEXT PREPROCESSING ===
97
+ description = str(row["description"])
98
+
99
+ # Tokenize the text
100
+ encoded = self.tokenizer(
101
+ description,
102
+ max_length=self.max_length,
103
+ padding="max_length",
104
+ truncation=True,
105
+ return_tensors="pt",
106
+ )
107
+
108
+ # Extract token_ids and attention_mask
109
+ text_ids = encoded["input_ids"].squeeze(0) # Remove the batch dimension
110
+ attention_mask = encoded["attention_mask"].squeeze(0)
111
+
112
+ # === IMAGE LOADING AND PREPROCESSING ===
113
+ # Build the image file path
114
+ image_filename = f"{row['national_number']:03d}.png"
115
+ image_path = self.image_dir / image_filename
116
+
117
+ # Load the image
118
+ image_rgba = Image.open(image_path).convert("RGBA")
119
+
120
+ # Handle transparency: composite the image onto a white background
121
+ background = Image.new("RGB", image_rgba.size, (255, 255, 255))
122
+ background.paste(image_rgba, mask=image_rgba.split()[-1])
123
+
124
+ # Apply the final transforms (ToTensor, Resize, Normalize)
125
+ image_tensor = self.final_transform(background)
126
+
127
+ # Build the sample dict (matches pokemon_dataset.py structure)
128
+ sample = {
129
+ "text": text_ids,
130
+ "image": image_tensor,
131
+ "description": description,
132
+ "pokemon_name": row["english_name"],
133
+ "idx": idx,
134
+ "attention_mask": attention_mask,
135
+ }
136
+
137
+ return sample
138
+
139
+
140
+ download_dataset_if_not_exists()
141
+ print("Dataset ready!")
pikapikagen/discriminators.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from model_blocks.text_encoder import TextEncoder
4
+
5
+
6
+ class Discriminator256(nn.Module):
7
+ def __init__(self, text_dim=256, img_channels=3):
8
+ super(Discriminator256, self).__init__()
9
+
10
+ self.text_encoder = TextEncoder() # Separate text encoder for discriminators
11
+
12
+ self.img_path = nn.Sequential(
13
+ # 256x256 -> 128x128
14
+ nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
15
+ nn.LeakyReLU(0.2, inplace=True),
16
+
17
+ # 128x128 -> 64x64
18
+ nn.Conv2d(16, 32, 4, 2, 1, bias=False),
19
+ nn.BatchNorm2d(32),
20
+ nn.LeakyReLU(0.2, inplace=True),
21
+
22
+ # 64x64 -> 32x32
23
+ nn.Conv2d(32, 64, 4, 2, 1, bias=False),
24
+ nn.BatchNorm2d(64),
25
+ nn.LeakyReLU(0.2, inplace=True),
26
+
27
+ # 32x32 -> 16x16
28
+ nn.Conv2d(64, 128, 4, 2, 1, bias=False),
29
+ nn.BatchNorm2d(128),
30
+ nn.LeakyReLU(0.2, inplace=True),
31
+
32
+ # 16x16 -> 8x8
33
+ nn.Conv2d(128, 256, 4, 2, 1, bias=False),
34
+ nn.BatchNorm2d(256),
35
+ nn.LeakyReLU(0.2, inplace=True),
36
+
37
+ # 8x8 -> 4x4
38
+ nn.Conv2d(256, 512, 4, 2, 1, bias=False),
39
+ nn.BatchNorm2d(512),
40
+ nn.LeakyReLU(0.2, inplace=True),
41
+ )
42
+
43
+ self.text_path = nn.Sequential(
44
+ nn.Linear(text_dim, 1024),
45
+ nn.LeakyReLU(0.2, inplace=True),
46
+ nn.Linear(1024, 512)
47
+ )
48
+
49
+ # Unconditional classifier (real/fake without text conditioning)
50
+ self.unconditional_classifier = nn.Sequential(
51
+ nn.Linear(512 * 4 * 4, 1024),
52
+ nn.LeakyReLU(0.2, inplace=True),
53
+ nn.Dropout(0.5),
54
+ nn.Linear(1024, 1),
55
+ )
56
+
57
+ # Conditional classifier (text-conditioned real/fake)
58
+ self.conditional_classifier = nn.Sequential(
59
+ nn.Linear(512 * 4 * 4 + 512, 1024), # size: sum of flattened image and text embedding
60
+ nn.LeakyReLU(0.2, inplace=True),
61
+ nn.Dropout(0.5),
62
+ nn.Linear(1024, 1),
63
+ )
64
+
65
+ def forward(self, images, text_features=None, text_mask=None, return_both=True):
66
+ # Encode image
67
+ img_features = self.img_path(images)
68
+ img_features_flat = img_features.view(img_features.size(0), -1) # Flatten
69
+
70
+ unconditional_output = self.unconditional_classifier(img_features_flat)
71
+
72
+ if not return_both:
73
+ return unconditional_output
74
+
75
+ if text_features is None or text_mask is None:
76
+ raise AttributeError("text_features and text_mask necessary for text conditioning")
77
+
78
+ # Encode text (mean pooling)
79
+ global_full_text = self.text_encoder(text_features, text_mask)
80
+ global_text = global_full_text.mean(dim=1)
81
+ text_features_encoded = self.text_path(global_text)
82
+
83
+ # Combine features
84
+ combined = torch.cat([img_features_flat, text_features_encoded], dim=1)
85
+ conditional_output = self.conditional_classifier(combined)
86
+
87
+ return unconditional_output, conditional_output
88
+
89
+
90
+ class Discriminator64(nn.Module):
91
+ def __init__(self, text_dim=256, img_channels=3):
92
+ super(Discriminator64, self).__init__()
93
+
94
+ self.text_encoder = TextEncoder()
95
+
96
+ self.img_path = nn.Sequential(
97
+ # 64x64 -> 32x32
98
+ nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
99
+ nn.LeakyReLU(0.2, inplace=True),
100
+
101
+ # 32x32 -> 16x16
102
+ nn.Conv2d(16, 32, 4, 2, 1, bias=False),
103
+ nn.BatchNorm2d(32),
104
+ nn.LeakyReLU(0.2, inplace=True),
105
+
106
+ # 16x16 -> 8x8
107
+ nn.Conv2d(32, 64, 4, 2, 1, bias=False),
108
+ nn.BatchNorm2d(64),
109
+ nn.LeakyReLU(0.2, inplace=True),
110
+
111
+ # 8x8 -> 4x4
112
+ nn.Conv2d(64, 128, 4, 2, 1, bias=False),
113
+ nn.BatchNorm2d(128),
114
+ nn.LeakyReLU(0.2, inplace=True),
115
+ )
116
+
117
+ # Text encoder for discriminator
118
+ self.text_path = nn.Sequential(
119
+ nn.Linear(text_dim, 1024),
120
+ nn.LeakyReLU(0.2, inplace=True),
121
+ nn.Linear(1024, 512)
122
+ )
123
+
124
+ # Unconditional classifier (real/fake without text conditioning)
125
+ self.unconditional_classifier = nn.Sequential(
126
+ nn.Linear(128 * 4 * 4, 1024),
127
+ nn.LeakyReLU(0.2, inplace=True),
128
+ nn.Dropout(0.5),
129
+ nn.Linear(1024, 1),
130
+ )
131
+
132
+ # Conditional classifier (text-conditioned real/fake)
133
+ self.conditional_classifier = nn.Sequential(
134
+ nn.Linear(128 * 4 * 4 + 512, 1024),
135
+ nn.LeakyReLU(0.2, inplace=True),
136
+ nn.Dropout(0.5),
137
+ nn.Linear(1024, 1),
138
+ )
139
+
140
+ def forward(self, images, text_features=None, text_mask=None, return_both=True):
141
+ img_features = self.img_path(images)
142
+ img_features_flat = img_features.view(img_features.size(0), -1) # Flatten
143
+
144
+ unconditional_output = self.unconditional_classifier(img_features_flat)
145
+
146
+ if not return_both:
147
+ return unconditional_output
148
+
149
+ if text_features is None or text_mask is None:
150
+ raise AttributeError("text_features and text_mask necessary for text conditioning")
151
+
152
+
153
+ # Encode text (mean pooling)
154
+ global_full_text = self.text_encoder(text_features, text_mask)
155
+ global_text = global_full_text.mean(dim=1)
156
+ text_features_encoded = self.text_path(global_text)
157
+
158
+ combined = torch.cat([img_features_flat, text_features_encoded], dim=1)
159
+ conditional_output = self.conditional_classifier(combined)
160
+
161
+ return unconditional_output, conditional_output
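A quick shape check for the dual-headed Discriminator64 above (a sketch only: the `from discriminator import Discriminator64` path is a guess, since this file's name is not visible in this excerpt; the token ids and mask are dummy stand-ins for a tokenized description):

import torch
from discriminator import Discriminator64  # hypothetical module path; use this file's actual module name

disc = Discriminator64()
images = torch.randn(2, 3, 64, 64)                 # batch of 64x64 RGB images
token_ids = torch.randint(0, 30000, (2, 128))      # dummy bert-mini token ids
text_mask = torch.ones(2, 128, dtype=torch.long)   # 1 = real token, 0 = padding

# With return_both=True (the default) both heads are returned:
# an unconditional real/fake logit and a text-conditioned one.
uncond, cond = disc(images, text_features=token_ids, text_mask=text_mask)
print(uncond.shape, cond.shape)  # torch.Size([2, 1]) torch.Size([2, 1])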
pikapikagen/evaluate_kid.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ from transformers import AutoTokenizer
3
+ from model import Generator as PikaPikaGen
4
+ from data_loader import create_training_setup
5
+ from utils import denormalize_image
6
+ from torch_fidelity import calculate_metrics
7
+ import os
8
+ import tempfile
9
+ from PIL import Image
10
+ import shutil
11
+
12
+ CHECKPOINT_PATH = "pikapikagen/model_checkpoint/checkpoint_epoch_150.pth"
13
+
14
+ TOKENIZER_NAME = "prajjwal1/bert-mini"
15
+
16
+ BATCH_SIZE = 16 # Batch size for generating images
17
+ NUM_WORKERS = 2 # Number of workers for the data loader
18
+
19
+ KID_SUBSET_SIZE = 50
20
+ KID_NUM_SUBSETS = 20
21
+
22
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+
24
+
25
+ class PokemonKIDEvaluator:
26
+ """Evaluator class for computing KID metrics on PikaPikaGen."""
27
+
28
+ def __init__(self, checkpoint_path, device=DEVICE):
29
+ self.device = device
30
+ self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
31
+ self.checkpoint_path = checkpoint_path
32
+
33
+ self._load_model()  # Same checkpoint-loading logic as the Gradio demo
34
+
35
+ def _load_model(self):
36
+ self.generator = PikaPikaGen().to(self.device)
37
+
38
+ checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=True)
39
+ self.generator.load_state_dict(checkpoint['generator_state_dict'])
40
+ self.generator.eval()
41
+
42
+ @staticmethod
43
+ def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
44
+ denormalized = denormalize_image(tensor)
45
+ uint8_tensor = (denormalized * 255).clamp(0, 255).to(torch.uint8)
46
+ img_np = uint8_tensor.cpu().permute(1, 2, 0).numpy()
47
+ return Image.fromarray(img_np)
48
+
49
+ def _save_images_to_temp_dir(self, images_tensor: torch.Tensor, prefix: str) -> str:
50
+ """Save a batch of image tensors to a new temporary directory."""
51
+ temp_dir = tempfile.mkdtemp(prefix=f"pikakid_{prefix}_")
52
+ for i, img_tensor in enumerate(images_tensor):
53
+ pil_img = self._tensor_to_pil(img_tensor)
54
+ img_path = os.path.join(temp_dir, f"{i:06d}.png")
55
+ pil_img.save(img_path)
56
+ return temp_dir
57
+
58
+ def evaluate_kid(self, test_loader, resolution="256x256"):
59
+
60
+ all_real_images = []
61
+ all_generated_images = []
62
+
63
+ with torch.no_grad():
64
+ for batch in test_loader:
65
+ text_ids = batch["text"].to(self.device)
66
+ attention_mask = batch["attention_mask"].to(self.device)
67
+ real_images_256 = batch["image"] # (B, 3, 256, 256)
68
+
69
+ generated_256, generated_64 = self.generator(text_ids, attention_mask)
70
+
71
+ # Select the correct resolution for both real and generated images
72
+ if resolution == "256x256":
73
+ generated_images = generated_256
74
+ processed_real_images = real_images_256
75
+ elif resolution == "64x64":
76
+ generated_images = generated_64
77
+ processed_real_images = torch.nn.functional.interpolate(
78
+ real_images_256, size=(64, 64), mode='bilinear', align_corners=False
79
+ )
80
+ else:
81
+ raise ValueError(f"Unsupported resolution: {resolution}")
82
+
83
+ all_real_images.append(processed_real_images.cpu())
84
+ all_generated_images.append(generated_images.cpu())
85
+
86
+ # Combine all batches into single tensors
87
+ all_real_images = torch.cat(all_real_images, dim=0)
88
+ all_generated_images = torch.cat(all_generated_images, dim=0)
89
+
90
+ # Save images to temporary directories for torch-fidelity
91
+ real_temp_dir = self._save_images_to_temp_dir(all_real_images, "real")
92
+ generated_temp_dir = self._save_images_to_temp_dir(all_generated_images, "generated")
93
+
94
+ metrics = calculate_metrics(
95
+ input1=generated_temp_dir, # Path to generated (fake) images
96
+ input2=real_temp_dir, # Path to real images
97
+ kid=True,
98
+ kid_subset_size=KID_SUBSET_SIZE,
99
+ kid_subsets=KID_NUM_SUBSETS,
100
+ batch_size=BATCH_SIZE,
101
+ device=self.device
102
+ )
103
+
104
+ kid_mean = metrics['kernel_inception_distance_mean']
105
+ kid_std = metrics['kernel_inception_distance_std']
106
+
107
+ # Clean up the temporary directories
108
+ shutil.rmtree(real_temp_dir)
109
+ shutil.rmtree(generated_temp_dir)
110
+
111
+ return kid_mean, kid_std
112
+
113
+ def main():
114
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
115
+ training_setup = create_training_setup(
116
+ tokenizer=tokenizer,
117
+ test_set_size=0.2,
118
+ val_set_size=0.1,
119
+ batch_size=BATCH_SIZE,
120
+ num_workers=NUM_WORKERS,
121
+ random_seed=42, # Use a fixed seed for a reproducible split
122
+ )
123
+ test_loader = training_setup['test_loader']
124
+ test_set_size = len(test_loader.dataset)
125
+
126
+ evaluator = PokemonKIDEvaluator(checkpoint_path=CHECKPOINT_PATH)
127
+
128
+ resolutions_to_test = ['64x64', '256x256']
129
+
130
+ print(f"Checkpoint: {CHECKPOINT_PATH}")
131
+ print(f"Test samples: {test_set_size}")
132
+ print(f"KID Subset Size: {KID_SUBSET_SIZE}")
133
+ print(f"KID Subsets: {KID_NUM_SUBSETS}")
134
+
135
+ for res in resolutions_to_test:
136
+ kid_mean, kid_std = evaluator.evaluate_kid(test_loader, resolution=res)
137
+ print(f"Resolution {res}:\t KID = {kid_mean:.6f} ± {kid_std:.6f}")
138
+
139
+
140
+ if __name__ == "__main__":
141
+ main()
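For reference, the torch-fidelity call used above reduces to the following minimal sketch; the two folder paths are placeholders for directories of PNG files, and kid_subset_size must not exceed the number of images in the smaller folder:

from torch_fidelity import calculate_metrics

metrics = calculate_metrics(
    input1="path/to/generated_pngs",  # placeholder folder of generated images
    input2="path/to/real_pngs",       # placeholder folder of real images
    kid=True,
    kid_subset_size=50,
    kid_subsets=20,
)
print(metrics["kernel_inception_distance_mean"], metrics["kernel_inception_distance_std"])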
pikapikagen/gradio_demo.py ADDED
@@ -0,0 +1,291 @@
1
+ import gradio as gr
2
+ import gradio.themes
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from transformers import AutoTokenizer
7
+ from model import Generator as PikaPikaGen
8
+ from utils import denormalize_image
9
+ from plots import plot_attention_visualization
10
+ import os
11
+
12
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ CHECKPOINT_PATH = "model_checkpoints/checkpoint_epoch_150.pth"
14
+ TOKENIZER_NAME = "prajjwal1/bert-mini"
15
+
16
+
17
+ class PokemonGenerator:
18
+ """Main class for the Pokemon generation demo"""
19
+
20
+ def __init__(self):
21
+ self.device = DEVICE
22
+ self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
23
+
24
+ self._load_model()
25
+
26
+
27
+ def _load_model(self):
28
+ """Load the trained PikaPikaGen model"""
29
+ try:
30
+ # Initialize model
31
+ self.generator = PikaPikaGen().to(self.device)
32
+
33
+ # Load checkpoint
34
+ checkpoint = torch.load(CHECKPOINT_PATH, map_location=self.device, weights_only=True)
35
+
36
+ # Load saved weights into model
37
+ self.generator.load_state_dict(checkpoint['generator_state_dict'])
38
+ print(f"✅ Generator loaded from checkpoint (epoch {checkpoint.get('epoch', 'unknown')})")
39
+
40
+ # No training
41
+ self.generator.eval()
42
+
43
+ except Exception as e:
44
+ print(f"❌ Error loading model: {e}")
45
+ raise
46
+
47
+ def _tensor_to_pil(self, tensor):
48
+ """Convert tensor to PIL Image"""
49
+ # tensor shape: (3, H, W)
50
+ img_np = tensor.permute(1, 2, 0).clamp(0, 1).numpy()
51
+ img_np = (img_np * 255).astype(np.uint8)
52
+ return Image.fromarray(img_np)
53
+
54
+ def generate_pokemon(self, description, num_samples=4, show_attention=False, resolution="both"):
55
+ """
56
+ Generate Pokemon sprites from text description
57
+
58
+ Args:
59
+ description (str): Text description of the desired Pokemon
60
+ num_samples (int): Number of samples to generate (1-8)
61
+ show_attention (bool): Whether to show attention visualization
62
+ resolution (str): Output resolution - "256x256", "64x64", or "both"
63
+
64
+ Returns:
65
+ tuple: (generated_images, attention_plot)
66
+ """
67
+ if not description.strip():
68
+ return [], "❌ Please enter a description."
69
+
70
+ # No reason to compute gradients
71
+ with torch.no_grad():
72
+ tokens = self.tokenizer(
73
+ description,
74
+ max_length=128,
75
+ padding='max_length',
76
+ truncation=True,
77
+ return_tensors='pt'
78
+ )
79
+
80
+ text_ids = tokens['input_ids'].repeat(num_samples, 1).to(self.device)
81
+ attention_mask = tokens['attention_mask'].repeat(num_samples, 1).to(self.device)
82
+
83
+ generated_256, generated_64, attention_maps, initial_weights = self.generator(
84
+ text_ids, attention_mask, return_attentions=True
85
+ )
86
+
87
+ # Convert tensors to PIL images
88
+ output_images = []
89
+ images_to_process = []
90
+ if resolution in ["256x256", "both"]:
91
+ images_to_process.append(generated_256)
92
+ if resolution in ["64x64", "both"]:
93
+ images_to_process.append(generated_64)
94
+
95
+ for img_batch in images_to_process:
96
+ img_batch_denorm = denormalize_image(img_batch.cpu())
97
+ for i in range(num_samples):
98
+ img_pil = self._tensor_to_pil(img_batch_denorm[i])
99
+ output_images.append(img_pil)
100
+
101
+ attention_plot = None
102
+ if show_attention:
103
+ # Create directory if it doesn't exist
104
+ output_dir = "attention_visualizations"
105
+ os.makedirs(output_dir, exist_ok=True)
106
+
107
+ # Create a more descriptive ID for the file
108
+ # so that plots for different prompts do not overwrite each other
109
+ plot_id = description.strip().replace(" ", "_")[:30]
110
+
111
+
112
+ # Use the first sample for the attention visualization
113
+ attention_plot = plot_attention_visualization(
114
+ epoch=0,
115
+ set_name="demo",
116
+ output_dir=output_dir,
117
+
118
+ generated_images=generated_256,
119
+
120
+ # Full batch data from the model
121
+ decoder_attention_maps=attention_maps,
122
+ initial_context_weights=initial_weights,
123
+
124
+ token_ids=text_ids,
125
+ attention_mask=attention_mask,
126
+ tokenizer=self.tokenizer,
127
+
128
+ # Metadata for the specific sample
129
+ description=description,
130
+ pokemon_id=plot_id,
131
+
132
+ sample_idx=0,
133
+ show_inline=False
134
+ )
135
+
136
+ return output_images, attention_plot
137
+
138
+
139
+ print("Initializing PikaPikaGen Demo...")
140
+ pokemon_gen = PokemonGenerator()
141
+
142
+ def generate_pokemon_interface(description, num_samples, show_attention, resolution):
143
+ images, attention_plot = pokemon_gen.generate_pokemon(
144
+ description=description,
145
+ num_samples=num_samples,
146
+ show_attention=show_attention,
147
+ resolution=resolution
148
+ )
149
+
150
+ if not images:
151
+ return [], attention_plot # attention_plot contains error message if error
152
+
153
+ status_msg = f"Generated {len(images)} Pokemon sprites"
154
+ if resolution == "both":
155
+ status_msg += f" ({num_samples} at 256x256 + {num_samples} at 64x64)"
156
+ else:
157
+ status_msg += f" at {resolution}"
158
+
159
+ return images, attention_plot
160
+
161
+ def create_interface():
162
+ with gr.Blocks(
163
+ title="PikaPikaGen: AI Pokemon Generator",
164
+ theme=gradio.themes.Soft(),
165
+ css="""
166
+ .main-header {
167
+ text-align: center;
168
+ background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
169
+ -webkit-background-clip: text;
170
+ -webkit-text-fill-color: transparent;
171
+ font-size: 2.5em;
172
+ font-weight: bold;
173
+ margin-bottom: 0.5em;
174
+ }
175
+ .description {
176
+ text-align: center;
177
+ font-size: 1.1em;
178
+ color: #666;
179
+ margin-bottom: 1em;
180
+ }
181
+ """
182
+ ) as demo:
183
+
184
+ gr.HTML("""
185
+ <div class="main-header">🎮 PikaPikaGen: AI Pokemon Generator</div>
186
+ <div class="description">
187
+ Generates Pokemon sprites from text descriptions using Transformer attention and CNN generation.
188
+ </div>
189
+ """)
190
+
191
+ with gr.Row():
192
+ with gr.Column(scale=1):
193
+ gr.Markdown("### 📝 Input")
194
+
195
+ description_input = gr.Textbox(
196
+ label="Pokemon Description",
197
+ placeholder="Describe your Pokemon! e.g., 'A fire dragon with golden scales and ruby eyes'",
198
+ lines=3,
199
+ value="A legendary fire dragon pokemon with golden scales and red eyes"
200
+ )
201
+
202
+ with gr.Row():
203
+ num_samples = gr.Slider(
204
+ minimum=1, maximum=8, value=4, step=1,
205
+ label="Number of samples"
206
+ )
207
+
208
+ resolution = gr.Radio(
209
+ choices=["256x256", "64x64", "both"],
210
+ value="256x256",
211
+ label="Output resolution"
212
+ )
213
+
214
+ show_attention = gr.Checkbox(
215
+ label="Show attention visualization",
216
+ value=True,
217
+ info="Visualize which words the model focuses on"
218
+ )
219
+
220
+ generate_btn = gr.Button(
221
+ "🎨 Generate Pokemon!",
222
+ variant="primary",
223
+ size="lg"
224
+ )
225
+
226
+ with gr.Column(scale=2):
227
+ gr.Markdown("### 🎨 Generated Pokemon")
228
+
229
+ output_gallery = gr.Gallery(
230
+ label="Generated Pokemon sprites",
231
+ show_label=True,
232
+ elem_id="gallery",
233
+ columns=2,
234
+ rows=2,
235
+ height="auto",
236
+ allow_preview=True
237
+ )
238
+
239
+ attention_output = gr.Image(
240
+ label="Attention visualization",
241
+ show_label=True,
242
+ interactive=False
243
+ )
244
+
245
+ # Examples section
246
+ gr.Markdown("### 🌟 Examples to try")
247
+ gr.Examples(
248
+ examples=[
249
+ ["A fire dragon with golden scales and red eyes", 4, True, "256x256"],
250
+ ["An electric mouse with yellow fur and lightning bolts", 3, False, "both"],
251
+ ["A water turtle with blue shell and powerful jaws", 2, True, "256x256"],
252
+ ["A psychic cat with purple fur and mystical powers", 4, True, "256x256"],
253
+ ["A grass serpent with emerald scales and vine whips", 3, False, "64x64"],
254
+ ["An ice phoenix with crystal wings and frozen flames", 4, True, "256x256"],
255
+ ["A dark wolf with shadow abilities and glowing eyes", 2, True, "both"],
256
+ ["A steel robot pokemon with metallic armor and laser beams", 3, False, "256x256"]
257
+ ],
258
+ inputs=[description_input, num_samples, show_attention, resolution],
259
+ outputs=[output_gallery, attention_output],
260
+ fn=generate_pokemon_interface,
261
+ cache_examples=False
262
+ )
263
+
264
+ # Event handlers
265
+ generate_btn.click(
266
+ fn=generate_pokemon_interface,
267
+ inputs=[description_input, num_samples, show_attention, resolution],
268
+ outputs=[output_gallery, attention_output]
269
+ )
270
+
271
+ # Footer
272
+ gr.Markdown("""
273
+ ---
274
+ **PikaPikaGen** - Text-to-Image Pokemon Generation using Transformer + CNN
275
+ """)
276
+
277
+ return demo
278
+
279
+ if __name__ == "__main__":
280
+ print("Starting PikaPikaGen Demo...")
281
+
282
+ # Create and launch interface
283
+ demo = create_interface()
284
+
285
+ demo.launch(
286
+ server_name="0.0.0.0", # Allow external access
287
+ share=False, # Set to True for public sharing
288
+ debug=False,
289
+ show_error=True,
290
+ inbrowser=True # Auto-open browser
291
+ )
pikapikagen/losses.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+ from torchvision.models import VGG19_Weights
6
+
7
+
8
+ class VGGPerceptualLoss(nn.Module):
9
+ """
10
+ Perceptual loss using VGG19 pretrained on ImageNet.
11
+ We extract features at:
12
+ - relu1_2 (index: 3)
13
+ - relu2_2 (index: 8)
14
+ - relu3_2 (index: 17)
15
+ - relu4_2 (index: 26)
16
+ Then compute L1 distance between those feature maps.
17
+ Input images are in [-1,1]. We convert to [0,1], then normalize with ImageNet stats.
18
+ """
19
+ def __init__(self, device):
20
+ super(VGGPerceptualLoss, self).__init__()
21
+ vgg19_features = models.vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
22
+ # We only need layers up to 26 (relu4_2)
23
+ self.slices = nn.ModuleDict({
24
+ "relu1_2": nn.Sequential(*list(vgg19_features.children())[:4]), # conv1_1, relu1_1, conv1_2, relu1_2
25
+ "relu2_2": nn.Sequential(*list(vgg19_features.children())[4:9]), # pool1, conv2_1, relu2_1, conv2_2, relu2_2
26
+ "relu3_2": nn.Sequential(*list(vgg19_features.children())[9:18]), # pool2, conv3_1, relu3_1, conv3_2, relu3_2, ...
27
+ "relu4_2": nn.Sequential(*list(vgg19_features.children())[18:27]) # pool3, conv4_1, relu4_1, conv4_2, relu4_2
28
+ })
29
+ for param in self.parameters():
30
+ param.requires_grad = False
31
+
32
+ self.l1 = nn.L1Loss()
33
+ self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
34
+ self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))
35
+
36
+ def forward(self, img_gen, img_ref):
37
+ """
38
+ img_gen, img_ref: [B,3,H,W] in range [-1,1].
39
+ Return: sum of L1 distances between VGG feature maps at chosen layers.
40
+ """
41
+ # Convert to [0,1]
42
+ gen = (img_gen + 1.0) / 2.0
43
+ ref = (img_ref + 1.0) / 2.0
44
+ # Normalize
45
+ gen_norm = (gen - self.mean) / self.std
46
+ ref_norm = (ref - self.mean) / self.std
47
+
48
+ loss = 0.0
49
+ x_gen = gen_norm
50
+ x_ref = ref_norm
51
+ for slice_mod in self.slices.values():
52
+ x_gen = slice_mod(x_gen)
53
+ x_ref = slice_mod(x_ref)
54
+ loss += self.l1(x_gen, x_ref)
55
+ return loss
56
+
57
+
58
+ class SobelLoss(nn.Module):
59
+ """
60
+ Computes the Sobel loss between two images, which encourages edge similarity.
61
+ This loss operates on the grayscale versions of the input images.
62
+ """
63
+ def __init__(self):
64
+ super(SobelLoss, self).__init__()
65
+ # Sobel kernels for edge detection
66
+ self.kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
67
+ self.kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
68
+ self.l1 = nn.L1Loss()
69
+
70
+ # Grayscale conversion weights (ITU-R BT.601)
71
+ self.rgb_to_gray_weights = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
72
+
73
+ def _get_edges(self, img):
74
+ """
75
+ Converts an RGB image to grayscale and applies Sobel filters.
76
+ Args:
77
+ img: [B, 3, H, W] image tensor in range [-1, 1].
78
+ Returns:
79
+ Gradient magnitude map [B, 1, H, W].
80
+ """
81
+
82
+ # Convert from [-1, 1] to [0, 1]
83
+ img = (img + 1.0) / 2.0
84
+
85
+ # Convert to grayscale
86
+ grayscale_img = F.conv2d(img, self.rgb_to_gray_weights.to(img.device))
87
+
88
+ # Apply Sobel filters
89
+ grad_x = F.conv2d(grayscale_img, self.kernel_x.to(img.device), padding=1)
90
+ grad_y = F.conv2d(grayscale_img, self.kernel_y.to(img.device), padding=1)
91
+
92
+ # Compute gradient magnitude
93
+ edges = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6) # add epsilon for stability
94
+ return edges
95
+
96
+ def forward(self, img_gen, img_ref):
97
+ """
98
+ img_gen, img_ref: [B, 3, H, W] in range [-1, 1].
99
+ Returns: L1 loss between the edge maps of the two images.
100
+ """
101
+ edges_gen = self._get_edges(img_gen)
102
+ edges_ref = self._get_edges(img_ref)
103
+ return self.l1(edges_gen, edges_ref)
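A minimal usage sketch for the two losses above, on dummy batches in [-1, 1]; the 0.1 and 0.05 weights are arbitrary placeholders, not the weights used by this project's training loop:

import torch
from losses import VGGPerceptualLoss, SobelLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
perceptual = VGGPerceptualLoss(device)
sobel = SobelLoss()

# Dummy "generated" and "real" batches; only the generated one needs gradients.
fake = (torch.rand(2, 3, 64, 64, device=device) * 2 - 1).requires_grad_()
real = torch.rand(2, 3, 64, 64, device=device) * 2 - 1

loss = 0.1 * perceptual(fake, real) + 0.05 * sobel(fake, real)
loss.backward()
print(float(loss))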
pikapikagen/model.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from model_blocks.text_encoder import TextEncoder
4
+ from model_blocks.image_decoder import ImageDecoder
5
+
6
+ class Generator(nn.Module):
7
+ """
8
+ Full model that combines the text encoder and the image decoder.
9
+ """
10
+ def __init__(self, text_encoder_model_name="prajjwal1/bert-mini", noise_dim=100):
11
+ super().__init__()
12
+ self.text_encoder = TextEncoder(
13
+ model_name=text_encoder_model_name,
14
+ )
15
+
16
+ text_embed_dim = 256
17
+
18
+ self.image_decoder = ImageDecoder(
19
+ noise_dim=noise_dim,
20
+ text_embed_dim=text_embed_dim
21
+ )
22
+
23
+ self.noise_dim = noise_dim
24
+
25
+ def forward(self, token_ids, attention_mask, return_attentions=False):
26
+ # token_ids.shape: (batch_size, seq_len)
27
+ # attention_mask.shape: (batch_size, seq_len)
28
+ # Generate random noise for the batch
29
+ batch_size = token_ids.size(0)
30
+ # noise.shape: (batch_size, noise_dim)
31
+ noise = torch.randn(batch_size, self.noise_dim, device=token_ids.device)
32
+
33
+ # 1. Encode the text to obtain per-token vectors
34
+ # encoder_output.shape: (batch_size, seq_len, text_embed_dim)
35
+ encoder_output = self.text_encoder(token_ids, attention_mask=attention_mask)
36
+
37
+ # 2. Generate the image using the full encoder output
38
+ # The decoder internally computes both the initial context (ATTENTION #1)
39
+ # and the per-step attention (ATTENTION #2)
40
+ # generated_image_256.shape: (batch_size, 3, 256, 256)
41
+ # generated_image_64.shape: (batch_size, 3, 64, 64)
42
+ generated_image_256, generated_image_64, attention_maps, initial_attention_weights = self.image_decoder(noise, encoder_output, attention_mask)
43
+
44
+ if return_attentions:
45
+ return generated_image_256, generated_image_64, attention_maps, initial_attention_weights
46
+ return generated_image_256, generated_image_64
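A quick end-to-end shape check for the Generator above (untrained weights, so the output is noise; the prompt is arbitrary):

import torch
from transformers import AutoTokenizer
from model import Generator

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")
tokens = tokenizer(
    "A small blue turtle pokemon with a hard shell",
    max_length=128, padding="max_length", truncation=True, return_tensors="pt",
)

generator = Generator()
with torch.no_grad():
    img_256, img_64 = generator(tokens["input_ids"], tokens["attention_mask"])
print(img_256.shape, img_64.shape)  # torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 64, 64])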
pikapikagen/model_blocks/decoder_block.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from model_blocks.image_cross_attention import ImageCrossAttention
4
+
5
+ class DecoderBlock(nn.Module):
6
+ """
7
+ Image decoder block
8
+ Channel adaptation (if necessary) -> Attention (optional) -> Merge -> Residual connection
9
+ -> Upsampling (ConvTranspose) -> Normalization -> Activation.
10
+ """
11
+ def __init__(self, in_channels, out_channels, use_attention=True, text_embed_dim=256, nhead=4):
12
+ super().__init__()
13
+ self.use_attention = use_attention
14
+
15
+ if self.use_attention:
16
+ # If in_channels is different from text_embed_dim, add a 1x1 conv to adapt the channel size
17
+ if in_channels != text_embed_dim:
18
+ self.channel_adapter = nn.Conv2d(in_channels, text_embed_dim, kernel_size=1, bias=False)
19
+ else:
20
+ self.channel_adapter = None
21
+
22
+ self.cross_attention = ImageCrossAttention(embed_dim=text_embed_dim, num_heads=nhead)
23
+ # Convolution to merge the (adapted) image features with the cross-attention output
24
+ self.fusion_conv = nn.Conv2d(text_embed_dim * 2, in_channels, kernel_size=1, bias=False)
25
+
26
+ # Upsample block as described in the instructions
27
+ self.upsample_block = nn.Sequential(
28
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
29
+ nn.GroupNorm(1, out_channels),
30
+ nn.LeakyReLU(inplace=True)
31
+ )
32
+
33
+ def forward(self, x, encoder_output=None, attention_mask=None):
34
+ attn_weights = None
35
+ if self.use_attention:
36
+ if encoder_output is None or attention_mask is None:
37
+ raise ValueError("encoder_output and attention_mask must be provided for attention.")
38
+
39
+ # Adapt the channel count if needed
40
+ if self.channel_adapter is not None:
41
+ x_adapted = self.channel_adapter(x)
42
+ else:
43
+ x_adapted = x
44
+
45
+ attn_output, attn_weights = self.cross_attention(
46
+ image_features=x_adapted,
47
+ text_features=encoder_output,
48
+ key_padding_mask=attention_mask
49
+ )
50
+
51
+ # Concatenates the features with the cross-attention output,
52
+ # then conv 1x1 and residual connection
53
+ fused_features = torch.cat([x_adapted, attn_output], dim=1) # Shape: (B, 2*in_channels, H, W)
54
+ skip = self.fusion_conv(fused_features) # Shape: (B, in_channels, H, W)
55
+ x = x + skip # Shape: (B, in_channels, H, W)
56
+
57
+
58
+ x = self.upsample_block(x)
59
+ return x, attn_weights
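Shape sketch for a single DecoderBlock with attention enabled (sizes are arbitrary; 16x16 feature maps are upsampled to 32x32):

import torch
from model_blocks.decoder_block import DecoderBlock

block = DecoderBlock(in_channels=256, out_channels=128, use_attention=True)
x = torch.randn(2, 256, 16, 16)              # image features (B, C, H, W)
text = torch.randn(2, 128, 256)              # encoder output (B, seq_len, 256)
mask = torch.ones(2, 128, dtype=torch.long)  # 1 = real token

y, attn = block(x, text, mask)
print(y.shape, attn.shape)  # torch.Size([2, 128, 32, 32]) torch.Size([2, 256, 128])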
pikapikagen/model_blocks/image_cross_attention.py ADDED
@@ -0,0 +1,49 @@
1
+ import torch.nn as nn
2
+
3
+ class ImageCrossAttention(nn.Module):
4
+ """
5
+ Image cross-attention module
6
+ Allows a sequence of queries (from the image) to "pay attention"
7
+ to a sequence of key/value (from the text), internally managing
8
+ the reshaping of tensors and the attention mask.
9
+ """
10
+ def __init__(self, embed_dim, num_heads):
11
+ super().__init__()
12
+ self.attention = nn.MultiheadAttention(
13
+ embed_dim=embed_dim, num_heads=num_heads, batch_first=True
14
+ )
15
+ self.layer_norm = nn.LayerNorm(embed_dim)
16
+
17
+ def forward(self, image_features, text_features, key_padding_mask=None):
18
+ # query: (B, C, H, W) - Image features
19
+ # key/value: (B, seq_len, embed_dim) - Text encoder output
20
+ # key_padding_mask: (B, seq_len) - Attention mask from the tokenizer
21
+
22
+ B, C, H, W = image_features.shape
23
+
24
+ # Reshape from image to sequence: (B, C, H, W) -> (B, H*W, C)
25
+ query_seq = image_features.view(B, C, H * W).permute(0, 2, 1)
26
+ query_norm = self.layer_norm(query_seq)
27
+
28
+ # Prepare the padding mask from the attention mask
29
+ # The HuggingFace mask is 1 for real tokens, 0 for padding.
30
+ # MultiheadAttention expects True for positions to ignore.
31
+ if key_padding_mask is not None:
32
+ mask = (key_padding_mask == 0)
33
+ else:
34
+ mask = None
35
+
36
+ attn_output, attn_weights = self.attention(
37
+ query=query_norm,
38
+ key=text_features,
39
+ value=text_features,
40
+ key_padding_mask=mask,
41
+ need_weights=True
42
+ )
43
+ # attn_output: (B, H*W, C)
44
+
45
+ # Convert output back into its original size
46
+ # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
47
+ attn_output_spatial = attn_output.permute(0, 2, 1).view(B, C, H, W)
48
+
49
+ return attn_output_spatial, attn_weights
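The reshaping logic above can be verified with a small sketch (arbitrary sizes; the returned weights are averaged over heads by nn.MultiheadAttention):

import torch
from model_blocks.image_cross_attention import ImageCrossAttention

attn = ImageCrossAttention(embed_dim=256, num_heads=4)
image_features = torch.randn(2, 256, 8, 8)              # (B, C, H, W)
text_features = torch.randn(2, 128, 256)                # (B, seq_len, embed_dim)
attention_mask = torch.ones(2, 128, dtype=torch.long)   # HuggingFace-style mask

out, weights = attn(image_features, text_features, key_padding_mask=attention_mask)
print(out.shape, weights.shape)  # torch.Size([2, 256, 8, 8]) torch.Size([2, 64, 128])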
pikapikagen/model_blocks/image_decoder.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from model_blocks.decoder_block import DecoderBlock
4
+
5
+ class ImageDecoder(nn.Module):
6
+ """
7
+ CNN decoder (generator) that synthesizes the image.
8
+ This version uses per-step attention from the very first block.
9
+ """
10
+ def __init__(self, noise_dim, text_embed_dim, final_image_channels=3):
11
+ super().__init__()
12
+
13
+ # Mechanism to calculate attention scores for the initial context.
14
+ self.initial_context_scorer = nn.Sequential(
15
+ nn.Linear(in_features=text_embed_dim, out_features=512),
16
+ nn.Tanh(),
17
+ nn.Linear(in_features=512, out_features=1)
18
+ # Softmax applied in forward pass to use the attention mask
19
+ )
20
+
21
+ # Initial linear projection to a 4x4 feature map.
22
+ self.initial_projection = nn.Sequential(
23
+ nn.Linear(noise_dim + text_embed_dim, 256 * 4 * 4),
24
+ nn.GroupNorm(1, 256 * 4 * 4),
25
+ nn.LeakyReLU(inplace=True)
26
+ )
27
+
28
+ # Shared blocks for both resolutions (until 64x64)
29
+ self.blocks_64 = nn.ModuleList([
30
+ # Input: (B, 256, 4, 4) -> Output: (B, 256, 8, 8)
31
+ DecoderBlock(in_channels=256, out_channels=256, use_attention=True),
32
+ # Input: (B, 256, 8, 8) -> Output: (B, 256, 16, 16)
33
+ DecoderBlock(in_channels=256, out_channels=256, use_attention=True),
34
+ # Input: (B, 256, 16, 16) -> Output: (B, 128, 32, 32)
35
+ DecoderBlock(in_channels=256, out_channels=128, use_attention=True),
36
+ # Input: (B, 128, 32, 32) -> Output: (B, 64, 64, 64)
37
+ DecoderBlock(in_channels=128, out_channels=64, use_attention=False),
38
+ ])
39
+
40
+ # ModuleList is used instead of nn.Sequential because
41
+ # of the branching based on use_attention in the forward pass
42
+
43
+ # Blocks only for 256x256 (from 64x64 to 256x256)
44
+ self.blocks_256 = nn.ModuleList([
45
+ # Input: (B, 64, 64, 64) -> Output: (B, 32, 128, 128)
46
+ DecoderBlock(in_channels=64, out_channels=32, use_attention=True),
47
+ # Input: (B, 32, 128, 128) -> Output: (B, 16, 256, 256)
48
+ DecoderBlock(in_channels=32, out_channels=16, use_attention=False),
49
+ ])
50
+
51
+ # Last layer to get to RGB channels - 256x256
52
+ # Input: (B, 16, 256, 256) -> Output: (B, 3, 256, 256)
53
+ self.final_conv_256 = nn.Conv2d(16, final_image_channels, kernel_size=3, padding=1)
54
+ self.final_activation_256 = nn.Tanh()
55
+
56
+ # Last layer to get to RGB channels - 64x64
57
+ # Input: (B, 64, 64, 64) -> Output: (B, 3, 64, 64)
58
+ self.final_conv_64 = nn.Conv2d(64, final_image_channels, kernel_size=3, padding=1)
59
+ self.final_activation_64 = nn.Tanh()
60
+
61
+ def forward(self, noise, encoder_output_full, attention_mask):
62
+ # noise.shape: (B, noise_dim)
63
+ # encoder_output_full.shape: (B, seq_len, text_embed_dim)
64
+ # attention_mask.shape: (B, seq_len)
65
+
66
+ # 1. Compute the first attention, with the scores (logits) for each token
67
+ attn_scores = self.initial_context_scorer(encoder_output_full)
68
+
69
+ # Apply attention mask before Softmax.
70
+ # Set the scores of the padding tokens, where attention mask is 0, to -inf.
71
+ # The mask is (B, seq_len), the scores (B, seq_len, 1)
72
+ # The unsqueeze takes care of the dimension difference.
73
+ attn_scores.masked_fill_(attention_mask.unsqueeze(-1) == 0, -1e9)
74
+
75
+ # attention_weights.shape: (B, seq_len, 1)
76
+ attention_weights = torch.softmax(attn_scores, dim=1)
77
+
78
+ # Weighted average of the encoder output
79
+ # context_vector.shape: (B, text_embed_dim)
80
+ context_vector = torch.sum(attention_weights * encoder_output_full, dim=1)
81
+
82
+ # 2. Merge the noise and the context vector for the initial projection
83
+ # initial_input.shape: (B, noise_dim + text_embed_dim)
84
+ initial_input = torch.cat([noise, context_vector], dim=1)
85
+
86
+ # 3. Initial projection and reshape to fit the transposed convolutions
87
+ # x.shape: (B, 256 * 4 * 4)
88
+ x = self.initial_projection(initial_input)
89
+ # x.shape: (B, 256, 4, 4)
90
+ x = x.view(x.size(0), 256, 4, 4)
91
+
92
+ # 4. Pass through the encoder blocks
93
+ attention_maps = []
94
+
95
+ # Shared path for both resolutions (up to 64x64)
96
+ for block in self.blocks_64:
97
+ encoder_ctx = encoder_output_full if block.use_attention else None
98
+ mask_ctx = attention_mask if block.use_attention else None
99
+ x, attn_weights = block(x, encoder_ctx, mask_ctx)
100
+ if attn_weights is not None:
101
+ attention_maps.append(attn_weights)
102
+
103
+ # Now x has size (B, 64, 64, 64)
104
+
105
+ # 64x64-only path
106
+ image_64 = self.final_conv_64(x)
107
+ image_64 = self.final_activation_64(image_64)
108
+
109
+ # 5. 256x256-only path
110
+ for block in self.blocks_256:
111
+ encoder_ctx = encoder_output_full if block.use_attention else None
112
+ mask_ctx = attention_mask if block.use_attention else None
113
+ x, attn_weights = block(x, encoder_ctx, mask_ctx)
114
+ if attn_weights is not None:
115
+ attention_maps.append(attn_weights)
116
+
117
+ # Final layer for 256x256
118
+ # x_256.shape: (B, 16, 256, 256) -> (B, 3, 256, 256)
119
+ image_256 = self.final_conv_256(x)
120
+ image_256 = self.final_activation_256(image_256)
121
+
122
+ return image_256, image_64, attention_maps, attention_weights
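A standalone shape check for the ImageDecoder (random noise and a random "encoder output"; three of the shared blocks plus one 256x256 block use attention, hence four attention maps):

import torch
from model_blocks.image_decoder import ImageDecoder

decoder = ImageDecoder(noise_dim=100, text_embed_dim=256)
noise = torch.randn(2, 100)
encoder_output = torch.randn(2, 128, 256)
mask = torch.ones(2, 128, dtype=torch.long)

img_256, img_64, maps, init_weights = decoder(noise, encoder_output, mask)
print(img_256.shape, img_64.shape, len(maps), init_weights.shape)
# torch.Size([2, 3, 256, 256]) torch.Size([2, 3, 64, 64]) 4 torch.Size([2, 128, 1])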
pikapikagen/model_blocks/text_encoder.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch.nn as nn
2
+ from transformers import AutoModel
3
+
4
+ class TextEncoder(nn.Module):
5
+ """
6
+ Text encoder
7
+ Uses bert-mini embeddings and passes them through a Transformer.
8
+ """
9
+ def __init__(self, model_name="prajjwal1/bert-mini", fine_tune_embeddings=True):
10
+ super().__init__()
11
+ # Load the pre-trained bert-mini model for embeddings
12
+ bert_mini_model = AutoModel.from_pretrained(model_name)
13
+
14
+ self.embedding = bert_mini_model.embeddings
15
+
16
+ # Set whether to fine-tune the embeddings during training
17
+ for param in self.embedding.parameters():
18
+ param.requires_grad = fine_tune_embeddings
19
+
20
+ encoder_layer = nn.TransformerEncoderLayer(
21
+ d_model=256, nhead=4, dim_feedforward=1024, batch_first=True
22
+ )
23
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
24
+
25
+ def forward(self, token_ids, attention_mask=None):
26
+ # Get the embeddings from the tokens
27
+ # Shape: (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
28
+ embedded_text = self.embedding(token_ids)
29
+
30
+ # Prepare the padding mask for TransformerEncoder
31
+ # The HuggingFace mask is 1 for real tokens, 0 for padding.
32
+ # TransformerEncoder expects True for positions to ignore (padding).
33
+ src_key_padding_mask = None
34
+ if attention_mask is not None:
35
+ src_key_padding_mask = (attention_mask == 0)
36
+
37
+ # Pass the embeddings through the Transformer Encoder with the mask
38
+ # Shape: (batch_size, seq_len, embedding_dim) -> (batch_size, seq_len, embedding_dim)
39
+ encoder_output = self.transformer_encoder(
40
+ src=embedded_text,
41
+ src_key_padding_mask=src_key_padding_mask
42
+ )
43
+ return encoder_output
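Shape sketch for the TextEncoder (bert-mini embeddings are 256-dimensional, so the output is one 256-d vector per token):

import torch
from transformers import AutoTokenizer
from model_blocks.text_encoder import TextEncoder

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")
tokens = tokenizer(
    "A fire dragon with golden scales",
    max_length=128, padding="max_length", truncation=True, return_tensors="pt",
)

encoder = TextEncoder()
with torch.no_grad():
    features = encoder(tokens["input_ids"], attention_mask=tokens["attention_mask"])
print(features.shape)  # torch.Size([1, 128, 256])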
pikapikagen/model_checkpoint/checkpoint_epoch_150.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7902ac75581c4a54ec5345ccf2bd30440a99a4c1031b5c12af6cabb318dde225
3
+ size 789795998
pikapikagen/plots.py ADDED
@@ -0,0 +1,428 @@
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ import io
6
+ from PIL import Image
7
+ from utils import denormalize_image
8
+ import torch.nn.functional as F
9
+ from transformers import AutoTokenizer
10
+
11
+
12
+ def save_attention_visualization(
13
+ epoch, model, tokenizer, batch, device, set_name, output_dir, show_inline=False
14
+ ):
15
+ print(f"Epoch {epoch}: Generating attention visualization for {set_name} set...")
16
+
17
+ attention_data = generate_attention_data(model, tokenizer, batch, device)
18
+
19
+ if attention_data:
20
+ plot_attention_visualization(
21
+ epoch=epoch,
22
+ set_name=set_name,
23
+ output_dir=output_dir,
24
+ show_inline=show_inline,
25
+ **attention_data,
26
+ )
27
+ print(f"Epoch {epoch}: Attention visualization saved for Pokémon #{attention_data['pokemon_id']}.")
28
+ else:
29
+ print(f"Epoch {epoch}: Skipped attention visualization due to missing data.")
30
+
31
+
32
+ def generate_attention_data(model, tokenizer, batch, device):
33
+ """
34
+ Runs the model to generate the image and attention maps, filtering the padding tokens.
35
+ """
36
+ model.eval()
37
+
38
+ with torch.no_grad():
39
+ token_ids = batch["text"].to(device)
40
+ attention_mask = batch["attention_mask"].to(device)
41
+ # Ensure batch size is 1 for visualization
42
+ if token_ids.dim() > 1:
43
+ token_ids = token_ids[0].unsqueeze(0)
44
+ attention_mask = attention_mask[0].unsqueeze(0)
45
+
46
+ # Get the first sample from the batch
47
+ pokemon_id = batch["idx"][0]
48
+ description = batch["description"][0]
49
+
50
+ generated_image, _, attention_maps, initial_context_weights = model(  # the 64x64 output is unused here
51
+ token_ids, attention_mask, return_attentions=True
52
+ )
53
+
54
+ decoder_attention_maps = [m for m in attention_maps if m is not None]
55
+
56
+ if not decoder_attention_maps or initial_context_weights is None:
57
+ print("Attention maps not available. Skipping data generation.")
58
+ return None
59
+
60
+ # Extract valid tokens to display
61
+ tokens_all = tokenizer.convert_ids_to_tokens(token_ids.squeeze(0))
62
+ display_tokens = []
63
+ for i, token in enumerate(tokens_all):
64
+ if (
65
+ token not in [tokenizer.sep_token, tokenizer.pad_token]
66
+ and attention_mask[0, i] == 1
67
+ ):
68
+ display_tokens.append({"token": token, "index": i})
69
+
70
+ if not display_tokens:
71
+ print(f"No valid tokens to display for '{description}'. Skipping.")
72
+ return None
73
+
74
+ return {
75
+ "generated_image": generated_image.cpu(),
76
+ "decoder_attention_maps": [m.cpu() for m in decoder_attention_maps],
77
+ "initial_context_weights": initial_context_weights.cpu(),
78
+ "display_tokens": display_tokens,
79
+ "description": description,
80
+ "pokemon_id": pokemon_id,
81
+ }
82
+
83
+
84
+ def plot_attention_visualization(
85
+ # Plot identification arguments
86
+ epoch: int,
87
+ set_name: str,
88
+ output_dir: str | None,
89
+ # Data generated by the model (can be full batches)
90
+ generated_images: torch.Tensor,
91
+ decoder_attention_maps: list[torch.Tensor],
92
+ initial_context_weights: torch.Tensor,
93
+ # Original text input (can be a full batch)
94
+ token_ids: torch.Tensor,
95
+ attention_mask: torch.Tensor,
96
+ tokenizer: AutoTokenizer,
97
+ # Batch metadata (for the specific sample)
98
+ description: str,
99
+ pokemon_id: int | str,
100
+ # Control options
101
+ sample_idx: int = 0,
102
+ show_inline: bool = False,
103
+ ):
104
+ """
105
+ Generates and saves an attention visualization for a single sample from a batch.
106
+
107
+ This function is self-contained: it accepts full batch tensors and internally
108
+ handles sample selection and token preparation.
109
+
110
+ Args:
111
+ epoch (int): Epoch number (for title/filename).
112
+ set_name (str): Set name (e.g., 'train', for title/filename).
113
+ output_dir (str, optional): Folder to save the image. If None, the plot is not saved.
114
+
115
+ generated_images (torch.Tensor): Tensor of generated images.
116
+ Shape: (B, C, H, W).
117
+ decoder_attention_maps (list[torch.Tensor]): List of attention tensors.
118
+ Each tensor shape: (B, num_patches, seq_length).
119
+ initial_context_weights (torch.Tensor): Initial attention weights.
120
+ Shape: (B, 1, seq_length).
121
+
122
+ token_ids (torch.Tensor): Input token.
123
+ Shape: (B, seq_length).
124
+ attention_mask (torch.Tensor): Attention mask for tokens.
125
+ Shape: (B, seq_length).
126
+ tokenizer: The tokenizer object for id -> token conversion.
127
+
128
+ description (str): The text prompt for the selected sample.
129
+ pokemon_id (int or str): The ID of the selected sample.
130
+
131
+ sample_idx (int, optional): Index of the sample in the batch to visualize.
132
+ Defaults to 0.
133
+ show_inline (bool, optional): If True, shows the plot. Defaults to False.
134
+ """
135
+ # Select the specific sample using sample_idx and move to CPU
136
+ img_tensor = generated_images[sample_idx].cpu()
137
+ layer_maps = [m[sample_idx].cpu() for m in decoder_attention_maps if m is not None]
138
+ initial_weights = initial_context_weights[sample_idx].cpu()
139
+ token_ids_sample = token_ids[sample_idx].cpu()
140
+ attention_mask_sample = attention_mask[sample_idx].cpu()
141
+
142
+ # Token filtering logic
143
+ tokens_all = tokenizer.convert_ids_to_tokens(token_ids_sample)
144
+ display_tokens = []
145
+ for i, token in enumerate(tokens_all):
146
+ if (
147
+ token not in [tokenizer.sep_token, tokenizer.pad_token]
148
+ and attention_mask_sample[i] == 1
149
+ ):
150
+ display_tokens.append({"token": token, "index": i})
151
+
152
+ img_tensor_cpu = denormalize_image(img_tensor).permute(1, 2, 0)
153
+ num_decoder_layers = len(layer_maps)
154
+ num_tokens = len(display_tokens)
155
+ token_indices_to_display = [t["index"] for t in display_tokens]
156
+
157
+ cols = min(num_tokens, 8)
158
+ rows_per_layer = (num_tokens + cols - 1) // cols
159
+ height_ratios = [3, 2] + [2 * rows_per_layer] * num_decoder_layers
160
+ fig_height = sum(height_ratios)
161
+ fig_width = max(20, 2.5 * cols)
162
+
163
+ fig = plt.figure(figsize=(fig_width, fig_height))
164
+ gs_main = fig.add_gridspec(len(height_ratios), 1, height_ratios=height_ratios, hspace=1.2)
165
+ fig.suptitle(f"Epoch {epoch}: Attention for Pokémon #{pokemon_id} ({set_name.capitalize()})", fontsize=24)
166
+
167
+ ax_main_img = fig.add_subplot(gs_main[0])
168
+ ax_main_img.imshow(img_tensor_cpu)
169
+ ax_main_img.set_title("Generated Image", fontsize=18)
170
+ ax_main_img.text(0.5, -0.1, f"Prompt: {description}", ha="center", va="top",
171
+ transform=ax_main_img.transAxes, fontsize=14, wrap=True)
172
+ ax_main_img.axis("off")
173
+
174
+ ax_initial_attn = fig.add_subplot(gs_main[1])
175
+ initial_weights_squeezed = initial_weights.squeeze().numpy()
176
+ token_strings = [t["token"] for t in display_tokens]
177
+ relevant_weights = initial_weights_squeezed[[t["index"] for t in display_tokens]]
178
+ ax_initial_attn.bar(np.arange(len(token_strings)), relevant_weights, color="skyblue")
179
+ ax_initial_attn.set_xticks(np.arange(len(token_strings)))
180
+ ax_initial_attn.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=10)
181
+ ax_initial_attn.set_title("Initial Context Attention (Global)", fontsize=16)
182
+ ax_initial_attn.set_ylabel("Weight", fontsize=12)
183
+ ax_initial_attn.grid(axis="y", linestyle="--", alpha=0.7)
184
+
185
+ # Iterate through each decoder layer's attention maps
186
+ for i, layer_attn_map in enumerate(layer_maps):
187
+ # layer_attn_map shape is now (num_patches, seq_len)
188
+ map_size_flat = layer_attn_map.shape[0]
189
+ map_side = int(np.sqrt(map_size_flat))
190
+ layer_title = f"Decoder Cross-Attention Layer {i+1} (Size: {map_side}x{map_side})"
191
+
192
+ # Extract attention weights only for tokens we want to display
193
+ relevant_attn_maps = layer_attn_map[:, token_indices_to_display]
194
+ vmin, vmax = relevant_attn_maps.min(), relevant_attn_maps.max()
195
+
196
+ # Create subplot grid for this layer
197
+ gs_layer = gs_main[2 + i].subgridspec(rows_per_layer, cols + 1, wspace=0.2, hspace=0.4, width_ratios=[*([1] * cols), 0.1])
198
+ axes_in_layer = [fig.add_subplot(gs_layer[r, c]) for r in range(rows_per_layer) for c in range(cols)]
199
+
200
+ # Add layer title above the token attention maps
201
+ if axes_in_layer:
202
+ y_pos = axes_in_layer[0].get_position().y1
203
+ fig.text(0.5, y_pos + 0.01, layer_title, ha="center", va="bottom", fontsize=16, weight="bold")
204
+
205
+ # Plot attention heatmap for each token
206
+ im = None
207
+ for j, token_info in enumerate(display_tokens):
208
+ if j >= len(axes_in_layer):
209
+ break
210
+ ax = axes_in_layer[j]
211
+ attn_for_token = layer_attn_map[:, token_info["index"]]
212
+ # Reshape flat attention to spatial grid
213
+ heatmap = attn_for_token.reshape(map_side, map_side)
214
+ im = ax.imshow(heatmap, cmap="jet", interpolation="nearest", vmin=vmin, vmax=vmax)
215
+ ax.set_title(f"'{token_info['token']}'", fontsize=12)
216
+ ax.axis("off")
217
+
218
+ # Add colorbar for the layer
219
+ if im:
220
+ cax = fig.add_subplot(gs_layer[:, -1])
221
+ cbar = fig.colorbar(im, cax=cax)
222
+ cbar.ax.tick_params(labelsize=10)
223
+ cbar.set_label("Attention Weight", rotation=270, labelpad=15, fontsize=12)
224
+
225
+ # Hide unused subplots
226
+ for j in range(num_tokens, len(axes_in_layer)):
227
+ axes_in_layer[j].axis("off")
228
+
229
+ plt.tight_layout(rect=(0, 0.03, 1, 0.96))
230
+ if output_dir is not None:
231
+ save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_attention_visualization_{pokemon_id}.png")
232
+ plt.savefig(save_path, bbox_inches="tight")
233
+
234
+ # Save figure to bytes for potential further use (e.g., logging)
235
+ buf = io.BytesIO()
236
+ fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
237
+ buf.seek(0)
238
+
239
+ # Convert to PIL image
240
+ attention_plot = Image.open(buf)
241
+
242
+ if show_inline:
243
+ plt.show()
244
+ plt.close(fig)
245
+
246
+ return attention_plot
247
+
248
+
249
+ def save_plot_losses(losses_g, losses_d, output_dir="training_output", show_inline=True):
250
+ """
251
+ Generates and saves a plot of the generator and discriminator losses.
252
+ """
253
+ os.makedirs(output_dir, exist_ok=True)
254
+
255
+ fig, ax = plt.subplots(figsize=(12, 6))
256
+ ax.plot(losses_g, label="Generator Loss", color="blue")
257
+ ax.plot(losses_d, label="Discriminator Loss", color="red")
258
+ ax.set_title("Training Losses")
259
+ ax.set_xlabel("Epochs")
260
+ ax.set_ylabel("Loss")
261
+ ax.legend()
262
+ ax.grid(True)
263
+
264
+ save_path = os.path.join(output_dir, "training_losses.png")
265
+ plt.savefig(save_path)
266
+ print(f"Loss plot saved to: {save_path}")
267
+
268
+ if show_inline:
269
+ plt.show()
270
+ else:
271
+ plt.close(fig)
272
+
273
+ def save_plot_non_gan_losses(train_losses_history, val_losses_history, output_dir="training_output", show_inline=True, filter_losses=None):
274
+ """
275
+ Generates and saves plots of losses for non-GAN models with multiple loss components.
276
+
277
+ Args:
278
+ train_losses_history (list[dict]): List of dicts containing training losses per epoch.
279
+ e.g., [{'l1': 0.5, 'sobel': 0.3}, ...]
280
+ val_losses_history (list[dict]): List of dicts containing validation losses per epoch.
281
+ output_dir (str): Directory to save the plot.
282
+ show_inline (bool): Whether to display the plot inline.
283
+ filter_losses (list[str], optional): List of loss names to plot.
284
+ If None, plots all found losses.
285
+ """
286
+ os.makedirs(output_dir, exist_ok=True)
287
+
288
+ # Extract all unique loss keys from both training and validation
289
+ all_keys = set()
290
+ for losses_dict in train_losses_history + val_losses_history:
291
+ all_keys.update(losses_dict.keys())
292
+
293
+ # Filter out non-numeric keys if any
294
+ loss_keys = [key for key in all_keys if key not in ['epoch']]
295
+
296
+ # Apply filter if specified
297
+ if filter_losses is not None:
298
+ loss_keys = [key for key in loss_keys if key in filter_losses]
299
+
300
+ loss_keys = sorted(loss_keys) # Sort for consistent ordering
301
+
302
+ # Create subplots
303
+ n_losses = len(loss_keys)
304
+ cols = min(3, n_losses) # Max 3 columns
305
+ rows = (n_losses + cols - 1) // cols # Ceiling division
306
+
307
+ fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows))
308
+ if n_losses == 1:
309
+ axes = [axes]
310
+ elif rows > 1:
311
+ axes = axes.flatten()
312
+
313
+ fig.suptitle("Training and Validation Losses", fontsize=16, y=0.98)
314
+
315
+ for i, loss_key in enumerate(loss_keys):
316
+ ax = axes[i]
317
+
318
+ # Extract train and validation losses for this key
319
+ train_values = [losses.get(loss_key, 0) for losses in train_losses_history]
320
+ val_values = [losses.get(loss_key, 0) for losses in val_losses_history]
321
+
322
+ epochs_train = range(1, len(train_values) + 1)
323
+ epochs_val = range(1, len(val_values) + 1)
324
+
325
+ # Plot training and validation curves
326
+ if train_values:
327
+ ax.plot(epochs_train, train_values, label=f"Train {loss_key}", color="blue", linewidth=1.5)
328
+ if val_values:
329
+ ax.plot(epochs_val, val_values, label=f"Val {loss_key}", color="red", linewidth=1.5, linestyle='--')
330
+
331
+ ax.set_title(f"{loss_key.capitalize()} Loss", fontsize=12)
332
+ ax.set_xlabel("Epoch")
333
+ ax.set_ylabel("Loss")
334
+ ax.legend()
335
+ ax.grid(True, alpha=0.3)
336
+ ax.set_ylim(bottom=0)
337
+
338
+ # Hide unused subplots
339
+ for i in range(n_losses, len(axes)):
340
+ axes[i].set_visible(False)
341
+
342
+ plt.tight_layout()
343
+
344
+ # Save the plot
345
+ save_path = os.path.join(output_dir, "non_gan_training_losses.png")
346
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
347
+ print(f"Non-GAN training losses plot saved to: {save_path}")
348
+
349
+ if show_inline:
350
+ plt.show()
351
+ else:
352
+ plt.close(fig)
353
+
354
+
355
+ def save_comparison_grid(epoch, model, batch, set_name, device, output_dir="training_output", show_inline=True):
356
+ """
357
+ Generates and saves/shows a horizontal comparison grid (real vs. generated).
358
+ Automatically handles 256x256 or 64x64 output based on set_name.
359
+ """
360
+ os.makedirs(output_dir, exist_ok=True)
361
+
362
+ model.eval()
363
+ token_ids = batch["text"].to(device)
364
+ attention_mask = batch["attention_mask"].to(device)
365
+ real_images = batch["image"]
366
+ pokemon_ids = batch["idx"]
367
+ descriptions = batch["description"]
368
+ num_images = real_images.size(0)
369
+
370
+ with torch.no_grad():
371
+ generated_images = model(token_ids, attention_mask)
372
+ # Handle tuple output from generator (e.g., 256px and 64px images)
373
+ if isinstance(generated_images, tuple):
374
+ # Check if we want 64x64 or 256x256 based on set_name
375
+ if "64" in set_name:
376
+ generated_images = generated_images[1] # Use 64x64 output
377
+ # Resize real images to 64x64 for comparison
378
+ real_images = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)
379
+ else:
380
+ generated_images = generated_images[0] # Use 256x256 output
381
+
382
+ fig, axs = plt.subplots(2, num_images, figsize=(4 * num_images, 8.5))
383
+ resolution = "64x64" if "64" in set_name else "256x256"
384
+ fig.suptitle(
385
+ f"Epoch {epoch} - {set_name.capitalize()} Comparison ({resolution})", fontsize=16, y=0.98
386
+ )
387
+
388
+ for i in range(num_images):
389
+ ax_real = axs[0, i]
390
+ ax_real.imshow(denormalize_image(real_images[i].cpu()).permute(1, 2, 0))
391
+ ax_real.set_title(f"#{pokemon_ids[i]}: {descriptions[i][:35]}...", fontsize=10)
392
+ ax_real.axis("off")
393
+
394
+ ax_gen = axs[1, i]
395
+ ax_gen.imshow(denormalize_image(generated_images[i].cpu()).permute(1, 2, 0))
396
+ ax_gen.axis("off")
397
+
398
+ axs[0, 0].text(
399
+ -0.1,
400
+ 0.5,
401
+ "Real",
402
+ ha="center",
403
+ va="center",
404
+ rotation="vertical",
405
+ fontsize=14,
406
+ transform=axs[0, 0].transAxes,
407
+ )
408
+ axs[1, 0].text(
409
+ -0.1,
410
+ 0.5,
411
+ "Generated",
412
+ ha="center",
413
+ va="center",
414
+ rotation="vertical",
415
+ fontsize=14,
416
+ transform=axs[1, 0].transAxes,
417
+ )
418
+
419
+ plt.tight_layout(rect=(0, 0, 1, 0.95))
420
+
421
+ # Save the figure and optionally show it
422
+ save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_comparison.png")
423
+ plt.savefig(save_path)
424
+
425
+ if show_inline:
426
+ plt.show()
427
+ else:
428
+ plt.close(fig)
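As a small usage note, the loss-plotting helper above can be exercised with dummy curves (values are made up):

from plots import save_plot_losses

save_plot_losses(
    losses_g=[1.00, 0.85, 0.70, 0.62],   # dummy generator losses per epoch
    losses_d=[0.90, 0.80, 0.78, 0.75],   # dummy discriminator losses per epoch
    output_dir="training_output",
    show_inline=False,
)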
pikapikagen/utils.py ADDED
@@ -0,0 +1,12 @@
1
+ def denormalize_image(tensor):
2
+ """
3
+ Denormalizes an image tensor from the [-1, 1] range to [0, 1] for visualization.
4
+
5
+ Args:
6
+ tensor (torch.Tensor): The image tensor, with values in [-1, 1].
7
+
8
+ Returns:
9
+ torch.Tensor: The denormalized tensor with values in [0, 1].
10
+ """
11
+ tensor = (tensor + 1) / 2
12
+ return tensor.clamp(0, 1)
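Tiny example of the helper above: values in [-1, 1] are mapped linearly to [0, 1] and clamped:

import torch
from utils import denormalize_image

x = torch.tensor([-1.0, 0.0, 1.0]).view(3, 1, 1).expand(3, 4, 4)
print(denormalize_image(x).unique())  # tensor([0.0000, 0.5000, 1.0000])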
pyproject.toml ADDED
@@ -0,0 +1,18 @@
1
+ [project]
2
+ name = "pikapikagen"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "gradio>=5.35.0",
9
+ "ipykernel>=6.29.5",
10
+ "ipywidgets>=8.1.7",
11
+ "jupyterlab>=4.4.5",
12
+ "matplotlib>=3.10.3",
13
+ "pandas>=2.3.0",
14
+ "sentence-transformers>=5.0.0",
15
+ "torch-fidelity>=0.3.0",
16
+ "torchvision>=0.22.1",
17
+ "transformers>=4.53.0",
18
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff