1#!/usr/bin/env python3
2from __future__ import annotations
3
4import argparse
5import itertools
6import json
7import re
8import sys
9from typing import Any, List, Optional, Set, Tuple, Union
10
11def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
12
13 if max_items == 0:
14 return ""
15
16 if min_items == 0 and max_items == 1:
17 return f'{item_rule}?'
18
19 if not separator_rule:
20 if min_items == 1 and max_items is None:
21 return f'{item_rule}+'
22 elif min_items == 0 and max_items is None:
23 return f'{item_rule}*'
24 else:
25 return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}'
26
27 result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None)
28 return f'({result})?' if min_items == 0 else result
29
30def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
31 has_min = min_value != None
32 has_max = max_value != None
33
34 def digit_range(from_char: str, to_char: str):
35 out.append("[")
36 if from_char == to_char:
37 out.append(from_char)
38 else:
39 out.append(from_char)
40 out.append("-")
41 out.append(to_char)
42 out.append("]")
43
44 def more_digits(min_digits: int, max_digits: int):
45 out.append("[0-9]")
46 if min_digits == max_digits and min_digits == 1:
47 return
48 out.append("{")
49 out.append(str(min_digits))
50 if max_digits != min_digits:
51 out.append(",")
52 if max_digits != sys.maxsize:
53 out.append(str(max_digits))
54 out.append("}")
55
56 def uniform_range(from_str: str, to_str: str):
57 i = 0
58 while i < len(from_str) and from_str[i] == to_str[i]:
59 i += 1
60 if i > 0:
61 out.append("\"")
62 out.append(from_str[:i])
63 out.append("\"")
64 if i < len(from_str):
65 if i > 0:
66 out.append(" ")
67 sub_len = len(from_str) - i - 1
68 if sub_len > 0:
69 from_sub = from_str[i+1:]
70 to_sub = to_str[i+1:]
71 sub_zeros = "0" * sub_len
72 sub_nines = "9" * sub_len
73
74 to_reached = False
75 out.append("(")
76 if from_sub == sub_zeros:
77 digit_range(from_str[i], chr(ord(to_str[i]) - 1))
78 out.append(" ")
79 more_digits(sub_len, sub_len)
80 else:
81 out.append("[")
82 out.append(from_str[i])
83 out.append("] ")
84 out.append("(")
85 uniform_range(from_sub, sub_nines)
86 out.append(")")
87 if ord(from_str[i]) < ord(to_str[i]) - 1:
88 out.append(" | ")
89 if to_sub == sub_nines:
90 digit_range(chr(ord(from_str[i]) + 1), to_str[i])
91 to_reached = True
92 else:
93 digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
94 out.append(" ")
95 more_digits(sub_len, sub_len)
96 if not to_reached:
97 out.append(" | ")
98 digit_range(to_str[i], to_str[i])
99 out.append(" ")
100 uniform_range(sub_zeros, to_sub)
101 out.append(")")
102 else:
103 out.append("[")
104 out.append(from_str[i])
105 out.append("-")
106 out.append(to_str[i])
107 out.append("]")
108
109 if has_min and has_max:
110 if min_value < 0 and max_value < 0:
111 out.append("\"-\" (")
112 _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
113 out.append(")")
114 return
115
116 if min_value < 0:
117 out.append("\"-\" (")
118 _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True)
119 out.append(") | ")
120 min_value = 0
121
122 min_s = str(min_value)
123 max_s = str(max_value)
124 min_digits = len(min_s)
125 max_digits = len(max_s)
126
127 for digits in range(min_digits, max_digits):
128 uniform_range(min_s, "9" * digits)
129 min_s = "1" + "0" * digits
130 out.append(" | ")
131 uniform_range(min_s, max_s)
132 return
133
134 less_decimals = max(decimals_left - 1, 1)
135
136 if has_min:
137 if min_value < 0:
138 out.append("\"-\" (")
139 _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
140 out.append(") | [0] | [1-9] ")
141 more_digits(0, decimals_left - 1)
142 elif min_value == 0:
143 if top_level:
144 out.append("[0] | [1-9] ")
145 more_digits(0, less_decimals)
146 else:
147 more_digits(1, decimals_left)
148 elif min_value <= 9:
149 c = str(min_value)
150 range_start = '1' if top_level else '0'
151 if c > range_start:
152 digit_range(range_start, chr(ord(c) - 1))
153 out.append(" ")
154 more_digits(1, less_decimals)
155 out.append(" | ")
156 digit_range(c, "9")
157 out.append(" ")
158 more_digits(0, less_decimals)
159 else:
160 min_s = str(min_value)
161 length = len(min_s)
162 c = min_s[0]
163
164 if c > "1":
165 digit_range("1" if top_level else "0", chr(ord(c) - 1))
166 out.append(" ")
167 more_digits(length, less_decimals)
168 out.append(" | ")
169 digit_range(c, c)
170 out.append(" (")
171 _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False)
172 out.append(")")
173 if c < "9":
174 out.append(" | ")
175 digit_range(chr(ord(c) + 1), "9")
176 out.append(" ")
177 more_digits(length - 1, less_decimals)
178 return
179
180 if has_max:
181 if max_value >= 0:
182 if top_level:
183 out.append("\"-\" [1-9] ")
184 more_digits(0, less_decimals)
185 out.append(" | ")
186 _generate_min_max_int(0, max_value, out, decimals_left, top_level=True)
187 else:
188 out.append("\"-\" (")
189 _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False)
190 out.append(")")
191 return
192
193 raise RuntimeError("At least one of min_value or max_value must be set")
194
class BuiltinRule:
    '''A predefined GBNF rule body plus the names of the rules that body references.'''

    def __init__(self, content: str, deps: list | None = None):
        self.content = content
        # A missing or empty dependency list becomes a fresh empty list.
        self.deps = deps if deps else []
