1from __future__ import annotations
  2
  3import re
  4import json
  5import yaml
  6import logging
  7from pathlib import Path
  8from typing import Any, Literal, Optional
  9from dataclasses import dataclass
 10
 11from .constants import Keys
 12
 13import gguf
 14
# Module-level logger. It uses the fixed name "metadata" rather than __name__,
# so logging configuration/filtering for this module must target that name.
logger = logging.getLogger("metadata")
 16
 17
 18@dataclass
 19class Metadata:
    # Recommended Sampler Parameters to be written to GGUF KV Store
    # Filled in by load() from generation_config.json and then from the metadata
    # override file; each field stays None when no source provides a value.
    sampling_sequence: Optional[str] = None
    sampling_top_k: Optional[int] = None
    sampling_top_p: Optional[float] = None
    sampling_min_p: Optional[float] = None
    sampling_xtc_probability: Optional[float] = None
    sampling_xtc_threshold: Optional[float] = None
    sampling_temp: Optional[float] = None
    sampling_penalty_last_n: Optional[int] = None
    sampling_penalty_repeat: Optional[float] = None
    sampling_mirostat: Optional[int] = None
    sampling_mirostat_tau: Optional[float] = None
    sampling_mirostat_eta: Optional[float] = None

    # Authorship Metadata to be written to GGUF KV Store
    # Filled in by apply_metadata_heuristic() (model card, config.json, model
    # directory name) and then by load() from the metadata override file.
    name: Optional[str] = None
    author: Optional[str] = None
    version: Optional[str] = None
    organization: Optional[str] = None
    finetune: Optional[str] = None
    basename: Optional[str] = None
    description: Optional[str] = None
    quantized_by: Optional[str] = None
    size_label: Optional[str] = None
    url: Optional[str] = None
    doi: Optional[str] = None
    uuid: Optional[str] = None
    repo_url: Optional[str] = None
    source_url: Optional[str] = None
    source_doi: Optional[str] = None
    source_uuid: Optional[str] = None
    source_repo_url: Optional[str] = None
    license: Optional[str] = None
    license_name: Optional[str] = None
    license_link: Optional[str] = None
    # One dict per parent model; entries built by the model card heuristic use the
    # optional keys "name", "organization", "version" and "repo_url", while dict
    # entries taken from the model card or override file are stored as-is.
    base_models: Optional[list[dict]] = None
    tags: Optional[list[str]] = None
    languages: Optional[list[str]] = None
    # One dict per training dataset, built like the base_models entries.
    datasets: Optional[list[dict]] = None
 59
 60    @staticmethod
 61    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
 62        # This grabs as many contextual authorship metadata as possible from the model repository
 63        # making any conversion as required to match the gguf kv store metadata format
 64        # as well as giving users the ability to override any authorship metadata that may be incorrect
 65
 66        # Create a new Metadata instance
 67        metadata = Metadata()
 68
 69        model_card = Metadata.load_model_card(model_path)
 70        hf_params = Metadata.load_hf_parameters(model_path)
 71        gen_config = Metadata.load_generation_config(model_path)
 72        # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
 73
 74        # heuristics
 75        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
 76
 77        if gen_config:
 78            metadata.sampling_sequence        = gen_config.get("sequence",        metadata.sampling_sequence)
 79            metadata.sampling_top_k           = gen_config.get("top_k",           metadata.sampling_top_k)
 80            metadata.sampling_top_p           = gen_config.get("top_p",           metadata.sampling_top_p)
 81            metadata.sampling_min_p           = gen_config.get("min_p",           metadata.sampling_min_p)
 82            metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
 83            metadata.sampling_xtc_threshold   = gen_config.get("xtc_threshold",   metadata.sampling_xtc_threshold)
 84            metadata.sampling_temp            = gen_config.get("temperature",     metadata.sampling_temp)
 85            metadata.sampling_penalty_last_n  = gen_config.get("penalty_last_n",  metadata.sampling_penalty_last_n)
 86            metadata.sampling_penalty_repeat  = gen_config.get("penalty_repeat",  metadata.sampling_penalty_repeat)
 87            metadata.sampling_mirostat        = gen_config.get("mirostat",        metadata.sampling_mirostat)
 88            metadata.sampling_mirostat_tau    = gen_config.get("mirostat_tau",    metadata.sampling_mirostat_tau)
 89            metadata.sampling_mirostat_eta    = gen_config.get("mirostat_eta",    metadata.sampling_mirostat_eta)
 90
 91        # Metadata Override File Provided
 92        # This is based on LLM_KV_NAMES mapping in llama.cpp
 93        metadata_override = Metadata.load_metadata_override(metadata_override_path)
 94
 95        metadata.sampling_sequence        = metadata_override.get(Keys.General.SAMPLING_SEQUENCE,        metadata.sampling_sequence)
 96        metadata.sampling_top_k           = metadata_override.get(Keys.General.SAMPLING_TOP_K,           metadata.sampling_top_k)
 97        metadata.sampling_top_p           = metadata_override.get(Keys.General.SAMPLING_TOP_P,           metadata.sampling_top_p)
 98        metadata.sampling_min_p           = metadata_override.get(Keys.General.SAMPLING_MIN_P,           metadata.sampling_min_p)
 99        metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
