1#!/usr/bin/env python3
  2from __future__ import annotations
  3
  4import argparse
  5import itertools
  6import json
  7import re
  8import sys
  9from typing import Any, List, Optional, Set, Tuple, Union
 10
 11def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
 12
 13    if max_items == 0:
 14        return ""
 15
 16    if min_items == 0 and max_items == 1:
 17        return f'{item_rule}?'
 18
 19    if not separator_rule:
 20        if min_items == 1 and max_items is None:
 21            return f'{item_rule}+'
 22        elif min_items == 0 and max_items is None:
 23            return f'{item_rule}*'
 24        else:
 25            return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}'
 26
 27    result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None)
 28    return f'({result})?' if min_items == 0 else result
 29
 30def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
 31    has_min = min_value != None
 32    has_max = max_value != None
 33
 34    def digit_range(from_char: str, to_char: str):
 35        out.append("[")
 36        if from_char == to_char:
 37            out.append(from_char)
 38        else:
 39            out.append(from_char)
 40            out.append("-")
 41            out.append(to_char)
 42        out.append("]")
 43
 44    def more_digits(min_digits: int, max_digits: int):
 45        out.append("[0-9]")
 46        if min_digits == max_digits and min_digits == 1:
 47            return
 48        out.append("{")
 49        out.append(str(min_digits))
 50        if max_digits != min_digits:
 51            out.append(",")
 52            if max_digits != sys.maxsize:
 53                out.append(str(max_digits))
 54        out.append("}")
 55
 56    def uniform_range(from_str: str, to_str: str):
 57        i = 0
 58        while i < len(from_str) and from_str[i] == to_str[i]:
 59            i += 1
 60        if i > 0:
 61            out.append("\"")
 62            out.append(from_str[:i])
 63            out.append("\"")
 64        if i < len(from_str):
 65            if i > 0:
 66                out.append(" ")
 67            sub_len = len(from_str) - i - 1
 68            if sub_len > 0:
 69                from_sub = from_str[i+1:]
 70                to_sub = to_str[i+1:]
 71                sub_zeros = "0" * sub_len
 72                sub_nines = "9" * sub_len
 73
 74                to_reached = False
 75                out.append("(")
 76                if from_sub == sub_zeros:
 77                    digit_range(from_str[i], chr(ord(to_str[i]) - 1))
 78                    out.append(" ")
 79                    more_digits(sub_len, sub_len)
 80                else:
 81                    out.append("[")
 82                    out.append(from_str[i])
 83                    out.append("] ")
 84                    out.append("(")
 85                    uniform_range(from_sub, sub_nines)
 86                    out.append(")")
 87                    if ord(from_str[i]) < ord(to_str[i]) - 1:
 88                        out.append(" | ")
 89                        if to_sub == sub_nines:
 90                            digit_range(chr(ord(from_str[i]) + 1), to_str[i])
 91                            to_reached = True
 92                        else:
 93                            digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
 94                        out.append(" ")
 95                        more_digits(sub_len, sub_len)
 96                if not to_reached:
 97                    out.append(" | ")
 98                    digit_range(to_str[i], to_str[i])
 99                    out.append(" ")
