1import argparse
  2import requests
  3import json
  4from pathlib import Path
  5import logging
  6
# Script-wide logger; basicConfig below makes its INFO-level progress messages visible.
logger = logging.getLogger("compare-logprobs")
logging.basicConfig(level=logging.INFO)


# Help text for the CLI: parse_args passes it as the argparse description with
# RawTextHelpFormatter, so the line breaks below are shown verbatim in `--help`.
# NOTE: despite the wording "logits", what gets dumped and compared is the
# per-token logprob each server reports (see get_token_logprobs).
DESCRIPTION = """
Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints.

Unlike compare-logits.py, it allows dumping logits from a hosted API endpoint. Useful when it's not possible to run both models locally.

Example usage:
    Step 1: Dump logits from two different servers
        python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
        python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions

        (optionally, you can add --api-key <key> if the endpoint requires authentication)

    Step 2: Compare the dumped logits
        python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
"""
 26
 27
 28def generate_input_prompt(length: int) -> list[str]:
 29    CORPUS = """
 30    You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
 31
 32    ### Tool Call Format:
 33    When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
 34
 35    You can make multiple calls in one go by placing them one after another.
 36    """
 37    words = [w.strip() for w in CORPUS.strip().split(" ")]
 38    words = [w for w in words if len(w) > 0]  # filter out empty strings
 39    while len(words) < length:
 40        words += words
 41    return words[:length]
 42
 43
 44def dump_logits(
 45    endpoint: str,
 46    output_path: Path,
 47    input_words: list[str],
 48    pattern: list[tuple[bool, int]],
 49    api_key=None,
 50):
 51    logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
 52    words = input_words
 53    curr_text = ""
 54    n_total = sum(n for get, n in pattern if get)
 55    n_done = 0
 56    i_cur = 0
 57    i_total = len(words)
 58    with output_path.open("w") as f:
 59        for get, n in pattern:
 60            if not get:
 61                # skip n words
 62                for i in range(n):
 63                    curr_text += words.pop(0) + " "
 64                    i_cur += 1
 65                continue
 66            # get n words
 67            for i in range(n):
 68                curr_text += words.pop(0) + " "
 69                payload = {
 70                    "prompt": curr_text.strip(),
 71                    "temperature": 0.0,
 72                    "top_k": 1,
 73                    "max_tokens": 1,
 74                    "logprobs": 1,
 75                    "stream": False,
 76                }
 77                response = requests.post(
 78                    endpoint,
 79                    json=payload,
 80                    headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
 81                )
 82                response.raise_for_status()
 83                data = response.json()
 84                data["__index"] = i_cur  # add index for easier debugging later
 85                data = json.dumps(data)
 86                f.write(f"{data}\n")
 87                n_done += 1
 88                i_cur += 1
 89                logger.info(
 90                    f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
 91                )
 92    logger.info(f"Logits dumped to {output_path}")
 93
 94
 95def get_token_logprobs(data: dict):
 96    logprobs = data["choices"][0]["logprobs"]
 97    if "content" in logprobs:
 98        # llama.cpp case
 99        top = logprobs["content"][0]["top_logprobs"][0]
100        return top["token"], top["logprob"]
101    else:
102        # vllm case
103        tokens = logprobs["tokens"]
104        token_logprobs = logprobs["token_logprobs"]
105        return tokens[0], token_logprobs[0]
106
107
def clean_text(text: str) -> str:
    r"""Quote ``text`` for safe display inside a Markdown table cell.

    Newline, tab and carriage return are replaced by their two-character
    backslash escapes, ``|`` becomes ``\|`` so it does not split the cell, and
    the result is wrapped in single quotes.
    """
    escapes = str.maketrans({"\n": "\\n", "\t": "\\t", "\r": "\\r", "|": "\\|"})
    return f"'{text.translate(escapes)}'"
117
118
def _render_markdown_table(header: list[str], rows: list[tuple[str, ...]]) -> str:
    """Render ``header`` and ``rows`` as a left-aligned Markdown table padded per column."""
    widths = [len(h) for h in header]
    for row in rows:
        for j, cell in enumerate(row):
            widths[j] = max(widths[j], len(cell))

    def fmt_row(cells) -> str:
        return "".join(f"| {cell:<{w}} " for cell, w in zip(cells, widths)) + "|\n"

    separator = "".join(f"|{'-' * (w + 2)}" for w in widths) + "|\n"
    return fmt_row(header) + separator + "".join(fmt_row(row) for row in rows)


