""" DalleBart model. """ |
|
|
|
import math |
|
import os |
|
from functools import partial |
|
from pickle import UnpicklingError |
|
from typing import Optional, Tuple, Union |
|
|
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
import msgpack.exceptions |
|
from flax.core.frozen_dict import unfreeze |
|
from flax.linen import make_causal_mask |
|
from flax.serialization import from_bytes |
|
from flax.traverse_util import flatten_dict, unflatten_dict |
|
from jax import lax |
|
from jax.random import PRNGKey |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.file_utils import ( |
|
FLAX_WEIGHTS_NAME, |
|
WEIGHTS_NAME, |
|
cached_path, |
|
hf_bucket_url, |
|
is_offline_mode, |
|
is_remote_url, |
|
) |
|
from transformers.modeling_flax_outputs import ( |
|
FlaxCausalLMOutputWithCrossAttentions, |
|
FlaxSeq2SeqLMOutput, |
|
) |
|
from transformers.modeling_flax_utils import ACT2FN |
|
from transformers.models.bart.modeling_flax_bart import ( |
|
FlaxBartAttention, |
|
FlaxBartDecoder, |
|
FlaxBartDecoderLayer, |
|
FlaxBartDecoderLayerCollection, |
|
FlaxBartEncoder, |
|
FlaxBartEncoderLayer, |
|
FlaxBartEncoderLayerCollection, |
|
FlaxBartForConditionalGeneration, |
|
FlaxBartForConditionalGenerationModule, |
|
FlaxBartModule, |
|
FlaxBartPreTrainedModel, |
|
) |
|
from transformers.utils import logging |
|
|
|
from .configuration import DalleBartConfig |
|
from .utils import PretrainedFromWandbMixin |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class FlaxBartAttention(FlaxBartAttention): |
|
""" |
|
Edits: |
|
    - the causal mask is used only in the decoder and is sized to image_length
|
""" |
|
|
|
def setup(self) -> None: |
|
self.head_dim = self.embed_dim // self.num_heads |
|
if self.head_dim * self.num_heads != self.embed_dim: |
|
raise ValueError( |
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" |
|
f" and `num_heads`: {self.num_heads})." |
|
) |
|
|
|
dense = partial( |
|
nn.Dense, |
|
self.embed_dim, |
|
use_bias=self.bias, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
|
|
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() |
|
self.out_proj = dense() |
|
|
|
self.dropout_layer = nn.Dropout(rate=self.dropout) |
|
|
|
if self.causal: |
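            # the causal mask is pre-computed for the full image sequence: it is only used in the
            # decoder, whose maximum length is config.image_length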
|
|
|
self.causal_mask = make_causal_mask( |
|
jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool" |
|
) |
|
|
|
|
|
class FlaxBartEncoderLayer(FlaxBartEncoderLayer): |
|
""" |
|
Edits: |
|
- no bias |
|
- use custom FlaxBartAttention |
|
""" |
|
|
|
def setup(self) -> None: |
|
self.embed_dim = self.config.d_model |
|
self.self_attn = FlaxBartAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.encoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
bias=False, |
|
dtype=self.dtype, |
|
) |
|
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
self.activation_fn = ACT2FN[self.config.activation_function] |
|
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) |
|
self.fc1 = nn.Dense( |
|
self.config.encoder_ffn_dim, |
|
dtype=self.dtype, |
|
use_bias=False, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
self.fc2 = nn.Dense( |
|
self.embed_dim, |
|
dtype=self.dtype, |
|
use_bias=False, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
|
|
|
|
class FlaxBartEncoderLayerCollection(FlaxBartEncoderLayerCollection): |
|
""" |
|
Edits: |
|
- use custom FlaxBartEncoderLayer |
|
- allow Gradient Checkpointing (nn.remat) |
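
    With `config.gradient_checkpointing` enabled, each layer is wrapped in `nn.remat`, which
    recomputes activations during the backward pass to reduce memory usage.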
|
""" |
|
|
|
def setup(self): |
|
layer_module = ( |
|
nn.remat(FlaxBartEncoderLayer, concrete=True) |
|
if self.config.gradient_checkpointing |
|
else FlaxBartEncoderLayer |
|
) |
|
self.layers = [ |
|
layer_module(self.config, name=str(i), dtype=self.dtype) |
|
for i in range(self.config.encoder_layers) |
|
] |
|
self.layerdrop = self.config.encoder_layerdrop |
|
|
|
|
|
class FlaxBartDecoderLayer(FlaxBartDecoderLayer): |
|
""" |
|
Edits: |
|
- no bias |
|
- uses custom FlaxBartAttention |
|
""" |
|
|
|
def setup(self) -> None: |
|
self.embed_dim = self.config.d_model |
|
self.self_attn = FlaxBartAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.decoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
causal=True, |
|
bias=False, |
|
dtype=self.dtype, |
|
) |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
self.activation_fn = ACT2FN[self.config.activation_function] |
|
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) |
|
|
|
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
self.encoder_attn = FlaxBartAttention( |
|
config=self.config, |
|
embed_dim=self.embed_dim, |
|
num_heads=self.config.decoder_attention_heads, |
|
dropout=self.config.attention_dropout, |
|
bias=False, |
|
dtype=self.dtype, |
|
) |
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
self.fc1 = nn.Dense( |
|
            self.config.decoder_ffn_dim,
|
dtype=self.dtype, |
|
use_bias=False, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
self.fc2 = nn.Dense( |
|
self.embed_dim, |
|
dtype=self.dtype, |
|
use_bias=False, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
|
|
|
|
class FlaxBartDecoderLayerCollection(FlaxBartDecoderLayerCollection): |
|
""" |
|
Edits: |
|
- use custom FlaxBartDecoderLayer |
|
- allow Gradient Checkpointing (nn.remat) |
|
""" |
|
|
|
def setup(self): |
|
layer_module = ( |
|
nn.remat(FlaxBartDecoderLayer, concrete=True) |
|
if self.config.gradient_checkpointing |
|
else FlaxBartDecoderLayer |
|
) |
|
self.layers = [ |
|
layer_module(self.config, name=str(i), dtype=self.dtype) |
|
for i in range(self.config.decoder_layers) |
|
] |
|
self.layerdrop = self.config.decoder_layerdrop |
|
|
|
|
|
class FlaxBartEncoder(FlaxBartEncoder): |
|
""" |
|
Edits: |
|
- offset set to 0 (no padding token) |
|
- use max_text_length instead of max_position_embeddings |
|
- use custom FlaxBartEncoderLayerCollection |
|
- embed_tokens cannot be None (issue at compile time) |
|
""" |
|
|
|
def setup(self): |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
|
|
embed_dim = self.config.d_model |
|
self.padding_idx = self.config.pad_token_id |
|
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 |
|
|
|
|
|
|
|
        # positions are not offset: unlike BART, there is no padding token to account for
        self.offset = 0
|
self.embed_positions = nn.Embed( |
|
self.config.max_text_length + self.offset, |
|
embed_dim, |
|
embedding_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype) |
|
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
|
|
|
|
class FlaxBartDecoder(FlaxBartDecoder): |
|
""" |
|
Edits: |
|
- offset set to 0 (no padding token) |
|
- use image_length instead of max_position_embeddings |
|
- use custom FlaxBartDecoderLayerCollection |
|
- embed_tokens cannot be None (issue at compile time) |
|
""" |
|
|
|
def setup(self): |
|
self.dropout_layer = nn.Dropout(rate=self.config.dropout) |
|
|
|
embed_dim = self.config.d_model |
|
self.padding_idx = self.config.pad_token_id |
|
self.embed_scale = ( |
|
math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 |
|
) |
|
|
|
|
|
|
|
self.offset = 0 |
|
self.embed_positions = nn.Embed( |
|
self.config.image_length + self.offset, |
|
embed_dim, |
|
embedding_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
|
|
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype) |
|
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) |
|
|
|
|
|
class FlaxBartModule(FlaxBartModule): |
|
""" |
|
    Edits:
|
- use custom FlaxBartEncoder & FlaxBartDecoder |
|
- use separate embeddings for Encoder & Decoder |
|
""" |
|
|
|
def setup(self): |
|
encoder_embed_tokens = nn.Embed( |
|
self.config.encoder_vocab_size, |
|
self.config.d_model, |
|
embedding_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
decoder_embed_tokens = nn.Embed( |
|
self.config.image_vocab_size + 1, |
|
self.config.d_model, |
|
embedding_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
|
|
self.encoder = FlaxBartEncoder( |
|
self.config, dtype=self.dtype, embed_tokens=encoder_embed_tokens |
|
) |
|
self.decoder = FlaxBartDecoder( |
|
self.config, dtype=self.dtype, embed_tokens=decoder_embed_tokens |
|
) |
|
|
|
|
|
class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel): |
|
""" |
|
Edits: |
|
- added num_params property |
|
- config_class replaced to DalleBartConfig |
|
    - __init__ accepts abstract_init, which uses parameter shapes to initialize the model without allocating weights
|
- init weights on CPU with `load_on_cpu` |
|
- restore weights on CPU with custom `from_pretrained` |
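
    A minimal loading sketch (the checkpoint path below is an illustrative assumption)::

        model = DalleBart.from_pretrained(
            "path-or-hub-id-with-flax-weights",
            abstract_init=True,  # only parameter shapes are computed at init
            load_on_cpu=True,    # any random init happens on CPU
        )
        print(model.num_params)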
|
""" |
|
|
|
config_class = DalleBartConfig |
|
|
|
def __init__( |
|
self, |
|
config: DalleBartConfig, |
|
input_shape: Tuple[int] = (1, 1), |
|
seed: int = 0, |
|
dtype: jnp.dtype = jnp.float32, |
|
abstract_init: bool = False, |
|
load_on_cpu: bool = False, |
|
**kwargs, |
|
): |
|
module = self.module_class(config=config, dtype=dtype, **kwargs) |
|
|
|
|
|
if config is None: |
|
raise ValueError("config cannot be None") |
|
|
|
if module is None: |
|
raise ValueError("module cannot be None") |
|
|
|
|
|
self._config = config |
|
self._module = module |
|
|
|
|
|
self.key = PRNGKey(seed) |
|
self.dtype = dtype |
|
|
|
|
|
if load_on_cpu: |
|
|
|
init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu") |
|
else: |
|
init_fn = self.init_weights |
|
|
|
|
|
        # parameters are initialized below: only shapes when abstract_init=True, real values otherwise
|
if abstract_init: |
|
|
|
init_fn = partial(init_fn, input_shape=input_shape) |
|
random_params = jax.eval_shape(init_fn, self.key) |
|
else: |
|
random_params = init_fn(self.key, input_shape) |
|
|
|
|
|
self._required_params = set(flatten_dict(unfreeze(random_params)).keys()) |
|
self.params = random_params |
|
|
|
@property |
|
def num_params(self): |
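        """Total number of model parameters (sum of the sizes of all leaves of `self.params`)."""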
|
num_params = jax.tree_map( |
|
lambda param: param.size, flatten_dict(unfreeze(self.params)) |
|
).values() |
|
return sum(list(num_params)) |
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
pretrained_model_name_or_path: Union[str, os.PathLike], |
|
dtype: jnp.dtype = jnp.float32, |
|
*model_args, |
|
**kwargs, |
|
): |
|
config = kwargs.pop("config", None) |
|
cache_dir = kwargs.pop("cache_dir", None) |
|
from_pt = kwargs.pop("from_pt", False) |
|
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) |
|
force_download = kwargs.pop("force_download", False) |
|
resume_download = kwargs.pop("resume_download", False) |
|
proxies = kwargs.pop("proxies", None) |
|
local_files_only = kwargs.pop("local_files_only", False) |
|
use_auth_token = kwargs.pop("use_auth_token", None) |
|
revision = kwargs.pop("revision", None) |
|
from_pipeline = kwargs.pop("_from_pipeline", None) |
|
from_auto_class = kwargs.pop("_from_auto", False) |
|
|
|
user_agent = { |
|
"file_type": "model", |
|
"framework": "flax", |
|
"from_auto_class": from_auto_class, |
|
} |
|
if from_pipeline is not None: |
|
user_agent["using_pipeline"] = from_pipeline |
|
|
|
if is_offline_mode() and not local_files_only: |
|
logger.info("Offline mode: forcing local_files_only=True") |
|
local_files_only = True |
|
|
|
|
|
if not isinstance(config, PretrainedConfig): |
|
config_path = ( |
|
config if config is not None else pretrained_model_name_or_path |
|
) |
|
config, model_kwargs = cls.config_class.from_pretrained( |
|
config_path, |
|
cache_dir=cache_dir, |
|
return_unused_kwargs=True, |
|
force_download=force_download, |
|
resume_download=resume_download, |
|
proxies=proxies, |
|
local_files_only=local_files_only, |
|
use_auth_token=use_auth_token, |
|
revision=revision, |
|
_from_auto=from_auto_class, |
|
_from_pipeline=from_pipeline, |
|
**kwargs, |
|
) |
|
else: |
|
model_kwargs = kwargs |
|
|
|
|
|
model_kwargs["dtype"] = dtype |
|
|
|
|
|
if pretrained_model_name_or_path is not None: |
|
if os.path.isdir(pretrained_model_name_or_path): |
|
if from_pt and os.path.isfile( |
|
os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) |
|
): |
|
|
|
archive_file = os.path.join( |
|
pretrained_model_name_or_path, WEIGHTS_NAME |
|
) |
|
elif os.path.isfile( |
|
os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) |
|
): |
|
|
|
archive_file = os.path.join( |
|
pretrained_model_name_or_path, FLAX_WEIGHTS_NAME |
|
) |
|
else: |
|
raise EnvironmentError( |
|
f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory " |
|
f"{pretrained_model_name_or_path} or `from_pt` set to False" |
|
) |
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url( |
|
pretrained_model_name_or_path |
|
): |
|
archive_file = pretrained_model_name_or_path |
|
else: |
|
archive_file = hf_bucket_url( |
|
pretrained_model_name_or_path, |
|
filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME, |
|
revision=revision, |
|
) |
|
|
|
|
|
try: |
|
resolved_archive_file = cached_path( |
|
archive_file, |
|
cache_dir=cache_dir, |
|
force_download=force_download, |
|
proxies=proxies, |
|
resume_download=resume_download, |
|
local_files_only=local_files_only, |
|
use_auth_token=use_auth_token, |
|
user_agent=user_agent, |
|
) |
|
except EnvironmentError as err: |
|
logger.error(err) |
|
msg = ( |
|
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" |
|
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n" |
|
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n" |
|
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n" |
|
) |
|
raise EnvironmentError(msg) |
|
|
|
if resolved_archive_file == archive_file: |
|
logger.info(f"loading weights file {archive_file}") |
|
else: |
|
logger.info( |
|
f"loading weights file {archive_file} from cache at {resolved_archive_file}" |
|
) |
|
else: |
|
resolved_archive_file = None |
|
|
|
|
|
model = cls(config, *model_args, **model_kwargs) |
|
|
|
with open(resolved_archive_file, "rb") as state_f: |
|
try: |
|
state = from_bytes(cls, state_f.read()) |
|
except (UnpicklingError, msgpack.exceptions.ExtraData) as e: |
|
try: |
|
with open(resolved_archive_file) as f: |
|
if f.read().startswith("version"): |
|
raise OSError( |
|
"You seem to have cloned a repository without having git-lfs installed. Please install " |
|
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " |
|
"you cloned." |
|
) |
|
else: |
|
raise ValueError from e |
|
except (UnicodeDecodeError, ValueError): |
|
raise EnvironmentError( |
|
f"Unable to convert {archive_file} to Flax deserializable object. " |
|
) |
|
|
|
|
|
if ( |
|
cls.base_model_prefix not in dict(model.params) |
|
and cls.base_model_prefix in state |
|
): |
|
state = state[cls.base_model_prefix] |
|
|
|
|
|
|
|
if ( |
|
cls.base_model_prefix in dict(model.params) |
|
and cls.base_model_prefix not in state |
|
): |
|
state = {cls.base_model_prefix: state} |
|
|
|
|
|
state = flatten_dict(state) |
|
|
|
random_state = flatten_dict(unfreeze(model.params)) |
|
|
|
missing_keys = model.required_params - set(state.keys()) |
|
unexpected_keys = set(state.keys()) - model.required_params |
|
|
|
|
|
|
|
mismatched_keys = [] |
|
for key in state.keys(): |
|
if key in random_state and state[key].shape != random_state[key].shape: |
|
if ignore_mismatched_sizes: |
|
mismatched_keys.append( |
|
(key, state[key].shape, random_state[key].shape) |
|
) |
|
state[key] = random_state[key] |
|
else: |
|
raise ValueError( |
|
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " |
|
f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " |
|
"Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " |
|
"model." |
|
) |
|
|
|
|
|
for missing_key in missing_keys: |
|
state[missing_key] = random_state[missing_key] |
|
|
|
|
|
for unexpected_key in unexpected_keys: |
|
del state[unexpected_key] |
|
|
|
if len(unexpected_keys) > 0: |
|
logger.warning( |
|
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " |
|
f"initializing {model.__class__.__name__}: {unexpected_keys}\n" |
|
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " |
|
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" |
|
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " |
|
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." |
|
) |
|
else: |
|
logger.info( |
|
f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" |
|
) |
|
|
|
if len(missing_keys) > 0: |
|
logger.warning( |
|
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " |
|
f"and are newly initialized: {missing_keys}\n" |
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." |
|
) |
|
elif len(mismatched_keys) == 0: |
|
logger.info( |
|
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" |
|
f"If your task is similar to the task the model of the checkpoint was trained on, " |
|
f"you can already use {model.__class__.__name__} for predictions without further training." |
|
) |
|
if len(mismatched_keys) > 0: |
|
mismatched_warning = "\n".join( |
|
[ |
|
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" |
|
for key, shape1, shape2 in mismatched_keys |
|
] |
|
) |
|
logger.warning( |
|
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " |
|
f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" |
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." |
|
) |
|
|
|
|
|
model.params = unflatten_dict(state) |
|
|
|
return model |
|
|
|
|
|
class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule): |
|
""" |
|
Edits: |
|
- no bias |
|
- lm_head set to image_vocab_size + 1 (for BOS) |
|
- uses custom FlaxBartModule |
|
""" |
|
|
|
def setup(self): |
|
self.model = FlaxBartModule(config=self.config, dtype=self.dtype) |
|
self.lm_head = nn.Dense( |
|
            self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS
|
use_bias=False, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(self.config.init_std), |
|
) |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
position_ids, |
|
decoder_position_ids, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
outputs = self.model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
decoder_input_ids=decoder_input_ids, |
|
decoder_attention_mask=decoder_attention_mask, |
|
position_ids=position_ids, |
|
decoder_position_ids=decoder_position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
|
|
if self.config.tie_word_embeddings: |
|
shared_embedding = self.model.variables["params"]["shared"]["embedding"] |
|
lm_logits = self.lm_head.apply( |
|
{"params": {"kernel": shared_embedding.T}}, hidden_states |
|
) |
|
else: |
|
lm_logits = self.lm_head(hidden_states) |
|
|
|
if not return_dict: |
|
output = (lm_logits,) + outputs[1:] |
|
return output |
|
|
|
return FlaxSeq2SeqLMOutput( |
|
logits=lm_logits, |
|
decoder_hidden_states=outputs.decoder_hidden_states, |
|
decoder_attentions=outputs.decoder_attentions, |
|
cross_attentions=outputs.cross_attentions, |
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, |
|
encoder_hidden_states=outputs.encoder_hidden_states, |
|
encoder_attentions=outputs.encoder_attentions, |
|
) |
|
|
|
|
|
class DalleBart( |
|
PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration |
|
): |
|
""" |
|
Edits: |
|
- renamed from FlaxBartForConditionalGeneration |
|
- uses custom FlaxBartPreTrainedModel |
|
- uses custom FlaxBartForConditionalGenerationModule |
|
- no bias in decode method |
|
- custom prepare_inputs_for_generation using "max_length - 1" to avoid issues |
|
related to position embedding during model.generate() |
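
    A minimal generation sketch (the checkpoint id and the tokenized `prompt` are illustrative
    assumptions, not defined in this module)::

        model = DalleBart.from_pretrained("path-or-hub-id-with-flax-weights")
        encoded_images = model.generate(
            input_ids=prompt["input_ids"],
            attention_mask=prompt["attention_mask"],
            do_sample=True,
            prng_key=jax.random.PRNGKey(0),
            params=model.params,
        )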
|
""" |
|
|
|
module_class = FlaxBartForConditionalGenerationModule |
|
|
|
def decode( |
|
self, |
|
decoder_input_ids, |
|
encoder_outputs, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_attention_mask: Optional[jnp.ndarray] = None, |
|
decoder_position_ids: Optional[jnp.ndarray] = None, |
|
past_key_values: dict = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
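        """
        Run the decoder over `decoder_input_ids`, cross-attending on the pre-computed
        `encoder_outputs`, and return image-token logits (plus the updated cache when
        `past_key_values` is used).
        """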
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.return_dict |
|
) |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
if encoder_attention_mask is None: |
|
batch_size, sequence_length = encoder_hidden_states.shape[:2] |
|
encoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
batch_size, sequence_length = decoder_input_ids.shape |
|
if decoder_attention_mask is None: |
|
decoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
if decoder_position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError( |
|
"Make sure to provide `decoder_position_ids` when passing `past_key_values`." |
|
) |
|
|
|
decoder_position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
inputs = {"params": params or self.params} |
|
|
|
|
|
|
|
|
|
        # when past_key_values are passed, the cache collection must be marked mutable so that the
        # attention modules can update it during incremental decoding
        if past_key_values:
|
inputs["cache"] = past_key_values |
|
mutable = ["cache"] |
|
else: |
|
mutable = False |
|
|
|
def _decoder_forward( |
|
module, |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
): |
|
decoder_module = module._get_decoder_module() |
|
outputs = decoder_module( |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
decoder_position_ids, |
|
**kwargs, |
|
) |
|
hidden_states = outputs[0] |
|
|
|
if self.config.tie_word_embeddings: |
|
shared_embedding = module.model.variables["params"]["shared"][ |
|
"embedding" |
|
] |
|
lm_logits = module.lm_head.apply( |
|
{"params": {"kernel": shared_embedding.T}}, hidden_states |
|
) |
|
else: |
|
lm_logits = module.lm_head(hidden_states) |
|
|
|
return lm_logits, outputs |
|
|
|
outputs = self.module.apply( |
|
inputs, |
|
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
|
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
|
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
mutable=mutable, |
|
method=_decoder_forward, |
|
) |
|
|
|
if past_key_values is None: |
|
lm_logits, decoder_outputs = outputs |
|
else: |
|
(lm_logits, decoder_outputs), past = outputs |
|
|
|
if return_dict: |
|
outputs = FlaxCausalLMOutputWithCrossAttentions( |
|
logits=lm_logits, |
|
hidden_states=decoder_outputs.hidden_states, |
|
attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
) |
|
else: |
|
outputs = (lm_logits,) + decoder_outputs[1:] |
|
|
|
|
|
if past_key_values is not None and return_dict: |
|
outputs["past_key_values"] = unfreeze(past["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
decoder_input_ids, |
|
max_length, |
|
attention_mask: Optional[jnp.DeviceArray] = None, |
|
decoder_attention_mask: Optional[jnp.DeviceArray] = None, |
|
encoder_outputs=None, |
|
**kwargs, |
|
): |
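        """
        Build the cache, static decoder attention mask and position ids used by `generate()`.
        The cache and mask are sized to `max_length - 1` to stay within the learned position
        embeddings (see the class docstring).
        """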
|
|
|
batch_size, seq_length = decoder_input_ids.shape |
|
|
|
past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs) |
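
        # a single static attention mask of length max_length - 1 is sufficient: positions beyond
        # the current step are hidden by the causal mask, and the fixed shape is friendlier to
        # XLA compilation during generation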
|
|
|
|
|
|
|
extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4") |
|
if decoder_attention_mask is not None: |
|
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 |
|
extended_attention_mask = lax.dynamic_update_slice( |
|
extended_attention_mask, decoder_attention_mask, (0, 0) |
|
) |
|
else: |
|
position_ids = jnp.broadcast_to( |
|
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) |
|
) |
|
|
|
return { |
|
"past_key_values": past_key_values, |
|
"encoder_outputs": encoder_outputs, |
|
"encoder_attention_mask": attention_mask, |
|
"decoder_attention_mask": extended_attention_mask, |
|
"decoder_position_ids": position_ids, |
|
} |
|
|