1from __future__ import annotations
2
3import re
4import json
5import yaml
6import logging
7from pathlib import Path
8from typing import Any, Literal, Optional
9from dataclasses import dataclass
10
11from .constants import Keys
12
13import gguf
14
# Module-level logger used by the metadata loading/heuristic helpers below.
logger = logging.getLogger("metadata")
16
17
@dataclass
class Metadata:
    """Model metadata destined for a GGUF KV store.

    Every field defaults to None, meaning "unknown / not set".  Fields are
    filled in by Metadata.load() from the model repository (model card,
    config.json, generation_config.json, directory name) and from an
    optional user-supplied metadata override file.
    """

    # Recommended Sampler Parameters to be written to GGUF KV Store
    sampling_sequence: Optional[str] = None
    sampling_top_k: Optional[int] = None
    sampling_top_p: Optional[float] = None
    sampling_min_p: Optional[float] = None
    sampling_xtc_probability: Optional[float] = None
    sampling_xtc_threshold: Optional[float] = None
    sampling_temp: Optional[float] = None
    sampling_penalty_last_n: Optional[int] = None
    sampling_penalty_repeat: Optional[float] = None
    sampling_mirostat: Optional[int] = None
    sampling_mirostat_tau: Optional[float] = None
    sampling_mirostat_eta: Optional[float] = None

    # Authorship Metadata to be written to GGUF KV Store
    name: Optional[str] = None
    author: Optional[str] = None
    version: Optional[str] = None
    organization: Optional[str] = None
    finetune: Optional[str] = None
    basename: Optional[str] = None
    description: Optional[str] = None
    quantized_by: Optional[str] = None
    size_label: Optional[str] = None
    url: Optional[str] = None
    doi: Optional[str] = None
    uuid: Optional[str] = None
    repo_url: Optional[str] = None
    source_url: Optional[str] = None
    source_doi: Optional[str] = None
    source_uuid: Optional[str] = None
    source_repo_url: Optional[str] = None
    license: Optional[str] = None
    license_name: Optional[str] = None
    license_link: Optional[str] = None
    # Each base model / dataset entry is a dict with keys such as
    # "name", "organization", "version", "repo_url" (see apply_metadata_heuristic).
    base_models: Optional[list[dict]] = None
    tags: Optional[list[str]] = None
    languages: Optional[list[str]] = None
    datasets: Optional[list[dict]] = None