100        metadata.sampling_xtc_threshold   = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD,   metadata.sampling_xtc_threshold)
101        metadata.sampling_temp            = metadata_override.get(Keys.General.SAMPLING_TEMP,            metadata.sampling_temp)
102        metadata.sampling_penalty_last_n  = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N,  metadata.sampling_penalty_last_n)
103        metadata.sampling_penalty_repeat  = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT,  metadata.sampling_penalty_repeat)
104        metadata.sampling_mirostat        = metadata_override.get(Keys.General.SAMPLING_MIROSTAT,        metadata.sampling_mirostat)
105        metadata.sampling_mirostat_tau    = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU,    metadata.sampling_mirostat_tau)
106        metadata.sampling_mirostat_eta    = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA,    metadata.sampling_mirostat_eta)
107
108        metadata.name            = metadata_override.get(Keys.General.NAME,            metadata.name)
109        metadata.author          = metadata_override.get(Keys.General.AUTHOR,          metadata.author)
110        metadata.version         = metadata_override.get(Keys.General.VERSION,         metadata.version)
111        metadata.organization    = metadata_override.get(Keys.General.ORGANIZATION,    metadata.organization)
112
113        metadata.finetune        = metadata_override.get(Keys.General.FINETUNE,        metadata.finetune)
114        metadata.basename        = metadata_override.get(Keys.General.BASENAME,        metadata.basename)
115
116        metadata.description     = metadata_override.get(Keys.General.DESCRIPTION,     metadata.description)
117        metadata.quantized_by    = metadata_override.get(Keys.General.QUANTIZED_BY,    metadata.quantized_by)
118
119        metadata.size_label      = metadata_override.get(Keys.General.SIZE_LABEL,      metadata.size_label)
120        metadata.license_name    = metadata_override.get(Keys.General.LICENSE_NAME,    metadata.license_name)
121        metadata.license_link    = metadata_override.get(Keys.General.LICENSE_LINK,    metadata.license_link)
122
123        metadata.url             = metadata_override.get(Keys.General.URL,             metadata.url)
124        metadata.doi             = metadata_override.get(Keys.General.DOI,             metadata.doi)
125        metadata.uuid            = metadata_override.get(Keys.General.UUID,            metadata.uuid)
126        metadata.repo_url        = metadata_override.get(Keys.General.REPO_URL,        metadata.repo_url)
127
128        metadata.source_url      = metadata_override.get(Keys.General.SOURCE_URL,      metadata.source_url)
129        metadata.source_doi      = metadata_override.get(Keys.General.SOURCE_DOI,      metadata.source_doi)
130        metadata.source_uuid     = metadata_override.get(Keys.General.SOURCE_UUID,     metadata.source_uuid)
131        metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
132
133        # Base Models is received here as an array of models
134        metadata.base_models     = metadata_override.get("general.base_models",        metadata.base_models)
135
136        # Datasets is received here as an array of datasets
137        metadata.datasets        = metadata_override.get("general.datasets",           metadata.datasets)
138
139        metadata.tags            = metadata_override.get(Keys.General.TAGS,            metadata.tags)
140        metadata.languages       = metadata_override.get(Keys.General.LANGUAGES,       metadata.languages)
141
142        # Direct Metadata Override (via direct cli argument)
143        if model_name is not None:
144            metadata.name = model_name
145
146        return metadata
147
148    @staticmethod
149    def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
150        if metadata_override_path is None or not metadata_override_path.is_file():
151            return {}
152
153        with open(metadata_override_path, "r", encoding="utf-8") as f:
154            return json.load(f)
155
156    @staticmethod
157    def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
158        if model_path is None or not model_path.is_dir():
159            return {}
160
161        model_card_path = model_path / "README.md"
162
163        if not model_card_path.is_file():
164            return {}
165
166        # The model card metadata is assumed to always be in YAML (frontmatter)
167        # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
168        yaml_content: str = ""
169        with open(model_card_path, "r", encoding="utf-8") as f:
170            content = f.read()
171            lines = content.splitlines()
172            lines_yaml = []
173            if len(lines) == 0:
174                # Empty file
175                return {}
176            if len(lines) > 0 and lines[0] != "---":
177                # No frontmatter
178                return {}
179            for line in lines[1:]:
180                if line == "---":
181                    break # End of frontmatter
182                else:
183                    lines_yaml.append(line)
184            yaml_content = "\n".join(lines_yaml) + "\n"
185
186        # Quick hack to fix the Norway problem
187        # https://hitchdev.com/strictyaml/why/implicit-typing-removed/
188        yaml_content = yaml_content.replace("- no\n", "- \"no\"\n")
189        # yaml should use 2 spaces insted of tab
190        # this issue has came up with the Qwen/Qwen3-235B-A22B-Instruct-2507 model card
191        #    (I've also sent a pr tp fix the modelcard too)
192        yaml_content = yaml_content.replace("\t", "  ")
193
194        if yaml_content:
195            data = yaml.safe_load(yaml_content)
196            if isinstance(data, dict):
197                return data
198            else:
199                logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
200                return {}
201        else:
202            return {}
203
204    @staticmethod
205    def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
206        if model_path is None or not model_path.is_dir():
207            return {}
208
209        config_path = model_path / "config.json"
210
211        if not config_path.is_file():
212            return {}
213
214        with open(config_path, "r", encoding="utf-8") as f:
215            return json.load(f)
216
217    @staticmethod
218    def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
219        if model_path is None or not model_path.is_dir():
220            return {}
221
222        generation_config_path = model_path / "generation_config.json"
223
224        if not generation_config_path.is_file():
225            return {}
226
227        try:
228            with open(generation_config_path, "r", encoding="utf-8") as f:
229                return json.load(f)
230        except (json.JSONDecodeError, IOError):
231            # not all models have valid generation_config.json
232            return {}
233
234    @staticmethod
235    def id_to_title(string):
236        # Convert capitalization into title form unless acronym or version number
237        return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
238
    @staticmethod
    def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
        """Split a model id into its name components using naming heuristics.

        Args:
            model_id: an id such as ``"<org>/<model name>"`` or a bare model name.
            total_params: parameter count of the model being converted. When
                non-zero it is used to tell real size labels (e.g. ``7B``) apart
                from other size-like parts (e.g. a ``32k`` context length). A
                negative value means the output is a LoRA adapter.

        Returns:
            ``(model_full_name_component, org_component, basename, finetune,
            version, size_label)``; any element may be None. If ``model_id``
            contains a space, only the first element is set (to ``model_id``).
        """
        # Huggingface often store model id as '<org>/<model name>'
        # so let's parse it and apply some heuristics if possible for model name components

        if model_id is None:
            # model ID missing
            return None, None, None, None, None, None

        if ' ' in model_id:
            # model ID is actually a normal human sentence
            # which means its most likely a normal model name only
            # not part of the hugging face naming standard, but whatever
            return model_id, None, None, None, None, None

        if '/' in model_id:
            # model ID (huggingface style); only the first '/' separates the org
            org_component, model_full_name_component = model_id.split('/', 1)
        else:
            # model ID but missing org components
            org_component, model_full_name_component = None, model_id

        # Check if we erroneously matched against './' or '../' etc...
        if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
            org_component = None

        name_parts: list[str] = model_full_name_component.split('-')

        # Remove empty parts (from leading, trailing or repeated '-')
        for i in reversed(range(len(name_parts))):
            if len(name_parts[i]) == 0:
                del name_parts[i]

        # One set of annotations per name part, kept index-aligned with name_parts
        name_types: list[
            set[Literal["basename", "size_label", "finetune", "version", "type"]]
        ] = [set() for _ in name_parts]

        # Annotate the name
        for i, part in enumerate(name_parts):
            # Version
            if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
                name_types[i].add("version")
            # Quant type (should not be there for base models, but still annotated)
            elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
                name_types[i].add("type")
                name_parts[i] = part.upper()
            # Model size (never the first part, which is assumed to belong to the basename)
            elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
                part = part.replace("_", ".")
                # Handle weird bloom-7b1 notation (e.g. "7b1" -> "7.1b")
                if part[-1].isdecimal():
                    part = part[:-2] + "." + part[-1] + part[-2]
                # Normalize the size suffixes (uppercase k/m/b/t after a digit)
                if len(part) > 1 and part[-2].isdecimal():
                    if part[-1] in "kmbt":
                        part = part[:-1] + part[-1].upper()
                if total_params != 0:
                    try:
                        # Numeric value of the label, e.g. "7.1B" -> 7.1 * 1000**3;
                        # word or prefixed labels ("large", "8x7B", "A3B") raise ValueError
                        label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
                        # Only use it as a size label if it's close or bigger than the model size
                        # Note that LoRA adapters don't necessarily include all layers,
                        # so this is why bigger label sizes are accepted.
                        # Do not use the size label when it's smaller than 1/8 of the model size
                        if (total_params < 0 and label_params < abs(total_params) // 8) or (
                            # Check both directions when the current model isn't a LoRA adapter
                            total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
                        ):
                            # Likely a context length
                            name_types[i].add("finetune")
                            # Lowercase the size when it's a context length
                            part = part[:-1] + part[-1].lower()
                    except ValueError:
                        # Failed to convert the size label to float, use it anyway
                        pass
                if len(name_types[i]) == 0:
                    name_types[i].add("size_label")
                name_parts[i] = part
            # Some easy to recognize finetune names
            elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
                if total_params < 0 and part.lower() == "lora":
                    # ignore redundant "lora" in the finetune part when the output is a lora adapter
                    name_types[i].add("type")
                else:
                    name_types[i].add("finetune")

        # Ignore word-based size labels when there is at least a number-based one present
        # TODO: should word-based size labels always be removed instead?
        if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
            for n, t in zip(name_parts, name_types):
                if "size_label" in t:
                    if all(c.isalpha() for c in n):
                        t.remove("size_label")

        at_start = True
        # Find the basename through the annotated name: the leading run of parts that
        # are either unannotated and start with a letter, or are versions. Any
        # unannotated part after that run is treated as part of the finetune.
        for part, t in zip(name_parts, name_types):
            if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
                t.add("basename")
            else:
                if at_start:
                    at_start = False
                if len(t) == 0:
                    t.add("finetune")

        # Remove the basename annotation from trailing version
        # (walks backwards and stops at the first part that is not basename+something)
        for part, t in zip(reversed(name_parts), reversed(name_types)):
            if "basename" in t and len(t) > 1:
                t.remove("basename")
            else:
                break

        basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
        # Deduplicate size labels using order-preserving 'dict' (a 'set' does not keep insertion order)
        size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
        finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
        # TODO: should the basename version always be excluded?
        # NOTE: multiple finetune versions are joined together
        version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None

        if size_label is None and finetune is None and version is None:
            # Too ambiguous, output nothing
            basename = None

        return model_full_name_component, org_component, basename, finetune, version, size_label
