boris committed
Commit 4ac66e4
1 Parent(s): dca3ada

feat: use default from_pretrained function

Files changed (1):
  1. src/dalle_mini/model/modeling.py  +33 -418
src/dalle_mini/model/modeling.py CHANGED
@@ -15,37 +15,21 @@
 """ DalleBart model. """
 
 import math
-import os
 from functools import partial
-from pickle import UnpicklingError
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
 
 import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-import msgpack.exceptions
 from einops import rearrange
 from flax.core.frozen_dict import unfreeze
 from flax.linen import combine_masks, make_causal_mask
 from flax.linen import partitioning as nn_partitioning
 from flax.linen.linear import PrecisionLike
-from flax.serialization import from_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax import custom_jvp, lax
 from jax.random import PRNGKey
-from transformers.configuration_utils import PretrainedConfig
-from transformers.file_utils import (
-    FLAX_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    is_remote_url,
-)
-from transformers.modeling_flax_pytorch_utils import (
-    load_pytorch_checkpoint_in_flax_state_dict,
-)
 from transformers.generation_flax_utils import FlaxSampleOutput
 from transformers.modeling_flax_outputs import (
     FlaxBaseModelOutput,
@@ -59,17 +43,8 @@ from transformers.models.bart.modeling_flax_bart import (
     FlaxBartForConditionalGeneration,
     FlaxBartForConditionalGenerationModule,
     FlaxBartModule,
-    FlaxBartPreTrainedModel,
-)
-from requests import HTTPError
-from transformers.utils import (
-    logging,
-    RepositoryNotFoundError,
-    RevisionNotFoundError,
-    EntryNotFoundError,
-    has_file,
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
 )
+from transformers.utils import logging
 
 from .configuration import DalleBartConfig
 from .utils import PretrainedFromWandbMixin
@@ -1321,393 +1296,6 @@ class FlaxBartModule(FlaxBartModule):
         )
 
 
-class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
-    """
-    Edits:
-    - added num_params property
-    - use_scan parameter
-    """
-
-    config_class = DalleBartConfig
-
-    def num_params(self, params=None):
-        if params is None:
-            params = self.params
-        num_params = jax.tree_map(
-            lambda param: param.size, flatten_dict(unfreeze(params))
-        ).values()
-        return sum(list(num_params))
-
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, os.PathLike],
-        dtype: jnp.dtype = jnp.float32,
-        use_scan: bool = None,
-        *model_args,
-        **kwargs,
-    ):
-        config = kwargs.pop("config", None)
-        cache_dir = kwargs.pop("cache_dir", None)
-        from_pt = kwargs.pop("from_pt", False)
-        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", False)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        from_pipeline = kwargs.pop("_from_pipeline", None)
-        from_auto_class = kwargs.pop("_from_auto", False)
-        _do_init = kwargs.pop("_do_init", True)
-
-        user_agent = {
-            "file_type": "model",
-            "framework": "flax",
-            "from_auto_class": from_auto_class,
-        }
-        if from_pipeline is not None:
-            user_agent["using_pipeline"] = from_pipeline
-
-        if is_offline_mode() and not local_files_only:
-            logger.info("Offline mode: forcing local_files_only=True")
-            local_files_only = True
-
-        # Load config if we don't provide a configuration
-        if not isinstance(config, PretrainedConfig):
-            config_path = (
-                config if config is not None else pretrained_model_name_or_path
-            )
-            config, model_kwargs = cls.config_class.from_pretrained(
-                config_path,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                _from_auto=from_auto_class,
-                _from_pipeline=from_pipeline,
-                **kwargs,
-            )
-        else:
-            model_kwargs = kwargs
-
-        # Add the dtype to model_kwargs
-        model_kwargs["dtype"] = dtype
-
-        # Load model
-        if pretrained_model_name_or_path is not None:
-            if os.path.isdir(pretrained_model_name_or_path):
-                if from_pt and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-                ):
-                    # Load from a PyTorch checkpoint
-                    archive_file = os.path.join(
-                        pretrained_model_name_or_path, WEIGHTS_NAME
-                    )
-                elif os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
-                ):
-                    # Load from a Flax checkpoint
-                    archive_file = os.path.join(
-                        pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
-                    )
-                # At this stage we don't have a weight file so we will raise an error.
-                elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
-                    raise EnvironmentError(
-                        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
-                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
-                        "weights."
-                    )
-                else:
-                    raise EnvironmentError(
-                        f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
-                        f"{pretrained_model_name_or_path}."
-                    )
-            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
-                pretrained_model_name_or_path
-            ):
-                archive_file = pretrained_model_name_or_path
-            else:
-                filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
-                archive_file = hf_bucket_url(
-                    pretrained_model_name_or_path,
-                    filename=filename,
-                    revision=revision,
-                )
-
-            # redirect to the cache, if necessary
-            try:
-                resolved_archive_file = cached_path(
-                    archive_file,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                )
-
-            except RepositoryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
-                    "login` and pass `use_auth_token=True`."
-                )
-            except RevisionNotFoundError:
-                raise EnvironmentError(
-                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
-                    "this model name. Check the model page at "
-                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-                )
-            except EntryNotFoundError:
-                if filename == FLAX_WEIGHTS_NAME:
-                    has_file_kwargs = {
-                        "revision": revision,
-                        "proxies": proxies,
-                        "use_auth_token": use_auth_token,
-                    }
-                    if has_file(
-                        pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
-                    ):
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
-                            "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
-                            "those weights."
-                        )
-                    else:
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
-                            f"or {WEIGHTS_NAME}."
-                        )
-                else:
-                    raise EnvironmentError(
-                        f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
-                    )
-            except HTTPError as err:
-                raise EnvironmentError(
-                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
-                    f"{err}"
-                )
-            except ValueError:
-                raise EnvironmentError(
-                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in the cached "
-                    f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
-                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
-                    "Checkout your internet connection or see how to run the library in offline mode at "
-                    "'https://huggingface.co/docs/transformers/installation#offline-mode'."
-                )
-            except EnvironmentError:
-                raise EnvironmentError(
-                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
-                )
-
-            if resolved_archive_file == archive_file:
-                logger.info(f"loading weights file {archive_file}")
-            else:
-                logger.info(
-                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
-                )
-        else:
-            resolved_archive_file = None
-
-        # unscan model
-        unscan_model = None
-        if use_scan is not None:
-            assert (config.use_scan and use_scan is True) or (
-                not use_scan
-            ), f"Wrong setting of use_scan: {use_scan} vs config.use_scan: {config.use_scan}"
-            if config.use_scan and not use_scan:
-                config.use_scan = False
-                unscan_model = True
-
-        # init random models
-        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
-
-        if from_pt:
-            state = load_pytorch_checkpoint_in_flax_state_dict(
-                model, resolved_archive_file
-            )
-        else:
-            with open(resolved_archive_file, "rb") as state_f:
-                try:
-                    state = from_bytes(cls, state_f.read())
-                except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
-                    try:
-                        with open(resolved_archive_file) as f:
-                            if f.read().startswith("version"):
-                                raise OSError(
-                                    "You seem to have cloned a repository without having git-lfs installed. Please install "
-                                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
-                                    "you cloned."
-                                )
-                            else:
-                                raise ValueError from e
-                    except (UnicodeDecodeError, ValueError):
-                        raise EnvironmentError(
-                            f"Unable to convert {archive_file} to Flax deserializable object. "
-                        )
-            # make sure all arrays are stored as jnp.arrays
-            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
-            # https://github.com/google/flax/issues/1261
-            if _do_init:
-                state = jax.tree_util.tree_map(jnp.array, state)
-            else:
-                # keep the params on CPU if we don't want to initialize
-                state = jax.tree_util.tree_map(
-                    lambda x: jax.device_put(x, jax.devices("cpu")[0]), state
-                )
-
-        # if model is base model only use model_prefix key
-        if (
-            cls.base_model_prefix not in dict(model.params_shape_tree)
-            and cls.base_model_prefix in state
-        ):
-            state = state[cls.base_model_prefix]
-
-        # if model is head model and we are loading weights from base model
-        # we initialize new params dict with base_model_prefix
-        if (
-            cls.base_model_prefix in dict(model.params_shape_tree)
-            and cls.base_model_prefix not in state
-        ):
-            state = {cls.base_model_prefix: state}
-
-        # flatten dicts
-        state = flatten_dict(state)
-        if unscan_model:
-            scanned_keys = [k for k in state.keys() if "layers" in k]
-            for k in scanned_keys:
-                v = state[k]
-                name_idx = k.index("layers") + 1
-                for i in range(len(v)):
-                    new_k = (
-                        *k[:name_idx],
-                        f"{k[name_idx][:-1]}_{i}",
-                        *k[name_idx + 1 :],
-                    )
-                    state[new_k] = v[i]
-                del state[k]
-
-        random_state = flatten_dict(
-            unfreeze(model.params if _do_init else model.params_shape_tree)
-        )
-
-        missing_keys = model.required_params - set(state.keys())
-        unexpected_keys = set(state.keys()) - model.required_params
-
-        if missing_keys and not _do_init:
-            logger.warn(
-                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
-                f"Make sure to call model.init_weights to initialize the missing weights."
-            )
-            cls._missing_keys = missing_keys
-
-        # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
-        # matching the weights in the model.
-        mismatched_keys = []
-        for key in state.keys():
-            if key in random_state and state[key].shape != random_state[key].shape:
-                if ignore_mismatched_sizes:
-                    mismatched_keys.append(
-                        (key, state[key].shape, random_state[key].shape)
-                    )
-                    state[key] = random_state[key]
-                else:
-                    raise ValueError(
-                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
-                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
-                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
-                        "model."
-                    )
-
-        # add missing keys as random parameters if we are initializing
-        if missing_keys and _do_init:
-            for missing_key in missing_keys:
-                state[missing_key] = random_state[missing_key]
-
-        # remove unexpected keys to not be saved again
-        for unexpected_key in unexpected_keys:
-            del state[unexpected_key]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
-                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
-                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
-                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
-                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
-                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
-            )
-        else:
-            logger.info(
-                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
-            )
-
-        if len(missing_keys) > 0:
-            logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
-                f"and are newly initialized: {missing_keys}\n"
-                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
-            )
-        elif len(mismatched_keys) == 0:
-            logger.info(
-                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
-                f"If your task is similar to the task the model of the checkpoint was trained on, "
-                f"you can already use {model.__class__.__name__} for predictions without further training."
-            )
-        if len(mismatched_keys) > 0:
-            mismatched_warning = "\n".join(
-                [
-                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
-                    for key, shape1, shape2 in mismatched_keys
-                ]
-            )
-            logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
-                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
-                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
-            )
-
-        # dictionary of key: dtypes for the model params
-        param_dtypes = jax.tree_map(lambda x: x.dtype, state)
-        # extract keys of parameters not in jnp.float32
-        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
-        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
-
-        # raise a warning if any of the parameters are not in jnp.float32
-        if len(fp16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
-            )
-
-        if len(bf16_params) > 0:
-            logger.warning(
-                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
-                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
-                "You should probably UPCAST the model weights to float32 if this was not intended. "
-                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
-            )
-
-        if _do_init:
-            # set correct parameters
-            model.params = unflatten_dict(state)
-            return model
-        else:
-            return model, unflatten_dict(state)
-
-
 class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     """
     Edits:
@@ -1788,21 +1376,48 @@ class SampleState:
     model_kwargs_uncond: Dict[str, jnp.ndarray]
 
 
-class DalleBart(
-    PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
-):
+class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration):
     """
     Edits:
     - renamed from FlaxBartForConditionalGeneration
-    - uses custom FlaxBartPreTrainedModel
     - uses custom FlaxBartForConditionalGenerationModule
     - no bias in decode method
     - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
       related to position embedding during model.generate()
     - custom generate method to allow super conditions
+    - num_params property
+    - unscan function
     """
 
     module_class = FlaxBartForConditionalGenerationModule
+    config_class = DalleBartConfig
+
+    def num_params(self, params=None):
+        if params is None:
+            params = self.params
+        num_params = jax.tree_map(
+            lambda param: param.size, flatten_dict(unfreeze(params))
+        ).values()
+        return sum(list(num_params))
+
+    def unscan(self, params):
+        if self.config.use_scan:
+            self.config.use_scan = False
+            params = flatten_dict(params)
+            scanned_keys = [k for k in params.keys() if "layers" in k]
+            for k in scanned_keys:
+                v = params[k]
+                name_idx = k.index("layers") + 1
+                for i in range(len(v)):
+                    new_k = (
+                        *k[:name_idx],
+                        f"{k[name_idx][:-1]}_{i}",
+                        *k[name_idx + 1 :],
+                    )
+                    params[new_k] = v[i]
+                del params[k]
+            params = unflatten_dict(params)
+        return params
 
     def decode(
         self,
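For orientation, a minimal usage sketch of DalleBart after this change, assuming the package root re-exports the class and using a placeholder checkpoint reference (both are illustrative, not part of the commit). Loading now goes through the inherited Hugging Face from_pretrained (via PretrainedFromWandbMixin and FlaxBartForConditionalGeneration); num_params and unscan are the only helpers kept on the class.

# Illustrative sketch; the import path and checkpoint id below are assumptions.
from dalle_mini import DalleBart

model = DalleBart.from_pretrained("dalle-mini/dalle-mini")  # stock loader, no custom override
print(model.num_params())                # parameter count helper defined in this file
params = model.unscan(model.params)      # splits scanned layer stacks, if config.use_scan is set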
 
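And a self-contained sketch of what the unscan helper does to a parameter tree, using toy key names and shapes (the FlaxBartEncoderLayers entry and the array sizes are assumptions for illustration): a scanned module stores every layer stacked along a leading axis under a single key, and unscan splits that stack into per-layer entries, dropping the trailing "s" and appending an index.

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

# Toy tree: two layers scanned together, stacked on a leading axis of size 2.
scanned = {"encoder": {"layers": {"FlaxBartEncoderLayers": {"kernel": jnp.ones((2, 4, 4))}}}}

flat = flatten_dict(scanned)
for k in [k for k in flat if "layers" in k]:
    v = flat.pop(k)
    name_idx = k.index("layers") + 1            # key component right after "layers"
    for i in range(v.shape[0]):                 # one entry per stacked layer
        new_k = (*k[:name_idx], f"{k[name_idx][:-1]}_{i}", *k[name_idx + 1 :])
        flat[new_k] = v[i]                      # per-layer slice of the stacked array

print(list(unflatten_dict(flat)["encoder"]["layers"]))
# ['FlaxBartEncoderLayer_0', 'FlaxBartEncoderLayer_1']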