#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

from dataclasses import dataclass
import logging
import argparse
import os
import sys
import json
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
from transformers import AutoConfig, AutoTokenizer

import torch

if TYPE_CHECKING:
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, ModelBase

from gguf.constants import GGUFValueType

logger = logging.getLogger("lora-to-gguf")


@dataclass
class PartialLoraTensor:
    A: Tensor | None = None
    B: Tensor | None = None


# magic to support tensor shape modifications and splitting
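# A PEFT LoRA adapter factors each weight delta as a low-rank product ΔW = B @ A,
# with A of shape (n_rank, row_size) and B of shape (col_size, n_rank). This wrapper
# keeps the pair unmaterialized while behaving enough like a Tensor for the conversion code.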
class LoraTorchTensor:
    _lora_A: Tensor  # (n_rank, row_size)
    _lora_B: Tensor  # (col_size, n_rank)
    _rank: int

    def __init__(self, A: Tensor, B: Tensor):
        assert len(A.shape) == len(B.shape)
        assert A.shape[-2] == B.shape[-1]
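        # their shared inner dimension is the LoRA rank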
        if A.dtype != B.dtype:
            A = A.to(torch.float32)
            B = B.to(torch.float32)
        self._lora_A = A
        self._lora_B = B
        self._rank = B.shape[-1]

    def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
        return (self._lora_A, self._lora_B)

    def __getitem__(
        self,
        indices: (
            SupportsIndex
            | slice
            | tuple[SupportsIndex | slice | Tensor, ...]  # TODO: add ellipsis in the type signature
        ),
    ) -> LoraTorchTensor:
        shape = self.shape
        if isinstance(indices, SupportsIndex):
            if len(shape) > 2:
                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
            else:
                raise NotImplementedError  # can't return a vector
        elif isinstance(indices, slice):
            if len(shape) > 2:
                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
            else:
                return LoraTorchTensor(self._lora_A, self._lora_B[indices])
        elif isinstance(indices, tuple):
            assert len(indices) > 0
            if indices[-1] is Ellipsis:
                return self[indices[:-1]]
            # expand ellipsis
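            # e.g. on a 4-D tensor, (0, Ellipsis, 3) expands to (0, slice(None), slice(None), 3)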
            indices = tuple(
                u
                for v in (
                    (
                        (slice(None, None) for _ in range(len(indices) - 1))
                        if i is Ellipsis
                        else (i,)
                    )
                    for i in indices
                )
                for u in v
            )

            if len(indices) < len(shape):
                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))

            # TODO: make sure this is correct
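            # batch dims index both A and B (A's batch dims are size 1, hence the modulo);
            # the last index applies to A's row_size dim, all preceding indices apply to B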
            indices_A = (
                *(
                    (
                        j.__index__() % self._lora_A.shape[i]
                        if isinstance(j, SupportsIndex)
                        else slice(None, None)
                    )
                    for i, j in enumerate(indices[:-2])
                ),
                slice(None, None),
                indices[-1],
            )
            indices_B = indices[:-1]
            return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
        else:
            raise NotImplementedError  # unknown index type

    @property
    def dtype(self) -> torch.dtype:
        assert self._lora_A.dtype == self._lora_B.dtype
        return self._lora_A.dtype

    @property
    def shape(self) -> tuple[int, ...]:
        assert len(self._lora_A.shape) == len(self._lora_B.shape)
        return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])

    def size(self, dim=None):
        assert dim is None
        return self.shape

    def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
        if isinstance(shape[0], tuple):
            new_shape: tuple[int, ...] = shape[0]
        else:
            new_shape = cast(tuple[int, ...], shape)
        orig_shape = self.shape
        if len(new_shape) < 2:
            raise NotImplementedError  # can't become a vector

        # expand -1 in the shape
        if any(dim == -1 for dim in new_shape):
            n_elems = prod(orig_shape)
            n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
            assert n_elems % n_new_elems == 0
            new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)

        if new_shape[-1] != orig_shape[-1]:
            raise NotImplementedError  # can't reshape the row size trivially

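        # only B takes on the new batch/column structure; A keeps singleton batch dims
        # so that the implied product B @ A still broadcasts to the reshaped result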
        shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
        shape_B = (*new_shape[:-1], self._rank)
        return LoraTorchTensor(
            self._lora_A.reshape(shape_A),
            self._lora_B.reshape(shape_B),
        )

    def reshape_as(self, other: Tensor) -> LoraTorchTensor:
        return self.reshape(*other.shape)

    def view(self, *size: int) -> LoraTorchTensor:
        return self.reshape(*size)

    def permute(self, *dims: int) -> LoraTorchTensor:
        shape = self.shape
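        # normalize dims to negative indices so the row and column dims are always -1 and -2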
        dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
        if dims[-1] == -1:
            # TODO: support higher dimensional A shapes bigger than 1
            assert all(dim == 1 for dim in self._lora_A.shape[:-2])
            return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
        if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
        else:
            # TODO: compose the above two
            raise NotImplementedError

    def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
        shape = self.shape
        dims = list(range(len(shape)))
        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
        return self.permute(*dims)

    def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
        return self.transpose(axis0, axis1)

    def to(self, *args, **kwargs):
        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))

    @classmethod
    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
        del types  # unused

        if kwargs is None:
            kwargs = {}

        if func is torch.permute:
            return type(args[0]).permute(*args, **kwargs)
        elif func is torch.reshape:
            return type(args[0]).reshape(*args, **kwargs)
        elif func is torch.stack:
            assert isinstance(args[0], Sequence)
            dim = kwargs.get("dim", 0)
            assert dim == 0
            return LoraTorchTensor(
                torch.stack([a._lora_A for a in args[0]], dim),
                torch.stack([b._lora_B for b in args[0]], dim),
            )
        elif func is torch.cat:
            assert isinstance(args[0], Sequence)
            dim = kwargs.get("dim", 0)
            assert dim == 0
            if len(args[0][0].shape) > 2:
                return LoraTorchTensor(
                    torch.cat([a._lora_A for a in args[0]], dim),
                    torch.cat([b._lora_B for b in args[0]], dim),
                )
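            # a 2-D cat along dim 0 concatenates columns; the result stays low-rank only if
            # all parts share the same A, since cat([B1 @ A, B2 @ A]) == cat([B1, B2]) @ A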
            elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
                return LoraTorchTensor(
                    args[0][0]._lora_A,
                    torch.cat([b._lora_B for b in args[0]], dim),
                )
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError


def get_base_tensor_name(lora_tensor_name: str) -> str:
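    # e.g. "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
    #   -> "model.layers.0.self_attn.q_proj.weight"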
    base_name = lora_tensor_name.replace("base_model.model.", "")
    base_name = base_name.replace(".lora_A.weight", ".weight")
    base_name = base_name.replace(".lora_B.weight", ".weight")
    # models produced by mergekit-extract-lora have token embeddings in the adapter
    base_name = base_name.replace(".lora_embedding_A", ".weight")
    base_name = base_name.replace(".lora_embedding_B", ".weight")
    return base_name


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file")
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f32",
        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
    )
    parser.add_argument(
        "--bigendian", action="store_true",
        help="model is executed on a big-endian machine",
    )
    parser.add_argument(
        "--no-lazy", action="store_true",
        help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="increase output verbosity",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="only print out what will be done, without writing any new files",
    )
    parser.add_argument(
        "--base", type=Path,
        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only the config is needed, actual model weights are not required. If the base model is unspecified, it will be loaded from the Hugging Face hub based on the adapter config",
    )
    parser.add_argument(
        "--base-model-id", type=str,
        help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, --base is ignored and the base model config is loaded from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')",
    )
    parser.add_argument(
        "lora_path", type=Path,
        help="directory containing the Hugging Face PEFT LoRA config (adapter_config.json) and weights (adapter_model.safetensors or adapter_model.bin)",
    )

    return parser.parse_args()


def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
    from huggingface_hub import try_to_load_from_cache

    # normally the adapter does not ship with the base model config, so load it via AutoConfig
    config = AutoConfig.from_pretrained(hf_model_id)
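    # locate the cached config.json (if any) so its parent directory can serve as a local base model path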
    cache_dir = try_to_load_from_cache(hf_model_id, "config.json")
    cache_dir = Path(cache_dir).parent if isinstance(cache_dir, str) else None

    return config.to_dict(), cache_dir


if __name__ == '__main__':
    args = parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    ftype_map: dict[str, gguf.LlamaFileType] = {
        "f32": gguf.LlamaFileType.ALL_F32,
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
        "auto": gguf.LlamaFileType.GUESSED,
    }

    ftype = ftype_map[args.outtype]

    dir_base_model: Path | None = args.base
    dir_lora: Path = args.lora_path
    base_model_id: str | None = args.base_model_id
    lora_config = dir_lora / "adapter_config.json"
    input_model = dir_lora / "adapter_model.safetensors"

    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_lora

    if os.path.exists(input_model):
        # lazily import load_file only if the LoRA is in safetensors format
        from safetensors.torch import load_file

        lora_model = load_file(input_model, device="cpu")
    else:
        input_model = os.path.join(dir_lora, "adapter_model.bin")
        lora_model = torch.load(input_model, map_location="cpu", weights_only=True)

    # load the LoRA config
    with open(lora_config, "r") as f:
        lparams: dict[str, Any] = json.load(f)

    # load the base model config
    if base_model_id is not None:
        logger.info(f"Loading base model from Hugging Face: {base_model_id}")
        hparams, dir_base_model = load_hparams_from_hf(base_model_id)
    elif dir_base_model is None:
        if "base_model_name_or_path" in lparams:
            model_id = lparams["base_model_name_or_path"]
            logger.info(f"Loading base model from Hugging Face: {model_id}")
            try:
                hparams, dir_base_model = load_hparams_from_hf(model_id)
            except OSError as e:
                logger.error(f"Failed to load base model config: {e}")
                logger.error("Please try downloading the base model and passing its path with --base")
                sys.exit(1)
        else:
            logger.error("'base_model_name_or_path' was not found in adapter_config.json")
            logger.error("The base model config is required. Please download the base model and pass its path with --base")
            sys.exit(1)
    else:
        logger.info(f"Loading base model: {dir_base_model.name}")
        hparams = ModelBase.load_hparams(dir_base_model, False)

    with torch.inference_mode():
        try:
            model_class = ModelBase.from_model_architecture(hparams["architectures"][0])
        except NotImplementedError:
            logger.error(f"Model {hparams['architectures'][0]} is not supported")
            sys.exit(1)

        class LoraModel(model_class):
            model_arch = model_class.model_arch

            lora_alpha: float

            def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
                super().__init__(*args, **kwargs)

                self.dir_model_card = dir_lora_model
                self.lora_alpha = float(lora_alpha)

            def set_vocab(self):
                pass

            def set_type(self):
                self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")

            def set_gguf_parameters(self):
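                # lora_alpha sets the adapter's scale; by the standard LoRA convention the
                # effective weight is W' = W + (alpha / rank) * B @ A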
                logger.debug("GGUF KV: %s = %f", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
                alora_invocation_tokens = lparams.get("alora_invocation_tokens")
                invocation_string = lparams.get("invocation_string")
                if invocation_string and not alora_invocation_tokens:
                    logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
                    base_model_path_or_id = hparams.get("_name_or_path")
                    try:
                        tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
                    except ValueError:
                        logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
                        raise
                    # NOTE: There's an off-by-one with the older aLoRAs where
                    # the invocation string includes the "<|start_of_turn|>"
                    # token, but the adapters themselves were trained to
                    # activate _after_ that first token, so we drop it here.
                    alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
                if alora_invocation_tokens:
                    logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
                    self.gguf_writer.add_key_value(
                        gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
                        alora_invocation_tokens,
                        GGUFValueType.ARRAY,
                        GGUFValueType.UINT32,
                    )

            def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
                return ()

            def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                tensor_map: dict[str, PartialLoraTensor] = {}

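                # pair each lora_A with its lora_B by their shared base tensor name,
                # yielding one combined LoraTorchTensor per target weight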
                for name, tensor in lora_model.items():
                    if self.lazy:
                        tensor = LazyTorchTensor.from_eager(tensor)
                    base_name = get_base_tensor_name(name)
                    # note: mergekit-extract-lora also adds token embeddings to the adapter
                    is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
                    is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
                    if not is_lora_a and not is_lora_b:
                        if ".base_layer.weight" in name:
                            continue
                        # mergekit-extract-lora adds these layernorms to the adapter; we need to keep them
                        if "_layernorm" in name or ".norm" in name:
                            yield (base_name, tensor)
                            continue
                        logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
                        if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
                            logger.error("Embeddings are present in the adapter. This can be due to new tokens added during fine-tuning")
                            logger.error("Please refer to https://github.com/ggml-org/llama.cpp/pull/9948")
                        sys.exit(1)

                    if base_name in tensor_map:
                        if is_lora_a:
                            tensor_map[base_name].A = tensor
                        else:
                            tensor_map[base_name].B = tensor
                    else:
                        if is_lora_a:
                            tensor_map[base_name] = PartialLoraTensor(A=tensor)
                        else:
                            tensor_map[base_name] = PartialLoraTensor(B=tensor)

                for name, tensor in tensor_map.items():
                    assert tensor.A is not None
                    assert tensor.B is not None
                    yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))

            def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
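                # let the base model class rename (and possibly split) the tensor into its
                # GGUF name(s); any split of a LoraTorchTensor stays an (A, B) pair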
                dest = list(super().modify_tensors(data_torch, name, bid))
                # some archs may share the same tensor between lm_head and output (tied word embeddings)
                # in that case, adapters targeting lm_head will fail when using llama-export-lora,
                # so we ignore them for now
                # see: https://github.com/ggml-org/llama.cpp/issues/9065
                if name == "lm_head.weight" and len(dest) == 0:
                    raise ValueError("lm_head is present in adapter, but is ignored in base model")
                for dest_name, dest_data in dest:
                    # mergekit-extract-lora adds these layernorms to the adapter
                    if "_norm" in dest_name:
                        assert dest_data.dim() == 1
                        yield (dest_name, dest_data)
                        continue

                    # otherwise, we must get the lora_A and lora_B tensors
                    assert isinstance(dest_data, LoraTorchTensor)
                    lora_a, lora_b = dest_data.get_lora_A_B()

                    # note: mergekit-extract-lora flips and transposes A and B,
                    # so here we only need to transpose token_embd.lora_a; see llm_build_inp_embd()
                    if "token_embd.weight" in dest_name:
                        lora_a = lora_a.T

                    yield (dest_name + ".lora_a", lora_a)
                    yield (dest_name + ".lora_b", lora_b)

        alpha: float = lparams["lora_alpha"]

        model_instance = LoraModel(
            dir_base_model,
            ftype,
            fname_out,
            is_big_endian=args.bigendian,
            use_temp_file=False,
            eager=args.no_lazy,
            dry_run=args.dry_run,
            dir_lora_model=dir_lora,
            lora_alpha=alpha,
            hparams=hparams,
            remote_hf_model_id=base_model_id,
        )

        logger.info("Exporting model...")
        model_instance.write()
        logger.info(f"Model successfully exported to {model_instance.fname_out}")