59
60 @staticmethod
61 def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
62 # This grabs as many contextual authorship metadata as possible from the model repository
63 # making any conversion as required to match the gguf kv store metadata format
64 # as well as giving users the ability to override any authorship metadata that may be incorrect
65
66 # Create a new Metadata instance
67 metadata = Metadata()
68
69 model_card = Metadata.load_model_card(model_path)
70 hf_params = Metadata.load_hf_parameters(model_path)
71 gen_config = Metadata.load_generation_config(model_path)
72 # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
73
74 # heuristics
75 metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
76
77 if gen_config:
78 metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
79 metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
80 metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
81 metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
82 metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
83 metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
84 metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
85 metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
86 metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
87 metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
88 metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
89 metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
90
91 # Metadata Override File Provided
92 # This is based on LLM_KV_NAMES mapping in llama.cpp
93 metadata_override = Metadata.load_metadata_override(metadata_override_path)
94
95 metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
96 metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
97 metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
98 metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
99 metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
100 metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
101 metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
102 metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
103 metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
104 metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
105 metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
106 metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
107
108 metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
109 metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
110 metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
111 metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)
112
113 metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
114 metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)
115
116 metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
117 metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)
118
119 metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
120 metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
121 metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)
122
123 metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
124 metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
125 metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
126 metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)
127
128 metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
129 metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
130 metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
131 metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
132
133 # Base Models is received here as an array of models
134 metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
135
136 # Datasets is received here as an array of datasets
137 metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)
138
139 metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
140 metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
141
142 # Direct Metadata Override (via direct cli argument)
143 if model_name is not None:
144 metadata.name = model_name
145
146 return metadata
147
148 @staticmethod
149 def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
150 if metadata_override_path is None or not metadata_override_path.is_file():
151 return {}
152
153 with open(metadata_override_path, "r", encoding="utf-8") as f:
154 return json.load(f)
155
156 @staticmethod
157 def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
158 if model_path is None or not model_path.is_dir():
159 return {}
160
161 model_card_path = model_path / "README.md"
162
163 if not model_card_path.is_file():
164 return {}
165
166 # The model card metadata is assumed to always be in YAML (frontmatter)
167 # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
168 yaml_content: str = ""
169 with open(model_card_path, "r", encoding="utf-8") as f:
170 content = f.read()
171 lines = content.splitlines()
172 lines_yaml = []
173 if len(lines) == 0:
174 # Empty file
175 return {}
176 if len(lines) > 0 and lines[0] != "---":
177 # No frontmatter
178 return {}
179 for line in lines[1:]:
180 if line == "---":
181 break # End of frontmatter
182 else:
183 lines_yaml.append(line)
184 yaml_content = "\n".join(lines_yaml) + "\n"
185
186 # Quick hack to fix the Norway problem
187 # https://hitchdev.com/strictyaml/why/implicit-typing-removed/
188 yaml_content = yaml_content.replace("- no\n", "- \"no\"\n")
189 # yaml should use 2 spaces insted of tab
190 # this issue has came up with the Qwen/Qwen3-235B-A22B-Instruct-2507 model card
191 # (I've also sent a pr tp fix the modelcard too)
192 yaml_content = yaml_content.replace("\t", " ")
193
194 if yaml_content:
195 data = yaml.safe_load(yaml_content)
196 if isinstance(data, dict):
197 return data
198 else:
199 logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
200 return {}
201 else:
202 return {}
203
204 @staticmethod
205 def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
206 if model_path is None or not model_path.is_dir():
207 return {}
208
209 config_path = model_path / "config.json"
210
211 if not config_path.is_file():
212 return {}
213
214 with open(config_path, "r", encoding="utf-8") as f:
215 return json.load(f)
216
217 @staticmethod
218 def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
219 if model_path is None or not model_path.is_dir():
220 return {}
221
222 generation_config_path = model_path / "generation_config.json"
223
224 if not generation_config_path.is_file():
225 return {}
226
227 try:
228 with open(generation_config_path, "r", encoding="utf-8") as f:
229 return json.load(f)
230 except (json.JSONDecodeError, IOError):
231 # not all models have valid generation_config.json
232 return {}
233
234 @staticmethod
235 def id_to_title(string):
236 # Convert capitalization into title form unless acronym or version number
237 return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
238
    @staticmethod
    def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
        """Heuristically split a model ID into naming components.

        Returns a 6-tuple (model_full_name_component, org_component, basename,
        finetune, version, size_label); any component that cannot be
        determined is None.  total_params is the known parameter count of the
        current model (negative when the output is a LoRA adapter, 0 when
        unknown) and is used to tell a real size label apart from other
        numeric tokens such as context lengths.
        """
        # Huggingface often store model id as '<org>/<model name>'
        # so let's parse it and apply some heuristics if possible for model name components

        if model_id is None:
            # model ID missing
            return None, None, None, None, None, None

        if ' ' in model_id:
            # model ID is actually a normal human sentence
            # which means its most likely a normal model name only
            # not part of the hugging face naming standard, but whatever
            return model_id, None, None, None, None, None

        if '/' in model_id:
            # model ID (huggingface style)
            org_component, model_full_name_component = model_id.split('/', 1)
        else:
            # model ID but missing org components
            org_component, model_full_name_component = None, model_id

        # Check if we erroneously matched against './' or '../' etc...
        if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
            org_component = None

        name_parts: list[str] = model_full_name_component.split('-')

        # Remove empty parts
        for i in reversed(range(len(name_parts))):
            if len(name_parts[i]) == 0:
                del name_parts[i]

        # One annotation set per name part; a part may carry several labels.
        name_types: list[
            set[Literal["basename", "size_label", "finetune", "version", "type"]]
        ] = [set() for _ in name_parts]

        # Annotate the name
        for i, part in enumerate(name_parts):
            # Version
            if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
                name_types[i].add("version")
            # Quant type (should not be there for base models, but still annotated)
            elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
                name_types[i].add("type")
                name_parts[i] = part.upper()
            # Model size
            elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
                part = part.replace("_", ".")
                # Handle weird bloom-7b1 notation
                if part[-1].isdecimal():
                    part = part[:-2] + "." + part[-1] + part[-2]
                # Normalize the size suffixes
                if len(part) > 1 and part[-2].isdecimal():
                    if part[-1] in "kmbt":
                        part = part[:-1] + part[-1].upper()
                if total_params != 0:
                    try:
                        label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
                        # Only use it as a size label if it's close or bigger than the model size
                        # Note that LoRA adapters don't necessarily include all layers,
                        # so this is why bigger label sizes are accepted.
                        # Do not use the size label when it's smaller than 1/8 of the model size
                        if (total_params < 0 and label_params < abs(total_params) // 8) or (
                            # Check both directions when the current model isn't a LoRA adapter
                            total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
                        ):
                            # Likely a context length
                            name_types[i].add("finetune")
                            # Lowercase the size when it's a context length
                            part = part[:-1] + part[-1].lower()
                    except ValueError:
                        # Failed to convert the size label to float, use it anyway
                        pass
                if len(name_types[i]) == 0:
                    name_types[i].add("size_label")
                name_parts[i] = part
            # Some easy to recognize finetune names
            elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
                if total_params < 0 and part.lower() == "lora":
                    # ignore redundant "lora" in the finetune part when the output is a lora adapter
                    name_types[i].add("type")
                else:
                    name_types[i].add("finetune")

        # Ignore word-based size labels when there is at least a number-based one present
        # TODO: should word-based size labels always be removed instead?
        if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
            for n, t in zip(name_parts, name_types):
                if "size_label" in t:
                    if all(c.isalpha() for c in n):
                        t.remove("size_label")

        at_start = True
        # Find the basename through the annotated name
        for part, t in zip(name_parts, name_types):
            if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
                t.add("basename")
            else:
                if at_start:
                    at_start = False
                if len(t) == 0:
                    t.add("finetune")

        # Remove the basename annotation from trailing version
        for part, t in zip(reversed(name_parts), reversed(name_types)):
            if "basename" in t and len(t) > 1:
                t.remove("basename")
            else:
                break

        basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
        # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys)
        size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
        finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
        # TODO: should the basename version always be excluded?
        # NOTE: multiple finetune versions are joined together
        version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None

        if size_label is None and finetune is None and version is None:
            # Too ambiguous, output nothing
            basename = None

        return model_full_name_component, org_component, basename, finetune, version, size_label
