1from __future__ import annotations
2
3from enum import Enum
4import re
5import logging
6import json
7import os
8from pathlib import Path
9from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable
10
11try:
12 from sentencepiece import SentencePieceProcessor
13except ImportError:
14 SentencePieceProcessor = None
15
16try:
17 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
18 from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
19 from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
20 _filter_valid_tokenizer_files,
21 )
22 from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
23 SentencePieceTokenizer,
24 )
25except ImportError:
26 _mistral_common_installed = False
27 MistralTokenizer = None
28 Tekkenizer = None
29 SentencePieceTokenizer = None
30 _filter_valid_tokenizer_files = None
31else:
32 _mistral_common_installed = True
33
34try:
35 from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
36 get_one_valid_tokenizer_file,
37 )
38except ImportError:
39 # We still want the conversion to work with older mistral-common versions.
40 get_one_valid_tokenizer_file = None
41
42
43import gguf
44
45from .gguf_writer import GGUFWriter
46
# Module-level logger used by every class in this file; named after this module.
logger = logging.getLogger(__name__)
48
49
class SpecialVocab:
    """Special-token metadata collected from a Hugging Face style model directory.

    Gathers BPE merges, special token ids, ``add_<type>_token`` flags and the
    chat template from ``tokenizer.json``, ``tokenizer_config.json``,
    ``chat_template.jinja`` / ``chat_template.json``, ``config.json`` and (as a
    merges fallback) ``merges.txt``.  :meth:`add_to_gguf` writes the result to
    a ``GGUFWriter``.
    """

    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
        """Load special-token metadata from the directory *path*.

        Args:
            path: model directory containing the tokenizer/config files.
            load_merges: also collect BPE merges (tokenizer.json, else merges.txt).
            special_token_types: token type names to look up (e.g. ``'bos'``);
                defaults to bos/eos/unk/sep/pad/cls/mask.
            n_vocab: when given, special token ids ``>= n_vocab`` are skipped
                with a warning.

        Raises:
            ValueError: on a negative special token id or an unknown merges format.
        """
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            # Materialize once: the loaders iterate this more than once and test
            # membership in it, which would silently exhaust a one-shot iterator.
            self.special_token_types = tuple(special_token_types)
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        """Write the collected merges, special ids, add-flags and chat template to *gw*.

        Special types for which *gw* has no ``add_<typ>_token_id`` /
        ``add_add_<typ>_token`` method are skipped with a warning.  *quiet*
        suppresses the informational log lines (warnings are still emitted).
        """
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
        """Run the loaders in priority order.

        ``_set_special_token`` keeps the first id recorded per type, so values
        from tokenizer.json / tokenizer_config.json win over config.json.
        """
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
        """Read ``merges.txt`` (one ``left right`` pair per line) into ``self.merges``.

        A leading line starting with ``#`` (a version header) is skipped; blank
        lines are ignored and lines without exactly two fields are logged and
        skipped.  Returns False when the file does not exist.
        """
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True

    def _set_special_token(self, typ: str, tid: Any) -> None:
        """Record *tid* as the id of special token type *typ*.

        Non-int values are ignored, an already-recorded type is not overwritten,
        and ids ``>= n_vocab`` (when ``n_vocab`` is set) are skipped with a warning.

        Raises:
            ValueError: if *tid* is a negative int.
        """
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        """Load merges, add-flags, chat template and special ids from the HF tokenizer files.

        Reads ``tokenizer.json`` (merges, ``added_tokens``, ``post_processor``)
        and ``tokenizer_config.json`` (``<typ>_token``, ``add_<typ>_token``,
        ``chat_template``), with ``chat_template.jinja`` (+ any
        ``additional_chat_templates/*.jinja``) or ``chat_template.json`` as
        the chat template fallback.  The parsed tokenizer_config dict is
        patched in place with bos/eos/cls/sep/eot/eom tokens inferred from the
        post-processor before special ids are resolved.  Always returns True.

        Raises:
            ValueError: if tokenizer.json merges have an unrecognized format.
        """
        tokenizer = None
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges:
                    if isinstance(merges[0], str):
                        self.merges = merges
                    elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
                        # New format since transformers 4.45 to support spaces in merges
                        # ref: https://github.com/ggml-org/llama.cpp/issues/9692
                        # TODO: internally store as the new format instead of converting to old
                        if any(' ' in s for pair in merges for s in pair):
                            logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
                        self.merges = [
                            ' '.join(
                                [
                                    # ensure the spaces are properly encoded
                                    ''.join(
                                        chr(ord(c) + 256) if c == ' ' else c
                                        for c in part
                                    )
                                    for part in pair
                                ]
                            )
                            for pair in merges
                        ]
                    else:
                        raise ValueError("Unknown tokenizer merges format")
            added_tokens = tokenizer.get('added_tokens', [])
        else:
            added_tokens = []
        tokenizer_config = None
        tokenizer_config_file = path / 'tokenizer_config.json'
        if tokenizer_config_file.is_file():
            with open(tokenizer_config_file, encoding = 'utf-8') as f:
                tokenizer_config = json.load(f)
        if tokenizer:
            special_bos = (tokenizer_config or {}).get('bos_token')
            special_cls = (tokenizer_config or {}).get('cls_token')
            special_eos = (tokenizer_config or {}).get('eos_token')
            special_sep = (tokenizer_config or {}).get('sep_token')
            # Fall back to cls/sep for a missing bos/eos.
            if not special_bos and special_cls and tokenizer_config:
                tokenizer_config['bos_token'] = special_bos = special_cls
            if not special_eos and special_sep and tokenizer_config:
                tokenizer_config['eos_token'] = special_eos = special_sep
            if post_processor := tokenizer.get('post_processor'):
                # A Sequence post-processor lists sub-processors; otherwise treat it as a list of one.
                for processor in post_processor.get('processors', [post_processor]):
                    if processor.get('type') == 'RobertaProcessing':
                        self.add_special_token['bos'] = True
                        self.add_special_token['eos'] = True
                        self.add_special_token['sep'] = True
                        if not special_cls and tokenizer_config:
                            special_cls = processor.get('cls', [special_bos])[0]
                            tokenizer_config['cls_token'] = special_cls
                        if not special_sep and tokenizer_config:
                            special_sep = processor.get('sep', [special_eos])[0]
                            tokenizer_config['sep_token'] = special_sep
                        continue
                    # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
                    # Only works with simple templates, **will** get it wrong on unusual sequences
                    if processor.get('type') == 'TemplateProcessing':
                        tmpl_single = processor.get('single', [])
                        tmpl_pair = processor.get('pair', [])
                        special_first = None
                        special_last = None
                        if len(tmpl_single) > 1:
                            if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
                                if not tokenizer_config:
                                    special_bos = special_first
                                self.add_special_token['bos'] = special_first in (special_bos, special_cls)
                                if special_first not in (special_bos, special_cls):
                                    logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
                            if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
                                if not tokenizer_config:
                                    special_eos = special_last
                                elif special_last != special_eos:
                                    # The template appends a different token than the configured EOS:
                                    # keep the old EOS as EOT (or EOM) and make the appended token EOS.
                                    if 'eot' not in self.special_token_types:
                                        self.special_token_types = tuple(self.special_token_types) + ('eot', )
                                        tokenizer_config['eot_token'] = special_eos
                                    elif 'eom' not in self.special_token_types:
                                        self.special_token_types = tuple(self.special_token_types) + ('eom', )
                                        tokenizer_config['eom_token'] = special_eos
                                    else:
                                        logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
                                    tokenizer_config['eos_token'] = special_eos = special_last
                                self.add_special_token['eos'] = special_last == special_eos
                                if special_last != special_eos:
                                    logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
                        if tmpl_pair:
                            seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
                            seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
                            if (special_first and seq_start == 0) or (special_last and seq_stop is None):
                                logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
                            if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
                                tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
                                tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
                                if tmpl_a != 'A' or tmpl_b != 'B':
                                    logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
                                # A [sep] [eos] B
                                if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
                                    add_sep = False
                                    if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
                                        if special_entry in (special_sep, special_eos) and not special_last:
                                            add_sep = True
                                        if special_entry not in (special_sep, special_eos):
                                            logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
                                    else:
                                        logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
                                    if len(tmpl_pair) == 2:
                                        if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
                                            if special_entry in (special_sep, special_eos):
                                                add_sep = True
                                            if special_entry not in (special_sep, special_eos):
                                                logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
                                        else:
                                            logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
                                    self.add_special_token['sep'] = add_sep
                                    if add_sep and not special_sep and tokenizer_config:
                                        tokenizer_config['sep_token'] = special_eos
                        continue
        if not tokenizer_config:
            return True
        chat_template_alt = None
        chat_template_json = path / 'chat_template.json'
        chat_template_jinja = path / 'chat_template.jinja'
        if chat_template_jinja.is_file():
            with open(chat_template_jinja, encoding = 'utf-8') as f:
                chat_template_alt = f.read()
            # Sorted so the emitted template list does not depend on directory listing order.
            if additional_templates := sorted((path / 'additional_chat_templates').glob('*.jinja')):
                chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
                for template_path in additional_templates:
                    with open(template_path, encoding = 'utf-8') as fp:
                        chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
        elif chat_template_json.is_file():
            with open(chat_template_json, encoding = 'utf-8') as f:
                chat_template_alt = json.load(f).get('chat_template')
        # An explicit chat_template in tokenizer_config.json takes precedence.
        chat_template = tokenizer_config.get('chat_template', chat_template_alt)
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True

    def _try_load_from_config_json(self, path: Path) -> bool:
        """Read ``<typ>_token_id`` values from ``config.json``.

        Ids missing at the top level are looked up in a nested ``text_config``
        object when it is a dict.  Returns False when the file does not exist.
        """
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        # text_config may be absent or null; only a dict can hold token ids.
        text_config = config.get('text_config')
        if not isinstance(text_config, dict):
            text_config = {}
        for typ in self.special_token_types:
            token_id = config.get(f'{typ}_token_id')
            # If not found at root, check in text_config (for multimodal models like Kimi-VL)
            if token_id is None:
                token_id = text_config.get(f'{typ}_token_id')
            self._set_special_token(typ, token_id)
        return True
