File size: 1,646 Bytes
2995564
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

import torch
from torch.nn import functional as F
from safetensors.torch import load_file, save_file

# Number of zero rows/columns interleaved into the MLP intermediate dimension.
# With the defaults below the MLP width grows from 29568 to 29696 (Qwen2-72B).
pad_size = 128
# Width of the MLP intermediate dimension in the unpadded checkpoint (Qwen2-72B).
intermediate_size = 29568
total_shards = 33  # Total number of shards in the model, edit according to the actual files

# NOTE(review): model loaders presumably also need `intermediate_size` in
# config.json set to intermediate_size + pad_size; this script does not touch it.


def interleave_zero_pad(tensor, pad_size, dim):
    """Insert a zero slice after each of the first ``pad_size`` slices along ``dim``.

    For ``dim=0`` the result rows are ``t[0], 0, t[1], 0, ..., t[pad_size-1], 0,
    t[pad_size], t[pad_size+1], ...``. So ``tensor.shape[dim]`` grows by
    ``pad_size`` and every other dimension is unchanged. The result is a new
    tensor with the input's dtype and device.

    Args:
        tensor: Tensor to pad.
        pad_size: Number of leading slices that each receive a trailing zero slice.
        dim: Dimension to pad along (negative values count from the end).

    Returns:
        The padded tensor.

    Raises:
        ValueError: If ``pad_size`` is negative or larger than ``tensor.shape[dim]``.
    """
    dim %= tensor.dim()
    size = tensor.shape[dim]
    if not 0 <= pad_size <= size:
        raise ValueError(
            f"pad_size={pad_size} is out of range for dimension {dim} of size {size}"
        )
    head = tensor.narrow(dim, 0, pad_size)
    tail = tensor.narrow(dim, pad_size, size - pad_size)
    # Pair each head slice with a zero slice on a new axis, then merge the pair
    # axis so every zero lands right after its partner. Only the head is copied
    # here, instead of zero-padding the whole tensor and discarding most of it.
    interleaved = torch.stack([head, torch.zeros_like(head)], dim=dim + 1).flatten(dim, dim + 1)
    return torch.cat([interleaved, tail], dim=dim)


def pad_mlp_weight(key, tensor, pad_size, intermediate_size):
    """Return the padded version of an MLP projection weight, or ``None``.

    ``mlp.up_proj.weight`` / ``mlp.gate_proj.weight`` tensors are padded along
    dim 0 and ``mlp.down_proj.weight`` tensors along dim 1, using
    ``interleave_zero_pad``.

    Args:
        key: State-dict key of the tensor.
        tensor: The weight tensor stored under ``key``.
        pad_size: Number of zero slices to interleave.
        intermediate_size: Expected size of the padded dimension before padding.

    Returns:
        The padded tensor. Returns ``None`` when ``key`` is not an MLP
        projection weight, or when the padded dimension is already
        ``intermediate_size + pad_size`` (an earlier run already padded it, or
        ``pad_size`` is 0).

    Raises:
        ValueError: If the padded dimension is neither ``intermediate_size``
            nor ``intermediate_size + pad_size``.
    """
    if key.endswith(("mlp.up_proj.weight", "mlp.gate_proj.weight")):
        dim = 0
    elif key.endswith("mlp.down_proj.weight"):
        dim = 1
    else:
        return None

    size = tensor.shape[dim]
    if size == intermediate_size + pad_size:
        return None  # already padded: makes reruns after a partial run safe
    if size != intermediate_size:
        raise ValueError(
            f"{key}: expected dimension {dim} to be {intermediate_size} "
            f"(or {intermediate_size + pad_size} if already padded), "
            f"got shape {tuple(tensor.shape)}"
        )
    return interleave_zero_pad(tensor, pad_size, dim)


def main():
    """Pad the MLP projection weights of every shard, rewriting changed shards."""
    for shard_idx in range(1, total_shards + 1):
        # Generate filename with zero-padded shard numbers
        filename = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"

        state_dict = load_file(filename)
        modified = False
        for key in list(state_dict):
            padded = pad_mlp_weight(key, state_dict[key], pad_size, intermediate_size)
            if padded is not None:
                state_dict[key] = padded
                modified = True

        if modified:
            # Write to a temporary file and atomically swap it in. A crash
            # mid-save then cannot leave a truncated shard behind, and we never
            # write over the file the loaded tensors may still be reading from.
            tmp_filename = f"{filename}.tmp"
            save_file(state_dict, tmp_filename, metadata={"format": "pt"})
            os.replace(tmp_filename, filename)
            print(f"Processed and saved {filename}")
        else:
            print(f"No modifications needed for {filename}")


if __name__ == "__main__":
    main()