def compare_logits(input1: Path, input2: Path, output_path: Path):
    """Compare two logprob dumps line by line and write a Markdown report.

    Lines of ``input1`` and ``input2`` are paired by position. Pairs where
    either line is blank are skipped; otherwise each line is parsed as JSON,
    the token/logprob pair is extracted with ``get_token_logprobs``, and the
    absolute logprob difference is tabulated. A warning is logged when the
    ``__index`` fields of a pair disagree. The report is written to
    ``output_path`` as UTF-8 and also echoed to the log.

    Raises:
        ValueError: if the two input files have different line counts.
    """
    # Read both inputs before touching the output, so invalid input does not
    # leave a truncated report file behind.
    with input1.open("r", encoding="utf-8") as f1, input2.open("r", encoding="utf-8") as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(lines1) != len(lines2):
        raise ValueError("Input files must have the same number of lines.")

    tab_header = [
        "idx",
        input1.name,
        "logprob_1",
        input2.name,
        "logprob_2",
        "diff (abs)",
    ]
    tab_entries: list[tuple[str, ...]] = []

    for i, (line1, line2) in enumerate(zip(lines1, lines2)):
        if not line1.strip() or not line2.strip():
            continue  # skip empty lines

        data1 = json.loads(line1)
        data2 = json.loads(line2)

        idx1 = data1.get("__index", -1)
        idx2 = data2.get("__index", -1)
        if idx1 != idx2:
            logger.warning(
                f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}"
            )

        token1, logprob1 = get_token_logprobs(data1)
        token2, logprob2 = get_token_logprobs(data2)

        tab_entries.append(
            (
                str(idx1 + 1),  # __index is 0-based; report 1-based word numbers
                clean_text(token1),
                f"{logprob1:.4f}",
                clean_text(token2),
                f"{logprob2:.4f}",
                f"{abs(logprob1 - logprob2):.4f}",
            )
        )

    output = _render_markdown_table(tab_header, tab_entries)
    logger.info("\n" + output)
    # UTF-8 explicitly: tokens decoded from the JSON dumps may be non-ASCII,
    # which the platform default encoding cannot always represent.
    with output_path.open("w", encoding="utf-8") as fout:
        fout.write("# Logits Comparison Report\n\n")
        fout.write(output)
    logger.info(f"Report written to {output_path}")
191
192
def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
    """Parse a comma-separated ``n_get,n_skip,n_get,...`` pattern string.

    Fields alternate starting with "get": even positions become ``(True, n)``
    (query the next ``n`` words one by one) and odd positions become
    ``(False, n)`` (append ``n`` words without querying).

    Raises:
        ValueError: if a field is not an integer.
    """
    counts = [int(field) for field in pattern.split(",")]
    return [(position % 2 == 0, count) for position, count in enumerate(counts)]
203
204
def parse_args() -> argparse.Namespace:
    """Build the ``dump`` / ``compare`` command-line parser and parse ``sys.argv``.

    On invalid arguments argparse itself prints the usage and an error message
    and exits via ``SystemExit`` (status 2); ``--help`` exits with status 0.
    """
    parser = argparse.ArgumentParser(
        description=DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter
    )
    subparsers = parser.add_subparsers(
        dest="verb", required=True, help="action to perform"
    )

    # dump subcommand
    parser_dump = subparsers.add_parser("dump", help="dump logits from an endpoint")
    parser_dump.add_argument(
        "output", type=Path, help="output path for dumped logits (.log)"
    )
    parser_dump.add_argument(
        "endpoint", type=str, help="OAI-compat /completions endpoint"
    )
    parser_dump.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="API key for authentication (if required)",
    )
    parser_dump.add_argument(
        "--file",
        type=Path,
        default=None,
        help="File containing prompt to use instead of the default",
    )
    parser_dump.add_argument(
        "--pattern",
        type=str,
        default="10,1000,10,4000,10",
        help="Pattern n_get,n_skip,... where n_get is number of words to get and n_skip is number of words to skip (num of words, NOT num of tokens)",
    )

    # compare subcommand
    parser_compare = subparsers.add_parser(
        "compare", help="compare two dumped logits files"
    )
    parser_compare.add_argument("input1", type=Path, help="first input file (.log)")
    parser_compare.add_argument("input2", type=Path, help="second input file (.log)")
    parser_compare.add_argument(
        "output", type=Path, help="output path for comparison report (.md)"
    )

    # No try/except here: argparse signals bad input with SystemExit, which is
    # not an Exception subclass, so an `except Exception` handler never runs.
    return parser.parse_args()
255
256
def main():
    """Dispatch to ``dump_logits`` or ``compare_logits`` based on the CLI verb.

    For ``dump``, the prompt words come from ``--file`` (split on single spaces,
    so the file's spacing and line breaks are kept in the prompt) or, without
    it, from ``generate_input_prompt``.

    Raises:
        ValueError: if ``--file`` has fewer words than the pattern requires, or
            the verb is unknown.
    """
    args = parse_args()

    if args.verb == "dump":
        pattern = parse_pattern(args.pattern)
        required_words = sum(n for _, n in pattern)
        if args.file is None:
            input_words = generate_input_prompt(required_words)
        else:
            input_words = args.file.read_text(encoding="utf-8").strip().split(" ")
            # Compare the file's actual word count against the requirement.
            if len(input_words) < required_words:
                raise ValueError(
                    f"Input file has only {len(input_words)} words, but pattern requires at least {required_words} words."
                )
        logger.info(f"Using {len(input_words)} words")
        dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
    elif args.verb == "compare":
        compare_logits(args.input1, args.input2, args.output)
    else:
        raise ValueError(f"Unknown verb: {args.verb}")
278
279
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()