329
330
@runtime_checkable
class BaseVocab(Protocol):
    """Minimal structural interface shared by every vocab class in this file.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, BaseVocab)``
    checks only for the presence of the members declared here.
    """

    # Tokenizer model identifier, e.g. "gpt2", "llama", "mistral" or "no_vocab" in this file.
    tokenizer_model: ClassVar[str]
    # Short name of the vocab implementation, e.g. "bpe", "spm", "hfft" in this file.
    name: ClassVar[str]
335
336
@runtime_checkable
class Vocab(BaseVocab, Protocol):
    """Structural interface for a loadable vocabulary with token enumeration.

    Implementations in this file are constructed from a model directory and
    yield ``(text, score, token_type)`` triples from ``all_tokens``.
    """

    # Total number of tokens, including added tokens.
    vocab_size: int
    # Added token text -> token id.
    added_tokens_dict: dict[str, int]
    # Added token texts (the implementations in this file order them by id).
    added_tokens_list: list[str]
    # Path of the tokenizer file the vocabulary was loaded from.
    fname_tokenizer: Path

    def __init__(self, base_path: Path): ...
    # Yields (token text, score, token type) for every token in id order.
    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
346
347
class NoVocab(BaseVocab):
    """Placeholder vocabulary for models that ship without a tokenizer."""

    name = "no_vocab"
    tokenizer_model = "no_vocab"

    def __repr__(self) -> str:
        description = "for a model without integrated vocabulary"
        return f"<NoVocab {description}>"
