1#!/usr/bin/env python3
  2from __future__ import annotations
  3
  4import logging
  5import argparse
  6import os
  7import sys
  8import json
  9from pathlib import Path
 10
 11from tqdm import tqdm
 12from typing import Any, Sequence, NamedTuple
 13
# Necessary to load the local gguf package.
# Unless NO_LOCAL_GGUF is set in the environment: if a 'gguf-py' directory exists four
# levels above this script, the directory three levels above it is inserted at the
# front of sys.path, so the `import gguf` below prefers that local copy over any
# installed package.
# NOTE(review): this presumes the script lives at gguf-py/gguf/<subdir>/<script>.py,
# so that both computed paths name the same 'gguf-py' directory; confirm if the file moves.
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import gguf  # must stay after the sys.path adjustment above

logger = logging.getLogger("gguf-new-metadata")
 21
 22
 23class MetadataDetails(NamedTuple):
 24    type: gguf.GGUFValueType
 25    value: Any
 26    description: str = ''
 27    sub_type: gguf.GGUFValueType | None = None
 28
 29
 30def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
 31    field = reader.get_field(key)
 32
 33    return field.contents() if field else None
 34
 35
 36def find_token(token_list: Sequence[int], token: str) -> Sequence[int]:
 37    token_ids = [index for index, value in enumerate(token_list) if value == token]
 38
 39    if len(token_ids) == 0:
 40        raise LookupError(f'Unable to find "{token}" in token list!')
 41
 42    return token_ids
 43
 44
 45def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
 46    for field in reader.fields.values():
 47        # Suppress virtual fields and fields written by GGUFWriter
 48        if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
 49            logger.debug(f'Suppressing {field.name}')
 50            continue
 51
 52        # Skip old chat templates if we have new ones
 53        if field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
 54            logger.debug(f'Skipping {field.name}')
 55            continue
 56
 57        if field.name in remove_metadata:
 58            logger.debug(f'Removing {field.name}')
 59            continue
 60
 61        val_type = field.types[0]
 62        sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None
 63        old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type)
 64        val = new_metadata.get(field.name, old_val)
 65
 66        if field.name in new_metadata:
 67            logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}')
 68            del new_metadata[field.name]
 69        elif val.value is not None:
 70            logger.debug(f'Copying {field.name}')
 71
 72        if val.value is not None:
 73            writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type)
 74
 75    if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
 76        logger.debug('Adding chat template(s)')
 77        writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value)
 78        del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]
 79
 80    for key, val in new_metadata.items():
 81        logger.debug(f'Adding {key}: "{val.value}" {val.description}')
 82        writer.add_key_value(key, val.value, val.type)
 83
 84    total_bytes = 0
 85
 86    for tensor in reader.tensors:
 87        total_bytes += tensor.n_bytes
 88        writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
 89
 90    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 91
 92    writer.write_header_to_file()
 93    writer.write_kv_data_to_file()
 94    writer.write_ti_data_to_file()
 95
 96    for tensor in reader.tensors:
 97        writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess)
 98        bar.update(tensor.n_bytes)
 99
100    writer.close()
101
102
def _confirm_or_exit() -> None:
    """Ask the user to type exactly ``YES``; exit the process with status 0 otherwise."""
    logger.warning('* Enter exactly YES if you are positive you want to proceed:')
    response = input('YES, I am sure> ')
    if response != 'YES':
        logger.info("You didn't enter YES. Okay then, see ya!")
        sys.exit(0)


def _collect_metadata_from_args(args: argparse.Namespace) -> dict[str, MetadataDetails]:
    """Build overrides from the value options (name, description, chat template(s), pre-tokenizer)."""
    new_metadata: dict[str, MetadataDetails] = {}

    if args.general_name:
        new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name)

    if args.general_description:
        new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description)

    # Chat-template sources are applied in this order, so a later one overrides an earlier one:
    # --chat-template, then --chat-template-config, then --chat-template-file.
    if args.chat_template:
        # A value starting with '[' is decoded as JSON (a list of templates)
        template = json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template
        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)

    if args.chat_template_config:
        with open(args.chat_template_config, 'r', encoding='utf-8') as fp:
            config = json.load(fp)
        template = config.get('chat_template')
        if template:
            new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)

    if args.chat_template_file:
        with open(args.chat_template_file, 'r', encoding='utf-8') as fp:
            template = fp.read()
        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)

    if args.pre_tokenizer:
        new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer)

    return new_metadata


