1#!/usr/bin/env python3
  2
  3import numpy as np
  4import sys
  5import os
  6import argparse
  7from pathlib import Path
  8from common import get_model_name_from_env_path  # type: ignore[import-not-found]
  9
 10def calculate_nmse(reference, test):
 11    mse = np.mean((test - reference) ** 2)
 12    ref_var = np.var(reference)
 13    if ref_var == 0:
 14        nmse = float('inf') if mse > 0 else 0.0
 15        return mse, mse, ref_var
 16
 17    nmse = mse / ref_var
 18
 19    return nmse, mse, ref_var
 20
 21def load_logits(file_path):
 22    if not os.path.exists(file_path):
 23        raise FileNotFoundError(f"File not found: {file_path}")
 24
 25    if file_path.suffix == '.npy':
 26        return np.load(file_path)
 27    elif file_path.suffix == '.bin':
 28        return np.fromfile(file_path, dtype=np.float32)
 29    else:
 30        # Try to load as text file
 31        try:
 32            # If it has index format "0: value", extract just values
 33            data = []
 34            with open(file_path, 'r') as f:
 35                for line in f:
 36                    if ':' in line:
 37                        # Format: "index: value"
 38                        value = float(line.split(':')[1].strip())
 39                    else:
 40                        # Just the value
 41                        value = float(line.strip())
 42                    data.append(value)
 43            return np.array(data, dtype=np.float32)
 44        except:
 45            return np.loadtxt(file_path, dtype=np.float32)
 46
 47def interpret_nmse(nmse):
 48    """Provide interpretation of NMSE value"""
 49    if nmse == 0:
 50        return "Perfect match", "๐ŸŽ‰"
 51    elif nmse < 1e-6:
 52        return "Essentially identical", "โœ…"
 53    elif nmse < 1e-4:
 54        return "Excellent match", "โœ…"
 55    elif nmse < 1e-3:
 56        return "Very good match", "๐Ÿ‘"
 57    elif nmse < 1e-2:
 58        return "Good match", "๐Ÿ‘"
 59    elif nmse < 0.1:
 60        return "Acceptable match", "โš ๏ธ"
 61    elif nmse < 1.0:
 62        return "Poor match", "โŒ"
 63    else:
 64        return "Very poor match (worse than noise)", "โŒ"
 65
 66def main():
 67    parser = argparse.ArgumentParser(description='Validate model logits')
 68    parser.add_argument('-m', '--model-path', required=True,  help='Path to the model directory')
 69    args = parser.parse_args()
 70
 71    model_name = get_model_name_from_env_path('MODEL_PATH')
 72    data_dir = Path("data")
 73
 74    pytorch_file = data_dir / f"pytorch-{model_name}.bin"
 75
 76    llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
 77    llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"
 78
 79    print(f"Model name: {model_name}")
 80    print(f"PyTorch logits file: {pytorch_file}")
 81    print(f"llama.cpp logits file: {llamacpp_file}")
 82
 83    reference_file = pytorch_file
 84    test_file = llamacpp_file
 85
 86    print("๐Ÿ“Š NMSE Check for Model Comparison")
 87    print("=" * 50)
 88    print(f"Reference (ground truth): {reference_file}")
 89    print(f"Test (to evaluate):       {test_file}")
 90    print()
 91
 92    try:
 93        print("Loading reference logits...")
 94        reference = load_logits(reference_file)
 95        print(f"  Shape: {reference.shape}, Type: {reference.dtype}")
 96
 97        print("Loading test logits...")
 98        test = load_logits(test_file)
 99        print(f"  Shape: {test.shape}, Type: {test.dtype}")
100
101        # Check shapes match
102        if reference.shape != test.shape:
103            print(f"\nโŒ Error: Shape mismatch!")
104            print(f"  Reference: {reference.shape}")
105            print(f"  Test: {test.shape}")
106            sys.exit(1)
107
108        print(f"\nโœ… Shapes match: {reference.shape}")
109
110        nmse, mse, ref_var = calculate_nmse(reference, test)
111
112        # Additional metrics
113        max_abs_error = np.max(np.abs(test - reference))
114        mean_abs_error = np.mean(np.abs(test - reference))
115
116        # Results
117        print(f"\n๐Ÿ“ˆ METRICS")
118        print("=" * 30)
119        print(f"MSE (Mean Squared Error):     {mse:.6e}")
120        print(f"Reference Variance:           {ref_var:.6e}")
121        print(f"NMSE:                         {nmse:.6e}")
122        print(f"Max Absolute Error:           {max_abs_error:.6f}")
123        print(f"Mean Absolute Error:          {mean_abs_error:.6f}")
124
125        # NMSE in dB (common in signal processing)
126        if nmse > 0:
127            nmse_db = 10 * np.log10(nmse)
128            print(f"NMSE (dB):                    {nmse_db:.2f} dB")
129
130        # Interpretation
131        interpretation, emoji = interpret_nmse(nmse)
132        print(f"\n๐ŸŽฏ INTERPRETATION")
133        print("=" * 30)
134        print(f"{emoji} {interpretation}")
135
136        # Detailed guidance
137        print(f"\n๐Ÿ“‹ GUIDANCE")
138        print("=" * 30)
139        if nmse < 1e-3:
140            print("โœ… EXCELLENT: Your GGML conversion is working very well!")
141            print("   The differences are negligible for practical use.")
142        elif nmse < 1e-2:
143            print("๐Ÿ‘ GOOD: Your GGML conversion is working well.")
144            print("   Small differences are likely due to precision/quantization.")
145        elif nmse < 0.1:
146            print("โš ๏ธ  ACCEPTABLE: Conversion is working but with some differences.")
147            print("   Check if you're using quantization (Q4, Q8, etc.)")
148            print("   Test generation quality to see if it's acceptable.")
149        else:
150            print("โŒ PROBLEMATIC: Large differences detected.")
151            print("   Check your conversion process for potential issues.")
152            print("   Verify you're using the same model weights.")
153
154        # NMSE benchmarks
155        print(f"\n๐Ÿ“š NMSE BENCHMARKS")
156        print("=" * 30)
157        print("< 1e-6:  Essentially identical")
158        print("< 1e-4:  Excellent (typical for good conversions)")
159        print("< 1e-3:  Very good")
160        print("< 1e-2:  Good (acceptable for most use cases)")
161        print("< 0.1:   Acceptable (may need verification)")
162        print("> 1.0:   Poor (worse than random)")
163
164        # Exit code based on NMSE
165        if nmse < 1e-2:
166            print(f"\nโœ… RESULT: PASS (NMSE = {nmse:.2e})")
167            sys.exit(0)
168        else:
169            print(f"\nโŒ RESULT: NEEDS REVIEW (NMSE = {nmse:.2e})")
170            sys.exit(1)
171
172    except Exception as e:
173        print(f"โŒ Error: {e}")
174        sys.exit(1)
175
# Entry point: run the comparison only when executed as a script, not on import.
if __name__ == '__main__':
    main()