1#!/usr/bin/env python3
2
3import numpy as np
4import sys
5import os
6import argparse
7from pathlib import Path
8from common import get_model_name_from_env_path # type: ignore[import-not-found]
9
def calculate_nmse(reference, test):
    """Compute the normalized mean squared error of ``test`` against ``reference``.

    NMSE is the mean squared error between the two arrays divided by the
    variance of the reference array.

    Args:
        reference: NumPy array of reference (ground truth) values.
        test: NumPy array of values to evaluate against ``reference``.

    Returns:
        A tuple ``(nmse, mse, ref_var)``. When the reference variance is zero
        the ratio is undefined; in that case ``nmse`` is ``0.0`` if the arrays
        match exactly (``mse == 0``) and ``inf`` otherwise.
    """
    mse = np.mean((test - reference) ** 2)
    ref_var = np.var(reference)

    if ref_var == 0:
        # Constant reference: dividing would yield nan/inf with a runtime
        # warning, so pick the limit value explicitly and return *that* as
        # the NMSE (not the raw MSE).
        nmse = float('inf') if mse > 0 else 0.0
        return nmse, mse, ref_var

    nmse = mse / ref_var
    return nmse, mse, ref_var
20
def load_logits(file_path):
    """Load logits from ``file_path`` into a NumPy array.

    The loader is chosen from the file suffix:

    * ``.npy`` - loaded with ``np.load``.
    * ``.bin`` - read as raw ``float32`` values with ``np.fromfile``.
    * anything else - treated as text with one value per line, either a bare
      number or ``"index: value"``. Blank lines are skipped. If a line cannot
      be parsed this way, the whole file is handed to ``np.loadtxt`` instead.

    Args:
        file_path: Path to the logits file (``str`` or ``pathlib.Path``).

    Returns:
        The loaded array (``float32`` for the ``.bin`` and text formats).

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    # Normalise so that plain string paths work too; `.suffix` needs a Path.
    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    suffix = file_path.suffix
    if suffix == '.npy':
        return np.load(file_path)
    if suffix == '.bin':
        return np.fromfile(file_path, dtype=np.float32)

    # Text file: one value per line, optionally prefixed with "index:".
    try:
        values = []
        with open(file_path, 'r') as f:
            for line in f:
                text = line.strip()
                if not text:
                    # A trailing/blank line carries no value.
                    continue
                if ':' in text:
                    # Format: "index: value" - keep only the value part.
                    text = text.split(':')[1].strip()
                values.append(float(text))
        return np.array(values, dtype=np.float32)
    except ValueError:
        # Not the simple per-line format (e.g. several columns per line);
        # let NumPy's generic text parser try. Only parse errors are caught
        # so that interrupts and I/O failures still propagate.
        return np.loadtxt(file_path, dtype=np.float32)
46
def interpret_nmse(nmse):
    """Map an NMSE value to a human-readable verdict.

    Args:
        nmse: Normalized mean squared error to classify.

    Returns:
        A tuple ``(description, emoji)``. Thresholds are checked in ascending
        order and each upper bound is exclusive; any value that is not below
        1.0 (including ``inf`` and ``nan``) falls into the last bucket.
    """
    if nmse == 0:
        return "Perfect match", "🎉"
    if nmse < 1e-6:
        return "Essentially identical", "✅"
    if nmse < 1e-4:
        return "Excellent match", "✅"
    if nmse < 1e-3:
        return "Very good match", "👍"
    if nmse < 1e-2:
        return "Good match", "👍"
    if nmse < 0.1:
        return "Acceptable match", "⚠️"
    if nmse < 1.0:
        return "Poor match", "❌"
    return "Very poor match (worse than noise)", "❌"
65
def main():
    """Compare PyTorch and llama.cpp logits and report NMSE-based metrics.

    Reads ``data/pytorch-<model>.bin`` as the reference and
    ``data/llamacpp-<converted model>.bin`` as the test data, where the two
    names come from ``get_model_name_from_env_path`` called with
    ``'MODEL_PATH'`` and ``'CONVERTED_MODEL'``. Prints the MSE, reference
    variance, NMSE, max/mean absolute error, an interpretation and guidance.

    Exits with status 0 when NMSE is below 1e-2, and with status 1 when it is
    not, when the two arrays differ in shape, or when any error occurs while
    loading or comparing the files.
    """
    parser = argparse.ArgumentParser(description='Validate model logits')
    parser.add_argument('-m', '--model-path', required=True, help='Path to the model directory')
    # NOTE(review): --model-path is required but its value is never read; the
    # model name below comes from get_model_name_from_env_path('MODEL_PATH')
    # (presumably an environment variable - confirm). Parsing is kept so the
    # command-line contract and --help stay unchanged.
    parser.parse_args()

    model_name = get_model_name_from_env_path('MODEL_PATH')
    data_dir = Path("data")

    pytorch_file = data_dir / f"pytorch-{model_name}.bin"

    llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
    llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"

    print(f"Model name: {model_name}")
    print(f"PyTorch logits file: {pytorch_file}")
    print(f"llama.cpp logits file: {llamacpp_file}")

    # The PyTorch output is the ground truth; llama.cpp is what gets evaluated.
    reference_file = pytorch_file
    test_file = llamacpp_file

    print("📊 NMSE Check for Model Comparison")
    print("=" * 50)
    print(f"Reference (ground truth): {reference_file}")
    print(f"Test (to evaluate): {test_file}")
    print()

    # sys.exit() raises SystemExit, which is not an Exception subclass, so the
    # exit calls inside this try block are not intercepted by the handler.
    try:
        print("Loading reference logits...")
        reference = load_logits(reference_file)
        print(f" Shape: {reference.shape}, Type: {reference.dtype}")

        print("Loading test logits...")
        test = load_logits(test_file)
        print(f" Shape: {test.shape}, Type: {test.dtype}")

        # Check shapes match
        if reference.shape != test.shape:
            print("\n❌ Error: Shape mismatch!")
            print(f" Reference: {reference.shape}")
            print(f" Test: {test.shape}")
            sys.exit(1)

        print(f"\n✅ Shapes match: {reference.shape}")

        nmse, mse, ref_var = calculate_nmse(reference, test)

        # Additional metrics (absolute difference computed once, used twice)
        abs_error = np.abs(test - reference)
        max_abs_error = np.max(abs_error)
        mean_abs_error = np.mean(abs_error)

        # Results
        print("\n📈 METRICS")
        print("=" * 30)
        print(f"MSE (Mean Squared Error): {mse:.6e}")
        print(f"Reference Variance: {ref_var:.6e}")
        print(f"NMSE: {nmse:.6e}")
        print(f"Max Absolute Error: {max_abs_error:.6f}")
        print(f"Mean Absolute Error: {mean_abs_error:.6f}")

        # NMSE in dB (common in signal processing); log10 needs a positive value
        if nmse > 0:
            nmse_db = 10 * np.log10(nmse)
            print(f"NMSE (dB): {nmse_db:.2f} dB")

        # Interpretation
        interpretation, emoji = interpret_nmse(nmse)
        print("\n🎯 INTERPRETATION")
        print("=" * 30)
        print(f"{emoji} {interpretation}")

        # Detailed guidance
        print("\n📋 GUIDANCE")
        print("=" * 30)
        if nmse < 1e-3:
            print("✅ EXCELLENT: Your GGML conversion is working very well!")
            print(" The differences are negligible for practical use.")
        elif nmse < 1e-2:
            print("👍 GOOD: Your GGML conversion is working well.")
            print(" Small differences are likely due to precision/quantization.")
        elif nmse < 0.1:
            print("⚠️ ACCEPTABLE: Conversion is working but with some differences.")
            print(" Check if you're using quantization (Q4, Q8, etc.)")
            print(" Test generation quality to see if it's acceptable.")
        else:
            print("❌ PROBLEMATIC: Large differences detected.")
            print(" Check your conversion process for potential issues.")
            print(" Verify you're using the same model weights.")

        # NMSE benchmarks
        print("\n📚 NMSE BENCHMARKS")
        print("=" * 30)
        print("< 1e-6: Essentially identical")
        print("< 1e-4: Excellent (typical for good conversions)")
        print("< 1e-3: Very good")
        print("< 1e-2: Good (acceptable for most use cases)")
        print("< 0.1: Acceptable (may need verification)")
        print("> 1.0: Poor (worse than random)")

        # Exit code based on NMSE
        if nmse < 1e-2:
            print(f"\n✅ RESULT: PASS (NMSE = {nmse:.2e})")
            sys.exit(0)
        else:
            print(f"\n❌ RESULT: NEEDS REVIEW (NMSE = {nmse:.2e})")
            sys.exit(1)

    except Exception as e:
        # Top-level boundary: report any load/compute failure and signal it
        # through the exit status.
        print(f"❌ Error: {e}")
        sys.exit(1)
175
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()