1#!/usr/bin/env python3
  2
  3import unittest
  4from pathlib import Path
  5import os
  6import sys
  7
# Necessary to load the local gguf package
# When running from a source checkout (and NO_LOCAL_GGUF is not set), put the
# in-tree package directory at the front of sys.path so `import gguf` below
# resolves to the local development copy instead of any installed release.
# NOTE(review): the existence check looks three levels up for 'gguf-py' while
# the inserted path is two levels up — presumably this file lives in
# gguf-py/tests/ so both resolve to the same package root; confirm layout.
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))

import gguf
 13
 14
class TestMetadataMethod(unittest.TestCase):
    """Tests for the model-naming and metadata heuristics in gguf.Metadata.

    Covers: turning model ids into human-readable titles, decomposing
    Hugging Face style model ids into naming components, and extracting
    metadata from model cards, HF config parameters and directory names.
    """

    def test_id_to_title(self):
        # id_to_title() converts a dash-separated model id into a spaced,
        # title-like string; note it preserves the original casing of each
        # component (e.g. "8b" stays lowercase in the last case).
        self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
        self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
        self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")

    def test_get_model_id_components(self):
        # get_model_id_components() splits a model id into a 6-tuple:
        #   (model full name, organization, basename, finetune, version, size label)
        # as demonstrated by the "Missing version" / "Missing finetune" cases
        # below. The optional second argument is the expected parameter count,
        # used to disambiguate size labels; a negative value marks the output
        # as a LoRA adapter (see the LoRA section at the end).

        # This is the basic standard form with organization marker
        self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
                         ('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))

        # Similar to basic standard form but without organization marker
        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
                         ('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))

        # Missing version
        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
                         ('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))

        # Missing finetune
        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
                         ('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))

        # Base name and size label only
        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
                         ('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))

        # Base name and version only
        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
                         ('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))

        ## Edge Cases ##

        # This is too ambiguous... best to err on the side of caution and output nothing
        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
                         ('Mixtral', None, None, None, None, None))

        # Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename
        self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
                         ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))

        # Non standard naming
        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
                         ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))

        # Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
                         ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))

        # Check that it can handle a real model id with no version code
        # Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count
        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),
                         ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))

        # There are some legitimate models with only thousands of parameters
        self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
                         ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))

        # Non standard and not easy to disambiguate
        self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
                         ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))

        # This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
        self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
                         ('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))

        # This is a real model id where the weight size has a decimal point
        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
                         ('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))

        # Uses an underscore in the size label (normalized to a decimal point in the output)
        self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
                         ('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))

        # Uses Iter3 for the version
        self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"),
                         ('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B'))

        # Has two potential versions in the basename
        self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"),
                         ('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B'))

        # Potential version in the basename
        self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"),
                         ('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B'))

        # Underscore in the basename, and 1m for the context size
        self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9),
                         ('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B'))

        # Version before the finetune name
        self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"),
                         ('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M'))

        # TODO: hf suffix which could be ignored but isn't
        self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"),
                         ('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B'))

        # Two sizes, don't merge them, the other is the number of tokens on which it was trained
        self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6),
                         ('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M'))

        # It's a trap, there is no size label (the parameter-count hint rules out 100B)
        self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6),
                         ('relu-100B', 'SparseLLM', 'relu', '100b', None, None))

        # Weird size notation
        self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
                         ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))

        # Ignore full-text size labels when there are number-based ones, and deduplicate size labels
        self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
                         ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))

        # Instruct in a name without a size label
        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"),
                         ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None))

        # Non-obvious splitting relying on 'chat' keyword
        self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"),
                         ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None))

        # Multiple versions
        self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"),
                         ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B'))

        # TODO: DPO in the name
        self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"),
                         ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B'))

        # DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename
        self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"),
                         ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B'))

        # Too ambiguous
        # TODO: should "base" be a 'finetune' or 'size_label'?
        # (in this case it should be a size label, but other models use it to signal that they are not finetuned)
        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"),
                         ('Florence-2-base', 'microsoft', None, None, None, None))

        ## Invalid cases ##

        # Starts with a dash and has dashes in a row: stray/duplicated dashes
        # are dropped from the components but kept in the full name
        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
                         ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))

        ## LoRA ##

        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"),
                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B'))

        # Negative size --> output is a LoRA adapter --> prune "LoRA" out of the name to avoid redundancy with the suffix
        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234),
                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B'))

    def test_apply_metadata_heuristic_from_model_card(self):
        # apply_metadata_heuristic() should pull tags, languages, datasets and
        # base models out of a Hugging Face model card dictionary.
        model_card = {
            'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
            'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
            'language': ['en'],
            'datasets': ['teknium/OpenHermes-2.5'],
            'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
            'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
        }
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
        expect = gguf.Metadata()
        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
        expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
        expect.languages=['en']
        expect.datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]
        self.assertEqual(got, expect)

        # Base Model spec is inferred from model id
        # NOTE(review): the card above uses the 'base_model' key (singular)
        # while the cases below use 'base_models' — presumably the heuristic
        # accepts both spellings; confirm in gguf.Metadata.
        model_card = {'base_models': 'teknium/OpenHermes-2.5'}
        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
        self.assertEqual(got, expect)

        # Base Model spec is only url
        model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']}
        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
        self.assertEqual(got, expect)

        # Base Model spec is given directly
        model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
        self.assertEqual(got, expect)

        # Dataset spec is inferred from model id
        model_card = {'datasets': 'teknium/OpenHermes-2.5'}
        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
        self.assertEqual(got, expect)

        # Dataset spec is only url
        model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']}
        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
        self.assertEqual(got, expect)

        # Dataset spec is given directly
        model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
        self.assertEqual(got, expect)

    def test_apply_metadata_heuristic_from_hf_parameters(self):
        # Naming components can be recovered from the HF config's
        # '_name_or_path' value when no model card is available.
        hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
        self.assertEqual(got, expect)

    def test_apply_metadata_heuristic_from_model_dir(self):
        # Naming components can also be recovered from the model directory
        # name alone (same expectation as the hf_params case above).
        model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
        self.assertEqual(got, expect)
235
236
237if __name__ == "__main__":
238    unittest.main()