363
    @staticmethod
    def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
        """Fill unset fields of *metadata* from the model card, HF config and directory name.

        Scalar fields are only set when currently None (earlier sources win);
        array fields (tags, languages, base_models, datasets) are appended to.
        Precedence: model card frontmatter, then config.json's _name_or_path,
        then the model directory name as a last resort.  Returns *metadata*
        (mutated in place).
        """
        # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1

        # Model Card Heuristics
        ########################
        if model_card is not None:

            def use_model_card_metadata(metadata_key: str, model_card_key: str):
                if model_card_key in model_card and getattr(metadata, metadata_key, None) is None:
                    setattr(metadata, metadata_key, model_card.get(model_card_key))

            def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
                # Note: Will append rather than replace if already exist
                tags_value = model_card.get(model_card_key, None)
                if tags_value is None:
                    return

                current_value = getattr(metadata, metadata_key, None)
                if current_value is None:
                    current_value = []

                if isinstance(tags_value, str):
                    current_value.append(tags_value)
                elif isinstance(tags_value, list):
                    current_value.extend(tags_value)

                setattr(metadata, metadata_key, current_value)

            # LLAMA.cpp's direct internal convention
            # (Definitely not part of hugging face formal/informal standard)
            #########################################
            use_model_card_metadata("name", "name")
            use_model_card_metadata("author", "author")
            use_model_card_metadata("version", "version")
            use_model_card_metadata("organization", "organization")
            use_model_card_metadata("description", "description")
            use_model_card_metadata("finetune", "finetune")
            use_model_card_metadata("basename", "basename")
            use_model_card_metadata("size_label", "size_label")
            use_model_card_metadata("source_url", "url")
            use_model_card_metadata("source_doi", "doi")
            use_model_card_metadata("source_uuid", "uuid")
            use_model_card_metadata("source_repo_url", "repo_url")

            # LLAMA.cpp's huggingface style convention
            # (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style)
            ###########################################
            use_model_card_metadata("name", "model_name")
            use_model_card_metadata("author", "model_author")
            use_model_card_metadata("version", "model_version")
            use_model_card_metadata("organization", "model_organization")
            use_model_card_metadata("description", "model_description")
            use_model_card_metadata("finetune", "model_finetune")
            use_model_card_metadata("basename", "model_basename")
            use_model_card_metadata("size_label", "model_size_label")
            use_model_card_metadata("source_url", "model_url")
            use_model_card_metadata("source_doi", "model_doi")
            use_model_card_metadata("source_uuid", "model_uuid")
            use_model_card_metadata("source_repo_url", "model_repo_url")

            # Hugging Face Direct Convention
            #################################

            # Not part of huggingface model card standard but notice some model creator using it
            # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
            use_model_card_metadata("name", "model_name")
            use_model_card_metadata("author", "model_creator")
            use_model_card_metadata("basename", "model_type")

            if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card:
                # This represents the parent models that this is based on
                # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
                # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
                metadata_base_models = []
                base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None)))

                if base_model_value is not None:
                    if isinstance(base_model_value, str):
                        metadata_base_models.append(base_model_value)
                    elif isinstance(base_model_value, list):
                        metadata_base_models.extend(base_model_value)

                if metadata.base_models is None:
                    metadata.base_models = []

                for model_id in metadata_base_models:
                    # NOTE: model size of base model is assumed to be similar to the size of the current model
                    base_model = {}
                    if isinstance(model_id, str):
                        if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
                            base_model["repo_url"] = model_id

                            # Check if Hugging Face ID is present in URL
                            if "huggingface.co" in model_id:
                                match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
                                if match:
                                    model_id_component = match.group(1)
                                    model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)

                                    # Populate model dictionary with extracted components
                                    if model_full_name_component is not None:
                                        base_model["name"] = Metadata.id_to_title(model_full_name_component)
                                    if org_component is not None:
                                        base_model["organization"] = Metadata.id_to_title(org_component)
                                    if version is not None:
                                        base_model["version"] = version

                        else:
                            # Likely a Hugging Face ID
                            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)

                            # Populate model dictionary with extracted components
                            if model_full_name_component is not None:
                                base_model["name"] = Metadata.id_to_title(model_full_name_component)
                            if org_component is not None:
                                base_model["organization"] = Metadata.id_to_title(org_component)
                            if version is not None:
                                base_model["version"] = version
                            if org_component is not None and model_full_name_component is not None:
                                base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"

                    elif isinstance(model_id, dict):
                        base_model = model_id

                    else:
                        logger.error(f"base model entry '{str(model_id)}' not in a known format")

                    metadata.base_models.append(base_model)

            if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card:
                # This represents the datasets that this was trained from
                metadata_datasets = []
                dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None)))

                if dataset_value is not None:
                    if isinstance(dataset_value, str):
                        metadata_datasets.append(dataset_value)
                    elif isinstance(dataset_value, list):
                        metadata_datasets.extend(dataset_value)

                if metadata.datasets is None:
                    metadata.datasets = []

                for dataset_id in metadata_datasets:
                    # NOTE: model size of base model is assumed to be similar to the size of the current model
                    dataset = {}
                    if isinstance(dataset_id, str):
                        if dataset_id.startswith(("http://", "https://", "ssh://")):
                            dataset["repo_url"] = dataset_id

                            # Check if Hugging Face ID is present in URL
                            if "huggingface.co" in dataset_id:
                                match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
                                if match:
                                    dataset_id_component = match.group(1)
                                    dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)

                                    # Populate dataset dictionary with extracted components
                                    if dataset_name_component is not None:
                                        dataset["name"] = Metadata.id_to_title(dataset_name_component)
                                    if org_component is not None:
                                        dataset["organization"] = Metadata.id_to_title(org_component)
                                    if version is not None:
                                        dataset["version"] = version

                        else:
                            # Likely a Hugging Face ID
                            dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)

                            # Populate dataset dictionary with extracted components
                            if dataset_name_component is not None:
                                dataset["name"] = Metadata.id_to_title(dataset_name_component)
                            if org_component is not None:
                                dataset["organization"] = Metadata.id_to_title(org_component)
                            if version is not None:
                                dataset["version"] = version
                            if org_component is not None and dataset_name_component is not None:
                                dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"

                    elif isinstance(dataset_id, dict):
                        dataset = dataset_id

                    else:
                        logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")

                    metadata.datasets.append(dataset)

            use_model_card_metadata("license", "license")
            use_model_card_metadata("license_name", "license_name")
            use_model_card_metadata("license_link", "license_link")

            use_array_model_card_metadata("tags", "tags")
            use_array_model_card_metadata("tags", "pipeline_tag")

            use_array_model_card_metadata("languages", "languages")
            use_array_model_card_metadata("languages", "language")

        # Hugging Face Parameter Heuristics
        ####################################

        if hf_params is not None:

            hf_name_or_path = hf_params.get("_name_or_path")
            if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
                # Use _name_or_path only if its actually a model name and not some computer path
                # e.g. 'meta-llama/Llama-2-7b-hf'
                model_id = hf_name_or_path
                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
                if metadata.name is None and model_full_name_component is not None:
                    metadata.name = Metadata.id_to_title(model_full_name_component)
                if metadata.organization is None and org_component is not None:
                    metadata.organization = Metadata.id_to_title(org_component)
                if metadata.basename is None and basename is not None:
                    metadata.basename = basename
                if metadata.finetune is None and finetune is not None:
                    metadata.finetune = finetune
                if metadata.version is None and version is not None:
                    metadata.version = version
                if metadata.size_label is None and size_label is not None:
                    metadata.size_label = size_label

        # Directory Folder Name Fallback Heuristics
        ############################################
        if model_path is not None:
            model_id = model_path.name
            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
            if metadata.name is None and model_full_name_component is not None:
                metadata.name = Metadata.id_to_title(model_full_name_component)
            if metadata.organization is None and org_component is not None:
                metadata.organization = Metadata.id_to_title(org_component)
            if metadata.basename is None and basename is not None:
                metadata.basename = basename
            if metadata.finetune is None and finetune is not None:
                metadata.finetune = finetune
            if metadata.version is None and version is not None:
                metadata.version = version
            if metadata.size_label is None and size_label is not None:
                metadata.size_label = size_label

        return metadata
