import argparse
import sys

from gguf.gguf_reader import GGUFReader
from gguf import GGUFWriter, GGUFValueType, ReaderField, OrderedDict  # noqa: E402

# Metadata keys that are NOT copied from the source model:
# - "general.architecture" is written by GGUFWriter itself via arch=...;
# - the "GGUF.*" entries mirror header values (version / counts) that the
#   writer regenerates in write_header_to_file().
_SKIPPED_KEYS = frozenset({
    "general.architecture",
    "GGUF.version",
    "GGUF.tensor_count",
    "GGUF.kv_count",
})


def _field_str(field: ReaderField) -> str:
    """Return the value of a STRING-typed reader field as text."""
    # parts[data[0]] holds the raw string bytes. Decode them as UTF-8; mapping
    # each byte through chr() would corrupt multi-byte (non-ASCII) characters.
    return bytes(field.parts[field.data[0]]).decode("utf-8")


def add_keys(writer: GGUFWriter, fields: "OrderedDict[str, ReaderField]"):
    """Copy metadata key/value pairs from *fields* into *writer*.

    Args:
        writer: destination writer; keys are added in iteration order.
        fields: reader fields, e.g. ``GGUFReader.fields``. Keys listed in
            _SKIPPED_KEYS are not copied.
    """
    for key, field in fields.items():
        if key in _SKIPPED_KEYS:
            continue
        vtype = field.types[0]
        if vtype == GGUFValueType.STRING:
            writer.add_string(key=key, val=_field_str(field))
        elif vtype == GGUFValueType.ARRAY:
            writer.add_array(key=key, val=field.contents())
        else:
            # Scalar types: the value is the single element of its data part.
            writer.add_key_value(key=key, val=field.parts[field.data[0]][0], vtype=vtype)


def _read_arch(reader: GGUFReader, label: str) -> str:
    """Return *reader*'s general.architecture value; exit with an error if absent."""
    field = reader.get_field("general.architecture")
    if field is None:
        sys.exit(f"Couldn't get arch from {label} file")
    return _field_str(field)


def _copy_tensor(writer: GGUFWriter, tensor) -> None:
    """Append *tensor* (an element of ``GGUFReader.tensors``) to *writer* as-is."""
    # raw_shape is deliberately not passed: GGUFWriter then derives the shape
    # from tensor.data, paired with raw_dtype. (Passing
    # ``tensor.shape.tolist().reverse()`` amounts to the same thing, since
    # list.reverse() returns None.)
    # NOTE(review): this relies on the installed gguf-py returning tensor.data
    # already shaped for its dtype; confirm if pinning an older gguf version.
    writer.add_tensor(name=tensor.name, tensor=tensor.data, raw_dtype=tensor.tensor_type)


def main() -> None:
    """Write dst_path: src0's metadata and tensors, with target tensors taken from src1."""
    parser = argparse.ArgumentParser(description="Merge GGUF models, especially embedding tables.")
    parser.add_argument("model_src_path", help="Path to the main model GGUF file.")
    parser.add_argument("embed_src_path", help="Path to the model GGUF file to take the embeddings table from (or other tensors).")
    parser.add_argument("dst_path", help="Path to the output GGUF file.")
    parser.add_argument("--target_blocks", nargs="+", default=["token_embd.weight"],
                        help="List of tensor names to merge from the embedding file. Default: token_embd.weight")
    args = parser.parse_args()

    reader_model = GGUFReader(args.model_src_path)
    reader_embed = GGUFReader(args.embed_src_path)

    arch = _read_arch(reader_model, "src0")
    # Compare against the *embedding* file's architecture.
    if _read_arch(reader_embed, "src1") != arch:
        sys.exit("src0 and src1 have different architectures")

    targets = set(args.target_blocks)
    # Index donor tensors once instead of rescanning src1 for every target.
    embed_tensors = {t.name: t for t in reader_embed.tensors}
    model_names = {t.name for t in reader_model.tensors}

    # Validate before creating the writer so no partial output is produced.
    not_in_model = sorted(targets - model_names)
    if not_in_model:
        print(f"Warning: target tensors not found in src0, ignored: {', '.join(not_in_model)}",
              file=sys.stderr)
    not_in_embed = sorted((targets & model_names) - set(embed_tensors))
    if not_in_embed:
        # Previously these tensors were silently dropped from the output.
        sys.exit(f"Target tensors not found in src1: {', '.join(not_in_embed)}")

    writer = GGUFWriter(path=args.dst_path, arch=arch)
    # NOTE: all metadata (e.g. tokenizer keys) comes from src0 only; a donor
    # tensor whose shape differs from src0's may not agree with that metadata.
    add_keys(writer, reader_model.fields)
    # Keep src0's tensor order, swapping in the donor tensor for each target.
    for tensor in reader_model.tensors:
        source = embed_tensors[tensor.name] if tensor.name in targets else tensor
        _copy_tensor(writer, source)

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()


if __name__ == "__main__":
    main()