stduhpf committed on
Commit
0cb29e4
·
verified ·
1 Parent(s): 3a3d8a0

Upload merging script used

Browse files
Files changed (1) hide show
  1. swap_embeds.py +58 -0
swap_embeds.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from gguf.gguf_reader import GGUFReader
3
+ from gguf import GGUFWriter,GGUFValueType,ReaderField,OrderedDict # noqa: E402
4
+
5
def add_keys(writer: GGUFWriter, fields: "OrderedDict[str, ReaderField]"):
    """Copy metadata key/value pairs from a GGUF reader into ``writer``.

    Args:
        writer: Destination writer that receives every copied key.
        fields: Mapping of key name to reader field, e.g. ``GGUFReader.fields``.

    The keys ``general.architecture``, ``GGUF.version``, ``GGUF.tensor_count``
    and ``GGUF.kv_count`` are skipped. Presumably the writer emits these
    itself (the architecture is passed to its constructor by the caller) --
    TODO confirm against the gguf-py writer.
    """
    skipped_keys = {
        "general.architecture",
        "GGUF.version",
        "GGUF.tensor_count",
        "GGUF.kv_count",
    }
    for key, field in fields.items():
        if key in skipped_keys:
            continue
        value_type = field.types[0]
        if value_type == GGUFValueType.STRING:
            # Decode the raw bytes as UTF-8. Mapping each byte through chr()
            # would turn every multi-byte character into several Latin-1
            # characters, which the writer would then re-encode (mojibake).
            # Assumes the string payload part holds one byte per element --
            # TODO confirm against GGUFReader.
            raw_bytes = bytes(field.parts[field.data[0]])
            writer.add_string(key=key, val=raw_bytes.decode("utf-8"))
        elif value_type == GGUFValueType.ARRAY:
            writer.add_array(key=key, val=field.contents())
        else:
            # Scalar value: the first element of the data part is the value.
            writer.add_key_value(
                key=key,
                val=field.parts[field.data[0]][0],
                vtype=value_type,
            )
14
+
15
if __name__ == "__main__":
    # Command-line entry point: copy every metadata key and tensor of the main
    # model into a new GGUF file, but take the tensors named in
    # --target_blocks from the second ("embed") file instead.
    parser = argparse.ArgumentParser(description="Merge GGUF models, especially embedding tables.")
    parser.add_argument("model_src_path", help="Path to the main model GGUF file.")
    parser.add_argument("embed_src_path", help="Path to the model GGUF file to take the embeddings table from (or other tensors).")
    parser.add_argument("dst_path", help="Path to the output GGUF file.")
    parser.add_argument("--target_blocks", nargs="+", default=["token_embd.weight"],
                        help="List of tensor names to merge from the embedding file. Default: token_embd.weight")

    args = parser.parse_args()
    reader_model = GGUFReader(args.model_src_path)
    reader_embed = GGUFReader(args.embed_src_path)

    def _read_arch(reader, label):
        """Return the ``general.architecture`` string of ``reader``.

        Prints an error mentioning ``label`` and exits with status -1 when the
        key is absent.
        """
        arch_field = reader.get_field("general.architecture")
        if arch_field is None:
            print(f"Couldn't get arch from {label} file")
            raise SystemExit(-1)
        return ''.join(chr(i) for i in arch_field.parts[arch_field.data[0]])

    arch = _read_arch(reader_model, "src0")
    # Compare against the *embedding* file. The original re-read the main
    # model here, so the architecture check could never fail.
    if _read_arch(reader_embed, "src1") != arch:
        print("src0 and src1 have different architectures")
        raise SystemExit(-1)

    # Index the replacement tensors by name once instead of rescanning the
    # embedding file for every model tensor. setdefault keeps the first match.
    target_names = set(args.target_blocks)
    replacements = {}
    for candidate in reader_embed.tensors:
        if candidate.name in target_names:
            replacements.setdefault(candidate.name, candidate)

    # A requested tensor that exists in the main model but not in the
    # embedding file used to be dropped silently, leaving it out of the
    # output. Refuse to write such a file.
    missing = [
        tensor.name
        for tensor in reader_model.tensors
        if tensor.name in target_names and tensor.name not in replacements
    ]
    if missing:
        print("Target tensors not found in src1 file: " + ", ".join(missing))
        raise SystemExit(-1)

    writer = GGUFWriter(path=args.dst_path, arch=arch)
    try:
        add_keys(writer, reader_model.fields)

        for tensor in reader_model.tensors:
            source = replacements.get(tensor.name, tensor)
            # The original passed raw_shape=source.shape.tolist().reverse();
            # list.reverse() works in place and returns None, so raw_shape was
            # always None. Keep that behaviour explicitly. Presumably the
            # writer then derives the shape from the data array -- TODO
            # confirm against GGUFWriter.add_tensor.
            writer.add_tensor(
                name=source.name,
                tensor=source.data,
                raw_shape=None,
                raw_dtype=source.tensor_type,
            )

        writer.write_header_to_file()
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file()
    finally:
        # Close the writer even if writing fails part-way.
        writer.close()