1#!/usr/bin/env python3
2from __future__ import annotations
3
4import logging
5import argparse
6import os
7import sys
8import json
9from pathlib import Path
10
11from tqdm import tqdm
12from typing import Any, Sequence, NamedTuple
13
# Make the local 'gguf' package importable when running from a source checkout;
# set NO_LOCAL_GGUF in the environment to force use of an installed gguf instead.
# NOTE(review): the existence check climbs four parents while the inserted path
# climbs three — consistent only if this file sits two directories inside the
# 'gguf-py' directory (which contains the 'gguf' package); confirm against the
# file's actual location in the tree.
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import gguf

# Module-level logger; its level is configured in main() based on --verbose.
logger = logging.getLogger("gguf-new-metadata")
21
22
class MetadataDetails(NamedTuple):
    """A single GGUF metadata value plus the type information needed to write it.

    Used both for values parsed from the command line and for values copied
    from an existing GGUF file.
    """
    # GGUF value type tag used when writing the key/value pair.
    type: gguf.GGUFValueType
    # The decoded value itself (scalar, string, or list for arrays).
    value: Any
    # Human-readable annotation used only in log output (e.g. '= <token>').
    description: str = ''
    # Element type when `type` is ARRAY; None for scalar values.
    sub_type: gguf.GGUFValueType | None = None
28
29
def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
    """Return the decoded contents of metadata field *key*, or None if the field is absent."""
    field = reader.get_field(key)
    if not field:
        return None
    return field.contents()
34
35
def find_token(token_list: Sequence[str], token: str) -> Sequence[int]:
    """Return the indices of every occurrence of *token* in *token_list*.

    Fix: the parameter was annotated ``Sequence[int]`` although its elements
    are compared against a ``str`` token (such a comparison could never
    succeed under the declared type); the token list holds strings.

    Raises:
        LookupError: if *token* does not occur in *token_list* at all.
    """
    token_ids = [index for index, value in enumerate(token_list) if value == token]

    if not token_ids:
        raise LookupError(f'Unable to find "{token}" in token list!')

    return token_ids
43
44
def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
    """Copy all metadata and tensors from *reader* to *writer*, applying edits.

    Every key/value field of the input is copied except:
      * the architecture key and 'GGUF.*' virtual fields (GGUFWriter emits
        those itself),
      * old chat-template keys when a replacement template is supplied,
      * keys listed in *remove_metadata*.
    Entries in *new_metadata* override the input's value for the same key;
    leftover entries are appended as brand-new keys at the end.

    NOTE: *new_metadata* is consumed — entries are deleted from the dict as
    they are written, so the caller's dict is mutated in place.
    """
    for field in reader.fields.values():
        # Suppress virtual fields and fields written by GGUFWriter
        if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
            logger.debug(f'Suppressing {field.name}')
            continue

        # Skip old chat templates if we have new ones
        if field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
            logger.debug(f'Skipping {field.name}')
            continue

        # Drop keys the user asked to remove.
        if field.name in remove_metadata:
            logger.debug(f'Removing {field.name}')
            continue

        # For arrays the last entry of field.types is the element type.
        val_type = field.types[0]
        sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None
        old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type)
        # Replacement value if one was supplied, otherwise the original.
        val = new_metadata.get(field.name, old_val)

        if field.name in new_metadata:
            logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}')
            del new_metadata[field.name]
        elif val.value is not None:
            logger.debug(f'Copying {field.name}')

        # Fields whose contents decoded to None are silently dropped.
        if val.value is not None:
            # Prefer the replacement's sub_type; fall back to the original's.
            writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type)

    # A new chat template goes through the dedicated writer helper.
    if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
        logger.debug('Adding chat template(s)')
        writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value)
        del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]

    # Whatever remains in new_metadata did not exist in the input: append it.
    for key, val in new_metadata.items():
        logger.debug(f'Adding {key}: "{val.value}" {val.description}')
        writer.add_key_value(key, val.value, val.type)

    total_bytes = 0

    # First pass: register tensor info and total the payload size for the bar.
    for tensor in reader.tensors:
        total_bytes += tensor.n_bytes
        writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)

    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

    # GGUF layout requires header, then KV data, then tensor info, then data.
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_ti_data_to_file()

    # Second pass: stream the tensor data itself.
    for tensor in reader.tensors:
        writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess)
        bar.update(tensor.n_bytes)

    writer.close()