199
# Constraining spaces to prevent model "running away".
# Alternatives: nothing, a single space, or 1-2 newlines followed by up to 20 spaces/tabs.
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'

# Built-in rules for JSON primitives: rule name -> BuiltinRule(body, names of rules the body references).
# 'integral-part' forbids leading zeros and caps numbers at 16 digits; 'char' accepts any
# character except '"', '\\', DEL and C0 controls, or a JSON backslash escape.
PRIMITIVE_RULES = {
    'boolean'      : BuiltinRule('("true" | "false") space', []),
    'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
    'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
    'number'       : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
    'integer'      : BuiltinRule('("-"? integral-part) space', ['integral-part']),
    'value'        : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
    'object'       : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
    'array'        : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
    'uuid'         : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
    'char'         : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
    'string'       : BuiltinRule(r'"\"" char* "\"" space', ['char']),
    'null'         : BuiltinRule('"null" space', []),
}

# Rules for JSON-schema string "format"s. The bare 'date'/'time'/'date-time' rules match
# the unquoted text; the '*-string' variants wrap them in JSON quotes plus trailing space.
# TODO: support "uri", "email" string formats
STRING_FORMAT_RULES = {
    'date'            : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
    'time'            : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
    'date-time'       : BuiltinRule('date "T" time', ['date', 'time']),
    'date-string'     : BuiltinRule('"\\"" date "\\"" space', ['date']),
    'time-string'     : BuiltinRule('"\\"" time "\\"" space', ['time']),
    'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}

# Regex "." translations: DOTALL matches any code point, DOT excludes \n (0x0A) and \r (0x0D).
DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'

# Rule names already used by built-in rules; SchemaConverter.visit appends '-' to a
# schema-derived name that collides with one of these.
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])

# Any run of characters not allowed in a GBNF rule name (replaced by '-').
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
# Characters that must be escaped inside a "..." GBNF literal.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
# Characters that would need escaping inside a [...] GBNF character range
# (NOTE(review): not referenced by the code shown here).
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
# Replacement text for each escapable character; '-' and ']' only matter for ranges.
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}