100                    uniform_range(sub_zeros, to_sub)
101                out.append(")")
102            else:
103                out.append("[")
104                out.append(from_str[i])
105                out.append("-")
106                out.append(to_str[i])
107                out.append("]")
108
109    if has_min and has_max:
110        if min_value < 0 and max_value < 0:
111            out.append("\"-\" (")
112            _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
113            out.append(")")
114            return
115
116        if min_value < 0:
117            out.append("\"-\" (")
118            _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True)
119            out.append(") | ")
120            min_value = 0
121
122        min_s = str(min_value)
123        max_s = str(max_value)
124        min_digits = len(min_s)
125        max_digits = len(max_s)
126
127        for digits in range(min_digits, max_digits):
128            uniform_range(min_s, "9" * digits)
129            min_s = "1" + "0" * digits
130            out.append(" | ")
131        uniform_range(min_s, max_s)
132        return
133
134    less_decimals = max(decimals_left - 1, 1)
135
136    if has_min:
137        if min_value < 0:
138            out.append("\"-\" (")
139            _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
140            out.append(") | [0] | [1-9] ")
141            more_digits(0, decimals_left - 1)
142        elif min_value == 0:
143            if top_level:
144                out.append("[0] | [1-9] ")
145                more_digits(0, less_decimals)
146            else:
147                more_digits(1, decimals_left)
148        elif min_value <= 9:
149            c = str(min_value)
150            range_start = '1' if top_level else '0'
151            if c > range_start:
152                digit_range(range_start, chr(ord(c) - 1))
153                out.append(" ")
154                more_digits(1, less_decimals)
155                out.append(" | ")
156            digit_range(c, "9")
157            out.append(" ")
158            more_digits(0, less_decimals)
159        else:
160            min_s = str(min_value)
161            length = len(min_s)
162            c = min_s[0]
163
164            if c > "1":
165                digit_range("1" if top_level else "0", chr(ord(c) - 1))
166                out.append(" ")
167                more_digits(length, less_decimals)
168                out.append(" | ")
169            digit_range(c, c)
170            out.append(" (")
171            _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False)
172            out.append(")")
173            if c < "9":
174                out.append(" | ")
175                digit_range(chr(ord(c) + 1), "9")
176                out.append(" ")
177                more_digits(length - 1, less_decimals)
178        return
179
180    if has_max:
181        if max_value >= 0:
182            if top_level:
183                out.append("\"-\" [1-9] ")
184                more_digits(0, less_decimals)
185                out.append(" | ")
186            _generate_min_max_int(0, max_value, out, decimals_left, top_level=True)
187        else:
188            out.append("\"-\" (")
189            _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False)
190            out.append(")")
191        return
192
193    raise RuntimeError("At least one of min_value or max_value must be set")
194
class BuiltinRule:
    """A predefined GBNF rule body together with the names of the rules it references."""

    def __init__(self, content: str, deps: list | None = None):
        # A missing or empty dependency list becomes a fresh empty list.
        self.deps = deps if deps else []
        self.content = content
199
# Constraining spaces to prevent model "running away".
# Alternatives: nothing (the leading '|' makes the first alternative empty), a single
# space, or one or two newlines followed by up to 20 spaces/tabs.
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'

# Builtin rules for JSON primitives, keyed by rule name. Each BuiltinRule lists the other
# rules its body references so they can be added alongside it; `space` is not listed
# because SchemaConverter always starts with a `space` rule.
PRIMITIVE_RULES = {
    'boolean'      : BuiltinRule('("true" | "false") space', []),
    # 1 to 16 fractional digits.
    'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
    # "0", or up to 16 digits without a leading zero.
    'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
    'number'       : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
    'integer'      : BuiltinRule('("-"? integral-part) space', ['integral-part']),
    'value'        : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
    'object'       : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
    'array'        : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
    # Quoted 8-4-4-4-12 hex UUID; any version/variant digits are accepted.
    'uuid'         : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
    # One JSON string character: anything except a double quote, backslash, DEL or
    # C0 control character, or a backslash escape (one of "\bfnrt or a uXXXX escape).
    'char'         : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
    'string'       : BuiltinRule(r'"\"" char* "\"" space', ['char']),
    'null'         : BuiltinRule('"null" space', []),
}

