|
import torch |
|
from torch.nn import functional as F |
|
from safetensors.torch import load_file, save_file |
|
|
|
# Number of zero rows/columns to interleave into the MLP intermediate
# dimension.  The first `pad_size` slices of each projection are expanded to
# `pad_size * 2` slices (original, zero, original, zero, ...), growing the
# intermediate dimension by `pad_size` overall.
# NOTE(review): presumably this pads an intermediate size of 29568 up to
# 29696 (a multiple of 128) for kernel/alignment reasons — confirm against
# the model config.
pad_size = 128

# Number of safetensors shards the checkpoint is split into
# ("model-00001-of-00033.safetensors" ... "model-00033-of-00033.safetensors").
total_shards = 33

for shard_idx in range(1, total_shards + 1):

    filename = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"

    state_dict = load_file(filename)
    modified = False

    # Iterate over a snapshot of the keys since we reassign entries in place.
    for key in list(state_dict.keys()):

        tensor = state_dict[key]

        if 'mlp.up_proj.weight' in key or 'mlp.gate_proj.weight' in key:

            # up/gate projections: (intermediate, hidden).  Insert a zero row
            # after each row (unsqueeze + pad on dim 1, then flatten), keep
            # only the first pad_size*2 interleaved rows, and splice them in
            # front of the untouched remainder.  Net effect: rows
            # [r0, 0, r1, 0, ..., r_{pad_size-1}, 0, r_pad_size, ...],
            # growing dim 0 by pad_size.
            prev_tensor = F.pad(tensor.unsqueeze(1), (0, 0, 0, 1, 0, 0)).reshape(tensor.shape[0] * 2, -1)[:pad_size * 2]
            new_tensor = torch.cat([prev_tensor, tensor[pad_size:]], dim=0)
            state_dict[key] = new_tensor
            modified = True

        elif 'mlp.down_proj.weight' in key:

            # down projection: (hidden, intermediate).  Same trick along the
            # column axis: interleave a zero column after each of the first
            # pad_size columns, growing dim 1 by pad_size.
            prev_tensor = F.pad(tensor.unsqueeze(2), (0, 1)).reshape(tensor.shape[0], tensor.shape[1] * 2)[:, :pad_size * 2]
            new_tensor = torch.cat([prev_tensor, tensor[:, pad_size:]], dim=1)
            state_dict[key] = new_tensor
            modified = True

    if modified:
        # Overwrite the shard in place; the "format" metadata keeps the file
        # loadable by transformers' safetensors loader.
        save_file(state_dict, filename, metadata={"format": "pt"})
        print(f"Processed and saved {filename}")
    else:
        print(f"No modifications needed for {filename}")
|
|