1import argparse
2import requests
3import json
4from pathlib import Path
5import logging
6
logger = logging.getLogger("compare-logprobs")
logging.basicConfig(level=logging.INFO)


# Help text shown by `--help` (rendered verbatim via RawTextHelpFormatter).
DESCRIPTION = """
Compare the top predicted token and its log-probability between llama.cpp and
another inference engine, using their OpenAI-compatible /completions endpoints.

Unlike compare-logits.py, it allows dumping logprobs from a hosted API endpoint.
Useful when it's not possible to run both models locally.

Example usage:
 Step 1: Dump logprobs from two different servers
 python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
 python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions

 (optionally, add --api-key <key> if the endpoint requires authentication,
 --file <path> to use your own prompt text, or --pattern to choose which words are queried)

 Step 2: Compare the dumped logprobs
 python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
"""
26
27
def generate_input_prompt(length: int) -> list[str]:
    """Return the first ``length`` words of a built-in sample prompt.

    The sample text is split on single spaces; every piece is stripped of
    surrounding whitespace (so line breaks disappear) and empty pieces are
    discarded. When more words are requested than the text provides, the
    word list is doubled repeatedly until it is long enough, then truncated,
    so the result cycles through the sample text.
    """
    CORPUS = """
    You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.

    ### Tool Call Format:
    When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.

    You can make multiple calls in one go by placing them one after another.
    """
    pieces = (piece.strip() for piece in CORPUS.strip().split(" "))
    vocabulary = [piece for piece in pieces if piece]
    while length > len(vocabulary):
        vocabulary = vocabulary + vocabulary
    return vocabulary[:length]
42
43
def _request_completion(endpoint: str, prompt: str, headers: dict) -> dict:
    """POST a greedy, single-token completion request and return the JSON body.

    Raises ``requests.HTTPError`` (via ``raise_for_status``) on an error status.
    """
    payload = {
        "prompt": prompt,
        "temperature": 0.0,
        "top_k": 1,
        "max_tokens": 1,
        "logprobs": 1,
        "stream": False,
    }
    response = requests.post(endpoint, json=payload, headers=headers)
    response.raise_for_status()
    return response.json()


def dump_logits(
    endpoint: str,
    output_path: Path,
    input_words: list[str],
    pattern: list[tuple[bool, int]],
    api_key=None,
):
    """Query ``endpoint`` word by word and dump the raw JSON responses.

    The prompt grows one word at a time following ``pattern``. For each
    ``(True, n)`` entry, ``n`` words are appended one by one and a greedy
    one-token completion request is sent after each. For each
    ``(False, n)`` entry, ``n`` words are appended without sending requests.
    Each response is written as one JSON line to ``output_path``, with an
    extra ``"__index"`` key holding the 0-based position of the prompt's
    last word.

    Args:
        endpoint: OpenAI-compatible ``/completions`` URL.
        output_path: file to write (overwritten).
        input_words: words used to build the prompt; the list is not modified.
        pattern: ``(get, n)`` pairs, e.g. as returned by ``parse_pattern``.
        api_key: optional bearer token sent in the ``Authorization`` header.

    Raises:
        ValueError: if ``pattern`` consumes more words than ``input_words``
            contains. Checked before any request is sent or file is opened.
        requests.HTTPError: if the server answers with an error status.
    """
    words_needed = sum(n for _, n in pattern if n > 0)
    if words_needed > len(input_words):
        raise ValueError(
            f"Pattern requires {words_needed} words, but only "
            f"{len(input_words)} were provided."
        )

    logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    n_total = sum(n for get, n in pattern if get)
    n_done = 0
    i_cur = 0  # index into input_words of the next word to append
    i_total = len(input_words)
    curr_text = ""
    with output_path.open("w", encoding="utf-8") as f:
        for get, n in pattern:
            for _ in range(n):
                word_index = i_cur
                curr_text += input_words[word_index] + " "
                i_cur += 1
                if not get:
                    # Skipped words only extend the prompt; no request is sent.
                    continue
                data = _request_completion(endpoint, curr_text.strip(), headers)
                # Record which word this response belongs to, so two dumps can
                # be checked for alignment when they are compared.
                data["__index"] = word_index
                line = json.dumps(data)
                f.write(f"{line}\n")
                n_done += 1
                logger.info(
                    f"\n\n{line}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
                )
    logger.info(f"Logits dumped to {output_path}")