# Regex metacharacters that end a run of literal characters while parsing a pattern.
NON_LITERAL_SET = set('|.()[]{}*+?')
# Characters whose backslash escape is dropped when copied from a regex into a GBNF literal.
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
240
241
class SchemaConverter:
    '''
    Converts a JSON schema into GBNF grammar rules.

    Rules are accumulated in `self._rules` (rule name -> rule body) while the schema is
    visited; `format_grammar` renders them as `name ::= body` lines.
    '''

    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
        # prop_order: property name -> precedence index (main() builds it from --prop-order).
        self._prop_order = prop_order
        self._allow_fetch = allow_fetch
        self._dotall = dotall
        self._raw_pattern = raw_pattern
        self._rules = {
            'space': SPACE_RULE,
        }
        # Absolute $ref string -> referenced (sub)schema, populated by resolve_refs().
        self._refs = {}
        # $refs currently being visited; lets recursive schemas terminate.
        self._refs_being_resolved = set()

    def _format_literal(self, literal):
        '''Returns `literal` as a quoted GBNF string literal, escaping \\r, \\n, " and \\.'''
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
        )
        return f'"{escaped}"'

    def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
        '''
        Returns a GBNF expression matching any prefix that diverges from `literal`.

        not_literal('a') -> '([^a])'
        not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)'
        '''
        assert len(literal) > 0, 'Empty literal not supported'
        def recurse(i: int):
            c = literal[i]
            if maybe_escaped_underscores and c == '_':
                yield f'[^{c}\\\\]'
                yield ' | '
                yield f'"\\\\"? "{c}"'
            else:
                yield f'[^{c}]'
            if i < len(literal) - 1:
                yield ' | '
                yield self._format_literal(c)
                yield ' ('
                yield from recurse(i + 1)
                yield ')?'

        return ''.join(('(', *recurse(0), ')'))

    def _not_strings(self, strings):
        '''
        Returns a rule body matching a quoted JSON string whose content is none of
        `strings` (used for additionalProperties keys so they cannot collide with
        declared property names).
        '''
        class TrieNode:
            def __init__(self):
                self.children = {}
                self.is_end_of_string = False

            def insert(self, string):
                node = self
                for c in string:
                    node = node.children.setdefault(c, TrieNode())
                node.is_end_of_string = True

        trie = TrieNode()
        for s in strings:
            trie.insert(s)

        char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
        out = ['["] ( ']

        def range_char(c):
            # Characters that would escape (\), close (]), negate (^) or form a range (-)
            # inside a [...] class, plus raw line breaks, are emitted as GBNF \xHH escapes.
            return f'\\x{ord(c):02X}' if c in '\\]^-\r\n' else c

        def visit(node):
            rejects = []
            first = True
            for c in sorted(node.children.keys()):
                child = node.children[c]
                rejects.append(c)
                if first:
                    first = False
                else:
                    out.append(' | ')
                out.append(f'[{range_char(c)}]')
                if child.children:
                    out.append(' (')
                    visit(child)
                    # If the prefix ending here is itself an excluded string, more characters
                    # are required; otherwise the prefix is an acceptable key on its own.
                    out.append(')' if child.is_end_of_string else ')?')
                elif child.is_end_of_string:
                    out.append(f' {char_rule}+')
            if node.children:
                if not first:
                    out.append(' | ')
                out.append(f'[^"{"".join(range_char(c) for c in rejects)}] {char_rule}*')
        visit(trie)

        out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
        return ''.join(out)

    def _add_rule(self, name, rule):
        '''
        Registers `rule` under a sanitized version of `name` and returns the key used.
        An existing identical rule is reused; a different rule with the same name gets
        a numeric suffix.
        '''
        esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
                i += 1
            key = f'{esc_name}{i}'
        self._rules[key] = rule
        return key

    def resolve_refs(self, schema: dict, url: str):
        '''
        Resolves all $ref fields in the given schema, fetching any remote schemas,
        replacing $ref with absolute reference URL and populating self._refs with the
        respective referenced (sub)schema dictionaries.
        '''
        def visit(n: dict):
            if isinstance(n, list):
                return [visit(x) for x in n]
            elif isinstance(n, dict):
                ref = n.get('$ref')
                if ref is not None and ref not in self._refs:
                    if ref.startswith('https://'):
                        assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
                        import requests

                        frag_split = ref.split('#')
                        base_url = frag_split[0]

                        target = self._refs.get(base_url)
                        if target is None:
                            target = self.resolve_refs(requests.get(ref).json(), base_url)
                            self._refs[base_url] = target

                        if len(frag_split) == 1 or frag_split[-1] == '':
                            # Whole-document reference: also register it under the exact $ref
                            # string (e.g. "https://x/s.json#"), which _resolve_ref looks up.
                            self._refs[ref] = target
                            return target
                    elif ref.startswith('#/'):
                        target = schema
                        ref = f'{url}{ref}'
                        n['$ref'] = ref
                    else:
                        raise ValueError(f'Unsupported ref {ref}')

                    # Walk the JSON-pointer fragment ("#/a/0/b") down from the document root.
                    for sel in ref.split('#')[-1].split('/')[1:]:
                        assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
                        if isinstance(target, list):
                            try:
                                sel_index = int(sel)
                            except ValueError:
                                raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
                            assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
                            target = target[sel_index]
                        else:
                            assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
                            target = target[sel]

                    self._refs[ref] = target
                else:
                    for v in n.values():
                        visit(v)

            return n
        return visit(schema)

    def _generate_union_rule(self, name, alt_schemas):
        '''Returns the alternatives of `alt_schemas`, each visited as its own rule, joined by " | ".'''
        return ' | '.join((
            self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
            for i, alt_schema in enumerate(alt_schemas)
        ))

    def _visit_pattern(self, pattern, name):
        '''
        Transforms a regular expression pattern into a GBNF rule.

        Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
        Output: https://github.com/ggml-org/llama.cpp/blob/master/grammars/README.md

        Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

        Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
        we define sub-rules to keep the output lean.
        '''

        assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
        pattern = pattern[1:-1]
        sub_rule_ids = {}

        i = 0
        length = len(pattern)

        def to_rule(s: tuple[str, bool]) -> str:
            (txt, is_literal) = s
            return "\"" + txt + "\"" if is_literal else txt

        def followed_by_operator(j: int) -> bool:
            # True if pattern[j] exists and is a regex metacharacter other than '.'; a
            # non-empty literal must stop before such a char so that a quantifier that
            # follows applies only to the last atom.
            return j < length and pattern[j] in NON_LITERAL_SET and pattern[j] != '.'

        def transform() -> tuple[str, bool]:
            '''
            Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
            '''
            nonlocal i
            nonlocal pattern
            nonlocal sub_rule_ids

            start = i
            # For each component of this sequence, store its string representation and whether it's a literal.
            # We only need a flat structure here to apply repetition operators to the last item, and
            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
            # (GBNF's syntax is luckily very close to regular expressions!)
            seq: list[tuple[str, bool]] = []

            def get_dot():
                if self._dotall:
                    rule = DOTALL
                else:
                    # Accept any character... except \n and \r line break chars (\x0A and \x0D)
                    rule = DOT
                return self._add_rule('dot', rule)

            def join_seq():
                # Merge runs of adjacent literals into a single quoted literal.
                nonlocal seq
                ret = []
                for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
                    if is_literal:
                        ret.append((''.join(x[0] for x in g), True))
                    else:
                        ret.extend(g)
                if len(ret) == 1:
                    return ret[0]
                return (' '.join(to_rule(x) for x in ret), False)

            while i < length:
                c = pattern[i]
                if c == '.':
                    seq.append((get_dot(), False))
                    i += 1
                elif c == '(':
                    i += 1
                    if i < length:
                        assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
                    seq.append((f'({to_rule(transform())})', False))
                elif c == ')':
                    i += 1
                    assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
                    return join_seq()
                elif c == '[':
                    # Character classes are copied verbatim (GBNF uses the same syntax).
                    square_brackets = c
                    i += 1
                    while i < length and pattern[i] != ']':
                        if pattern[i] == '\\':
                            square_brackets += pattern[i:i+2]
                            i += 2
                        else:
                            square_brackets += pattern[i]
                            i += 1
                    assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
                    square_brackets += ']'
                    i += 1
                    seq.append((square_brackets, False))
                elif c == '|':
                    seq.append(('|', False))
                    i += 1
                elif c in ('*', '+', '?'):
                    seq[-1] = (to_rule(seq[-1]) + c, False)
                    i += 1
                elif c == '{':
                    curly_brackets = c
                    i += 1
                    while i < length and pattern[i] != '}':
                        curly_brackets += pattern[i]
                        i += 1
                    assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
                    curly_brackets += '}'
                    i += 1
                    nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
                    min_times = 0
                    max_times = None
                    try:
                        if len(nums) == 1:
                            min_times = int(nums[0])
                            max_times = min_times
                        else:
                            assert len(nums) == 2
                            min_times = int(nums[0]) if nums[0] else 0
                            max_times = int(nums[1]) if nums[1] else None
                    except ValueError:
                        raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')

                    (sub, sub_is_literal) = seq[-1]

                    if not sub_is_literal:
                        # Give the repeated sub-expression its own (shared) rule to keep output lean.
                        sub_id = sub_rule_ids.get(sub)
                        if sub_id is None:
                            sub_id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
                            sub_rule_ids[sub] = sub_id
                        sub = sub_id

                    seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
                else:
                    literal = ''
                    while i < length:
                        if pattern[i] == '\\' and i < length - 1:
                            if literal and followed_by_operator(i + 2):
                                # Keep an escaped char that a quantifier applies to out of the
                                # literal collected so far (/ab\.+/ -> "ab" "."+, not "ab."+).
                                break
                            next_char = pattern[i + 1]
                            if next_char in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
                                i += 1
                                literal += pattern[i]
                                i += 1
                            else:
                                literal += pattern[i:i+2]
                                i += 2
                        elif pattern[i] == '"' and not self._raw_pattern:
                            if literal and followed_by_operator(i + 1):
                                break
                            literal += '\\"'
                            i += 1
                        elif pattern[i] not in NON_LITERAL_SET and \
                                (literal == '' or not followed_by_operator(i + 1)):
                            literal += pattern[i]
                            i += 1
                        else:
                            break
                    if literal:
                        seq.append((literal, True))
                    else:
                        # Only a stray ']' or '}' gets here; regexes treat it as a literal
                        # character. Consume it so the outer loop always makes progress.
                        seq.append((pattern[i], True))
                        i += 1

            return join_seq()

        return self._add_rule(
            name,
            to_rule(transform()) if self._raw_pattern \
                else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")


    def _resolve_ref(self, ref):
        '''
        Returns the rule name for `ref`, visiting the referenced schema the first time.
        While `ref` is already being resolved (recursive schema) the name is returned
        without visiting again, so the rule can refer to itself.
        '''
        ref_fragment = ref.split('#')[-1]
        ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
        if ref_name not in self._rules and ref not in self._refs_being_resolved:
            self._refs_being_resolved.add(ref)
            resolved = self._refs[ref]
            ref_name = self.visit(resolved, ref_name)
            self._refs_being_resolved.remove(ref)
        return ref_name

    def _generate_constant_rule(self, value):
        '''Returns a GBNF literal matching the JSON serialization of `value`.'''
        return self._format_literal(json.dumps(value))

    def visit(self, schema, name):
        '''
        Adds the rule(s) for `schema` under (a sanitized form of) `name` and returns
        the name of the top rule. An empty `name` means the root rule.
        '''
        schema_type = schema.get('type')
        schema_format = schema.get('format')
        rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'

        if (ref := schema.get('$ref')) is not None:
            return self._add_rule(rule_name, self._resolve_ref(ref))

        elif 'oneOf' in schema or 'anyOf' in schema:
            return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))

        elif isinstance(schema_type, list):
            return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))

        elif 'const' in schema:
            return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')

        elif 'enum' in schema:
            rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
            return self._add_rule(rule_name, rule)

        elif schema_type in (None, 'object') and \
                ('properties' in schema or \
                 ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
            required = set(schema.get('required', []))
            properties = list(schema.get('properties', {}).items())
            return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))

        elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
            required = set()
            properties = []
            enum_sets = []
            hybrid_name = name
            def add_component(comp_schema, is_required):
                if (ref := comp_schema.get('$ref')) is not None:
                    comp_schema = self._refs[ref]

                if 'properties' in comp_schema:
                    for prop_name, prop_schema in comp_schema['properties'].items():
                        properties.append((prop_name, prop_schema))
                        if is_required:
                            required.add(prop_name)

                if 'enum' in comp_schema:
                    enum_sets.append(set(comp_schema['enum']))

            for t in schema['allOf']:
                if 'anyOf' in t:
                    for tt in t['anyOf']:
                        add_component(tt, is_required=False)
                else:
                    add_component(t, is_required=True)

            if enum_sets:
                enum_intersection = enum_sets[0]
                for s in enum_sets[1:]:
                    enum_intersection &= s

                if enum_intersection:
                    rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
                    return self._add_rule(rule_name, rule)

            return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))

        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
            items = schema.get('items') or schema['prefixItems']
            if isinstance(items, list):
                return self._add_rule(
                    rule_name,
                    '"[" space ' +
                    ' "," space '.join(
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
                        for i, item in enumerate(items)) +
                    ' "]" space')
            else:
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
                min_items = schema.get("minItems", 0)
                max_items = schema.get("maxItems")
                return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')

        elif schema_type in (None, 'string') and 'pattern' in schema:
            return self._visit_pattern(schema['pattern'], rule_name)

        elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
            return self._add_primitive(
                'root' if rule_name == 'root' else schema_format,
                PRIMITIVE_RULES['uuid']
            )

        elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
            prim_name = f'{schema_format}-string'
            return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))

        elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
            char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
            min_len = schema.get('minLength', 0)
            max_len = schema.get('maxLength')

            return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')

        elif schema_type in (None, 'integer') and \
                ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
            import math  # local import, like the lazy `import requests` used elsewhere in this file

            # Bounds may be non-integral numbers (e.g. "minimum": 1.5); round them to the
            # nearest admissible integer, since the digit-based generator needs ints.
            # For integer bounds this is identical to using them directly.
            min_value = None
            max_value = None
            if 'minimum' in schema:
                min_value = math.ceil(schema['minimum'])
            elif 'exclusiveMinimum' in schema:
                min_value = math.floor(schema['exclusiveMinimum']) + 1
            if 'maximum' in schema:
                max_value = math.floor(schema['maximum'])
            elif 'exclusiveMaximum' in schema:
                max_value = math.ceil(schema['exclusiveMaximum']) - 1

            out = ["("]
            _generate_min_max_int(min_value, max_value, out)
            out.append(") space")
            return self._add_rule(rule_name, ''.join(out))

        elif (schema_type == 'object') or (len(schema) == 0):
            return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))

        else:
            assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
            # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
            return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])

    def _add_primitive(self, name: str, rule: BuiltinRule):
        '''Adds a built-in rule and, recursively, the built-in rules it depends on; returns its name.'''
        n = self._add_rule(name, rule.content)

        for dep in rule.deps:
            dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
            assert dep_rule, f'Rule {dep} not known'
            if dep not in self._rules:
                self._add_primitive(dep, dep_rule)
        return n

    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
        '''
        Returns the rule body for an object: required properties first (in prop_order,
        then schema order), followed by any subset of the optional properties in that
        same order, and, if allowed, extra key/value pairs.
        '''
        prop_order = self._prop_order
        # sort by position in prop_order (if specified) then by original order
        sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]

        prop_kv_rule_names = {}
        for prop_name, prop_schema in properties:
            prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
            prop_kv_rule_names[prop_name] = self._add_rule(
                f'{name}{"-" if name else ""}{prop_name}-kv',
                fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
            )
        required_props = [k for k in sorted_props if k in required]
        optional_props = [k for k in sorted_props if k not in required]

        if additional_properties is not None and additional_properties != False:
            sub_name = f'{name}{"-" if name else ""}additional'
            value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
                self._add_primitive('value', PRIMITIVE_RULES['value'])
            # Extra keys must not collide with declared property names.
            key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
                else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))

            prop_kv_rule_names["*"] = self._add_rule(
                f'{sub_name}-kv',
                f'{key_rule} ":" space {value_rule}'
            )
            optional_props.append("*")

        rule = '"{" space '
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)

        if optional_props:
            rule += ' ('
            if required_props:
                rule += ' "," space ( '

            def get_recursive_refs(ks, first_is_optional):
                # Builds "first kv, then optionally each later kv in order" chains; "*"
                # (additional properties) may repeat.
                [k, *rest] = ks
                kv_rule_name = prop_kv_rule_names[k]
                comma_ref = f'( "," space {kv_rule_name} )'
                if first_is_optional:
                    res = comma_ref + ('*' if k == '*' else '?')
                else:
                    res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
                if len(rest) > 0:
                    res += ' ' + self._add_rule(
                        f'{name}{"-" if name else ""}{k}-rest',
                        get_recursive_refs(rest, first_is_optional=True)
                    )
                return res

            rule += ' | '.join(
                get_recursive_refs(optional_props[i:], first_is_optional=False)
                for i in range(len(optional_props))
            )
            if required_props:
                rule += ' )'
            rule += ' )?'

        rule += ' "}" space'

        return rule

    def format_grammar(self):
        '''Renders all collected rules as GBNF, one `name ::= body` line per rule, sorted by name.'''
        return '\n'.join(
            f'{name} ::= {rule}'
            for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
        )
