1#!/usr/bin/env python3
 2
 3import sys
 4import numpy as np
 5from pathlib import Path
 6import os
 7
 8# Add utils directory to path for direct script execution
 9sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
10from common import get_model_name_from_env_path, compare_tokens, exit_with_warning  # type: ignore[import-not-found]
11
def quick_logits_check(pytorch_file, llamacpp_file, top_k=10):
    """Lightweight sanity check of llama.cpp logits against PyTorch logits, run before NMSE.

    Each file is read with ``np.fromfile`` as a flat float32 array. The data is
    treated as a single logits vector (presumably the last-token logits over the
    vocabulary -- TODO confirm against the scripts that write these dumps).

    The check fails in any of these cases:

    - a file cannot be read;
    - the two arrays differ in length;
    - the arrays are empty;
    - either array contains NaN;
    - the two models disagree on the highest-scoring token, so greedy
      generation would diverge.

    Along the way it prints:

    - each model's top-``top_k`` token ids and logits;
    - how many of those ids the two lists share;
    - the maximum absolute logit difference.

    Args:
        pytorch_file: Path of the PyTorch logits dump.
        llamacpp_file: Path of the llama.cpp logits dump.
        top_k: How many top predictions to report; must be >= 1 (default 10).

    Returns:
        True if every check passes, otherwise False (the reason is printed).

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        # A slice of [-0:] would select the whole array instead of nothing.
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    try:
        pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32)
        llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32)
    except Exception as e:
        # Any read/parse failure is reported as a failed check rather than raised.
        print(f"❌ NOK: Failed to load files - {e}")
        return False

    if pytorch_logits.shape != llamacpp_logits.shape:
        print(f"❌ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}")
        return False

    # np.max raises on zero-size arrays; an empty dump means nothing useful was written.
    if pytorch_logits.size == 0:
        print("❌ NOK: Logits files are empty")
        return False

    # NaN makes every comparison below meaningless, so treat it as a failure.
    if np.isnan(pytorch_logits).any() or np.isnan(llamacpp_logits).any():
        print("❌ NOK: NaN values found in logits")
        return False

    max_diff = np.max(np.abs(pytorch_logits - llamacpp_logits))

    # Indices (token ids) of the top-k logits of each model, highest first.
    pytorch_top = np.argsort(pytorch_logits)[-top_k:][::-1]
    llamacpp_top = np.argsort(llamacpp_logits)[-top_k:][::-1]
    shared = np.intersect1d(pytorch_top, llamacpp_top).size

    print(f"Top {top_k} PyTorch token ids: {pytorch_top}")
    print(f"Top {top_k} PyTorch logits: {pytorch_logits[pytorch_top]}")
    print(f"Top {top_k} llama.cpp token ids: {llamacpp_top}")
    print(f"Top {top_k} llama.cpp logits: {llamacpp_logits[llamacpp_top]}")
    print(f"Shared top {top_k} token ids: {shared}/{pytorch_top.size}")
    print(f"Max absolute difference: {max_diff:.4f}")

    # The highest-scoring token is what greedy decoding would emit next.
    if pytorch_top[0] != llamacpp_top[0]:
        print(f"❌ NOK: Top prediction differs - PyTorch: {pytorch_top[0]}, llama.cpp: {llamacpp_top[0]}")
        return False

    return True
40
def _require_logits_file(path, label, producer_script):
    """Exit with status 1 if the logits file *path* is missing, naming the script that creates it."""
    if not path.exists():
        print(f"Error: {label} logits file not found: {path}")
        print(f"Please run {producer_script} first to generate this file.")
        sys.exit(1)


def main():
    """Validate converted-model logits against the original PyTorch model's logits.

    Model names come from the project helper ``get_model_name_from_env_path``.
    It is called for the ``MODEL_PATH`` (original model) and ``CONVERTED_MODEL``
    (converted model) environment variables. The dumps are then expected at
    ``data/pytorch-<model>.bin`` and ``data/llamacpp-<converted>.bin``,
    relative to the current working directory.

    Steps:

    1. Exit with status 1 if either logits file is missing.
    2. Compare the token dumps with ``compare_tokens``; on mismatch, call
       ``exit_with_warning``.
    3. Run ``quick_logits_check``. Exit with status 0 if it passes; otherwise
       call ``exit_with_warning``.
    """
    model_path = os.environ.get('MODEL_PATH')
    model_name = get_model_name_from_env_path('MODEL_PATH')
    data_dir = Path("data")
    pytorch_file = data_dir / f"pytorch-{model_name}.bin"

    llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
    print(f"Using converted model: {llamacpp_model_name}")
    llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"

    _require_logits_file(pytorch_file, "PyTorch", "scripts/run-org-model.sh")
    _require_logits_file(llamacpp_file, "llama.cpp", "scripts/run-converted-model.sh")

    print("Checked all required files were found. Proceeding...\n")

    # Verify tokens as they are a prerequisite for logits comparison.
    print("🔍 Token Comparison Check")
    print("=" * 40)
    if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"):
        exit_with_warning("\n❌ Token mismatch detected", model_path)
    print()

    print(f"🔍 GGML Model Validation for model {model_name}")
    print("=" * 40)
    print(f"PyTorch logits  : {pytorch_file}")
    print(f"llama.cpp logits: {llamacpp_file}")
    print()

    if quick_logits_check(pytorch_file, llamacpp_file):
        print("✅ OK: Lightweight model check successful!")
        print("       Ok to proceed with NMSE check...")
        sys.exit(0)

    # quick_logits_check prints its own reason for failing (e.g. a load error or
    # shape mismatch), so the final message stays generic instead of guessing.
    exit_with_warning("❌ NOK: Lightweight logits check failed - see details above", model_path)
85
if __name__ == "__main__":
    # Run the comparison only when executed as a script, never on import.
    main()