605
606 def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
607 assert self.name is not None
608
609 if self.sampling_sequence is not None:
610 gguf_writer.add_sampling_sequence(self.sampling_sequence)
611 if self.sampling_top_k is not None:
612 gguf_writer.add_sampling_top_k(self.sampling_top_k)
613 if self.sampling_top_p is not None:
614 gguf_writer.add_sampling_top_p(self.sampling_top_p)
615 if self.sampling_min_p is not None:
616 gguf_writer.add_sampling_min_p(self.sampling_min_p)
617 if self.sampling_xtc_probability is not None:
618 gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
619 if self.sampling_xtc_threshold is not None:
620 gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
621 if self.sampling_temp is not None:
622 gguf_writer.add_sampling_temp(self.sampling_temp)
623 if self.sampling_penalty_last_n is not None:
624 gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
625 if self.sampling_penalty_repeat is not None:
626 gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
627 if self.sampling_mirostat is not None:
628 gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
629 if self.sampling_mirostat_tau is not None:
630 gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
631 if self.sampling_mirostat_eta is not None:
632 gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
633
634 gguf_writer.add_name(self.name)
635
636 if self.author is not None:
637 gguf_writer.add_author(self.author)
638 if self.version is not None:
639 gguf_writer.add_version(self.version)
640 if self.organization is not None:
641 gguf_writer.add_organization(self.organization)
642
643 if self.finetune is not None:
644 gguf_writer.add_finetune(self.finetune)
645 if self.basename is not None:
646 gguf_writer.add_basename(self.basename)
647
648 if self.description is not None:
649 gguf_writer.add_description(self.description)
650 if self.quantized_by is not None:
651 gguf_writer.add_quantized_by(self.quantized_by)
652
653 if self.size_label is not None:
654 gguf_writer.add_size_label(self.size_label)
655
656 if self.license is not None:
657 if isinstance(self.license, list):
658 gguf_writer.add_license(",".join(self.license))
659 else:
660 gguf_writer.add_license(self.license)
661 if self.license_name is not None:
662 gguf_writer.add_license_name(self.license_name)
663 if self.license_link is not None:
664 gguf_writer.add_license_link(self.license_link)
665
666 if self.url is not None:
667 gguf_writer.add_url(self.url)
668 if self.doi is not None:
669 gguf_writer.add_doi(self.doi)
670 if self.uuid is not None:
671 gguf_writer.add_uuid(self.uuid)
672 if self.repo_url is not None:
673 gguf_writer.add_repo_url(self.repo_url)
674
675 if self.source_url is not None:
676 gguf_writer.add_source_url(self.source_url)
677 if self.source_doi is not None:
678 gguf_writer.add_source_doi(self.source_doi)
679 if self.source_uuid is not None:
680 gguf_writer.add_source_uuid(self.source_uuid)
681 if self.source_repo_url is not None:
682 gguf_writer.add_source_repo_url(self.source_repo_url)
683
684 if self.base_models is not None:
685 gguf_writer.add_base_model_count(len(self.base_models))
686 for key, base_model_entry in enumerate(self.base_models):
687 if "name" in base_model_entry:
688 gguf_writer.add_base_model_name(key, base_model_entry["name"])
689 if "author" in base_model_entry:
690 gguf_writer.add_base_model_author(key, base_model_entry["author"])
691 if "version" in base_model_entry:
692 gguf_writer.add_base_model_version(key, base_model_entry["version"])
693 if "organization" in base_model_entry:
694 gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
695 if "description" in base_model_entry:
696 gguf_writer.add_base_model_description(key, base_model_entry["description"])
697 if "url" in base_model_entry:
698 gguf_writer.add_base_model_url(key, base_model_entry["url"])
699 if "doi" in base_model_entry:
700 gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
701 if "uuid" in base_model_entry:
702 gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
703 if "repo_url" in base_model_entry:
704 gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
705
706 if self.datasets is not None:
707 gguf_writer.add_dataset_count(len(self.datasets))
708 for key, dataset_entry in enumerate(self.datasets):
709 if "name" in dataset_entry:
710 gguf_writer.add_dataset_name(key, dataset_entry["name"])
711 if "author" in dataset_entry:
712 gguf_writer.add_dataset_author(key, dataset_entry["author"])
713 if "version" in dataset_entry:
714 gguf_writer.add_dataset_version(key, dataset_entry["version"])
715 if "organization" in dataset_entry:
716 gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
717 if "description" in dataset_entry:
718 gguf_writer.add_dataset_description(key, dataset_entry["description"])
719 if "url" in dataset_entry:
720 gguf_writer.add_dataset_url(key, dataset_entry["url"])
721 if "doi" in dataset_entry:
722 gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
723 if "uuid" in dataset_entry:
724 gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
725 if "repo_url" in dataset_entry:
726 gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])
727
728 if self.tags is not None:
729 gguf_writer.add_tags(self.tags)
730 if self.languages is not None:
731 gguf_writer.add_languages(self.languages)