summaryrefslogtreecommitdiff
path: root/llama.cpp/examples/json_schema_to_grammar.py
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/examples/json_schema_to_grammar.py')
-rwxr-xr-xllama.cpp/examples/json_schema_to_grammar.py837
1 files changed, 837 insertions, 0 deletions
diff --git a/llama.cpp/examples/json_schema_to_grammar.py b/llama.cpp/examples/json_schema_to_grammar.py
new file mode 100755
index 0000000..9fc90a3
--- /dev/null
+++ b/llama.cpp/examples/json_schema_to_grammar.py
@@ -0,0 +1,837 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import itertools
+import json
+import re
+import sys
+from typing import Any, List, Optional, Set, Tuple, Union
+
+def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
+
+ if max_items == 0:
+ return ""
+
+ if min_items == 0 and max_items == 1:
+ return f'{item_rule}?'
+
+ if not separator_rule:
+ if min_items == 1 and max_items is None:
+ return f'{item_rule}+'
+ elif min_items == 0 and max_items is None:
+ return f'{item_rule}*'
+ else:
+ return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}'
+
+ result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None)
+ return f'({result})?' if min_items == 0 else result
+
+def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
+ has_min = min_value != None
+ has_max = max_value != None
+
+ def digit_range(from_char: str, to_char: str):
+ out.append("[")
+ if from_char == to_char:
+ out.append(from_char)
+ else:
+ out.append(from_char)
+ out.append("-")
+ out.append(to_char)
+ out.append("]")
+
+ def more_digits(min_digits: int, max_digits: int):
+ out.append("[0-9]")
+ if min_digits == max_digits and min_digits == 1:
+ return
+ out.append("{")
+ out.append(str(min_digits))
+ if max_digits != min_digits:
+ out.append(",")
+ if max_digits != sys.maxsize:
+ out.append(str(max_digits))
+ out.append("}")
+
+ def uniform_range(from_str: str, to_str: str):
+ i = 0
+ while i < len(from_str) and from_str[i] == to_str[i]:
+ i += 1
+ if i > 0:
+ out.append("\"")
+ out.append(from_str[:i])
+ out.append("\"")
+ if i < len(from_str):
+ if i > 0:
+ out.append(" ")
+ sub_len = len(from_str) - i - 1
+ if sub_len > 0:
+ from_sub = from_str[i+1:]
+ to_sub = to_str[i+1:]
+ sub_zeros = "0" * sub_len
+ sub_nines = "9" * sub_len
+
+ to_reached = False
+ out.append("(")
+ if from_sub == sub_zeros:
+ digit_range(from_str[i], chr(ord(to_str[i]) - 1))
+ out.append(" ")
+ more_digits(sub_len, sub_len)
+ else:
+ out.append("[")
+ out.append(from_str[i])
+ out.append("] ")
+ out.append("(")
+ uniform_range(from_sub, sub_nines)
+ out.append(")")
+ if ord(from_str[i]) < ord(to_str[i]) - 1:
+ out.append(" | ")
+ if to_sub == sub_nines:
+ digit_range(chr(ord(from_str[i]) + 1), to_str[i])
+ to_reached = True
+ else:
+ digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
+ out.append(" ")
+ more_digits(sub_len, sub_len)
+ if not to_reached:
+ out.append(" | ")
+ digit_range(to_str[i], to_str[i])
+ out.append(" ")
+ uniform_range(sub_zeros, to_sub)
+ out.append(")")
+ else:
+ out.append("[")
+ out.append(from_str[i])
+ out.append("-")
+ out.append(to_str[i])
+ out.append("]")
+
+ if has_min and has_max:
+ if min_value < 0 and max_value < 0:
+ out.append("\"-\" (")
+ _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
+ out.append(")")
+ return
+
+ if min_value < 0:
+ out.append("\"-\" (")
+ _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True)
+ out.append(") | ")
+ min_value = 0
+
+ min_s = str(min_value)
+ max_s = str(max_value)
+ min_digits = len(min_s)
+ max_digits = len(max_s)
+
+ for digits in range(min_digits, max_digits):
+ uniform_range(min_s, "9" * digits)
+ min_s = "1" + "0" * digits
+ out.append(" | ")
+ uniform_range(min_s, max_s)
+ return
+
+ less_decimals = max(decimals_left - 1, 1)
+
+ if has_min:
+ if min_value < 0:
+ out.append("\"-\" (")
+ _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
+ out.append(") | [0] | [1-9] ")
+ more_digits(0, decimals_left - 1)
+ elif min_value == 0:
+ if top_level:
+ out.append("[0] | [1-9] ")
+ more_digits(0, less_decimals)
+ else:
+ more_digits(1, decimals_left)
+ elif min_value <= 9:
+ c = str(min_value)
+ range_start = '1' if top_level else '0'
+ if c > range_start:
+ digit_range(range_start, chr(ord(c) - 1))
+ out.append(" ")
+ more_digits(1, less_decimals)
+ out.append(" | ")
+ digit_range(c, "9")
+ out.append(" ")
+ more_digits(0, less_decimals)
+ else:
+ min_s = str(min_value)
+ length = len(min_s)
+ c = min_s[0]
+
+ if c > "1":
+ digit_range("1" if top_level else "0", chr(ord(c) - 1))
+ out.append(" ")
+ more_digits(length, less_decimals)
+ out.append(" | ")
+ digit_range(c, c)
+ out.append(" (")
+ _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False)
+ out.append(")")
+ if c < "9":
+ out.append(" | ")
+ digit_range(chr(ord(c) + 1), "9")
+ out.append(" ")
+ more_digits(length - 1, less_decimals)
+ return
+
+ if has_max:
+ if max_value >= 0:
+ if top_level:
+ out.append("\"-\" [1-9] ")
+ more_digits(0, less_decimals)
+ out.append(" | ")
+ _generate_min_max_int(0, max_value, out, decimals_left, top_level=True)
+ else:
+ out.append("\"-\" (")
+ _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False)
+ out.append(")")
+ return
+
+ raise RuntimeError("At least one of min_value or max_value must be set")
+
+class BuiltinRule:
+ def __init__(self, content: str, deps: list | None = None):
+ self.content = content
+ self.deps = deps or []
+
+# Constraining spaces to prevent model "running away".
+SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
+
+PRIMITIVE_RULES = {
+ 'boolean' : BuiltinRule('("true" | "false") space', []),
+ 'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
+ 'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
+ 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
+ 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
+ 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
+ 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
+ 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
+ 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
+ 'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
+ 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
+ 'null' : BuiltinRule('"null" space', []),
+}
+
+# TODO: support "uri", "email" string formats
+STRING_FORMAT_RULES = {
+ 'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
+ 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
+ 'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
+ 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
+ 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
+ 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
+}
+
+DOTALL = '[\\U00000000-\\U0010FFFF]'
+DOT = '[^\\x0A\\x0D]'
+
+RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
+
+INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
+GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
+GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
+GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
+
+NON_LITERAL_SET = set('|.()[]{}*+?')
+ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
+
+
+class SchemaConverter:
+ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
+ self._prop_order = prop_order
+ self._allow_fetch = allow_fetch
+ self._dotall = dotall
+ self._raw_pattern = raw_pattern
+ self._rules = {
+ 'space': SPACE_RULE,
+ }
+ self._refs = {}
+ self._refs_being_resolved = set()
+
+ def _format_literal(self, literal):
+ escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
+ lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
+ )
+ return f'"{escaped}"'
+
+ def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
+ '''
+ not_literal('a') -> '[^a]'
+ not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
+ '''
+ assert len(literal) > 0, 'Empty literal not supported'
+ def recurse(i: int):
+ c = literal[i]
+ if maybe_escaped_underscores and c == '_':
+ yield f'[^{c}\\\\]'
+ yield ' | '
+ yield f'"\\\\"? "{c}"'
+ else:
+ yield f'[^{c}]'
+ if i < len(literal) - 1:
+ yield ' | '
+ yield self._format_literal(c)
+ yield ' ('
+ yield from recurse(i + 1)
+ yield ')?'
+
+ return ''.join(('(', *recurse(0), ')'))
+
+ def _not_strings(self, strings):
+ class TrieNode:
+ def __init__(self):
+ self.children = {}
+ self.is_end_of_string = False
+
+ def insert(self, string):
+ node = self
+ for c in string:
+ node = node.children.setdefault(c, TrieNode())
+ node.is_end_of_string = True
+
+ trie = TrieNode()
+ for s in strings:
+ trie.insert(s)
+
+ char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
+ out = ['["] ( ']
+
+ def visit(node):
+ rejects = []
+ first = True
+ for c in sorted(node.children.keys()):
+ child = node.children[c]
+ rejects.append(c)
+ if first:
+ first = False
+ else:
+ out.append(' | ')
+ out.append(f'[{c}]')
+ if child.children:
+ out.append(f' (')
+ visit(child)
+ out.append(')')
+ elif child.is_end_of_string:
+ out.append(f' {char_rule}+')
+ if node.children:
+ if not first:
+ out.append(' | ')
+ out.append(f'[^"{"".join(rejects)}] {char_rule}*')
+ visit(trie)
+
+ out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
+ return ''.join(out)
+
+ def _add_rule(self, name, rule):
+ esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
+ if esc_name not in self._rules or self._rules[esc_name] == rule:
+ key = esc_name
+ else:
+ i = 0
+ while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
+ i += 1
+ key = f'{esc_name}{i}'
+ self._rules[key] = rule
+ return key
+
+ def resolve_refs(self, schema: dict, url: str):
+ '''
+ Resolves all $ref fields in the given schema, fetching any remote schemas,
+ replacing $ref with absolute reference URL and populating self._refs with the
+ respective referenced (sub)schema dictionaries.
+ '''
+ def visit(n: dict):
+ if isinstance(n, list):
+ return [visit(x) for x in n]
+ elif isinstance(n, dict):
+ ref = n.get('$ref')
+ if ref is not None and ref not in self._refs:
+ if ref.startswith('https://'):
+ assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
+ import requests
+
+ frag_split = ref.split('#')
+ base_url = frag_split[0]
+
+ target = self._refs.get(base_url)
+ if target is None:
+ target = self.resolve_refs(requests.get(ref).json(), base_url)
+ self._refs[base_url] = target
+
+ if len(frag_split) == 1 or frag_split[-1] == '':
+ return target
+ elif ref.startswith('#/'):
+ target = schema
+ ref = f'{url}{ref}'
+ n['$ref'] = ref
+ else:
+ raise ValueError(f'Unsupported ref {ref}')
+
+ for sel in ref.split('#')[-1].split('/')[1:]:
+ assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
+ if isinstance(target, list):
+ try:
+ sel_index = int(sel)
+ except ValueError:
+ raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
+ assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
+ target = target[sel_index]
+ else:
+ assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
+ target = target[sel]
+
+ self._refs[ref] = target
+ else:
+ for v in n.values():
+ visit(v)
+
+ return n
+ return visit(schema)
+
+ def _generate_union_rule(self, name, alt_schemas):
+ return ' | '.join((
+ self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
+ for i, alt_schema in enumerate(alt_schemas)
+ ))
+
+ def _visit_pattern(self, pattern, name):
+ '''
+ Transforms a regular expression pattern into a GBNF rule.
+
+ Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
+ Output: https://github.com/ggml-org/llama.cpp/blob/master/grammars/README.md
+
+ Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
+
+ Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
+ we define sub-rules to keep the output lean.
+ '''
+
+ assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
+ pattern = pattern[1:-1]
+ sub_rule_ids = {}
+
+ i = 0
+ length = len(pattern)
+
+ def to_rule(s: tuple[str, bool]) -> str:
+ (txt, is_literal) = s
+ return "\"" + txt + "\"" if is_literal else txt
+
+ def transform() -> tuple[str, bool]:
+ '''
+ Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
+ '''
+ nonlocal i
+ nonlocal pattern
+ nonlocal sub_rule_ids
+
+ start = i
+ # For each component of this sequence, store its string representation and whether it's a literal.
+ # We only need a flat structure here to apply repetition operators to the last item, and
+ # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
+ # (GBNF's syntax is luckily very close to regular expressions!)
+ seq: list[tuple[str, bool]] = []
+
+ def get_dot():
+ if self._dotall:
+ rule = DOTALL
+ else:
+ # Accept any character... except \n and \r line break chars (\x0A and \xOD)
+ rule = DOT
+ return self._add_rule(f'dot', rule)
+
+ def join_seq():
+ nonlocal seq
+ ret = []
+ for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
+ if is_literal:
+ ret.append((''.join(x[0] for x in g), True))
+ else:
+ ret.extend(g)
+ if len(ret) == 1:
+ return ret[0]
+ return (' '.join(to_rule(x) for x in seq), False)
+
+ while i < length:
+ c = pattern[i]
+ if c == '.':
+ seq.append((get_dot(), False))
+ i += 1
+ elif c == '(':
+ i += 1
+ if i < length:
+ assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
+ seq.append((f'({to_rule(transform())})', False))
+ elif c == ')':
+ i += 1
+ assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
+ return join_seq()
+ elif c == '[':
+ square_brackets = c
+ i += 1
+ while i < length and pattern[i] != ']':
+ if pattern[i] == '\\':
+ square_brackets += pattern[i:i+2]
+ i += 2
+ else:
+ square_brackets += pattern[i]
+ i += 1
+ assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
+ square_brackets += ']'
+ i += 1
+ seq.append((square_brackets, False))
+ elif c == '|':
+ seq.append(('|', False))
+ i += 1
+ elif c in ('*', '+', '?'):
+ seq[-1] = (to_rule(seq[-1]) + c, False)
+ i += 1
+ elif c == '{':
+ curly_brackets = c
+ i += 1
+ while i < length and pattern[i] != '}':
+ curly_brackets += pattern[i]
+ i += 1
+ assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
+ curly_brackets += '}'
+ i += 1
+ nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
+ min_times = 0
+ max_times = None
+ try:
+ if len(nums) == 1:
+ min_times = int(nums[0])
+ max_times = min_times
+ else:
+ assert len(nums) == 2
+ min_times = int(nums[0]) if nums[0] else 0
+ max_times = int(nums[1]) if nums[1] else None
+ except ValueError:
+ raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
+
+ (sub, sub_is_literal) = seq[-1]
+
+ if not sub_is_literal:
+ id = sub_rule_ids.get(sub)
+ if id is None:
+ id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
+ sub_rule_ids[sub] = id
+ sub = id
+
+ seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
+ else:
+ literal = ''
+ while i < length:
+ if pattern[i] == '\\' and i < length - 1:
+ next = pattern[i + 1]
+ if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
+ i += 1
+ literal += pattern[i]
+ i += 1
+ else:
+ literal += pattern[i:i+2]
+ i += 2
+ elif pattern[i] == '"' and not self._raw_pattern:
+ literal += '\\"'
+ i += 1
+ elif pattern[i] not in NON_LITERAL_SET and \
+ (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
+ literal += pattern[i]
+ i += 1
+ else:
+ break
+ if literal:
+ seq.append((literal, True))
+
+ return join_seq()
+
+ return self._add_rule(
+ name,
+ to_rule(transform()) if self._raw_pattern \
+ else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
+
+
+ def _resolve_ref(self, ref):
+ ref_fragment = ref.split('#')[-1]
+ ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
+ if ref_name not in self._rules and ref not in self._refs_being_resolved:
+ self._refs_being_resolved.add(ref)
+ resolved = self._refs[ref]
+ ref_name = self.visit(resolved, ref_name)
+ self._refs_being_resolved.remove(ref)
+ return ref_name
+
+ def _generate_constant_rule(self, value):
+ return self._format_literal(json.dumps(value))
+
+ def visit(self, schema, name):
+ schema_type = schema.get('type')
+ schema_format = schema.get('format')
+ rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
+
+ if (ref := schema.get('$ref')) is not None:
+ return self._add_rule(rule_name, self._resolve_ref(ref))
+
+ elif 'oneOf' in schema or 'anyOf' in schema:
+ return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
+
+ elif isinstance(schema_type, list):
+ return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
+
+ elif 'const' in schema:
+ return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
+
+ elif 'enum' in schema:
+ rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
+ return self._add_rule(rule_name, rule)
+
+ elif schema_type in (None, 'object') and \
+ ('properties' in schema or \
+ ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
+ required = set(schema.get('required', []))
+ properties = list(schema.get('properties', {}).items())
+ return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
+
+ elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
+ required = set()
+ properties = []
+ enum_sets = []
+ hybrid_name = name
+ def add_component(comp_schema, is_required):
+ if (ref := comp_schema.get('$ref')) is not None:
+ comp_schema = self._refs[ref]
+
+ if 'properties' in comp_schema:
+ for prop_name, prop_schema in comp_schema['properties'].items():
+ properties.append((prop_name, prop_schema))
+ if is_required:
+ required.add(prop_name)
+
+ if 'enum' in comp_schema:
+ enum_sets.append(set(comp_schema['enum']))
+
+ for t in schema['allOf']:
+ if 'anyOf' in t:
+ for tt in t['anyOf']:
+ add_component(tt, is_required=False)
+ else:
+ add_component(t, is_required=True)
+
+ if enum_sets:
+ enum_intersection = enum_sets[0]
+ for s in enum_sets[1:]:
+ enum_intersection &= s
+
+ if enum_intersection:
+ rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
+ return self._add_rule(rule_name, rule)
+
+ return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
+
+ elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
+ items = schema.get('items') or schema['prefixItems']
+ if isinstance(items, list):
+ return self._add_rule(
+ rule_name,
+ '"[" space ' +
+ ' "," space '.join(
+ self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
+ for i, item in enumerate(items)) +
+ ' "]" space')
+ else:
+ item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
+ min_items = schema.get("minItems", 0)
+ max_items = schema.get("maxItems")
+ return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
+
+ elif schema_type in (None, 'string') and 'pattern' in schema:
+ return self._visit_pattern(schema['pattern'], rule_name)
+
+ elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
+ return self._add_primitive(
+ 'root' if rule_name == 'root' else schema_format,
+ PRIMITIVE_RULES['uuid']
+ )
+
+ elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
+ prim_name = f'{schema_format}-string'
+ return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
+
+ elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
+ char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
+ min_len = schema.get('minLength', 0)
+ max_len = schema.get('maxLength')
+
+ return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
+
+ elif schema_type in (None, 'integer') and \
+ ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
+ min_value = None
+ max_value = None
+ if 'minimum' in schema:
+ min_value = schema['minimum']
+ elif 'exclusiveMinimum' in schema:
+ min_value = schema['exclusiveMinimum'] + 1
+ if 'maximum' in schema:
+ max_value = schema['maximum']
+ elif 'exclusiveMaximum' in schema:
+ max_value = schema['exclusiveMaximum'] - 1
+
+ out = ["("]
+ _generate_min_max_int(min_value, max_value, out)
+ out.append(") space")
+ return self._add_rule(rule_name, ''.join(out))
+
+ elif (schema_type == 'object') or (len(schema) == 0):
+ return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
+
+ else:
+ assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
+ # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
+ return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
+
+ def _add_primitive(self, name: str, rule: BuiltinRule):
+ n = self._add_rule(name, rule.content)
+
+ for dep in rule.deps:
+ dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
+ assert dep_rule, f'Rule {dep} not known'
+ if dep not in self._rules:
+ self._add_primitive(dep, dep_rule)
+ return n
+
+ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
+ prop_order = self._prop_order
+ # sort by position in prop_order (if specified) then by original order
+ sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
+
+ prop_kv_rule_names = {}
+ for prop_name, prop_schema in properties:
+ prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
+ prop_kv_rule_names[prop_name] = self._add_rule(
+ f'{name}{"-" if name else ""}{prop_name}-kv',
+ fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
+ )
+ required_props = [k for k in sorted_props if k in required]
+ optional_props = [k for k in sorted_props if k not in required]
+
+ if additional_properties is not None and additional_properties != False:
+ sub_name = f'{name}{"-" if name else ""}additional'
+ value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
+ self._add_primitive('value', PRIMITIVE_RULES['value'])
+ key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
+ else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))
+
+ prop_kv_rule_names["*"] = self._add_rule(
+ f'{sub_name}-kv',
+ f'{key_rule} ":" space {value_rule}'
+ )
+ optional_props.append("*")
+
+ rule = '"{" space '
+ rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
+
+ if optional_props:
+ rule += ' ('
+ if required_props:
+ rule += ' "," space ( '
+
+ def get_recursive_refs(ks, first_is_optional):
+ [k, *rest] = ks
+ kv_rule_name = prop_kv_rule_names[k]
+ comma_ref = f'( "," space {kv_rule_name} )'
+ if first_is_optional:
+ res = comma_ref + ('*' if k == '*' else '?')
+ else:
+ res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
+ if len(rest) > 0:
+ res += ' ' + self._add_rule(
+ f'{name}{"-" if name else ""}{k}-rest',
+ get_recursive_refs(rest, first_is_optional=True)
+ )
+ return res
+
+ rule += ' | '.join(
+ get_recursive_refs(optional_props[i:], first_is_optional=False)
+ for i in range(len(optional_props))
+ )
+ if required_props:
+ rule += ' )'
+ rule += ' )?'
+
+ rule += ' "}" space'
+
+ return rule
+
+ def format_grammar(self):
+ return '\n'.join(
+ f'{name} ::= {rule}'
+ for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
+ )
+
+
+def main(args_in = None):
+ parser = argparse.ArgumentParser(
+ description='''
+ Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
+ given JSON schema. Only a subset of JSON schema features are supported; more may be
+ added in the future.
+ ''',
+ )
+ parser.add_argument(
+ '--prop-order',
+ default=[],
+ type=lambda s: s.split(','),
+ help='''
+ comma-separated property names defining the order of precedence for object properties;
+ properties not specified here are given lower precedence than those that are, and
+ are kept in their original order from the schema. Required properties are always
+ given precedence over optional properties.
+ '''
+ )
+ parser.add_argument(
+ '--allow-fetch',
+ action='store_true',
+ default=False,
+ help='Whether to allow fetching referenced schemas over HTTPS')
+ parser.add_argument(
+ '--dotall',
+ action='store_true',
+ default=False,
+ help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
+ parser.add_argument(
+ '--raw-pattern',
+ action='store_true',
+ default=False,
+ help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')
+
+ parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
+ args = parser.parse_args(args_in)
+
+ if args.schema.startswith('https://'):
+ url = args.schema
+ import requests
+ schema = requests.get(url).json()
+ elif args.schema == '-':
+ url = 'stdin'
+ schema = json.load(sys.stdin)
+ else:
+ url = f'file://{args.schema}'
+ with open(args.schema) as f:
+ schema = json.load(f)
+ converter = SchemaConverter(
+ prop_order={name: idx for idx, name in enumerate(args.prop_order)},
+ allow_fetch=args.allow_fetch,
+ dotall=args.dotall,
+ raw_pattern=args.raw_pattern)
+ schema = converter.resolve_refs(schema, url)
+ converter.visit(schema, '')
+ print(converter.format_grammar())
+
+
+if __name__ == '__main__':
+ main()