775
776
def main(args_in = None):
    '''
    Command-line entry point: loads a JSON schema from a file, stdin ("-") or an
    https:// URL, converts it to a GBNF grammar and prints the grammar to stdout.

    Args:
        args_in: argument list to parse instead of sys.argv[1:] (None = use sys.argv).
    '''
    parser = argparse.ArgumentParser(
        description='''
            Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
            given JSON schema. Only a subset of JSON schema features are supported; more may be
            added in the future.
        ''',
    )
    parser.add_argument(
        '--prop-order',
        default=[],
        type=lambda s: s.split(','),
        help='''
            comma-separated property names defining the order of precedence for object properties;
            properties not specified here are given lower precedence than those that are, and
            are kept in their original order from the schema. Required properties are always
            given precedence over optional properties.
        '''
    )
    parser.add_argument(
        '--allow-fetch',
        action='store_true',
        default=False,
        help='Whether to allow fetching referenced schemas over HTTPS')
    parser.add_argument(
        '--dotall',
        action='store_true',
        default=False,
        help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
    parser.add_argument(
        '--raw-pattern',
        action='store_true',
        default=False,
        help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')

    parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
    args = parser.parse_args(args_in)

    if args.schema.startswith('https://'):
        url = args.schema
        import requests
        schema = requests.get(url).json()
    elif args.schema == '-':
        url = 'stdin'
        schema = json.load(sys.stdin)
    else:
        url = f'file://{args.schema}'
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(args.schema, encoding='utf-8') as f:
            schema = json.load(f)
    converter = SchemaConverter(
        prop_order={name: idx for idx, name in enumerate(args.prop_order)},
        allow_fetch=args.allow_fetch,
        dotall=args.dotall,
        raw_pattern=args.raw_pattern)
    schema = converter.resolve_refs(schema, url)
    converter.visit(schema, '')
    print(converter.format_grammar())
834
835
if __name__ == "__main__":
    # Run the command-line converter only when executed as a script, not on import.
    main()