1#!/usr/bin/env python3
2
3import sys
4import numpy as np
5from pathlib import Path
6import os
7
8# Add utils directory to path for direct script execution
9sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
10from common import get_model_name_from_env_path, compare_tokens, exit_with_warning # type: ignore[import-not-found]
11
def quick_logits_check(pytorch_file, llamacpp_file):
    """Lightweight sanity check before the full NMSE comparison.

    Loads two raw float32 logit dumps, verifies the shapes match, reports
    the maximum absolute element-wise difference, and checks that both
    models agree on their top-10 predicted token indices.

    Args:
        pytorch_file: path to the PyTorch logits dump (raw float32).
        llamacpp_file: path to the llama.cpp logits dump (raw float32).

    Returns:
        True when both files load, shapes match, and the top-10 predicted
        token indices are identical (same tokens, same order); False otherwise.
    """
    try:
        pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32)
        llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32)
    except Exception as e:
        print(f"โ NOK: Failed to load files - {e}")
        return False

    # Shapes must match or an element-wise comparison is meaningless.
    if pytorch_logits.shape != llamacpp_logits.shape:
        print(f"โ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}")
        return False

    # Key metric: largest element-wise deviation between the two dumps.
    max_diff = np.max(np.abs(pytorch_logits - llamacpp_logits))

    # Top-10 predicted token indices (highest logit first) from both models.
    pytorch_top10 = np.argsort(pytorch_logits)[-10:][::-1]
    llamacpp_top10 = np.argsort(llamacpp_logits)[-10:][::-1]
    print(f"Top 10 PyTorch logits: {pytorch_logits[pytorch_top10]}")
    print(f"Top 10 llama.cpp logits: {llamacpp_logits[llamacpp_top10]}")
    print(f"Max absolute difference: {max_diff:.4f}")

    # BUG FIX: the caller reports "Top 10 predictions don't match" on failure,
    # but the original never compared them and returned True unconditionally.
    # Fail the check when the two models disagree on the top-10 token ranking.
    if not np.array_equal(pytorch_top10, llamacpp_top10):
        return False

    return True
40
def main():
    """Compare PyTorch and llama.cpp logit dumps for the configured model.

    Reads MODEL_PATH and CONVERTED_MODEL from the environment, checks that
    both logit dump files exist under data/, verifies token agreement (a
    prerequisite for comparing logits), then runs the lightweight logits
    check. Exits 0 on success; exits non-zero (via sys.exit or
    exit_with_warning) on any failure.
    """
    model_path = os.environ.get('MODEL_PATH')
    model_name = get_model_name_from_env_path('MODEL_PATH')
    data_dir = Path("data")
    pytorch_file = data_dir / f"pytorch-{model_name}.bin"

    llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
    print(f"Using converted model: {llamacpp_model_name}")
    llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"

    # The dumps are produced by separate scripts; fail fast with a pointer
    # to the right script when one is missing.
    if not pytorch_file.exists():
        print(f"Error: PyTorch logits file not found: {pytorch_file}")
        print("Please run scripts/run-org-model.sh first to generate this file.")
        sys.exit(1)

    if not llamacpp_file.exists():
        print(f"Error: llama.cpp logits file not found: {llamacpp_file}")
        print("Please run scripts/run-converted-model.sh first to generate this file.")
        sys.exit(1)

    print("Checked all required files were found. Proceeding...\n")

    # Verify tokens as they are a prerequisite for logits comparison.
    print("๐ Token Comparison Check")
    print("=" * 40)
    if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"):
        exit_with_warning("\nโ Token mismatch detected", model_path)
    print()

    print("๐ GGML Model Validation for model ", model_name)
    print("=" * 40)
    print(f"PyTorch logits : {pytorch_file}")
    print(f"llama.cpp logits: {llamacpp_file}")
    print()

    success = quick_logits_check(pytorch_file, llamacpp_file)

    # Exit with appropriate code.
    if success:
        print("โ OK: Lightweight model check successful!")
        print("   Ok to proceed with NMSE check...")
        sys.exit(0)
    else:
        # BUG FIX: was an f-string with no placeholders (unnecessary f prefix).
        exit_with_warning("โ NOK: Top 10 predictions don't match - generation will differ", model_path)
85
# Allow the comparison to be run directly as a script.
if __name__ == "__main__":
    main()