from __future__ import annotations

from enum import Enum
import re
import logging
import json
import os
from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable

try:
    from sentencepiece import SentencePieceProcessor
except ImportError:
    SentencePieceProcessor = None

try:
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer  # pyright: ignore[reportMissingImports]
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer  # pyright: ignore[reportMissingImports]
    from mistral_common.tokens.tokenizers.utils import (  # pyright: ignore[reportMissingImports]
        _filter_valid_tokenizer_files,
    )
    from mistral_common.tokens.tokenizers.sentencepiece import (  # pyright: ignore[reportMissingImports]
        SentencePieceTokenizer,
    )
except ImportError:
    _mistral_common_installed = False
    MistralTokenizer = None
    Tekkenizer = None
    SentencePieceTokenizer = None
    _filter_valid_tokenizer_files = None
else:
    _mistral_common_installed = True

try:
    from mistral_common.tokens.tokenizers.utils import (  # pyright: ignore[reportMissingImports]
        get_one_valid_tokenizer_file,
    )
except ImportError:
    # We still want the conversion to work with older mistral-common versions.
    get_one_valid_tokenizer_file = None


import gguf

from .gguf_writer import GGUFWriter

logger = logging.getLogger(__name__)
class SpecialVocab:
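    """Loads special-token metadata (merges, token IDs, chat templates) from a
    model directory and writes it to a GGUFWriter.

    A minimal usage sketch (hedged; the model directory and vocab size are
    hypothetical):

        sv = SpecialVocab('./my-model', load_merges=True, n_vocab=32000)
        sv.add_to_gguf(writer)  # `writer` is an open gguf.GGUFWriter
    """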
    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            self.special_token_types = special_token_types
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
        # Plain-text BPE merges: one space-separated pair per line, with an
        # optional '#...' header line (e.g. "#version: 0.2") at the top.
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding='utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True

    def _set_special_token(self, typ: str, tid: Any) -> None:
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        tokenizer = None
        tokenizer_file = path / 'tokenizer.json'
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding='utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get('model', {}).get('merges')
                if isinstance(merges, list) and merges:
                    if isinstance(merges[0], str):
                        self.merges = merges
                    elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str):
                        # New format since transformers 4.45 to support spaces in merges
                        # ref: https://github.com/ggml-org/llama.cpp/issues/9692
                        # TODO: internally store as the new format instead of converting to old
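                        # e.g. the new-format pair [' a', 'b'] is stored as the
                        # legacy string 'Ġa b': spaces inside a part are remapped
                        # to chr(ord(' ') + 256) == 'Ġ'.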
                        if any(' ' in s for pair in merges for s in pair):
                            logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
                        self.merges = [
                            ' '.join(
                                [
                                    # ensure the spaces are properly encoded
                                    ''.join(
                                        chr(ord(c) + 256) if c == ' ' else c
                                        for c in part
                                    )
                                    for part in pair
                                ]
                            )
                            for pair in merges
                        ]
                    else:
                        raise ValueError("Unknown tokenizer merges format")
            added_tokens = tokenizer.get('added_tokens', {})
        else:
            added_tokens = {}
        tokenizer_config = None
        tokenizer_config_file = path / 'tokenizer_config.json'
        if tokenizer_config_file.is_file():
            with open(tokenizer_config_file, encoding='utf-8') as f:
                tokenizer_config = json.load(f)
        if tokenizer:
            special_bos = (tokenizer_config or {}).get('bos_token')
            special_cls = (tokenizer_config or {}).get('cls_token')
            special_eos = (tokenizer_config or {}).get('eos_token')
            special_sep = (tokenizer_config or {}).get('sep_token')
            if not special_bos and special_cls and tokenizer_config:
                tokenizer_config['bos_token'] = special_bos = special_cls
            if not special_eos and special_sep and tokenizer_config:
                tokenizer_config['eos_token'] = special_eos = special_sep
            if post_processor := tokenizer.get('post_processor'):
                for processor in post_processor.get('processors', [post_processor]):
                    if processor.get('type') == 'RobertaProcessing':
                        self.add_special_token['bos'] = True
                        self.add_special_token['eos'] = True
                        self.add_special_token['sep'] = True
                        if not special_cls and tokenizer_config:
                            special_cls = processor.get('cls', [special_bos])[0]
                            tokenizer_config['cls_token'] = special_cls
                        if not special_sep and tokenizer_config:
                            special_sep = processor.get('sep', [special_eos])[0]
                            tokenizer_config['sep_token'] = special_sep
                        continue
                    # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
                    # Only works with simple templates, **will** get it wrong on unusual sequences
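                    # A typical tokenizer.json TemplateProcessing looks roughly like:
                    #   'single': [{'SpecialToken': {'id': '<s>'}}, {'Sequence': {'id': 'A'}}]
                    #   'pair':   [{'SpecialToken': {'id': '<s>'}}, {'Sequence': {'id': 'A'}},
                    #              {'SpecialToken': {'id': '</s>'}}, {'Sequence': {'id': 'B'}}]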
                    if processor.get('type') == 'TemplateProcessing':
                        tmpl_single = processor.get('single', [])
                        tmpl_pair = processor.get('pair', [])
                        special_first = None
                        special_last = None
                        if len(tmpl_single) > 1:
                            if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
                                if not tokenizer_config:
                                    special_bos = special_first
                                self.add_special_token['bos'] = special_first in (special_bos, special_cls)
                                if special_first not in (special_bos, special_cls):
                                    logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
                            if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
                                if not tokenizer_config:
                                    special_eos = special_last
                                elif special_last != special_eos:
                                    if 'eot' not in self.special_token_types:
                                        self.special_token_types = tuple(self.special_token_types) + ('eot',)
                                        tokenizer_config['eot_token'] = special_eos
                                    elif 'eom' not in self.special_token_types:
                                        self.special_token_types = tuple(self.special_token_types) + ('eom',)
                                        tokenizer_config['eom_token'] = special_eos
                                    else:
                                        logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
                                    tokenizer_config['eos_token'] = special_eos = special_last
                                self.add_special_token['eos'] = special_last == special_eos
                                if special_last != special_eos:
                                    logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
                        if tmpl_pair:
                            seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
                            seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
                            if (special_first and seq_start == 0) or (special_last and seq_stop is None):
                                logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
                            if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
                                tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
                                tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
                                if tmpl_a != 'A' or tmpl_b != 'B':
                                    logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
                                # A [sep] [eos] B
                                if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
                                    add_sep = False
                                    if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
                                        if special_entry in (special_sep, special_eos) and not special_last:
                                            add_sep = True
                                        if special_entry not in (special_sep, special_eos):
                                            logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
                                    else:
                                        logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
                                    if len(tmpl_pair) == 2:
                                        if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
                                            if special_entry in (special_sep, special_eos):
                                                add_sep = True
                                            if special_entry not in (special_sep, special_eos):
                                                logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
                                        else:
                                            logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
                                    self.add_special_token['sep'] = add_sep
                                    if add_sep and not special_sep and tokenizer_config:
                                        tokenizer_config['sep_token'] = special_eos
                        continue
        if not tokenizer_config:
            return True
        chat_template_alt = None
        chat_template_json = path / 'chat_template.json'
        chat_template_jinja = path / 'chat_template.jinja'
        if chat_template_jinja.is_file():
            with open(chat_template_jinja, encoding='utf-8') as f:
                chat_template_alt = f.read()
            if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
                chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
                for template_path in additional_templates:
                    with open(template_path, encoding='utf-8') as fp:
                        chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
        elif chat_template_json.is_file():
            with open(chat_template_json, encoding='utf-8') as f:
                chat_template_alt = json.load(f).get('chat_template')
        chat_template = tokenizer_config.get('chat_template', chat_template_alt)
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True

    def _try_load_from_config_json(self, path: Path) -> bool:
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding='utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            token_id = config.get(f'{typ}_token_id')
            # If not found at root, check in text_config (for multimodal models like Kimi-VL)
            if token_id is None and 'text_config' in config:
                token_id = config['text_config'].get(f'{typ}_token_id')
            self._set_special_token(typ, token_id)
        return True


@runtime_checkable
class BaseVocab(Protocol):
    tokenizer_model: ClassVar[str]
    name: ClassVar[str]


@runtime_checkable
class Vocab(BaseVocab, Protocol):
    vocab_size: int
    added_tokens_dict: dict[str, int]
    added_tokens_list: list[str]
    fname_tokenizer: Path

    def __init__(self, base_path: Path): ...
    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
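
# Note: these Protocols are @runtime_checkable, so `isinstance(v, Vocab)` only
# verifies that the attributes and methods above exist on `v`; attribute types
# are not checked at runtime.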


class NoVocab(BaseVocab):
    tokenizer_model = "no_vocab"
    name = "no_vocab"

    def __repr__(self) -> str:
        return "<NoVocab for a model without integrated vocabulary>"


class BpeVocab(Vocab):
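    """GPT-2-style BPE vocabulary, read from either a "slow" vocab.json or a
    "fast" tokenizer.json.

    Usage sketch (hedged; the model directory is hypothetical):

        vocab = BpeVocab(Path('./my-model'))
        for text, score, toktype in vocab.all_tokens():
            ...
    """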
    tokenizer_model = "gpt2"
    name = "bpe"

    def __init__(self, base_path: Path):
        added_tokens: dict[str, int] = {}

        if (fname_tokenizer := base_path / 'vocab.json').exists():
            # "slow" tokenizer
            with open(fname_tokenizer, encoding="utf-8") as f:
                self.vocab = json.load(f)

            try:
                # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        else:
            # "fast" tokenizer
            fname_tokenizer = base_path / 'tokenizer.json'

            # if this fails, FileNotFoundError propagates to caller
            with open(fname_tokenizer, encoding="utf-8") as f:
                tokenizer_json = json.load(f)

            tokenizer_model: dict[str, Any] = tokenizer_json['model']
            if (
                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
                or tokenizer_json['decoder']['type'] != 'ByteLevel'
            ):
                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')

            self.vocab = tokenizer_model["vocab"]

            if (added := tokenizer_json.get('added_tokens')) is not None:
                # Added tokens here can be duplicates of the main vocabulary.
                added_tokens = {item['content']: item['id']
                                for item in added
                                if item['content'] not in self.vocab}

        vocab_size   = len(self.vocab)
        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
        actual_ids   = sorted(added_tokens.values())
        if expected_ids != actual_ids:
            expected_end_id = vocab_size + len(actual_ids) - 1
            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")

        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
        self.added_tokens_dict    = added_tokens
        self.added_tokens_list    = [text for (text, idx) in items]
        self.vocab_size_base      = vocab_size
        self.vocab_size           = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer      = fname_tokenizer

    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()}

        for i, _ in enumerate(self.vocab):
            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.bpe_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class SentencePieceVocab(Vocab):
    tokenizer_model = "llama"
    name = "spm"

    def __init__(self, base_path: Path):
        if SentencePieceProcessor is None:
            raise RuntimeError("sentencepiece is not installed")

        added_tokens: dict[str, int] = {}
        if (fname_tokenizer := base_path / 'tokenizer.model').exists():
            # normal location
            try:
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists():
            # not found in alternate location either
            raise FileNotFoundError('Cannot find tokenizer.model')

        self.sentencepiece_tokenizer = SentencePieceProcessor()
        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
        vocab_size = self.sentencepiece_tokenizer.vocab_size()

        new_tokens       = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
        actual_new_ids   = sorted(new_tokens.keys())

        if expected_new_ids != actual_new_ids:
            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

        # Token pieces that were added to the base vocabulary.
        self.added_tokens_dict  = added_tokens
        self.added_tokens_list  = [new_tokens[id] for id in actual_new_ids]
        self.vocab_size_base    = vocab_size
        self.vocab_size         = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer    = fname_tokenizer

    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        tokenizer = self.sentencepiece_tokenizer
        for i in range(tokenizer.vocab_size()):
            piece = tokenizer.IdToPiece(i)
            text         = piece.encode("utf-8")
            score: float = tokenizer.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if tokenizer.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if tokenizer.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            # NOTE: I think added_tokens are user defined.
            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED

            if tokenizer.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if tokenizer.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class LlamaHfVocab(Vocab):
    tokenizer_model = "llama"
    name = "hfft"

    def __init__(self, base_path: Path):
        fname_tokenizer = base_path / 'tokenizer.json'
        # if this fails, FileNotFoundError propagates to caller
        with open(fname_tokenizer, encoding='utf-8') as f:
            tokenizer_json = json.load(f)

        # pre-check so we know if we need transformers
        tokenizer_model: dict[str, Any] = tokenizer_json['model']
        is_llama3 = (
            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
            and not tokenizer_model.get('byte_fallback', True)
        )
        if is_llama3:
            raise TypeError('Llama 3 must be converted with BpeVocab')

        if not is_llama3 and (
            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
            or tokenizer_json['decoder']['type'] != 'Sequence'
        ):
            raise FileNotFoundError('Cannot find Llama BPE tokenizer')

        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "To use LlamaHfVocab, please install the `transformers` package. "
                "You can install it with `pip install transformers`."
            ) from e

        # Allow the tokenizer to default to slow or fast versions.
        # Explicitly set tokenizer to use local paths.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_path,
            cache_dir=base_path,
            local_files_only=True,
        )
        assert self.tokenizer.is_fast  # assume tokenizer.json is used

        # Initialize lists and dictionaries for added tokens
        self.added_tokens_list = []
        self.added_tokens_dict = dict()
        self.added_tokens_ids  = set()

        # Process added tokens
        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            # Only consider added tokens that are not in the base vocabulary
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)

        # Store special tokens and their IDs
        self.specials = {
            tok: self.tokenizer.get_vocab()[tok]
            for tok in self.tokenizer.all_special_tokens
        }
        self.special_ids = set(self.tokenizer.all_special_ids)

        # Set vocabulary sizes
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size      = self.vocab_size_base + len(self.added_tokens_list)

        self.fname_tokenizer = fname_tokenizer

    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        reverse_vocab = {
            id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
        }

        for token_id in range(self.vocab_size_base):
            # Skip processing added tokens here
            if token_id in self.added_tokens_ids:
                continue

            # Convert token text to bytes
            token_text = reverse_vocab[token_id].encode("utf-8")

            # Yield token text, score, and type
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids  # Reuse already stored special IDs
            )

    def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
        # Special case for byte tokens
        if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
            return gguf.TokenType.BYTE

        # Determine token type based on whether it's a special token
        return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        # Placeholder for actual logic to determine the token's score
        # This needs to be implemented based on specific requirements
        return -1000.0  # Default score

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = gguf.TokenType.USER_DEFINED
                score = -1000.0

            yield text.encode("utf-8"), score, toktype

    def has_newline_token(self):
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"


class MistralTokenizerType(str, Enum):
    spm = "spm"
    tekken = "tekken"


# Copied from Transformers (Apache 2.0)
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544

def bytes_to_unicode() -> dict[int, str]:
    """
    Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid
    mapping to whitespace/control characters that the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large number of unicode characters in
    your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around
    5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we
    want lookup tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs_str = [chr(n) for n in cs]
    return dict(zip(bs, cs_str))


class MistralVocab(Vocab):
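    """Vocabulary backed by mistral-common, wrapping either a SentencePiece or a
    Tekken tokenizer.

    Usage sketch (hedged; the checkpoint directory is hypothetical):

        vocab = MistralVocab(Path('./mistral-model'))
        for text, score, toktype in vocab.all_tokens():
            ...
    """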
    tokenizer_model = "mistral"
    name = "mistral"

    # Mistral tokenizers define no added tokens; these stay empty and exist only
    # to satisfy the Vocab protocol.
    added_tokens_dict: dict[str, int] = {}
    added_tokens_list: list[str] = []

    def __init__(self, base_path: Path):
        if not _mistral_common_installed:
            raise ImportError(
                "To use MistralVocab, please install the `mistral-common` package. "
                "You can install it with `pip install mistral-common`."
            )
        assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
        assert MistralTokenizer is not None, "mistral_common is not installed"
        assert Tekkenizer is not None, "mistral_common is not installed"

        logger.info(f"Loading Mistral tokenizer from {base_path}")

        # Find the tokenizer files
        all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]

        if get_one_valid_tokenizer_file is not None:
            tokenizer_file_path = get_one_valid_tokenizer_file(all_files)
        else:
            valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)

            if len(valid_tokenizer_files) == 0:
                raise ValueError(f"No tokenizer file found in the directory: {base_path}")
            # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
            if len(valid_tokenizer_files) > 1:
                if "tekken.json" in valid_tokenizer_files:
                    tokenizer_file = "tekken.json"
                else:
                    tokenizer_file = sorted(valid_tokenizer_files)[-1]
                logger.warning(
                    f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
                )
            else:
                tokenizer_file = valid_tokenizer_files[0]

            tokenizer_file_path = base_path / tokenizer_file

        self.tokenizer = MistralTokenizer.from_file(
            tokenizer_file_path
        ).instruct_tokenizer.tokenizer
        self.tokenizer_type = (
            MistralTokenizerType.tekken
            if isinstance(self.tokenizer, Tekkenizer)
            else MistralTokenizerType.spm
        )
        self.vocab_size = self.tokenizer.n_words
        self.fname_tokenizer = tokenizer_file_path
        self._name = (
            "mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
        )

    @property
    def tokenizer_name(self) -> str:
        return self._name

    @property
    def gguf_tokenizer_model(self) -> str:
        return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"

    def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        assert SentencePieceTokenizer is not None, "mistral_common is not installed"
        assert isinstance(self.tokenizer, SentencePieceTokenizer), (
            f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
        )

        for i in range(self.tokenizer._model.vocab_size()):
            piece = self.tokenizer._model.IdToPiece(i)
            text = piece.encode("utf-8")
            score: float = self.tokenizer._model.GetScore(i)

            toktype = gguf.TokenType.NORMAL
            if self.tokenizer._model.IsUnknown(i):
                toktype = gguf.TokenType.UNKNOWN
            if self.tokenizer._model.IsControl(i):
                toktype = gguf.TokenType.CONTROL

            if self.tokenizer._model.IsUnused(i):
                toktype = gguf.TokenType.UNUSED
            if self.tokenizer._model.IsByte(i):
                toktype = gguf.TokenType.BYTE

            yield text, score, toktype

    def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        assert Tekkenizer is not None, "mistral_common is not installed"
        assert isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )

        byte_encoder = bytes_to_unicode()
        for token_id in range(self.tokenizer.num_special_tokens):
            yield (
                self.tokenizer.id_to_piece(token_id).encode("utf-8"),
                0,
                gguf.TokenType.CONTROL
            )
        for token in self.tokenizer._tekken_token2id_nospecial:
            yield (
                self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
                0,
                gguf.TokenType.NORMAL,
            )

    def get_token_id(self, token: str) -> int:
        assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
        if self.tokenizer_type == MistralTokenizerType.spm:
            assert isinstance(self.tokenizer, SentencePieceTokenizer)
            return self.tokenizer._vocab.index(token)
        elif self.tokenizer_type == MistralTokenizerType.tekken:
            assert isinstance(self.tokenizer, Tekkenizer)
            return (
                self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
            )
        else:
            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")

    @property
    def bos_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def pad_id(self) -> int:
        if self.tokenizer.pad_id == -1:
            return self.eos_id
        return self.tokenizer.pad_id

    @property
    def unk_id(self) -> int:
        return self.tokenizer.unk_id

    @property
    def bos_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.bos_id)

    @property
    def eos_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.eos_id)

    @property
    def pad_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.pad_id)

    @property
    def unk_token(self) -> str:
        return self.tokenizer.id_to_piece(self.tokenizer.unk_id)

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        if self.tokenizer_type == MistralTokenizerType.spm:
            yield from self._sentencepiece_tokens()

        elif self.tokenizer_type == MistralTokenizerType.tekken:
            yield from self._tekken_tokens()

        else:
            raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")

    @staticmethod
    def token_bytes_to_string(b, byte_encoder):
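        """Re-encode raw token bytes into the printable BPE alphabet.

        For example (sanity check): token_bytes_to_string(b" hi", bytes_to_unicode())
        returns "Ġhi".
        """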
        return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

    def extract_vocab_merges_from_model(self):
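        """Reconstruct the BPE merge list from Tekken's mergeable ranks.

        For each merged token (rank 256 and up), collect the (left, right) splits
        whose halves are themselves ranked tokens, order them by the ranks of the
        halves, then emit all pairs sorted by the merged token's rank.
        """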
        # Adapted from Transformers (Apache 2.0)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
        assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
            f"Expected Tekkenizer, got {type(self.tokenizer)}"
        )
        mergeable_ranks = self.tokenizer._model._mergeable_ranks
        token_bytes_map = {
            rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
        }
        merge_pairs = []

        # Sort vocab by rank to ensure correct merge order
        for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
            merged_token = token_bytes_map[i]
            local = []
            for j in range(1, len(merged_token)):
                left = merged_token[:j]
                right = merged_token[j:]
                if (
                    left in mergeable_ranks
                    and right in mergeable_ranks
                    and (left + right) in mergeable_ranks
                ):
                    local.append((left, right, i))
            if not local:
                raise ValueError(
                    f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
                )
            local = sorted(
                local,
                key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
                reverse=False,
            )
            merge_pairs.extend(local)
        merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)

        byte_encoder = bytes_to_unicode()

        decoded_merge_pairs = [
            [
                self.token_bytes_to_string(val[0], byte_encoder),
                self.token_bytes_to_string(val[1], byte_encoder),
            ]
            for val in merge_pairs
        ]

        merges = [
            " ".join(
                [
                    # ensure the spaces are properly encoded
                    "".join(chr(ord(c) + 256) if c == " " else c for c in part)
                    for part in pair
                ]
            )
            for pair in decoded_merge_pairs
        ]

        return merges