1#!/usr/bin/env python
2'''
3 Fetches the Jinja chat template of a HuggingFace model.
4 If a model has multiple chat templates, you can specify the variant name.
5
6 Syntax:
7 ./scripts/get_chat_template.py model_id [variant]
8
9 Examples:
10 ./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use
11 ./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct
12'''
13
14import json
15import re
16import sys
17
18
def get_chat_template(model_id, variant=None):
    """Fetch the Jinja chat template of a HuggingFace model.

    Args:
        model_id: HuggingFace repo id, e.g. "microsoft/Phi-3.5-mini-instruct".
        variant: optional chat-template variant name for models that ship
            several templates (e.g. "tool_use"); when None, "default" is
            picked if present.

    Returns:
        The chat template as a Jinja string.

    Raises:
        ValueError: if model_id is malformed (plain-HTTP fallback path only).
        Exception: if the model is gated, the download fails, or the
            requested variant does not exist.
    """
    try:
        # Use huggingface_hub library if available.
        # Allows access to gated models if the user has access and ran `huggingface-cli login`.
        from huggingface_hub import hf_hub_download
        with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json"), encoding="utf-8") as f:
            config_str = f.read()
    except ImportError:
        import requests
        # Validate before interpolating into the URL. Use an explicit raise
        # rather than `assert`, which is stripped under `python -O`.
        if not re.match(r"^[\w.-]+/[\w.-]+$", model_id):
            raise ValueError(f"Invalid model ID: {model_id}")
        # timeout keeps the script from hanging forever on a stalled connection.
        response = requests.get(
            f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json",
            timeout=30)
        if response.status_code == 401:
            raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
        response.raise_for_status()
        config_str = response.text

    try:
        config = json.loads(config_str)
    except json.JSONDecodeError:
        # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
        # (Remove extra '}' near the end of the file)
        config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))

    chat_template = config['chat_template']
    if isinstance(chat_template, str):
        # Single template: return it directly.
        return chat_template

    # Multiple templates: a list of {"name": ..., "template": ...} entries.
    variants = {
        ct['name']: ct['template']
        for ct in chat_template
    }

    def format_variants():
        return ', '.join(f'"{v}"' for v in variants.keys())

    if variant is None:
        if 'default' not in variants:
            raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
        variant = 'default'
        sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n')
    elif variant not in variants:
        raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")

    return variants[variant]
63
64
def main(args):
    """CLI entry point: args is [model_id] or [model_id, variant]."""
    if not args:
        raise ValueError("Please provide a model ID and an optional variant name")
    model_id, *rest = args
    variant = rest[0] if rest else None

    # Write the raw template to stdout (no trailing newline added).
    sys.stdout.write(get_chat_template(model_id, variant))
73
74
# Script entry point: pass the CLI arguments (minus the program name) to main().
if __name__ == '__main__':
    main(sys.argv[1:])