# Rules for the JSON Schema string formats that are supported. 'date', 'time' and
# 'date-time' match the bare value; the '-string' variants wrap it in double quotes
# and a trailing `space`. The day part of 'date' accepts 01-31 regardless of month.
# TODO: support "uri", "email" string formats
STRING_FORMAT_RULES = {
    'date'            : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
    # HH:MM:SS, optional 3-digit fraction, then "Z" or a +HH:MM / -HH:MM offset.
    'time'            : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
    'date-time'       : BuiltinRule('date "T" time', ['date', 'time']),
    'date-string'     : BuiltinRule('"\\"" date "\\"" space', ['date']),
    'time-string'     : BuiltinRule('"\\"" time "\\"" space', ['time']),
    'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}

# Rule bodies for the regex '.' in patterns: any Unicode code point when dotall is
# enabled, otherwise any character except LF and CR.
DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'

# Rule names that visit() does not use verbatim for schema-derived names (it appends '-').
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])

# Runs of characters not allowed in rule names; _add_rule replaces each run with '-'.
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
# Characters escaped inside a quoted GBNF string literal (used by _format_literal).
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
# NOTE(review): presumably the escape set for [...] character classes (adds ']' and '-');
# it is not referenced in the code visible here.
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
# Replacement text for each character matched by the two regexes above.
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}

