feat: use default from_pretrained function
src/dalle_mini/model/modeling.py  (+33 -418)
@@ -15,37 +15,21 @@
 """ DalleBart model. """
 
 import math
-import os
 from functools import partial
-from pickle import UnpicklingError
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
 
 import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-import msgpack.exceptions
 from einops import rearrange
 from flax.core.frozen_dict import unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen import partitioning as nn_partitioning
 from flax.linen.linear import PrecisionLike
-from flax.serialization import from_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import custom_jvp, lax
 from jax.random import PRNGKey
-from transformers.configuration_utils import PretrainedConfig
-from transformers.file_utils import (
-    FLAX_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    is_remote_url,
-)
-from transformers.modeling_flax_pytorch_utils import (
-    load_pytorch_checkpoint_in_flax_state_dict,
-)
 from transformers.generation_flax_utils import FlaxSampleOutput
 from transformers.modeling_flax_outputs import (
     FlaxBaseModelOutput,
@@ -59,17 +43,8 @@ from transformers.models.bart.modeling_flax_bart import (
     FlaxBartForConditionalGeneration,
     FlaxBartForConditionalGenerationModule,
     FlaxBartModule,
-    FlaxBartPreTrainedModel,
-)
-from requests import HTTPError
-from transformers.utils import (
-    logging,
-    RepositoryNotFoundError,
-    RevisionNotFoundError,
-    EntryNotFoundError,
-    has_file,
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
 )
+from transformers.utils import logging
 
 from .configuration import DalleBartConfig
 from .utils import PretrainedFromWandbMixin
@@ -1321,393 +1296,6 @@ class FlaxBartModule(FlaxBartModule):
         )
 
 
-class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
-    """
-    Edits:
-    - added num_params property
-    - use_scan parameter
-    """
-
-    config_class = DalleBartConfig
-
-    def num_params(self, params=None):
-        if params is None:
-            params = self.params
-        num_params = jax.tree_map(
-            lambda param: param.size, flatten_dict(unfreeze(params))
-        ).values()
-        return sum(list(num_params))
-
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, os.PathLike],
-        dtype: jnp.dtype = jnp.float32,
-        use_scan: bool = None,
-        *model_args,
-        **kwargs,
-    ):
-        config = kwargs.pop("config", None)
-        cache_dir = kwargs.pop("cache_dir", None)
-        from_pt = kwargs.pop("from_pt", False)
-        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", False)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        from_pipeline = kwargs.pop("_from_pipeline", None)
-        from_auto_class = kwargs.pop("_from_auto", False)
-        _do_init = kwargs.pop("_do_init", True)
-
-        user_agent = {
-            "file_type": "model",
-            "framework": "flax",
-            "from_auto_class": from_auto_class,
-        }
-        if from_pipeline is not None:
-            user_agent["using_pipeline"] = from_pipeline
-
-        if is_offline_mode() and not local_files_only:
-            logger.info("Offline mode: forcing local_files_only=True")
-            local_files_only = True
-
-        # Load config if we don't provide a configuration
-        if not isinstance(config, PretrainedConfig):
-            config_path = (
-                config if config is not None else pretrained_model_name_or_path
-            )
-            config, model_kwargs = cls.config_class.from_pretrained(
-                config_path,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                _from_auto=from_auto_class,
-                _from_pipeline=from_pipeline,
-                **kwargs,
-            )
-        else:
-            model_kwargs = kwargs
-
-        # Add the dtype to model_kwargs
-        model_kwargs["dtype"] = dtype
-
-        # Load model
-        if pretrained_model_name_or_path is not None:
-            if os.path.isdir(pretrained_model_name_or_path):
-                if from_pt and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-                ):
-                    # Load from a PyTorch checkpoint
-                    archive_file = os.path.join(
-                        pretrained_model_name_or_path, WEIGHTS_NAME
-                    )
-                elif os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
-                ):
-                    # Load from a Flax checkpoint
-                    archive_file = os.path.join(
-                        pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
-                    )
-                # At this stage we don't have a weight file so we will raise an error.
-                elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
-                    raise EnvironmentError(
-                        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
-                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
-                        "weights."
-                    )
-                else:
-                    raise EnvironmentError(
-                        f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
-                        f"{pretrained_model_name_or_path}."
-                    )
-            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
-                pretrained_model_name_or_path
-            ):
-                archive_file = pretrained_model_name_or_path
-            else:
-                filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
-                archive_file = hf_bucket_url(
-                    pretrained_model_name_or_path,
-                    filename=filename,
-                    revision=revision,
-                )
-
-            # redirect to the cache, if necessary
-            try:
-                resolved_archive_file = cached_path(
-                    archive_file,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                )
-
-            except RepositoryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
-                    "login` and pass `use_auth_token=True`."
-                )
-            except RevisionNotFoundError:
-                raise EnvironmentError(
-                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
-                    "this model name. Check the model page at "
-                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-                )
-            except EntryNotFoundError:
-                if filename == FLAX_WEIGHTS_NAME:
-                    has_file_kwargs = {
-                        "revision": revision,
-                        "proxies": proxies,
-                        "use_auth_token": use_auth_token,
-                    }
-                    if has_file(
-                        pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
-                    ):
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
-                            "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
-                            "those weights."
-                        )
-                    else:
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
-                            f"or {WEIGHTS_NAME}."
-                        )
-                else:
-                    raise EnvironmentError(
-                        f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
-                    )
-            except HTTPError as err:
-                raise EnvironmentError(
-                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
-                    f"{err}"
-                )
-            except ValueError:
-                raise EnvironmentError(
-                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in the cached "
-                    f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
-                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
-                    "Checkout your internet connection or see how to run the library in offline mode at "
-                    "'https://huggingface.co/docs/transformers/installation#offline-mode'."
-                )
-            except EnvironmentError:
-                raise EnvironmentError(
-                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
-                )
-
-            if resolved_archive_file == archive_file:
-                logger.info(f"loading weights file {archive_file}")
-            else:
-                logger.info(
-                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
-                )
-        else:
-            resolved_archive_file = None
-
-        # unscan model
-        unscan_model = None
-        if use_scan is not None:
-            assert (config.use_scan and use_scan is True) or (
-                not use_scan
-            ), f"Wrong setting of use_scan: {use_scan} vs config.use_scan: {config.use_scan}"
-            if config.use_scan and not use_scan:
-                config.use_scan = False
-                unscan_model = True
-
-        # init random models
-        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
-
-        if from_pt:
-            state = load_pytorch_checkpoint_in_flax_state_dict(
-                model, resolved_archive_file
-            )
-        else:
-            with open(resolved_archive_file, "rb") as state_f:
-                try:
-                    state = from_bytes(cls, state_f.read())
-                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
-                    try:
-                        with open(resolved_archive_file) as f:
-                            if f.read().startswith("version"):
-                                raise OSError(
-                                    "You seem to have cloned a repository without having git-lfs installed. Please install "
-                                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
-                                    "you cloned."
-                                )
-                            else:
-                                raise ValueError from e
-                    except (UnicodeDecodeError, ValueError):
-                        raise EnvironmentError(
-                            f"Unable to convert {archive_file} to Flax deserializable object. "
-                        )
-            # make sure all arrays are stored as jnp.arrays
-            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
-            # https://github.com/google/flax/issues/1261
-            if _do_init:
-                state = jax.tree_util.tree_map(jnp.array, state)
-            else:
-                # keep the params on CPU if we don't want to initialize
-                state = jax.tree_util.tree_map(
-                    lambda x: jax.device_put(x, jax.devices("cpu")[0]), state
-                )
-
-        # if model is base model only use model_prefix key
-        if (
-            cls.base_model_prefix not in dict(model.params_shape_tree)
-            and cls.base_model_prefix in state
-        ):
-            state = state[cls.base_model_prefix]
-
-        # if model is head model and we are loading weights from base model
-        # we initialize new params dict with base_model_prefix
-        if (
-            cls.base_model_prefix in dict(model.params_shape_tree)
-            and cls.base_model_prefix not in state
-        ):
-            state = {cls.base_model_prefix: state}
-
-        # flatten dicts
-        state = flatten_dict(state)
-        if unscan_model:
-            scanned_keys = [k for k in state.keys() if "layers" in k]
-            for k in scanned_keys:
-                v = state[k]
-                name_idx = k.index("layers") + 1
-                for i in range(len(v)):
-                    new_k = (
-                        *k[:name_idx],
-                        f"{k[name_idx][:-1]}_{i}",
-                        *k[name_idx + 1 :],
-                    )
-                    state[new_k] = v[i]
-                del state[k]
-
-        random_state = flatten_dict(
-            unfreeze(model.params if _do_init else model.params_shape_tree)
-        )
-
-        missing_keys = model.required_params - set(state.keys())
-        unexpected_keys = set(state.keys()) - model.required_params
-
-        if missing_keys and not _do_init:
-            logger.warn(
-                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
-                f"Make sure to call model.init_weights to initialize the missing weights."
-            )
-            cls._missing_keys = missing_keys
-
-        # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
-        # matching the weights in the model.
-        mismatched_keys = []
-        for key in state.keys():
-            if key in random_state and state[key].shape != random_state[key].shape:
-                if ignore_mismatched_sizes:
-                    mismatched_keys.append(
-                        (key, state[key].shape, random_state[key].shape)
-                    )
-                    state[key] = random_state[key]
-                else:
-                    raise ValueError(
-                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
-                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
-                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
-                        "model."
-                    )
-
-        # add missing keys as random parameters if we are initializing
-        if missing_keys and _do_init:
-            for missing_key in missing_keys:
-                state[missing_key] = random_state[missing_key]
-
-        # remove unexpected keys to not be saved again
-        for unexpected_key in unexpected_keys:
-            del state[unexpected_key]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
-                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
-                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
-                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
-                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
-                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
-            )
-        else:
-            logger.info(
-                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
-            )
-
-        if len(missing_keys) > 0:
-            logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
-                f"and are newly initialized: {missing_keys}\n"
-                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
-            )
-        elif len(mismatched_keys) == 0:
-            logger.info(
-                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
-                f"If your task is similar to the task the model of the checkpoint was trained on, "
-                f"you can already use {model.__class__.__name__} for predictions without further training."
-            )
-        if len(mismatched_keys) > 0:
-            mismatched_warning = "\n".join(
-                [
-                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
-                    for key, shape1, shape2 in mismatched_keys
-                ]
-            )
-            logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
-                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
-                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
-            )
-
-        # dictionary of key: dtypes for the model params
-        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
-        # extract keys of parameters not in jnp.float32
-        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
-        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
-
-        # raise a warning if any of the parameters are not in jnp.float32
-        if len(fp16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
-            )
-
-        if len(bf16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
-            )
-
-        if _do_init:
-            # set correct parameters
-            model.params = unflatten_dict(state)
-            return model
-        else:
-            return model, unflatten_dict(state)
-
-
 class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     """
     Edits:
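With this hunk, the copied-in from_pretrained override is gone: checkpoint resolution, key matching and the fp16/bf16 dtype warnings are now handled entirely by the stock FlaxPreTrainedModel.from_pretrained in transformers, and the custom use_scan loading argument disappears with it. The unscanning step instead becomes an explicit unscan method on DalleBart (next hunk). A rough sketch of the calling code this implies, with a placeholder checkpoint name:

    # Before this commit: the custom loader could unscan stacked layer weights on the fly.
    model = DalleBart.from_pretrained("org/checkpoint", use_scan=False)  # "org/checkpoint" is a placeholder

    # After this commit: load with the default transformers loader, then unscan explicitly
    # when the checkpoint config was saved with use_scan=True.
    model = DalleBart.from_pretrained("org/checkpoint")
    params = model.unscan(model.params)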
@@ -1788,21 +1376,48 @@ class SampleState:
     model_kwargs_uncond: Dict[str, jnp.ndarray]
 
 
-class DalleBart(
-    PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
-):
+class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration):
     """
     Edits:
     - renamed from FlaxBartForConditionalGeneration
-    - uses custom FlaxBartPreTrainedModel
     - uses custom FlaxBartForConditionalGenerationModule
     - no bias in decode method
     - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
       related to position embedding during model.generate()
     - custom generate method to allow super conditions
+    - num_params property
+    - unscan function
     """
 
     module_class = FlaxBartForConditionalGenerationModule
+    config_class = DalleBartConfig
+
+    def num_params(self, params=None):
+        if params is None:
+            params = self.params
+        num_params = jax.tree_map(
+            lambda param: param.size, flatten_dict(unfreeze(params))
+        ).values()
+        return sum(list(num_params))
+
+    def unscan(self, params):
+        if self.config.use_scan:
+            self.config.use_scan = False
+            params = flatten_dict(params)
+            scanned_keys = [k for k in params.keys() if "layers" in k]
+            for k in scanned_keys:
+                v = params[k]
+                name_idx = k.index("layers") + 1
+                for i in range(len(v)):
+                    new_k = (
+                        *k[:name_idx],
+                        f"{k[name_idx][:-1]}_{i}",
+                        *k[name_idx + 1 :],
+                    )
+                    params[new_k] = v[i]
+                del params[k]
+            params = unflatten_dict(params)
+        return params
 
     def decode(
         self,
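The num_params and unscan helpers now live directly on DalleBart and compose with the default loader. A minimal usage sketch, assuming an illustrative checkpoint identifier and a config saved with use_scan=True; unscan renames each stacked entry under a "layers" collection to per-layer keys (dropping the trailing "s" and appending _0 ... _{n-1}) and flips config.use_scan off:

    import jax.numpy as jnp

    from dalle_mini.model.modeling import DalleBart  # module shown in this diff

    # "org/checkpoint" is a placeholder identifier, not a published model.
    model = DalleBart.from_pretrained("org/checkpoint", dtype=jnp.float16)

    print(f"{model.num_params():,} parameters")  # sums param.size over the flattened tree

    # Convert scanned (stacked) layer weights back to one entry per layer.
    params = model.unscan(model.params)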