354
355
class BpeVocab(Vocab):
    """GPT-2 style byte-level BPE vocabulary.

    Loaded from a "slow" ``vocab.json`` (+ optional ``added_tokens.json``) or,
    failing that, from a "fast" ``tokenizer.json``.
    """

    tokenizer_model = "gpt2"
    name = "bpe"

    def __init__(self, base_path: Path):
        """Load the vocabulary from *base_path*.

        Raises:
            FileNotFoundError: if neither vocab.json nor tokenizer.json exists,
                or tokenizer.json is not a byte-level BPE model without byte
                fallback (callers use this to try another vocab type).
            ValueError: if added token ids are not contiguous right after the
                base vocabulary.
        """
        added_tokens: dict[str, int] = {}

        if (fname_tokenizer := base_path / 'vocab.json').exists():
            # "slow" tokenizer
            with open(fname_tokenizer, encoding="utf-8") as f:
                self.vocab = json.load(f)

            try:
                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        else:
            # "fast" tokenizer
            fname_tokenizer = base_path / 'tokenizer.json'

            # if this fails, FileNotFoundError propagates to caller
            with open(fname_tokenizer, encoding="utf-8") as f:
                tokenizer_json = json.load(f)

            tokenizer_model: dict[str, Any] = tokenizer_json['model']
            # A missing/null decoder means "not a GPT-2 BPE tokenizer", not a crash.
            decoder_type = (tokenizer_json.get('decoder') or {}).get('type')
            if (
                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
                or decoder_type != 'ByteLevel'
            ):
                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')

            self.vocab = tokenizer_model["vocab"]

            if (added := tokenizer_json.get('added_tokens')) is not None:
                # Added tokens here can be duplicates of the main vocabulary.
                added_tokens = {item['content']: item['id']
                                for item in added
                                if item['content'] not in self.vocab}

        vocab_size = len(self.vocab)
        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
        actual_ids = sorted(added_tokens.values())
        if expected_ids != actual_ids:
            expected_end_id = vocab_size + len(actual_ids) - 1
            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")

        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [text for (text, idx) in items]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield the base vocabulary as UTF-8 bytes in id order ``0 .. len(vocab) - 1``.

        Assumes ids are dense; a gap raises KeyError.
        """
        reverse_vocab = {tok_id: encoded_tok for encoded_tok, tok_id in self.vocab.items()}

        for tok_id in range(len(self.vocab)):
            # Encode like added_tokens() so all_tokens() yields bytes throughout.
            yield reverse_vocab[tok_id].encode("utf-8"), 0.0, gguf.TokenType.NORMAL

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield added tokens (in id order) as CONTROL tokens with score -1000."""
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield base tokens followed by added tokens."""
        yield from self.bpe_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
429
430
class SentencePieceVocab(Vocab):
    """Vocabulary backed by a SentencePiece ``tokenizer.model`` (+ optional ``added_tokens.json``)."""

    tokenizer_model = "llama"
    name = "spm"

    def __init__(self, base_path: Path):
        """Load ``tokenizer.model`` from *base_path* or, failing that, its parent.

        ``added_tokens.json`` is only consulted when the model is found in
        *base_path* itself.

        Raises:
            RuntimeError: if the sentencepiece package is unavailable.
            FileNotFoundError: if no tokenizer.model is found in either place.
            ValueError: if added ids beyond the base vocab are not contiguous.
        """
        if SentencePieceProcessor is None:
            raise RuntimeError("sentencepiece is not installed")

        added_tokens: dict[str, int] = {}
        fname_tokenizer = base_path / 'tokenizer.model'
        if fname_tokenizer.exists():
            # normal location
            try:
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        else:
            fname_tokenizer = base_path.parent / 'tokenizer.model'
            if not fname_tokenizer.exists():
                # not found in alternate location either
                raise FileNotFoundError('Cannot find tokenizer.model')

        self.sentencepiece_tokenizer = SentencePieceProcessor()
        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
        base_size = self.sentencepiece_tokenizer.vocab_size()

        # Only ids past the SentencePiece vocabulary are genuinely new tokens.
        new_tokens = {tok_id: piece for piece, tok_id in added_tokens.items() if tok_id >= base_size}
        actual_new_ids = sorted(new_tokens)
        expected_new_ids = list(range(base_size, base_size + len(new_tokens)))
        if actual_new_ids != expected_new_ids:
            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

        # Token pieces that were added to the base vocabulary.
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [new_tokens[tok_id] for tok_id in actual_new_ids]
        self.vocab_size_base = base_size
        self.vocab_size = base_size + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield (piece bytes, score, type) for every SentencePiece id.

        Type priority, highest first: BYTE, UNUSED, CONTROL, UNKNOWN, NORMAL.
        """
        sp = self.sentencepiece_tokenizer
        for token_id in range(sp.vocab_size()):
            # NOTE: added_tokens are presumably user defined.
            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
            if sp.IsByte(token_id):
                toktype = gguf.TokenType.BYTE
            elif sp.IsUnused(token_id):
                toktype = gguf.TokenType.UNUSED
            elif sp.IsControl(token_id):
                toktype = gguf.TokenType.CONTROL
            elif sp.IsUnknown(token_id):
                toktype = gguf.TokenType.UNKNOWN
            else:
                toktype = gguf.TokenType.NORMAL

            piece_score: float = sp.GetScore(token_id)
            yield sp.IdToPiece(token_id).encode("utf-8"), piece_score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield added pieces as USER_DEFINED tokens with score -1000."""
        yield from (
            (piece.encode("utf-8"), -1000.0, gguf.TokenType.USER_DEFINED)
            for piece in self.added_tokens_list
        )

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield SentencePiece tokens followed by added tokens."""
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        base, extra = self.vocab_size_base, len(self.added_tokens_list)
        return f"<SentencePieceVocab with {base} base tokens and {extra} added tokens>"
