1#!/usr/bin/env python3
 2
 3import argparse
 4import sys
 5from common import compare_tokens  # type: ignore
 6
 7
 8def parse_arguments():
 9    parser = argparse.ArgumentParser(
10        description='Compare tokens between two models',
11        formatter_class=argparse.RawDescriptionHelpFormatter,
12        epilog="""
13Examples:
14  %(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16
15        """
16    )
17    parser.add_argument(
18        'original',
19        help='Original model name'
20    )
21    parser.add_argument(
22        'converted',
23        help='Converted model name'
24    )
25    parser.add_argument(
26        '-s', '--suffix',
27        default='',
28        help='Type suffix (e.g., "-embeddings")'
29    )
30    parser.add_argument(
31        '-d', '--data-dir',
32        default='data',
33        help='Directory containing token files (default: data)'
34    )
35    parser.add_argument(
36        '-v', '--verbose',
37        action='store_true',
38        help='Print prompts from both models'
39    )
40    return parser.parse_args()
41
42
def _print_prompt(label, model_name, prompt_file):
    """Print one model's saved prompt; do nothing if the file is missing.

    Args:
        label: Heading word placed before "model prompt" (e.g. "Original").
        model_name: Model name shown in parentheses in the heading.
        prompt_file: ``pathlib.Path`` of the prompt text file to read.
    """
    try:
        # Decode explicitly instead of relying on the locale default, and
        # replace undecodable bytes so this diagnostic output can never
        # abort the comparison. Assumes prompt files are UTF-8 -- TODO confirm.
        prompt = prompt_file.read_text(encoding="utf-8", errors="replace")
    except FileNotFoundError:
        # Prompt files are optional; a missing one is skipped silently.
        return

    print(f"\n{label} model prompt ({model_name}):")
    print(f"  {prompt.strip()}")


def main():
    """Run the token comparison and exit with a shell-friendly status.

    With ``--verbose``, the prompt files
    ``<data-dir>/<model><suffix>-prompt.txt`` of both models are printed
    first (each only if it exists). The process then exits with status 0
    when ``compare_tokens`` returns a truthy result and 1 otherwise.
    """
    args = parse_arguments()

    if args.verbose:
        from pathlib import Path
        data_dir = Path(args.data_dir)

        _print_prompt(
            "Original",
            args.original,
            data_dir / f"{args.original}{args.suffix}-prompt.txt",
        )
        _print_prompt(
            "Converted",
            args.converted,
            data_dir / f"{args.converted}{args.suffix}-prompt.txt",
        )

        print()

    result = compare_tokens(
        args.original,
        args.converted,
        type_suffix=args.suffix,
        output_dir=args.data_dir
    )

    # Enable the script to be used in shell scripts so that they can check
    # the exit code for success/failure.
    sys.exit(0 if result else 1)
73
74
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()