101
102
def main() -> None:
    """CLI entry point: copy a GGUF file while adding/replacing/removing metadata.

    Parses the command line, assembles the set of metadata edits, prompts for
    confirmation on destructive operations (unless --force), then delegates
    the actual copy to copy_with_new_metadata().
    """
    # Collect every tokenizer key of the form '...<name>_token_id' and map the
    # short name (e.g. 'bos', 'eos') to the full metadata key.
    tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_'))
    token_names = dict((n.split('.')[-1][:-len('_token_id')], n) for n in tokenizer_metadata if n.endswith('_token_id'))

    parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
    parser.add_argument("input", type=Path, help="GGUF format model input filename")
    parser.add_argument("output", type=Path, help="GGUF format model output filename")
    parser.add_argument("--general-name", type=str, help="The models general.name", metavar='"name"')
    parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."')
    parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."')
    parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json')
    parser.add_argument("--chat-template-file", type=Path, help="Jinja file containing chat template", metavar='chat_template.jinja')
    parser.add_argument("--pre-tokenizer", type=str, help="The models tokenizer.ggml.pre", metavar='"pre tokenizer"')
    parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url')
    parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
    parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))
    parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
    parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
    # Fewer than two CLI tokens cannot name both input and output: show help.
    args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    # Metadata to add/replace (keyed by full GGUF key) and keys to drop.
    new_metadata: dict[str, MetadataDetails] = {}
    remove_metadata = args.remove_metadata or []

    if args.general_name:
        new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name)

    if args.general_description:
        new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description)

    if args.chat_template:
        # A value starting with '[' is treated as a JSON list of templates.
        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template)

    if args.chat_template_config:
        # Pull the template(s) out of a tokenizer_config.json-style file.
        with open(args.chat_template_config, 'r', encoding='utf-8') as fp:
            config = json.load(fp)
            template = config.get('chat_template')
            if template:
                new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)

    if args.chat_template_file:
        # A plain Jinja file; overwrites any template set by the options above.
        with open(args.chat_template_file, 'r', encoding='utf-8') as fp:
            template = fp.read()
            new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)

    if args.pre_tokenizer:
        new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer)

    # Removing metadata can corrupt the output; require explicit confirmation.
    if remove_metadata:
        logger.warning('*** Warning *** Warning *** Warning **')
        logger.warning('* Most metadata is required for a fully functional GGUF file,')
        logger.warning('* removing crucial metadata may result in a corrupt output file!')

        if not args.force:
            logger.warning('* Enter exactly YES if you are positive you want to proceed:')
            response = input('YES, I am sure> ')
            if response != 'YES':
                logger.info("You didn't enter YES. Okay then, see ya!")
                sys.exit(0)

    logger.info(f'* Loading: {args.input}')
    reader = gguf.GGUFReader(args.input, 'r')

    # Architecture is needed to construct the writer below; the token list is
    # needed to resolve --special-token / --special-token-by-id.
    arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)

    token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []

    for name, token in args.special_token or []:
        if name not in token_names:
            logger.warning(f'Unknown special token "{name}", ignoring...')
        else:
            ids = find_token(token_list, token)
            # Multiple matches are possible; the first id wins.
            new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}')

            if len(ids) > 1:
                logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:')
                logger.warning(', '.join(str(i) for i in ids))

    for name, id_string in args.special_token_by_id or []:
        if name not in token_names:
            logger.warning(f'Unknown special token "{name}", ignoring...')
        elif not id_string.isdecimal():
            raise LookupError(f'Token ID "{id_string}" is not a valid ID!')
        else:
            id_int = int(id_string)

            # The id must reference an existing entry in the token list.
            if id_int >= 0 and id_int < len(token_list):
                new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}')
            else:
                raise LookupError(f'Token ID {id_int} is not within token list!')

    # Refuse to silently overwrite an existing output file.
    if os.path.isfile(args.output) and not args.force:
        logger.warning('*** Warning *** Warning *** Warning **')
        logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
        logger.warning('* Enter exactly YES if you are positive you want to proceed:')
        response = input('YES, I am sure> ')
        if response != 'YES':
            logger.info("You didn't enter YES. Okay then, see ya!")
            sys.exit(0)

    logger.info(f'* Writing: {args.output}')
    writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess)

    # Preserve a non-default data alignment if the input file declared one.
    alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
    if alignment is not None:
        logger.debug(f'Setting custom alignment: {alignment}')
        writer.data_alignment = alignment

    copy_with_new_metadata(reader, writer, new_metadata, remove_metadata)
213
214
# Allow use both as a script and as an importable module.
if __name__ == '__main__':
    main()