# Regex metacharacters that _visit_pattern never folds into literal text.
NON_LITERAL_SET = set('|.()[]{}*+?')
# Characters that need a backslash in a regex but not in a GBNF string literal;
# _visit_pattern drops the backslash for these.
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
240
241
242class SchemaConverter:
243    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
244        self._prop_order = prop_order
245        self._allow_fetch = allow_fetch
246        self._dotall = dotall
247        self._raw_pattern = raw_pattern
248        self._rules = {
249            'space': SPACE_RULE,
250        }
251        self._refs = {}
252        self._refs_being_resolved = set()
253
254    def _format_literal(self, literal):
255        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
256            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
257        )
258        return f'"{escaped}"'
259
260    def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
261        '''
262            not_literal('a') -> '[^a]'
263            not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
264        '''
265        assert len(literal) > 0, 'Empty literal not supported'
266        def recurse(i: int):
267            c = literal[i]
268            if maybe_escaped_underscores and c == '_':
269                yield f'[^{c}\\\\]'
270                yield ' | '
271                yield f'"\\\\"? "{c}"'
272            else:
273                yield f'[^{c}]'
274            if i < len(literal) - 1:
275                yield ' | '
276                yield self._format_literal(c)
277                yield ' ('
278                yield from recurse(i + 1)
279                yield ')?'
280
281        return ''.join(('(', *recurse(0), ')'))
282
283    def _not_strings(self, strings):
284        class TrieNode:
285            def __init__(self):
286                self.children = {}
287                self.is_end_of_string = False
288
289            def insert(self, string):
290                node = self
291                for c in string:
292                    node = node.children.setdefault(c, TrieNode())
293                node.is_end_of_string = True
294
295        trie = TrieNode()
296        for s in strings:
297            trie.insert(s)
298
299        char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
300        out = ['["] ( ']
301
302        def visit(node):
303            rejects = []
304            first = True
305            for c in sorted(node.children.keys()):
306                child = node.children[c]
307                rejects.append(c)
308                if first:
309                    first = False
310                else:
311                    out.append(' | ')
312                out.append(f'[{c}]')
313                if child.children:
314                    out.append(f' (')
315                    visit(child)
316                    out.append(')')
317                elif child.is_end_of_string:
318                    out.append(f' {char_rule}+')
319            if node.children:
320                if not first:
321                    out.append(' | ')
322                out.append(f'[^"{"".join(rejects)}] {char_rule}*')
323        visit(trie)
324
325        out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
326        return ''.join(out)
327
328    def _add_rule(self, name, rule):
329        esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
330        if esc_name not in self._rules or self._rules[esc_name] == rule:
331            key = esc_name
332        else:
333            i = 0
334            while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
335                i += 1
336            key = f'{esc_name}{i}'
337        self._rules[key] = rule
338        return key
339
340    def resolve_refs(self, schema: dict, url: str):
341        '''
342            Resolves all $ref fields in the given schema, fetching any remote schemas,
343            replacing $ref with absolute reference URL and populating self._refs with the
344            respective referenced (sub)schema dictionaries.
345        '''
346        def visit(n: dict):
347            if isinstance(n, list):
348                return [visit(x) for x in n]
349            elif isinstance(n, dict):
350                ref = n.get('$ref')
351                if ref is not None and ref not in self._refs:
352                    if ref.startswith('https://'):
353                        assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
354                        import requests
355
356                        frag_split = ref.split('#')
357                        base_url = frag_split[0]
358
359                        target = self._refs.get(base_url)
360                        if target is None:
361                            target = self.resolve_refs(requests.get(ref).json(), base_url)
362                            self._refs[base_url] = target
363
364                        if len(frag_split) == 1 or frag_split[-1] == '':
365                            return target
366                    elif ref.startswith('#/'):
367                        target = schema
368                        ref = f'{url}{ref}'
369                        n['$ref'] = ref
370                    else:
371                        raise ValueError(f'Unsupported ref {ref}')
372
373                    for sel in ref.split('#')[-1].split('/')[1:]:
374                        assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
375                        if isinstance(target, list):
376                            try:
377                                sel_index = int(sel)
378                            except ValueError:
379                                raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
380                            assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
381                            target = target[sel_index]
382                        else:
383                            assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
384                            target = target[sel]
385
386                    self._refs[ref] = target
387                else:
388                    for v in n.values():
389                        visit(v)
390
391            return n
392        return visit(schema)
393
394    def _generate_union_rule(self, name, alt_schemas):
395        return ' | '.join((
396            self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
397            for i, alt_schema in enumerate(alt_schemas)
398        ))
399
    def _visit_pattern(self, pattern, name):
        '''
            Transforms a regular expression pattern into a GBNF rule.

            Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
            Output: https://github.com/ggml-org/llama.cpp/blob/master/grammars/README.md

            Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

            Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
            we define sub-rules to keep the output lean.

            The pattern must be anchored with a leading '^' and a trailing '$' (asserted).
            Unless raw_pattern is set, the rule matches the translation wrapped in JSON
            double quotes followed by `space`, and '"' characters in the pattern are
            escaped; with raw_pattern the bare translation is used.

            Returns the name under which the rule was registered (via _add_rule).
        '''

        assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        # Memoizes the sub-rule created for each non-literal operand of a {m,n} quantifier.
        sub_rule_ids = {}

        # Parse cursor into `pattern`, shared by the recursive transform() via nonlocal.
        i = 0
        length = len(pattern)

        def to_rule(s: tuple[str, bool]) -> str:
            # Literal text is wrapped in double quotes; non-literal text is already GBNF.
            (txt, is_literal) = s
            return "\"" + txt + "\"" if is_literal else txt

        def transform() -> tuple[str, bool]:
            '''
                Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            '''
            nonlocal i
            nonlocal pattern
            nonlocal sub_rule_ids

            start = i
            # For each component of this sequence, store its string representation and whether it's a literal.
            # We only need a flat structure here to apply repetition operators to the last item, and
            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
            # (GBNF's syntax is luckily very close to regular expressions!)
            seq: list[tuple[str, bool]] = []

            def get_dot():
                if self._dotall:
                    rule = DOTALL
                else:
                    # Accept any character... except \n and \r line break chars (\x0A and \x0D)
                    rule = DOT
                return self._add_rule(f'dot', rule)

            def join_seq():
                nonlocal seq
                ret = []
                for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append((''.join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                # NOTE(review): this joins `seq`, not the merged `ret`, so adjacent literals are
                # emitted as separate quoted strings ("a" "b") instead of one ("ab"). The matched
                # language is the same; confirm whether `ret` was intended.
                return (' '.join(to_rule(x) for x in seq), False)

            while i < length:
                c = pattern[i]
                if c == '.':
                    seq.append((get_dot(), False))
                    i += 1
                elif c == '(':
                    i += 1
                    if i < length:
                        assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    # The recursive call consumes the group up to and including its ')'.
                    seq.append((f'({to_rule(transform())})', False))
                elif c == ')':
                    i += 1
                    assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
                    return join_seq()
                elif c == '[':
                    # Character classes are copied through verbatim, escapes included.
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != ']':
                        if pattern[i] == '\\':
                            square_brackets += pattern[i:i+2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
                    square_brackets += ']'
                    i += 1
                    seq.append((square_brackets, False))
                elif c == '|':
                    # Alternation is passed through unchanged as its own sequence element.
                    seq.append(('|', False))
                    i += 1
                elif c in ('*', '+', '?'):
                    # The quantifier applies to the previous sequence element only.
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == '{':
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != '}':
                        curly_brackets += pattern[i]
                        i += 1
                    assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
                    curly_brackets += '}'
                    i += 1
                    # {n} -> exactly n; {m,} -> at least m; {,n} / {m,n} -> bounded (missing min = 0).
                    nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')

                    (sub, sub_is_literal) = seq[-1]

                    # A non-literal operand is hoisted into its own (memoized) sub-rule so the
                    # repetition refers to a single rule name.
                    if not sub_is_literal:
                        id = sub_rule_ids.get(sub)
                        if id is None:
                            id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
                            sub_rule_ids[sub] = id
                        sub = id

                    seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
                else:
                    # Accumulate a run of literal characters.
                    literal = ''
                    while i < length:
                        if pattern[i] == '\\' and i < length - 1:
                            next = pattern[i + 1]
                            if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                # Escaped regex metacharacter: keep only the character itself.
                                # NOTE(review): unlike the plain-character case below, this does not
                                # stop before a following quantifier, so /^a\.+$/ becomes "a."+
                                # (repeating "a.") instead of "a" "."+ -- confirm and fix.
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                # Other escapes (e.g. \d, \n) are copied verbatim into the quoted
                                # literal. NOTE(review): escapes such as \d are presumably not valid
                                # inside a GBNF string literal -- verify.
                                literal += pattern[i:i+2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and \
                                (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
                            # A plain character joins the run unless it is followed by a
                            # metacharacter other than '.' while the run is non-empty; in that
                            # case it starts a new element so a quantifier binds to it alone.
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if literal:
                        seq.append((literal, True))

            return join_seq()

        # transform() runs exactly once, in whichever branch applies.
        return self._add_rule(
            name,
            to_rule(transform()) if self._raw_pattern \
                else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
556
557
558    def _resolve_ref(self, ref):
559        ref_fragment = ref.split('#')[-1]
560        ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
561        if ref_name not in self._rules and ref not in self._refs_being_resolved:
562            self._refs_being_resolved.add(ref)
563            resolved = self._refs[ref]
564            ref_name = self.visit(resolved, ref_name)
565            self._refs_being_resolved.remove(ref)
566        return ref_name
567
568    def _generate_constant_rule(self, value):
569        return self._format_literal(json.dumps(value))
570
571    def visit(self, schema, name):
572        schema_type = schema.get('type')
573        schema_format = schema.get('format')
574        rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
575
576        if (ref := schema.get('$ref')) is not None:
577            return self._add_rule(rule_name, self._resolve_ref(ref))
578
579        elif 'oneOf' in schema or 'anyOf' in schema:
580            return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
581
582        elif isinstance(schema_type, list):
583            return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
584
585        elif 'const' in schema:
586            return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
587
588        elif 'enum' in schema:
589            rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
590            return self._add_rule(rule_name, rule)
591
592        elif schema_type in (None, 'object') and \
593             ('properties' in schema or \
594              ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
595            required = set(schema.get('required', []))
596            properties = list(schema.get('properties', {}).items())
597            return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
598
599        elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
600            required = set()
601            properties = []
602            enum_sets = []
603            hybrid_name = name
604            def add_component(comp_schema, is_required):
605                if (ref := comp_schema.get('$ref')) is not None:
606                    comp_schema = self._refs[ref]
607
608                if 'properties' in comp_schema:
609                    for prop_name, prop_schema in comp_schema['properties'].items():
610                        properties.append((prop_name, prop_schema))
611                        if is_required:
612                            required.add(prop_name)
613
614                if 'enum' in comp_schema:
615                    enum_sets.append(set(comp_schema['enum']))
616
617            for t in schema['allOf']:
618                if 'anyOf' in t:
619                    for tt in t['anyOf']:
620                        add_component(tt, is_required=False)
621                else:
622                    add_component(t, is_required=True)
623
624            if enum_sets:
625                enum_intersection = enum_sets[0]
626                for s in enum_sets[1:]:
627                    enum_intersection &= s
628
629                if enum_intersection:
630                    rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
631                    return self._add_rule(rule_name, rule)
632
633            return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
634
635        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
636            items = schema.get('items') or schema['prefixItems']
637            if isinstance(items, list):
638                return self._add_rule(
639                    rule_name,
640                    '"[" space ' +
641                    ' "," space '.join(
642                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
643                        for i, item in enumerate(items)) +
644                    ' "]" space')
645            else:
646                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
647                min_items = schema.get("minItems", 0)
648                max_items = schema.get("maxItems")
649                return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
650
651        elif schema_type in (None, 'string') and 'pattern' in schema:
652            return self._visit_pattern(schema['pattern'], rule_name)
653
654        elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
655            return self._add_primitive(
656                'root' if rule_name == 'root' else schema_format,
657                PRIMITIVE_RULES['uuid']
658            )
659
660        elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
661            prim_name = f'{schema_format}-string'
662            return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
663
664        elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
665            char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
666            min_len = schema.get('minLength', 0)
667            max_len = schema.get('maxLength')
668
669            return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
670
671        elif schema_type in (None, 'integer') and \
672                ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
673            min_value = None
674            max_value = None
675            if 'minimum' in schema:
676                min_value = schema['minimum']
677            elif 'exclusiveMinimum' in schema:
678                min_value = schema['exclusiveMinimum'] + 1
679            if 'maximum' in schema:
680                max_value = schema['maximum']
681            elif 'exclusiveMaximum' in schema:
682                max_value = schema['exclusiveMaximum'] - 1
683
684            out = ["("]
685            _generate_min_max_int(min_value, max_value, out)
686            out.append(") space")
687            return self._add_rule(rule_name, ''.join(out))
688
689        elif (schema_type == 'object') or (len(schema) == 0):
690            return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
691
692        else:
693            assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
694            # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
695            return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
696
697    def _add_primitive(self, name: str, rule: BuiltinRule):
698        n = self._add_rule(name, rule.content)
699
700        for dep in rule.deps:
701            dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
702            assert dep_rule, f'Rule {dep} not known'
703            if dep not in self._rules:
704                self._add_primitive(dep, dep_rule)
705        return n
706
707    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
708        prop_order = self._prop_order
709        # sort by position in prop_order (if specified) then by original order
710        sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
711
712        prop_kv_rule_names = {}
713        for prop_name, prop_schema in properties:
714            prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
715            prop_kv_rule_names[prop_name] = self._add_rule(
716                f'{name}{"-" if name else ""}{prop_name}-kv',
717                fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
718            )
719        required_props = [k for k in sorted_props if k in required]
720        optional_props = [k for k in sorted_props if k not in required]
721
722        if additional_properties is not None and additional_properties != False:
723            sub_name = f'{name}{"-" if name else ""}additional'
724            value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
725                self._add_primitive('value', PRIMITIVE_RULES['value'])
726            key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
727                else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))
728
729            prop_kv_rule_names["*"] = self._add_rule(
730                f'{sub_name}-kv',
731                f'{key_rule} ":" space {value_rule}'
732            )
733            optional_props.append("*")
734
735        rule = '"{" space '
736        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
737
738        if optional_props:
739            rule += ' ('
740            if required_props:
741                rule += ' "," space ( '
742
743            def get_recursive_refs(ks, first_is_optional):
744                [k, *rest] = ks
745                kv_rule_name = prop_kv_rule_names[k]
746                comma_ref = f'( "," space {kv_rule_name} )'
747                if first_is_optional:
748                    res = comma_ref + ('*' if k == '*' else '?')
749                else:
750                    res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
751                if len(rest) > 0:
752                    res += ' ' + self._add_rule(
753                        f'{name}{"-" if name else ""}{k}-rest',
754                        get_recursive_refs(rest, first_is_optional=True)
755                    )
756                return res
757
758            rule += ' | '.join(
759                get_recursive_refs(optional_props[i:], first_is_optional=False)
760                for i in range(len(optional_props))
761            )
762            if required_props:
763                rule += ' )'
764            rule += ' )?'
765
766        rule += ' "}" space'
767
768        return rule
769
770    def format_grammar(self):
771        return '\n'.join(
772            f'{name} ::= {rule}'
773            for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
774        )
775
776
def main(args_in = None):
    """Command-line entry point: print a grammar for a JSON schema to stdout.

    The schema is read from a local file, from stdin (when the argument is
    "-"), or from an ``https://`` URL. URLs are fetched with the third-party
    ``requests`` package, which is imported only in that case.

    Args:
        args_in: argument list to parse instead of the real command line;
            None (the default) lets argparse read ``sys.argv``.

    Raises:
        requests.HTTPError: if fetching an ``https://`` schema returns an
            HTTP error status.
    """
    parser = argparse.ArgumentParser(
        description='''
            Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
            given JSON schema. Only a subset of JSON schema features are supported; more may be
            added in the future.
        ''',
    )
    parser.add_argument(
        '--prop-order',
        default=[],
        type=lambda s: s.split(','),
        help='''
            comma-separated property names defining the order of precedence for object properties;
            properties not specified here are given lower precedence than those that are, and
            are kept in their original order from the schema. Required properties are always
            given precedence over optional properties.
        '''
    )
    parser.add_argument(
        '--allow-fetch',
        action='store_true',
        default=False,
        help='Whether to allow fetching referenced schemas over HTTPS')
    parser.add_argument(
        '--dotall',
        action='store_true',
        default=False,
        help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
    parser.add_argument(
        '--raw-pattern',
        action='store_true',
        default=False,
        help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')

    parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
    args = parser.parse_args(args_in)

    if args.schema.startswith('https://'):
        url = args.schema
        import requests
        response = requests.get(url)
        # Surface HTTP failures (404, 500, ...) as an HTTPError instead of a
        # confusing JSON decode error raised on the error page's body.
        response.raise_for_status()
        schema = response.json()
    elif args.schema == '-':
        url = 'stdin'
        schema = json.load(sys.stdin)
    else:
        url = f'file://{args.schema}'
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default
        # locale encoding, which is not UTF-8 on many Windows setups.
        with open(args.schema, encoding='utf-8') as f:
            schema = json.load(f)
    # Earlier --prop-order names get lower indices, i.e. higher precedence.
    converter = SchemaConverter(
        prop_order={name: idx for idx, name in enumerate(args.prop_order)},
        allow_fetch=args.allow_fetch,
        dotall=args.dotall,
        raw_pattern=args.raw_pattern)
    # `url` labels where the schema came from; resolve_refs receives it
    # (presumably as the base for `$ref` resolution — confirm in resolve_refs).
    schema = converter.resolve_refs(schema, url)
    converter.visit(schema, '')
    print(converter.format_grammar())
834
835
if __name__ == '__main__':
    # Script entry point: main() is called with args_in=None, so argparse parses sys.argv.
    main()