def _resolve_special_tokens(args: argparse.Namespace, token_names: dict[str, str], token_list: Sequence[str], new_metadata: dict[str, MetadataDetails]) -> None:
    """Add special-token ID overrides from --special-token / --special-token-by-id to ``new_metadata``.

    Raises:
        LookupError: If a token text is not found, or a token ID is malformed or out of range.
    """
    for name, token in args.special_token or []:
        if name not in token_names:
            logger.warning(f'Unknown special token "{name}", ignoring...')
            continue

        ids = find_token(token_list, token)
        new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}')

        if len(ids) > 1:
            logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:')
            logger.warning(', '.join(str(i) for i in ids))

    for name, id_string in args.special_token_by_id or []:
        if name not in token_names:
            logger.warning(f'Unknown special token "{name}", ignoring...')
            continue

        if not id_string.isdecimal():
            raise LookupError(f'Token ID "{id_string}" is not a valid ID!')

        id_int = int(id_string)
        if not 0 <= id_int < len(token_list):
            raise LookupError(f'Token ID {id_int} is not within token list!')

        new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}')


def main() -> None:
    """Command-line entry point: copy a GGUF file while applying the requested metadata edits."""
    tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_'))
    # Short token name (last dotted component minus the '_token_id' suffix) -> full *_token_id key
    token_names = {n.split('.')[-1][:-len('_token_id')]: n for n in tokenizer_metadata if n.endswith('_token_id')}

    parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
    parser.add_argument("input",                                       type=Path, help="GGUF format model input filename")
    parser.add_argument("output",                                      type=Path, help="GGUF format model output filename")
    parser.add_argument("--general-name",                              type=str,  help="The models general.name", metavar='"name"')
    parser.add_argument("--general-description",                       type=str,  help="The models general.description", metavar='"Description ..."')
    parser.add_argument("--chat-template",                             type=str,  help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."')
    parser.add_argument("--chat-template-config",                      type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json')
    parser.add_argument("--chat-template-file",                        type=Path, help="Jinja file containing chat template", metavar='chat_template.jinja')
    parser.add_argument("--pre-tokenizer",                             type=str,  help="The models tokenizer.ggml.pre", metavar='"pre tokenizer"')
    parser.add_argument("--remove-metadata",      action="append",     type=str,  help="Remove metadata (by key name) from output model", metavar='general.url')
    parser.add_argument("--special-token",        action="append",     type=str,  help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
    parser.add_argument("--special-token-by-id",  action="append",     type=str,  help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))
    parser.add_argument("--force",                action="store_true",            help="Bypass warnings without confirmation")
    parser.add_argument("--verbose",              action="store_true",            help="Increase output verbosity")
    args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    # Tensor data is read from the input only while copy_with_new_metadata runs, i.e.
    # after the writer has started writing the output. Writing over the input in place
    # would therefore corrupt the very data being copied, so refuse even with --force.
    if args.input.exists() and args.output.exists() and args.input.samefile(args.output):
        logger.error(f'* Input and output refer to the same file "{args.output}", refusing to overwrite it in place!')
        sys.exit(1)

    new_metadata = _collect_metadata_from_args(args)
    remove_metadata = args.remove_metadata or []

    if remove_metadata:
        logger.warning('*** Warning *** Warning *** Warning **')
        logger.warning('* Most metadata is required for a fully functional GGUF file,')
        logger.warning('* removing crucial metadata may result in a corrupt output file!')

        if not args.force:
            _confirm_or_exit()

    logger.info(f'* Loading: {args.input}')
    reader = gguf.GGUFReader(args.input, 'r')

    arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
    token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []

    _resolve_special_tokens(args, token_names, token_list, new_metadata)

    if args.output.is_file() and not args.force:
        logger.warning('*** Warning *** Warning *** Warning **')
        logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
        _confirm_or_exit()

    logger.info(f'* Writing: {args.output}')
    writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess)

    # Preserve a non-default tensor data alignment from the input file
    alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
    if alignment is not None:
        logger.debug(f'Setting custom alignment: {alignment}')
        writer.data_alignment = alignment

    copy_with_new_metadata(reader, writer, new_metadata, remove_metadata)
213
214
# Run the command-line tool only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()