504
505
class LlamaHfVocab(Vocab):
    """Llama BPE (byte-fallback) vocabulary loaded through Hugging Face ``transformers``."""

    tokenizer_model = "llama"
    name = "hfft"

    def __init__(self, base_path: Path):
        """Load ``tokenizer.json`` from *base_path* via ``AutoTokenizer``.

        Raises:
            FileNotFoundError: if tokenizer.json is missing or is not a
                byte-fallback BPE model with a ``Sequence`` decoder.
            TypeError: if the tokenizer looks like Llama 3 (use BpeVocab).
            ImportError: if transformers is not installed.
        """
        fname_tokenizer = base_path / 'tokenizer.json'
        # if this fails, FileNotFoundError propagates to caller
        with open(fname_tokenizer, encoding='utf-8') as f:
            tokenizer_json = json.load(f)

        # pre-check so we know if we need transformers
        tokenizer_model: dict[str, Any] = tokenizer_json['model']
        is_llama3 = (
            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
            and not tokenizer_model.get('byte_fallback', True)
        )
        if is_llama3:
            raise TypeError('Llama 3 must be converted with BpeVocab')

        # A missing/null decoder means "not this tokenizer kind", not a crash.
        decoder_type = (tokenizer_json.get('decoder') or {}).get('type')
        if (
            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
            or decoder_type != 'Sequence'
        ):
            raise FileNotFoundError('Cannot find Llama BPE tokenizer')

        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "To use LlamaHfVocab, please install the `transformers` package. "
                "You can install it with `pip install transformers`."
            ) from e

        # Allow the tokenizer to default to slow or fast versions.
        # Explicitly set tokenizer to use local paths.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_path,
            cache_dir=base_path,
            local_files_only=True,
        )
        assert self.tokenizer.is_fast  # assume tokenizer.json is used

        # Initialize lists and dictionaries for added tokens
        self.added_tokens_list = []
        self.added_tokens_dict = {}
        self.added_tokens_ids = set()

        # Process added tokens in id order
        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            # Only consider added tokens that are not in the base vocabulary
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)

        # Store special tokens and their IDs; fetch the vocab once rather than
        # once per special token.
        vocab = self.tokenizer.get_vocab()
        self.specials = {tok: vocab[tok] for tok in self.tokenizer.all_special_tokens}
        self.special_ids = set(self.tokenizer.all_special_ids)

        # Set vocabulary sizes
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)

        self.fname_tokenizer = fname_tokenizer

    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield (text bytes, score, type) for base ids ``0 .. vocab_size_base - 1``."""
        reverse_vocab = {
            tok_id: encoded_tok for encoded_tok, tok_id in self.tokenizer.get_vocab().items()
        }

        for token_id in range(self.vocab_size_base):
            # Skip processing added tokens here
            if token_id in self.added_tokens_ids:
                continue

            # Convert token text to bytes
            token_text = reverse_vocab[token_id].encode("utf-8")

            # Yield token text, score, and type
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids  # Reuse already stored special IDs
            )

    def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
        """Classify a token: ``<0xNN>`` pieces are BYTE, ids in *special_ids* CONTROL, else NORMAL."""
        # Special case for byte tokens
        if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
            return gguf.TokenType.BYTE

        # Determine token type based on whether it's a special token
        return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        """Return the score for *token_id*; currently a constant -1000.0 placeholder."""
        # Placeholder for actual logic to determine the token's score
        # This needs to be implemented based on specific requirements
        return -1000.0  # Default score

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield added tokens; special ones are typed via get_token_type, others USER_DEFINED."""
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = gguf.TokenType.USER_DEFINED
                score = -1000.0

            yield text.encode("utf-8"), score, toktype

    def has_newline_token(self):
        """Return True if the vocab contains a newline piece (``<0x0A>`` or ``\\n``)."""
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield base tokens followed by added tokens."""
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
627
628
class MistralTokenizerType(str, Enum):
    """Tokenizer backend of a Mistral model, as distinguished by ``MistralVocab``.

    Mixing in ``str`` makes each member compare equal to its string value.
    """

    # SentencePiece-based tokenizer
    spm = "spm"
    # mistral_common Tekkenizer
    tekken = "tekken"
632
633
634# Copied from Transformers (Apache 2.0)
635# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
636
def bytes_to_unicode() -> dict[int, str]:
    """
    Return a mapping from every byte value (0-255) to a printable unicode character.

    The reversible BPE codes work on unicode strings, so every byte needs a
    character that is neither whitespace nor a control character.  Bytes that
    are already printable ('!'..'~', U+00A1..U+00AC, U+00AE..U+00FF) map to
    themselves; the remaining 68 bytes are mapped, in increasing byte order,
    to consecutive code points starting at U+0100 (so space becomes 'Ġ').
    """
    # The Latin-1 ranges are written as code points rather than literal
    # characters so the table does not depend on the source file's encoding.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(0xA1, 0xAC + 1))
        + list(range(0xAE, 0xFF + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, (chr(code) for code in cs)))
661
662
class MistralVocab(Vocab):
    """Vocabulary backed by a ``mistral_common`` tokenizer (SentencePiece or Tekken).

    Requires the optional ``mistral-common`` package.
    """

    tokenizer_model = "mistral"
    name = "mistral"

    added_tokens_dict: dict[str, int] = {}
    added_tokens_list: list[str] = []

    def __init__(self, base_path: Path):
        """Locate and load the Mistral tokenizer file found under *base_path*.

        Raises:
            ImportError: if mistral-common is not installed.
            ValueError: if no valid tokenizer file is found (older
                mistral-common code path; newer versions raise from
                ``get_one_valid_tokenizer_file``).
        """
        if not _mistral_common_installed:
            raise ImportError(
                "To use MistralVocab, please install the `mistral-common` package. "
                "You can install it with `pip install mistral-common`."
            )
        assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
        assert MistralTokenizer is not None, "mistral_common is not installed"
        assert Tekkenizer is not None, "mistral_common is not installed"

        logger.info(f"Loading Mistral tokenizer from {base_path}")

        # Find the tokenizer files, as paths relative to base_path. Relative
        # paths let the bare-name "tekken.json" check below match a top-level
        # file, and keep `base_path / <file>` from duplicating the directory
        # prefix when base_path itself is relative.
        all_files = [
            f.relative_to(base_path).as_posix()
            for f in base_path.glob("**/*")
            if f.is_file()
        ]

        if get_one_valid_tokenizer_file is not None:
            tokenizer_file_path = base_path / get_one_valid_tokenizer_file(all_files)
        else:
            valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)

            if len(valid_tokenizer_files) == 0:
                raise ValueError(f"No tokenizer file found in the directory: {base_path}")
            # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
            if len(valid_tokenizer_files) > 1:
                if "tekken.json" in valid_tokenizer_files:
                    tokenizer_file = "tekken.json"
                else:
                    tokenizer_file = sorted(valid_tokenizer_files)[-1]
                logger.warning(
                    f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
                )
            else:
                tokenizer_file = valid_tokenizer_files[0]

            tokenizer_file_path = base_path / tokenizer_file

        self.tokenizer = MistralTokenizer.from_file(
            tokenizer_file_path
        ).instruct_tokenizer.tokenizer
        self.tokenizer_type = (
            MistralTokenizerType.tekken
            if isinstance(self.tokenizer, Tekkenizer)
            else MistralTokenizerType.spm
        )
        self.vocab_size = self.tokenizer.n_words
        self.fname_tokenizer = tokenizer_file_path
        self._name = (
            "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
        )

    @property
    def tokenizer_name(self) -> str:
        """Name of the form ``mistral-<spm|tekken>-<tokenizer version>``."""
        return self._name

    @property
    def gguf_tokenizer_model(self) -> str:
        """GGUF tokenizer model: "llama" for SentencePiece, "gpt2" for Tekken."""
        return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"

    def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield (piece bytes, score, type) for every id of the SentencePiece model.

        Later checks override earlier ones, so BYTE > UNUSED > CONTROL > UNKNOWN > NORMAL.
        """
        assert SentencePieceTokenizer is not None, "mistral_common is not installed"
        assert isinstance(self.tokenizer, SentencePieceTokenizer), (
            f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
        )

        for i in range(self.tokenizer._model.vocab_size()):
            piece = self.tokenizer._model.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = self.tokenizer._model.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if self.tokenizer._model.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if self.tokenizer._model.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            if self.tokenizer._model.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if self.tokenizer._model.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield special tokens (CONTROL) then byte-level-encoded normal tokens, all with score 0.0."""
        assert Tekkenizer is not None, "mistral_common is not installed"
        assert isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )

        byte_encoder = bytes_to_unicode()
        # Scores are floats (0.0, not 0) so the scores array is typed as float
        # like the SentencePiece path, rather than inferred as an integer array.
        for token_id in range(self.tokenizer.num_special_tokens):
            yield (
                self.tokenizer.id_to_piece(token_id).encode("utf-8"),
                0.0,
                gguf.TokenType.CONTROL
            )
        for token in self.tokenizer._tekken_token2id_nospecial:
            yield (
                self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
                0.0,
                gguf.TokenType.NORMAL,
            )

    def get_token_id(self, token: str) -> int:
        """Return the id of *token*; Tekken ids are offset by the number of special tokens.

        Raises:
            ValueError: if the token is not in the vocabulary or the tokenizer type is unknown.
        """
        assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
        if self.tokenizer_type == MistralTokenizerType.spm:
            assert isinstance(self.tokenizer, SentencePieceTokenizer)
            return self.tokenizer._vocab.index(token)
        elif self.tokenizer_type == MistralTokenizerType.tekken:
            assert isinstance(self.tokenizer, Tekkenizer)
            return (
                self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
            )
        else:
            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")

    @property
    def bos_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def pad_id(self) -> int:
        """Pad id, falling back to the EOS id when the tokenizer reports -1."""
        if self.tokenizer.pad_id == -1:
            return self.eos_id
        return self.tokenizer.pad_id

    @property
    def unk_id(self) -> int:
        return self.tokenizer.unk_id

    @property
    def bos_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.bos_id)

    @property
    def eos_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.eos_id)

    @property
    def pad_token(self) -> str:
        """Pad token text; uses ``pad_id`` so it agrees with its -1 -> EOS fallback."""
        return self.tokenizer.id_to_piece(self.pad_id)

    @property
    def unk_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.unk_id)

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield every token for the detected tokenizer type.

        Raises:
            ValueError: if the tokenizer type is unknown.
        """
        if self.tokenizer_type == MistralTokenizerType.spm:
            yield from self._sentencepiece_tokens()

        elif self.tokenizer_type == MistralTokenizerType.tekken:
            yield from self._tekken_tokens()

        else:
            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")

    @staticmethod
    def token_bytes_to_string(b, byte_encoder):
        """Map each raw byte of *b* through *byte_encoder* (see ``bytes_to_unicode``)."""
        # latin-1 decodes each byte to the code point with the same value.
        return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

    def extract_vocab_merges_from_model(self):
        """Reconstruct BPE merges (``"left right"`` strings) from the Tekken mergeable ranks.

        Raises:
            ValueError: if some rank has no split into two lower-ranked tokens.
        """
        # Adapted from Transformers (Apache 2.0)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
        assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )
        mergeable_ranks = self.tokenizer._model._mergeable_ranks
        token_bytes_map = {
            rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
        }
        merge_pairs = []

        # Walk the non-special ranks in increasing order; ranks below 256 are
        # skipped (presumably the single-byte base tokens, which have no merge).
        for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
            merged_token = token_bytes_map[i]
            local = []
            for j in range(1, len(merged_token)):
                left = merged_token[:j]
                right = merged_token[j:]
                if (
                    left in mergeable_ranks
                    and right in mergeable_ranks
                    and (left + right) in mergeable_ranks
                ):
                    local.append((left, right, i))
            if not local:
                raise ValueError(
                    f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
                )
            local = sorted(
                local,
                key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
                reverse=False,
            )
            merge_pairs.extend(local)
        merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)

        byte_encoder = bytes_to_unicode()

        decoded_merge_pairs = [
            [
                self.token_bytes_to_string(val[0], byte_encoder),
                self.token_bytes_to_string(val[1], byte_encoder),
            ]
            for val in merge_pairs
        ]

        merges = [
            " ".join(
                [
                    # ensure the spaces are properly encoded
                    "".join(chr(ord(c) + 256) if c == " " else c for c in part)
                    for part in pair
                ]
            )
            for pair in decoded_merge_pairs
        ]

        return merges