93
94
def get_token_logprobs(data: dict):
    """Extract ``(token, logprob)`` for the first completion token of a response.

    Two response layouts are handled:
    * a ``logprobs`` object with a ``"content"`` list (llama.cpp) — the first
      entry of that token's ``top_logprobs`` is used;
    * otherwise, parallel ``"tokens"`` / ``"token_logprobs"`` arrays (vLLM) —
      their first elements are used.
    """
    lp = data["choices"][0]["logprobs"]
    if "content" not in lp:
        # vLLM case
        token_list = lp["tokens"]
        logprob_list = lp["token_logprobs"]
        return token_list[0], logprob_list[0]
    # llama.cpp case
    best = lp["content"][0]["top_logprobs"][0]
    return best["token"], best["logprob"]
106
107
# Characters that would break a single markdown table cell, mapped to
# visible escape sequences.
_MARKDOWN_CELL_ESCAPES = str.maketrans(
    {
        "\n": "\\n",
        "\t": "\\t",
        "\r": "\\r",
        "|": "\\|",
    }
)


def clean_text(text: str) -> str:
    """Quote ``text`` for a markdown table cell.

    Newlines, tabs and carriage returns become their backslash escapes, and
    ``|`` is escaped so it does not split the cell. The result is wrapped in
    single quotes so leading/trailing spaces stay visible.
    """
    escaped = text.translate(_MARKDOWN_CELL_ESCAPES)
    return f"'{escaped}'"
117
118
def _format_markdown_table(header: list[str], rows: list[tuple[str, ...]]) -> str:
    """Render ``header`` and ``rows`` as a markdown table with left-aligned, padded cells."""
    widths = [len(cell) for cell in header]
    for row in rows:
        widths = [max(width, len(cell)) for width, cell in zip(widths, row)]

    def render_row(cells) -> str:
        return "".join(f"| {cell:<{width}} " for cell, width in zip(cells, widths)) + "|\n"

    separator = "".join(f"|{'-' * (width + 2)}" for width in widths) + "|\n"
    return render_row(header) + separator + "".join(render_row(row) for row in rows)


def compare_logits(input1: Path, input2: Path, output_path: Path):
    """Write a markdown report comparing two files produced by ``dump_logits``.

    Lines of the two files are paired by position. For each pair, the top
    token and its log-probability are extracted from both responses and
    tabulated together with the absolute difference of the log-probabilities.
    Pairs where either line is blank are skipped. A warning is logged when the
    ``__index`` values recorded in a pair differ.

    Args:
        input1: first dump file (JSON lines).
        input2: second dump file (JSON lines).
        output_path: markdown report to write (overwritten).

    Raises:
        ValueError: if the input files have different numbers of lines. This
            is checked before ``output_path`` is opened, so an existing report
            is left untouched.
    """
    with input1.open("r", encoding="utf-8") as f1, input2.open("r", encoding="utf-8") as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(lines1) != len(lines2):
        raise ValueError(
            f"Input files must have the same number of lines "
            f"({input1}: {len(lines1)}, {input2}: {len(lines2)})."
        )

    tab_header = [
        "idx",
        input1.name,
        "logprob_1",
        input2.name,
        "logprob_2",
        "diff (abs)",
    ]
    tab_entries = []
    for i, (line1, line2) in enumerate(zip(lines1, lines2)):
        if not line1.strip() or not line2.strip():
            continue  # skip empty lines

        data1 = json.loads(line1)
        data2 = json.loads(line2)

        idx1 = data1.get("__index", -1)
        idx2 = data2.get("__index", -1)
        if idx1 != idx2:
            logger.warning(
                f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}"
            )

        token1, logprob1 = get_token_logprobs(data1)
        token2, logprob2 = get_token_logprobs(data2)
        abs_diff = abs(logprob1 - logprob2)

        tab_entries.append(
            (
                str(idx1 + 1),  # 1-based word position for readability
                clean_text(token1),
                f"{logprob1:.4f}",
                clean_text(token2),
                f"{logprob2:.4f}",
                f"{abs_diff:.4f}",
            )
        )

    output = _format_markdown_table(tab_header, tab_entries)
    logger.info("\n" + output)
    # UTF-8 explicitly: decoded tokens may contain non-ASCII characters that
    # the platform default encoding cannot represent.
    with output_path.open("w", encoding="utf-8") as fout:
        fout.write("# Logits Comparison Report\n\n")
        fout.write(output)
    logger.info(f"Report written to {output_path}")
