"""Convert a PyTorch VQGAN checkpoint into Flax weights for VQModel."""

import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

import torch

from modeling_flax_vqgan import VQModel
from configuration_vqgan import VQGANConfig


# Matches sub-module names that end in a numeric index, e.g. "up.0" or "block.1".
regex = r"\w+[.]\d+"


def rename_key(key):
    # PyTorch ModuleList entries appear as "name.<idx>" in checkpoint keys; replace
    # the "." before the index with "_" ("up.0" -> "up_0") so the keys line up with
    # the corresponding Flax parameter names.
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key
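
# For illustration, a hypothetical checkpoint key (not taken from a real file):
#   rename_key("decoder.up.0.block.1.norm1.weight")
#   -> "decoder.up_0.block_1.norm1.weight"
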
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # Convert PyTorch tensors to NumPy arrays so they can be wrapped as JAX arrays.
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # Flatten the randomly initialized Flax params to key tuples for lookups.
    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    # Work out whether the base-model prefix must be stripped from, or added to,
    # the PyTorch keys so that they line up with the Flax parameter tree.
    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )

    # Rename parameters so that no Flax layer has to be forked.
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
        # Correctly rename weight parameters, transposing tensors where the
        # PyTorch and Flax layouts differ. Note that Flax norm parameters are
        # 1-D, so no broadcasting dimensions are added here; a reshaped tensor
        # would fail the shape check below.
        if (
            "norm" in pt_key
            and (pt_tuple_key[-1] == "bias")
            and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
        ):
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            # embedding table (e.g. the VQ codebook); the layout is unchanged
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
            # conv kernel: PyTorch stores (out_ch, in_ch, kh, kw), Flax expects
            # (kh, kw, in_ch, out_ch), e.g. (256, 128, 3, 3) -> (3, 3, 128, 256)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            # linear layer: PyTorch stores (out_features, in_features), Flax the transpose
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
        # Sanity-check shapes against the Flax parameters before storing.
        if pt_tuple_key in random_flax_state_dict:
            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # Store the weight even when it is unexpected, so that a warning can be
        # raised about it when the converted weights are loaded.
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)


def convert_model(config_path, pt_state_dict_path, save_path):
    config = VQGANConfig.from_pretrained(config_path)
    model = VQModel(config)

    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
    keys = list(state_dict.keys())
    for key in keys:
        # Skip training-only "loss.*" weights; they are not part of the model.
        if key.startswith("loss"):
            state_dict.pop(key)
            continue
        renamed_key = rename_key(key)
        state_dict[renamed_key] = state_dict.pop(key)

    # convert_pytorch_state_dict_to_flax already returns an unflattened tree,
    # so the result can be assigned to the model parameters directly.
    state = convert_pytorch_state_dict_to_flax(state_dict, model)
    model.params = state
    model.save_pretrained(save_path)
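

if __name__ == "__main__":
    # Example invocation; the paths below are illustrative placeholders rather
    # than files bundled with this script.
    convert_model("./config.json", "./model.ckpt", "./flax_vqgan")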