363
364    @staticmethod
365    def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
366        # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
367
368        # Model Card Heuristics
369        ########################
370        if model_card is not None:
371
372            def use_model_card_metadata(metadata_key: str, model_card_key: str):
373                if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
374                    setattr(metadata, metadata_key, model_card.get(model_card_key))
375
376            def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
377                # Note: Will append rather than replace if already exist
378                tags_value = model_card.get(model_card_key, None)
379                if tags_value is None:
380                    return
381
382                current_value = getattr(metadata, metadata_key, None)
383                if current_value is None:
384                    current_value = []
385
386                if isinstance(tags_value, str):
387                    current_value.append(tags_value)
388                elif isinstance(tags_value, list):
389                    current_value.extend(tags_value)
390
391                setattr(metadata, metadata_key, current_value)
392
393            # LLAMA.cpp's direct internal convention
394            # (Definitely not part of hugging face formal/informal standard)
395            #########################################
396            use_model_card_metadata("name", "name")
397            use_model_card_metadata("author", "author")
398            use_model_card_metadata("version", "version")
399            use_model_card_metadata("organization", "organization")
400            use_model_card_metadata("description", "description")
401            use_model_card_metadata("finetune", "finetune")
402            use_model_card_metadata("basename", "basename")
403            use_model_card_metadata("size_label", "size_label")
404            use_model_card_metadata("source_url", "url")
405            use_model_card_metadata("source_doi", "doi")
406            use_model_card_metadata("source_uuid", "uuid")
407            use_model_card_metadata("source_repo_url", "repo_url")
408
409            # LLAMA.cpp's huggingface style convention
410            # (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style)
411            ###########################################
412            use_model_card_metadata("name", "model_name")
413            use_model_card_metadata("author", "model_author")
414            use_model_card_metadata("version", "model_version")
415            use_model_card_metadata("organization", "model_organization")
416            use_model_card_metadata("description", "model_description")
417            use_model_card_metadata("finetune", "model_finetune")
418            use_model_card_metadata("basename", "model_basename")
419            use_model_card_metadata("size_label", "model_size_label")
420            use_model_card_metadata("source_url", "model_url")
421            use_model_card_metadata("source_doi", "model_doi")
422            use_model_card_metadata("source_uuid", "model_uuid")
423            use_model_card_metadata("source_repo_url", "model_repo_url")
424
425            # Hugging Face Direct Convention
426            #################################
427
428            # Not part of huggingface model card standard but notice some model creator using it
429            # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
430            use_model_card_metadata("name", "model_name")
431            use_model_card_metadata("author", "model_creator")
432            use_model_card_metadata("basename", "model_type")
433
434            if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card:
435                # This represents the parent models that this is based on
436                # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
437                # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
438                metadata_base_models = []
439                base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None)))
440
441                if base_model_value is not None:
442                    if isinstance(base_model_value, str):
443                        metadata_base_models.append(base_model_value)
444                    elif isinstance(base_model_value, list):
445                        metadata_base_models.extend(base_model_value)
446
447                if metadata.base_models is None:
448                    metadata.base_models = []
449
450                for model_id in metadata_base_models:
451                    # NOTE: model size of base model is assumed to be similar to the size of the current model
452                    base_model = {}
453                    if isinstance(model_id, str):
454                        if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
455                            base_model["repo_url"] = model_id
456
457                            # Check if Hugging Face ID is present in URL
458                            if "huggingface.co" in model_id:
459                                match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
460                                if match:
461                                    model_id_component = match.group(1)
462                                    model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)
463
464                                    # Populate model dictionary with extracted components
465                                    if model_full_name_component is not None:
466                                        base_model["name"] = Metadata.id_to_title(model_full_name_component)
467                                    if org_component is not None:
468                                        base_model["organization"] = Metadata.id_to_title(org_component)
469                                    if version is not None:
470                                        base_model["version"] = version
471
472                        else:
473                            # Likely a Hugging Face ID
474                            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
475
476                            # Populate model dictionary with extracted components
477                            if model_full_name_component is not None:
478                                base_model["name"] = Metadata.id_to_title(model_full_name_component)
479                            if org_component is not None:
480                                base_model["organization"] = Metadata.id_to_title(org_component)
481                            if version is not None:
482                                base_model["version"] = version
483                            if org_component is not None and model_full_name_component is not None:
484                                base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
485
486                    elif isinstance(model_id, dict):
487                        base_model = model_id
488
489                    else:
490                        logger.error(f"base model entry '{str(model_id)}' not in a known format")
491
492                    metadata.base_models.append(base_model)
493
494            if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card:
495                # This represents the datasets that this was trained from
496                metadata_datasets = []
497                dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None)))
498
499                if dataset_value is not None:
500                    if isinstance(dataset_value, str):
501                        metadata_datasets.append(dataset_value)
502                    elif isinstance(dataset_value, list):
503                        metadata_datasets.extend(dataset_value)
504
505                if metadata.datasets is None:
506                    metadata.datasets = []
507
508                for dataset_id in metadata_datasets:
509                    # NOTE: model size of base model is assumed to be similar to the size of the current model
510                    dataset = {}
511                    if isinstance(dataset_id, str):
512                        if dataset_id.startswith(("http://", "https://", "ssh://")):
513                            dataset["repo_url"] = dataset_id
514
515                            # Check if Hugging Face ID is present in URL
516                            if "huggingface.co" in dataset_id:
517                                match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
518                                if match:
519                                    dataset_id_component = match.group(1)
520                                    dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)
521
522                                    # Populate dataset dictionary with extracted components
523                                    if dataset_name_component is not None:
524                                        dataset["name"] = Metadata.id_to_title(dataset_name_component)
525                                    if org_component is not None:
526                                        dataset["organization"] = Metadata.id_to_title(org_component)
527                                    if version is not None:
528                                        dataset["version"] = version
529
530                        else:
531                            # Likely a Hugging Face ID
532                            dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
533
534                            # Populate dataset dictionary with extracted components
535                            if dataset_name_component is not None:
536                                dataset["name"] = Metadata.id_to_title(dataset_name_component)
537                            if org_component is not None:
538                                dataset["organization"] = Metadata.id_to_title(org_component)
539                            if version is not None:
540                                dataset["version"] = version
541                            if org_component is not None and dataset_name_component is not None:
542                                dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
543
544                    elif isinstance(dataset_id, dict):
545                        dataset = dataset_id
546
547                    else:
548                        logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
549
550                    metadata.datasets.append(dataset)
551
552            use_model_card_metadata("license", "license")
553            use_model_card_metadata("license_name", "license_name")
554            use_model_card_metadata("license_link", "license_link")
555
556            use_array_model_card_metadata("tags", "tags")
557            use_array_model_card_metadata("tags", "pipeline_tag")
558
559            use_array_model_card_metadata("languages", "languages")
560            use_array_model_card_metadata("languages", "language")
561
562        # Hugging Face Parameter Heuristics
563        ####################################
564
565        if hf_params is not None:
566
567            hf_name_or_path = hf_params.get("_name_or_path")
568            if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
569                # Use _name_or_path only if its actually a model name and not some computer path
570                # e.g. 'meta-llama/Llama-2-7b-hf'
571                model_id = hf_name_or_path
572                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
573                if metadata.name is None and model_full_name_component is not None:
574                    metadata.name = Metadata.id_to_title(model_full_name_component)
575                if metadata.organization is None and org_component is not None:
576                    metadata.organization = Metadata.id_to_title(org_component)
577                if metadata.basename is None and basename is not None:
578                    metadata.basename = basename
579                if metadata.finetune is None and finetune is not None:
580                    metadata.finetune = finetune
581                if metadata.version is None and version is not None:
582                    metadata.version = version
583                if metadata.size_label is None and size_label is not None:
584                    metadata.size_label = size_label
585
586        # Directory Folder Name Fallback Heuristics
587        ############################################
588        if model_path is not None:
589            model_id = model_path.name
590            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
591            if metadata.name is None and model_full_name_component is not None:
592                metadata.name = Metadata.id_to_title(model_full_name_component)
593            if metadata.organization is None and org_component is not None:
594                metadata.organization = Metadata.id_to_title(org_component)
595            if metadata.basename is None and basename is not None:
596                metadata.basename = basename
597            if metadata.finetune is None and finetune is not None:
598                metadata.finetune = finetune
599            if metadata.version is None and version is not None:
600                metadata.version = version
601            if metadata.size_label is None and size_label is not None:
602                metadata.size_label = size_label
603
604        return metadata
605
606    def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
607        assert self.name is not None
608
609        if self.sampling_sequence is not None:
610            gguf_writer.add_sampling_sequence(self.sampling_sequence)
611        if self.sampling_top_k is not None:
612            gguf_writer.add_sampling_top_k(self.sampling_top_k)
613        if self.sampling_top_p is not None:
614            gguf_writer.add_sampling_top_p(self.sampling_top_p)
615        if self.sampling_min_p is not None:
616            gguf_writer.add_sampling_min_p(self.sampling_min_p)
617        if self.sampling_xtc_probability is not None:
618            gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
619        if self.sampling_xtc_threshold is not None:
620            gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
621        if self.sampling_temp is not None:
622            gguf_writer.add_sampling_temp(self.sampling_temp)
623        if self.sampling_penalty_last_n is not None:
624            gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
625        if self.sampling_penalty_repeat is not None:
626            gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
627        if self.sampling_mirostat is not None:
628            gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
629        if self.sampling_mirostat_tau is not None:
630            gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
631        if self.sampling_mirostat_eta is not None:
632            gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
633
634        gguf_writer.add_name(self.name)
635
636        if self.author is not None:
637            gguf_writer.add_author(self.author)
638        if self.version is not None:
639            gguf_writer.add_version(self.version)
640        if self.organization is not None:
641            gguf_writer.add_organization(self.organization)
642
643        if self.finetune is not None:
644            gguf_writer.add_finetune(self.finetune)
645        if self.basename is not None:
646            gguf_writer.add_basename(self.basename)
647
648        if self.description is not None:
649            gguf_writer.add_description(self.description)
650        if self.quantized_by is not None:
651            gguf_writer.add_quantized_by(self.quantized_by)
652
653        if self.size_label is not None:
654            gguf_writer.add_size_label(self.size_label)
655
656        if self.license is not None:
657            if isinstance(self.license, list):
658                gguf_writer.add_license(",".join(self.license))
659            else:
660                gguf_writer.add_license(self.license)
661        if self.license_name is not None:
662            gguf_writer.add_license_name(self.license_name)
663        if self.license_link is not None:
664            gguf_writer.add_license_link(self.license_link)
665
666        if self.url is not None:
667            gguf_writer.add_url(self.url)
668        if self.doi is not None:
669            gguf_writer.add_doi(self.doi)
670        if self.uuid is not None:
671            gguf_writer.add_uuid(self.uuid)
672        if self.repo_url is not None:
673            gguf_writer.add_repo_url(self.repo_url)
674
675        if self.source_url is not None:
676            gguf_writer.add_source_url(self.source_url)
677        if self.source_doi is not None:
678            gguf_writer.add_source_doi(self.source_doi)
679        if self.source_uuid is not None:
680            gguf_writer.add_source_uuid(self.source_uuid)
681        if self.source_repo_url is not None:
682            gguf_writer.add_source_repo_url(self.source_repo_url)
683
684        if self.base_models is not None:
685            gguf_writer.add_base_model_count(len(self.base_models))
686            for key, base_model_entry in enumerate(self.base_models):
687                if "name" in base_model_entry:
688                    gguf_writer.add_base_model_name(key, base_model_entry["name"])
689                if "author" in base_model_entry:
690                    gguf_writer.add_base_model_author(key, base_model_entry["author"])
691                if "version" in base_model_entry:
692                    gguf_writer.add_base_model_version(key, base_model_entry["version"])
693                if "organization" in base_model_entry:
694                    gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
695                if "description" in base_model_entry:
696                    gguf_writer.add_base_model_description(key, base_model_entry["description"])
697                if "url" in base_model_entry:
698                    gguf_writer.add_base_model_url(key, base_model_entry["url"])
699                if "doi" in base_model_entry:
700                    gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
701                if "uuid" in base_model_entry:
702                    gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
703                if "repo_url" in base_model_entry:
704                    gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
705
706        if self.datasets is not None:
707            gguf_writer.add_dataset_count(len(self.datasets))
708            for key, dataset_entry in enumerate(self.datasets):
709                if "name" in dataset_entry:
710                    gguf_writer.add_dataset_name(key, dataset_entry["name"])
711                if "author" in dataset_entry:
712                    gguf_writer.add_dataset_author(key, dataset_entry["author"])
713                if "version" in dataset_entry:
714                    gguf_writer.add_dataset_version(key, dataset_entry["version"])
715                if "organization" in dataset_entry:
716                    gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
717                if "description" in dataset_entry:
718                    gguf_writer.add_dataset_description(key, dataset_entry["description"])
719                if "url" in dataset_entry:
720                    gguf_writer.add_dataset_url(key, dataset_entry["url"])
721                if "doi" in dataset_entry:
722                    gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
723                if "uuid" in dataset_entry:
724                    gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
725                if "repo_url" in dataset_entry:
726                    gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])
727
728        if self.tags is not None:
729            gguf_writer.add_tags(self.tags)
730        if self.languages is not None:
731            gguf_writer.add_languages(self.languages)