191
192
def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
    """Parse a comma-separated ``n_get,n_skip,n_get,...`` pattern.

    Entries alternate between "get" counts (even positions: query each of the
    next ``n`` words) and "skip" counts (odd positions: append ``n`` words
    without querying).

    Returns:
        A list of ``(get, n)`` tuples, one per entry.

    Raises:
        ValueError: if an entry is not an integer or is negative.
    """
    result = []
    for i, part in enumerate(pattern.split(",")):
        try:
            n = int(part)
        except ValueError:
            raise ValueError(
                f"Invalid pattern entry {part!r} in {pattern!r}: expected an integer"
            ) from None
        if n < 0:
            # A negative count would silently do nothing while still lowering
            # the total number of words the pattern appears to need.
            raise ValueError(
                f"Invalid pattern entry {part!r} in {pattern!r}: counts must be >= 0"
            )
        result.append((i % 2 == 0, n))  # even index: get n words; odd: skip n
    return result
203
204
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Two subcommands are defined: ``dump`` (query an endpoint and write the raw
    JSON responses) and ``compare`` (turn two dumps into a markdown report).
    On invalid arguments argparse prints a usage message and exits via
    ``SystemExit``.
    """
    parser = argparse.ArgumentParser(
        description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter
    )
    subparsers = parser.add_subparsers(
        dest="verb", required=True, help="action to perform"
    )

    # dump subcommand
    parser_dump = subparsers.add_parser("dump", help="dump logits from an endpoint")
    parser_dump.add_argument(
        "output", type=Path, help="output path for dumped logits (.log)"
    )
    parser_dump.add_argument(
        "endpoint", type=str, help="OAI-compat /completions endpoint"
    )
    parser_dump.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="API key for authentication (if required)",
    )
    parser_dump.add_argument(
        "--file",
        type=Path,
        default=None,
        help="File containing prompt to use instead of the default",
    )
    parser_dump.add_argument(
        "--pattern",
        type=str,
        default="10,1000,10,4000,10",
        help="Pattern n_get,n_skip,... where n_get is number of words to get and n_skip is number of words to skip (num of words, NOT num of tokens)",
    )

    # compare subcommand
    parser_compare = subparsers.add_parser(
        "compare", help="compare two dumped logits files"
    )
    parser_compare.add_argument("input1", type=Path, help="first input file (.log)")
    parser_compare.add_argument("input2", type=Path, help="second input file (.log)")
    parser_compare.add_argument(
        "output", type=Path, help="output path for comparison report (.md)"
    )

    # No try/except wrapper: argparse reports bad arguments itself through
    # parser.error(), which prints usage and raises SystemExit. SystemExit is
    # not an Exception subclass, so an `except Exception` handler never runs.
    return parser.parse_args()
255
256
def main():
    """Entry point: dispatch to the ``dump`` or ``compare`` subcommand.

    For ``dump``, prompt words come from ``--file`` (split on single spaces)
    when given, otherwise from the built-in sample prompt.

    Raises:
        ValueError: if the ``--file`` prompt has fewer words than the pattern
            requires.
    """
    args = parse_args()

    if args.verb == "dump":
        pattern = parse_pattern(args.pattern)
        required = sum(n for _, n in pattern if n > 0)
        if args.file is not None:
            with args.file.open("r", encoding="utf-8") as f:
                input_words = f.read().strip().split(" ")
            # Validate against the words actually read from the file.
            if len(input_words) < required:
                raise ValueError(
                    f"Input file has only {len(input_words)} words, "
                    f"but pattern requires at least {required} words."
                )
        else:
            input_words = generate_input_prompt(required)
        logger.info(f"Using {len(input_words)} words")
        dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
    elif args.verb == "compare":
        compare_logits(args.input1, args.input2, args.output)
    else:
        raise ValueError(f"Unknown verb: {args.verb}")


if __name__ == "__main__":
    main()