1package sitter
   2
   3// #include "bindings.h"
   4import "C"
   5
   6import (
   7	"context"
   8	"errors"
   9	"fmt"
  10	"math"
  11	"reflect"
  12	"regexp"
  13	"runtime"
  14	"strings"
  15	"sync"
  16	"sync/atomic"
  17	"unsafe"
  18)
  19
  20// maintain a map of read functions that can be called from C
  21var readFuncs = &readFuncsMap{funcs: make(map[int]ReadFunc)}
  22
  23// Parse is a shortcut for parsing bytes of source code,
  24// returns root node
  25//
  26// Deprecated: use ParseCtx instead
  27func Parse(content []byte, lang *Language) *Node {
  28	n, _ := ParseCtx(context.Background(), content, lang)
  29	return n
  30}
  31
  32// ParseCtx is a shortcut for parsing bytes of source code,
  33// returns root node
  34func ParseCtx(ctx context.Context, content []byte, lang *Language) (*Node, error) {
  35	p := NewParser()
  36	p.SetLanguage(lang)
  37	tree, err := p.ParseCtx(ctx, nil, content)
  38	if err != nil {
  39		return nil, err
  40	}
  41
  42	return tree.RootNode(), nil
  43}
  44
// Parser produces a concrete syntax tree from source code using a Language.
type Parser struct {
	// isClosed records whether Close has already deleted c, so the C parser is
	// deleted at most once.
	isClosed bool
	// c is the underlying tree-sitter parser handle.
	c        *C.TSParser
	// cancel points at the flag registered with ts_parser_set_cancellation_flag
	// in NewParser; ParseCtx atomically stores 1 into it when its context is done.
	cancel   *uintptr
}
  51
  52// NewParser creates new Parser
  53func NewParser() *Parser {
  54	cancel := uintptr(0)
  55	p := &Parser{c: C.ts_parser_new(), cancel: &cancel}
  56	C.ts_parser_set_cancellation_flag(p.c, (*C.size_t)(unsafe.Pointer(p.cancel)))
  57	runtime.SetFinalizer(p, (*Parser).Close)
  58	return p
  59}
  60
  61// SetLanguage assignes Language to a parser
  62func (p *Parser) SetLanguage(lang *Language) {
  63	cLang := (*C.struct_TSLanguage)(lang.ptr)
  64	C.ts_parser_set_language(p.c, cLang)
  65}
  66
// ReadFunc is a function to retrieve a chunk of text at a given byte offset and
// (row, column) position. It is called from C while ParseInput/ParseInputCtx run.
// It should return nil to indicate the end of the document.
type ReadFunc func(offset uint32, position Point) []byte
  70
// InputEncoding is an encoding of the text to parse. ParseInputCtx converts the
// value directly to C.TSInputEncoding, so the constant order below must match
// tree-sitter's TSInputEncoding enum.
type InputEncoding int

const (
	// InputEncodingUTF8 marks input.Read as producing UTF-8 text.
	InputEncodingUTF8 InputEncoding = iota
	// InputEncodingUTF16 marks input.Read as producing UTF-16 text.
	InputEncodingUTF16
)
  78
// Input defines the parameters for ParseInput and ParseInputCtx.
type Input struct {
	// Read supplies chunks of the source text on demand.
	Read     ReadFunc
	// Encoding is the encoding of the bytes returned by Read.
	Encoding InputEncoding
}
  84
// Errors returned by the Parse*Ctx methods when tree-sitter produces no tree.
var (
	// ErrOperationLimit is returned when no tree was produced, the context was
	// not cancelled and a language is set. This is the case where the limit
	// configured with SetOperationLimit was exceeded.
	ErrOperationLimit = errors.New("operation limit was hit")
	// ErrNoLanguage is returned when parsing is attempted before a Language
	// has been set on the parser.
	ErrNoLanguage     = errors.New("cannot parse without language")
)
  89
  90// Parse produces new Tree from content using old tree
  91//
  92// Deprecated: use ParseCtx instead
  93func (p *Parser) Parse(oldTree *Tree, content []byte) *Tree {
  94	t, _ := p.ParseCtx(context.Background(), oldTree, content)
  95	return t
  96}
  97
  98// ParseCtx produces new Tree from content using old tree
  99func (p *Parser) ParseCtx(ctx context.Context, oldTree *Tree, content []byte) (*Tree, error) {
 100	var BaseTree *C.TSTree
 101	if oldTree != nil {
 102		BaseTree = oldTree.c
 103	}
 104
 105	parseComplete := make(chan struct{})
 106
 107	// run goroutine only if context is cancelable to avoid performance impact
 108	if ctx.Done() != nil {
 109		go func() {
 110			select {
 111			case <-ctx.Done():
 112				atomic.StoreUintptr(p.cancel, 1)
 113			case <-parseComplete:
 114				return
 115			}
 116		}()
 117	}
 118
 119	input := C.CBytes(content)
 120	BaseTree = C.ts_parser_parse_string(p.c, BaseTree, (*C.char)(input), C.uint32_t(len(content)))
 121	close(parseComplete)
 122	C.free(input)
 123
 124	return p.convertTSTree(ctx, BaseTree)
 125}
 126
 127// ParseInput produces new Tree by reading from a callback defined in input
 128// it is useful if your data is stored in specialized data structure
 129// as it will avoid copying the data into []bytes
 130// and faster access to edited part of the data
 131func (p *Parser) ParseInput(oldTree *Tree, input Input) *Tree {
 132	t, _ := p.ParseInputCtx(context.Background(), oldTree, input)
 133	return t
 134}
 135
 136// ParseInputCtx produces new Tree by reading from a callback defined in input
 137// it is useful if your data is stored in specialized data structure
 138// as it will avoid copying the data into []bytes
 139// and faster access to edited part of the data
 140func (p *Parser) ParseInputCtx(ctx context.Context, oldTree *Tree, input Input) (*Tree, error) {
 141	var BaseTree *C.TSTree
 142	if oldTree != nil {
 143		BaseTree = oldTree.c
 144	}
 145
 146	funcID := readFuncs.register(input.Read)
 147	BaseTree = C.call_ts_parser_parse(p.c, BaseTree, C.int(funcID), C.TSInputEncoding(input.Encoding))
 148	readFuncs.unregister(funcID)
 149
 150	return p.convertTSTree(ctx, BaseTree)
 151}
 152
 153// convertTSTree converts the tree-sitter response into a *Tree or an error.
 154//
 155// tree-sitter can fail for 3 reasons:
 156// - cancelation
 157// - operation limit hit
 158// - no language set
 159//
 160// We check for all those conditions if ther return value is nil.
 161// see: https://github.com/tree-sitter/tree-sitter/blob/7890a29db0b186b7b21a0a95d99fa6c562b8316b/lib/include/tree_sitter/api.h#L209-L246
 162func (p *Parser) convertTSTree(ctx context.Context, tsTree *C.TSTree) (*Tree, error) {
 163	if tsTree == nil {
 164		if ctx.Err() != nil {
 165			// reset cancellation flag so the parse can be re-used
 166			atomic.StoreUintptr(p.cancel, 0)
 167			// context cancellation caused a timeout, return that error
 168			return nil, ctx.Err()
 169		}
 170
 171		if C.ts_parser_language(p.c) == nil {
 172			return nil, ErrNoLanguage
 173		}
 174
 175		return nil, ErrOperationLimit
 176	}
 177
 178	return p.newTree(tsTree), nil
 179}
 180
 181// OperationLimit returns the duration in microseconds that parsing is allowed to take
 182func (p *Parser) OperationLimit() int {
 183	return int(C.ts_parser_timeout_micros(p.c))
 184}
 185
 186// SetOperationLimit limits the maximum duration in microseconds that parsing should be allowed to take before halting
 187func (p *Parser) SetOperationLimit(limit int) {
 188	C.ts_parser_set_timeout_micros(p.c, C.uint64_t(limit))
 189}
 190
 191// Reset causes the parser to parse from scratch on the next call to parse, instead of resuming
 192// so that it sees the changes to the beginning of the source code.
 193func (p *Parser) Reset() {
 194	C.ts_parser_reset(p.c)
 195}
 196
 197// SetIncludedRanges sets text ranges of a file
 198func (p *Parser) SetIncludedRanges(ranges []Range) {
 199	cRanges := make([]C.TSRange, len(ranges))
 200	for i, r := range ranges {
 201		cRanges[i] = C.TSRange{
 202			start_point: C.TSPoint{
 203				row:    C.uint32_t(r.StartPoint.Row),
 204				column: C.uint32_t(r.StartPoint.Column),
 205			},
 206			end_point: C.TSPoint{
 207				row:    C.uint32_t(r.EndPoint.Row),
 208				column: C.uint32_t(r.EndPoint.Column),
 209			},
 210			start_byte: C.uint32_t(r.StartByte),
 211			end_byte:   C.uint32_t(r.EndByte),
 212		}
 213	}
 214	C.ts_parser_set_included_ranges(p.c, (*C.TSRange)(unsafe.Pointer(&cRanges[0])), C.uint(len(ranges)))
 215}
 216
 217// Debug enables debug output to stderr
 218func (p *Parser) Debug() {
 219	logger := C.stderr_logger_new(true)
 220	C.ts_parser_set_logger(p.c, logger)
 221}
 222
 223// Close should be called to ensure that all the memory used by the parse is freed.
 224//
 225// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer,
 226// parser.Close() will be called by Go's garbage collector and users would not have to call this manually.
 227func (p *Parser) Close() {
 228	if !p.isClosed {
 229		C.ts_parser_delete(p.c)
 230	}
 231
 232	p.isClosed = true
 233}
 234
// Point is a position in the source text, given as a (row, column) pair.
type Point struct {
	Row    uint32
	Column uint32
}

// Range describes a span of source text. It records both the start/end points
// and the start/end byte offsets.
type Range struct {
	StartPoint Point
	EndPoint   Point
	StartByte  uint32
	EndByte    uint32
}
 246
// BaseTree owns the underlying C tree and is the object the finalizer is
// attached to (see newTree).
//
// It is kept separate from Tree because Tree's node cache forms a reference
// cycle: Tree -> cache -> *Node -> Tree. Go does not guarantee that finalizers
// run for objects that are part of such a cycle.
// For details see: https://github.com/golang/go/issues/7358#issuecomment-66091558
type BaseTree struct {
	c        *C.TSTree
	isClosed bool
}
 255
 256// newTree creates a new tree object from a C pointer. The function will set a finalizer for the object,
 257// thus no free is needed for it.
 258func (p *Parser) newTree(c *C.TSTree) *Tree {
 259	base := &BaseTree{c: c}
 260	runtime.SetFinalizer(base, (*BaseTree).Close)
 261
 262	newTree := &Tree{p: p, BaseTree: base, cache: make(map[C.TSNode]*Node)}
 263	return newTree
 264}
 265
// Tree represents the syntax tree of an entire source code file.
// Note: Tree instances are not thread safe;
// you must copy a tree if you want to use it on multiple threads simultaneously.
type Tree struct {
	// BaseTree owns the C tree; its finalizer frees it.
	*BaseTree

	// p is a pointer to a Parser that produced the Tree. Only used to keep Parser alive.
	// Otherwise Parser may be GC'ed (and deleted by the finalizer) while some Tree objects are still in use.
	p *Parser

	// cache memoizes the *Node wrapper for each C node (see cachedNode), so
	// repeated lookups return the same pointer. It is an unsynchronized map.
	// NOTE(review): keying by node.id instead of the whole TSNode may be better.
	cache map[C.TSNode]*Node
}
 279
 280// Copy returns a new copy of a tree
 281func (t *Tree) Copy() *Tree {
 282	return t.p.newTree(C.ts_tree_copy(t.c))
 283}
 284
 285// RootNode returns root node of a tree
 286func (t *Tree) RootNode() *Node {
 287	ptr := C.ts_tree_root_node(t.c)
 288	return t.cachedNode(ptr)
 289}
 290
 291func (t *Tree) cachedNode(ptr C.TSNode) *Node {
 292	if ptr.id == nil {
 293		return nil
 294	}
 295
 296	if n, ok := t.cache[ptr]; ok {
 297		return n
 298	}
 299
 300	n := &Node{ptr, t}
 301	t.cache[ptr] = n
 302	return n
 303}
 304
 305// Close should be called to ensure that all the memory used by the tree is freed.
 306//
 307// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer,
 308// parser.Close() will be called by Go's garbage collector and users would not have to call this manually.
 309func (t *BaseTree) Close() {
 310	if !t.isClosed {
 311		C.ts_tree_delete(t.c)
 312	}
 313
 314	t.isClosed = true
 315}
 316
// EditInput describes a single edit to the source text, as byte offsets
// ("Index" fields) and (row, column) points. It records where the edit starts,
// where the replaced text ended, and where the new text ends.
type EditInput struct {
	StartIndex  uint32
	OldEndIndex uint32
	NewEndIndex uint32
	StartPoint  Point
	OldEndPoint Point
	NewEndPoint Point
}
 325
 326func (i EditInput) c() *C.TSInputEdit {
 327	return &C.TSInputEdit{
 328		start_byte:   C.uint32_t(i.StartIndex),
 329		old_end_byte: C.uint32_t(i.OldEndIndex),
 330		new_end_byte: C.uint32_t(i.NewEndIndex),
 331		start_point: C.TSPoint{
 332			row:    C.uint32_t(i.StartPoint.Row),
 333			column: C.uint32_t(i.StartPoint.Column),
 334		},
 335		old_end_point: C.TSPoint{
 336			row:    C.uint32_t(i.OldEndPoint.Row),
 337			column: C.uint32_t(i.OldEndPoint.Column),
 338		},
 339		new_end_point: C.TSPoint{
 340			row:    C.uint32_t(i.OldEndPoint.Row),
 341			column: C.uint32_t(i.OldEndPoint.Column),
 342		},
 343	}
 344}
 345
 346// Edit the syntax tree to keep it in sync with source code that has been edited.
 347func (t *Tree) Edit(i EditInput) {
 348	C.ts_tree_edit(t.c, i.c())
 349}
 350
// Language defines how to parse a particular programming language.
type Language struct {
	// ptr is the C TSLanguage pointer; it is cast to *C.TSLanguage on use.
	ptr unsafe.Pointer
}
 355
 356// NewLanguage creates new Language from c pointer
 357func NewLanguage(ptr unsafe.Pointer) *Language {
 358	return &Language{ptr}
 359}
 360
 361// SymbolName returns a node type string for the given Symbol.
 362func (l *Language) SymbolName(s Symbol) string {
 363	return C.GoString(C.ts_language_symbol_name((*C.TSLanguage)(l.ptr), s))
 364}
 365
 366// SymbolType returns named, anonymous, or a hidden type for a Symbol.
 367func (l *Language) SymbolType(s Symbol) SymbolType {
 368	return SymbolType(C.ts_language_symbol_type((*C.TSLanguage)(l.ptr), s))
 369}
 370
 371// SymbolCount returns the number of distinct field names in the language.
 372func (l *Language) SymbolCount() uint32 {
 373	return uint32(C.ts_language_symbol_count((*C.TSLanguage)(l.ptr)))
 374}
 375
 376func (l *Language) FieldName(idx int) string {
 377	return C.GoString(C.ts_language_field_name_for_id((*C.TSLanguage)(l.ptr), C.ushort(idx)))
 378}
 379
// Node represents a single node in the syntax tree
// It tracks its start and end positions in the source code,
// as well as its relation to other nodes like its parent, siblings and children.
// Nodes are usually obtained through Tree.cachedNode, which returns one shared
// *Node per C node value.
type Node struct {
	c C.TSNode
	t *Tree // keep pointer on tree because node is valid only as long as tree is
}
 387
// Symbol identifies a node type within a Language. It is an alias of
// tree-sitter's C TSSymbol.
type Symbol = C.TSSymbol
 389
 390type SymbolType int
 391
 392const (
 393	SymbolTypeRegular SymbolType = iota
 394	SymbolTypeAnonymous
 395	SymbolTypeAuxiliary
 396)
 397
 398var symbolTypeNames = []string{
 399	"Regular",
 400	"Anonymous",
 401	"Auxiliary",
 402}
 403
 404func (t SymbolType) String() string {
 405	return symbolTypeNames[t]
 406}
 407
 408func (n Node) ID() uintptr {
 409	return uintptr(n.c.id)
 410}
 411
 412// StartByte returns the node's start byte.
 413func (n Node) StartByte() uint32 {
 414	return uint32(C.ts_node_start_byte(n.c))
 415}
 416
 417// EndByte returns the node's end byte.
 418func (n Node) EndByte() uint32 {
 419	return uint32(C.ts_node_end_byte(n.c))
 420}
 421
 422// StartPoint returns the node's start position in terms of rows and columns.
 423func (n Node) StartPoint() Point {
 424	p := C.ts_node_start_point(n.c)
 425	return Point{
 426		Row:    uint32(p.row),
 427		Column: uint32(p.column),
 428	}
 429}
 430
 431// EndPoint returns the node's end position in terms of rows and columns.
 432func (n Node) EndPoint() Point {
 433	p := C.ts_node_end_point(n.c)
 434	return Point{
 435		Row:    uint32(p.row),
 436		Column: uint32(p.column),
 437	}
 438}
 439
 440func (n Node) Range() Range {
 441	return Range{
 442		StartByte:  n.StartByte(),
 443		EndByte:    n.EndByte(),
 444		StartPoint: n.StartPoint(),
 445		EndPoint:   n.EndPoint(),
 446	}
 447}
 448
 449// Symbol returns the node's type as a Symbol.
 450func (n Node) Symbol() Symbol {
 451	return C.ts_node_symbol(n.c)
 452}
 453
 454// Type returns the node's type as a string.
 455func (n Node) Type() string {
 456	return C.GoString(C.ts_node_type(n.c))
 457}
 458
 459// String returns an S-expression representing the node as a string.
 460func (n Node) String() string {
 461	ptr := C.ts_node_string(n.c)
 462	defer C.free(unsafe.Pointer(ptr))
 463	return C.GoString(ptr)
 464}
 465
 466// Equal checks if two nodes are identical.
 467func (n Node) Equal(other *Node) bool {
 468	return bool(C.ts_node_eq(n.c, other.c))
 469}
 470
 471// IsNull checks if the node is null.
 472func (n Node) IsNull() bool {
 473	return bool(C.ts_node_is_null(n.c))
 474}
 475
 476// IsNamed checks if the node is *named*.
 477// Named nodes correspond to named rules in the grammar,
 478// whereas *anonymous* nodes correspond to string literals in the grammar.
 479func (n Node) IsNamed() bool {
 480	return bool(C.ts_node_is_named(n.c))
 481}
 482
 483// IsMissing checks if the node is *missing*.
 484// Missing nodes are inserted by the parser in order to recover from certain kinds of syntax errors.
 485func (n Node) IsMissing() bool {
 486	return bool(C.ts_node_is_missing(n.c))
 487}
 488
 489// IsExtra checks if the node is *extra*.
 490// Extra nodes represent things like comments, which are not required the grammar, but can appear anywhere.
 491func (n Node) IsExtra() bool {
 492	return bool(C.ts_node_is_extra(n.c))
 493}
 494
 495// IsError checks if the node is a syntax error.
 496// Syntax errors represent parts of the code that could not be incorporated into a valid syntax tree.
 497func (n Node) IsError() bool {
 498	return n.Symbol() == math.MaxUint16
 499}
 500
 501// HasChanges checks if a syntax node has been edited.
 502func (n Node) HasChanges() bool {
 503	return bool(C.ts_node_has_changes(n.c))
 504}
 505
 506// HasError check if the node is a syntax error or contains any syntax errors.
 507func (n Node) HasError() bool {
 508	return bool(C.ts_node_has_error(n.c))
 509}
 510
 511// Parent returns the node's immediate parent.
 512func (n Node) Parent() *Node {
 513	nn := C.ts_node_parent(n.c)
 514	return n.t.cachedNode(nn)
 515}
 516
 517// Child returns the node's child at the given index, where zero represents the first child.
 518func (n Node) Child(idx int) *Node {
 519	nn := C.ts_node_child(n.c, C.uint32_t(idx))
 520	return n.t.cachedNode(nn)
 521}
 522
 523// NamedChild returns the node's *named* child at the given index.
 524func (n Node) NamedChild(idx int) *Node {
 525	nn := C.ts_node_named_child(n.c, C.uint32_t(idx))
 526	return n.t.cachedNode(nn)
 527}
 528
 529// ChildCount returns the node's number of children.
 530func (n Node) ChildCount() uint32 {
 531	return uint32(C.ts_node_child_count(n.c))
 532}
 533
 534// NamedChildCount returns the node's number of *named* children.
 535func (n Node) NamedChildCount() uint32 {
 536	return uint32(C.ts_node_named_child_count(n.c))
 537}
 538
 539// ChildByFieldName returns the node's child with the given field name.
 540func (n Node) ChildByFieldName(name string) *Node {
 541	str := C.CString(name)
 542	defer C.free(unsafe.Pointer(str))
 543	nn := C.ts_node_child_by_field_name(n.c, str, C.uint32_t(len(name)))
 544	return n.t.cachedNode(nn)
 545}
 546
 547// FieldNameForChild returns the field name of the child at the given index, or "" if not named.
 548func (n Node) FieldNameForChild(idx int) string {
 549	return C.GoString(C.ts_node_field_name_for_child(n.c, C.uint32_t(idx)))
 550}
 551
 552// NextSibling returns the node's next sibling.
 553func (n Node) NextSibling() *Node {
 554	nn := C.ts_node_next_sibling(n.c)
 555	return n.t.cachedNode(nn)
 556}
 557
 558// NextNamedSibling returns the node's next *named* sibling.
 559func (n Node) NextNamedSibling() *Node {
 560	nn := C.ts_node_next_named_sibling(n.c)
 561	return n.t.cachedNode(nn)
 562}
 563
 564// PrevSibling returns the node's previous sibling.
 565func (n Node) PrevSibling() *Node {
 566	nn := C.ts_node_prev_sibling(n.c)
 567	return n.t.cachedNode(nn)
 568}
 569
 570// PrevNamedSibling returns the node's previous *named* sibling.
 571func (n Node) PrevNamedSibling() *Node {
 572	nn := C.ts_node_prev_named_sibling(n.c)
 573	return n.t.cachedNode(nn)
 574}
 575
 576// Edit the node to keep it in-sync with source code that has been edited.
 577func (n Node) Edit(i EditInput) {
 578	C.ts_node_edit(&n.c, i.c())
 579}
 580
 581// Content returns node's source code from input as a string
 582func (n Node) Content(input []byte) string {
 583	return string(input[n.StartByte():n.EndByte()])
 584}
 585
 586func (n Node) NamedDescendantForPointRange(start Point, end Point) *Node {
 587	cStartPoint := C.TSPoint{
 588		row:    C.uint32_t(start.Row),
 589		column: C.uint32_t(start.Column),
 590	}
 591	cEndPoint := C.TSPoint{
 592		row:    C.uint32_t(end.Row),
 593		column: C.uint32_t(end.Column),
 594	}
 595	nn := C.ts_node_named_descendant_for_point_range(n.c, cStartPoint, cEndPoint)
 596	return n.t.cachedNode(nn)
 597}
 598
// TreeCursor allows you to walk a syntax tree more efficiently than is
// possible using the `Node` functions. It is a mutable object that is always
// on a certain syntax node, and can be moved imperatively to different nodes.
type TreeCursor struct {
	// c is the C cursor state; it is released by Close.
	c *C.TSTreeCursor
	// t is the tree the current node belongs to; it keeps the tree alive and
	// provides the node cache used by CurrentNode.
	t *Tree

	isClosed bool
}
 608
 609// NewTreeCursor creates a new tree cursor starting from the given node.
 610func NewTreeCursor(n *Node) *TreeCursor {
 611	cc := C.ts_tree_cursor_new(n.c)
 612	c := &TreeCursor{
 613		c: &cc,
 614		t: n.t,
 615	}
 616
 617	runtime.SetFinalizer(c, (*TreeCursor).Close)
 618	return c
 619}
 620
 621// Close should be called to ensure that all the memory used by the tree cursor
 622// is freed.
 623//
 624// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer,
 625// parser.Close() will be called by Go's garbage collector and users would not have to call this manually.
 626func (c *TreeCursor) Close() {
 627	if !c.isClosed {
 628		C.ts_tree_cursor_delete(c.c)
 629	}
 630
 631	c.isClosed = true
 632}
 633
 634// Reset re-initializes a tree cursor to start at a different node.
 635func (c *TreeCursor) Reset(n *Node) {
 636	c.t = n.t
 637	C.ts_tree_cursor_reset(c.c, n.c)
 638}
 639
 640// CurrentNode of the tree cursor.
 641func (c *TreeCursor) CurrentNode() *Node {
 642	n := C.ts_tree_cursor_current_node(c.c)
 643	return c.t.cachedNode(n)
 644}
 645
 646// CurrentFieldName gets the field name of the tree cursor's current node.
 647//
 648// This returns empty string if the current node doesn't have a field.
 649func (c *TreeCursor) CurrentFieldName() string {
 650	return C.GoString(C.ts_tree_cursor_current_field_name(c.c))
 651}
 652
 653// GoToParent moves the cursor to the parent of its current node.
 654//
 655// This returns `true` if the cursor successfully moved, and returns `false`
 656// if there was no parent node (the cursor was already on the root node).
 657func (c *TreeCursor) GoToParent() bool {
 658	return bool(C.ts_tree_cursor_goto_parent(c.c))
 659}
 660
 661// GoToNextSibling moves the cursor to the next sibling of its current node.
 662//
 663// This returns `true` if the cursor successfully moved, and returns `false`
 664// if there was no next sibling node.
 665func (c *TreeCursor) GoToNextSibling() bool {
 666	return bool(C.ts_tree_cursor_goto_next_sibling(c.c))
 667}
 668
 669// GoToFirstChild moves the cursor to the first child of its current node.
 670//
 671// This returns `true` if the cursor successfully moved, and returns `false`
 672// if there were no children.
 673func (c *TreeCursor) GoToFirstChild() bool {
 674	return bool(C.ts_tree_cursor_goto_first_child(c.c))
 675}
 676
 677// GoToFirstChildForByte moves the cursor to the first child of its current node
 678// that extends beyond the given byte offset.
 679//
 680// This returns the index of the child node if one was found, and returns -1
 681// if no such child was found.
 682func (c *TreeCursor) GoToFirstChildForByte(b uint32) int64 {
 683	return int64(C.ts_tree_cursor_goto_first_child_for_byte(c.c, C.uint32_t(b)))
 684}
 685
 686// QueryErrorType - value that indicates the type of QueryError.
 687type QueryErrorType int
 688
 689const (
 690	QueryErrorNone QueryErrorType = iota
 691	QueryErrorSyntax
 692	QueryErrorNodeType
 693	QueryErrorField
 694	QueryErrorCapture
 695	QueryErrorStructure
 696	QueryErrorLanguage
 697)
 698
 699func QueryErrorTypeToString(errorType QueryErrorType) string {
 700	switch errorType {
 701	case QueryErrorNone:
 702		return "none"
 703	case QueryErrorNodeType:
 704		return "node type"
 705	case QueryErrorField:
 706		return "field"
 707	case QueryErrorCapture:
 708		return "capture"
 709	case QueryErrorSyntax:
 710		return "syntax"
 711	default:
 712		return "unknown"
 713	}
 714
 715}
 716
 717// QueryError - if there is an error in the query,
 718// then the Offset argument will be set to the byte offset of the error,
 719// and the Type argument will be set to a value that indicates the type of error.
 720type QueryError struct {
 721	Offset  uint32
 722	Type    QueryErrorType
 723	Message string
 724}
 725
 726func (qe *QueryError) Error() string {
 727	return qe.Message
 728}
 729
// Query is a compiled set of tree-sitter query patterns; create one with
// NewQuery and run it with a QueryCursor.
type Query struct {
	// c is the compiled C query; it is released by Close.
	c        *C.TSQuery
	isClosed bool
}
 735
 736// NewQuery creates a query by specifying a string containing one or more patterns.
 737// In case of error returns QueryError.
 738func NewQuery(pattern []byte, lang *Language) (*Query, error) {
 739	var (
 740		erroff  C.uint32_t
 741		errtype C.TSQueryError
 742	)
 743
 744	input := C.CBytes(pattern)
 745	c := C.ts_query_new(
 746		(*C.struct_TSLanguage)(lang.ptr),
 747		(*C.char)(input),
 748		C.uint32_t(len(pattern)),
 749		&erroff,
 750		&errtype,
 751	)
 752	C.free(input)
 753	if errtype != C.TSQueryError(QueryErrorNone) {
 754		errorOffset := uint32(erroff)
 755		// search for the line containing the offset
 756		line := 1
 757		line_start := 0
 758		for i, c := range pattern {
 759			line_start = i
 760			if uint32(i) >= errorOffset {
 761				break
 762			}
 763			if c == '\n' {
 764				line++
 765			}
 766		}
 767		column := int(errorOffset) - line_start
 768		errorType := QueryErrorType(errtype)
 769		errorTypeToString := QueryErrorTypeToString(errorType)
 770
 771		var message string
 772		switch errorType {
 773		// errors that apply to a single identifier
 774		case QueryErrorNodeType:
 775			fallthrough
 776		case QueryErrorField:
 777			fallthrough
 778		case QueryErrorCapture:
 779			// find identifier at input[errorOffset]
 780			// and report it in the error message
 781			s := string(pattern[errorOffset:])
 782			identifierRegexp := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_-]*`)
 783			m := identifierRegexp.FindStringSubmatch(s)
 784			if len(m) > 0 {
 785				message = fmt.Sprintf("invalid %s '%s' at line %d column %d",
 786					errorTypeToString, m[0], line, column)
 787			} else {
 788				message = fmt.Sprintf("invalid %s at line %d column %d",
 789					errorTypeToString, line, column)
 790			}
 791
 792		// errors the report position
 793		case QueryErrorSyntax:
 794			fallthrough
 795		case QueryErrorStructure:
 796			fallthrough
 797		case QueryErrorLanguage:
 798			fallthrough
 799		default:
 800			s := string(pattern[errorOffset:])
 801			lines := strings.Split(s, "\n")
 802			whitespace := strings.Repeat(" ", column)
 803			message = fmt.Sprintf("invalid %s at line %d column %d\n%s\n%s^",
 804				errorTypeToString, line, column,
 805				lines[0], whitespace)
 806		}
 807
 808		return nil, &QueryError{
 809			Offset:  errorOffset,
 810			Type:    errorType,
 811			Message: message,
 812		}
 813	}
 814
 815	q := &Query{c: c}
 816
 817	// Copied from: https://github.com/klothoplatform/go-tree-sitter/commit/e351b20167b26d515627a4a1a884528ede5fef79
 818	// this is just used for syntax validation - it does not actually filter anything
 819	for i := uint32(0); i < q.PatternCount(); i++ {
 820		predicates := q.PredicatesForPattern(i)
 821		for _, steps := range predicates {
 822			if len(steps) == 0 {
 823				continue
 824			}
 825
 826			if steps[0].Type != QueryPredicateStepTypeString {
 827				return nil, errors.New("predicate must begin with a literal value")
 828			}
 829
 830			operator := q.StringValueForId(steps[0].ValueId)
 831			switch operator {
 832			case "eq?", "not-eq?":
 833				if len(steps) != 4 {
 834					return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
 835				}
 836				if steps[1].Type != QueryPredicateStepTypeCapture {
 837					return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
 838				}
 839			case "match?", "not-match?":
 840				if len(steps) != 4 {
 841					return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2)
 842				}
 843				if steps[1].Type != QueryPredicateStepTypeCapture {
 844					return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId))
 845				}
 846				if steps[2].Type != QueryPredicateStepTypeString {
 847					return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
 848				}
 849			case "set!", "is?", "is-not?":
 850				if len(steps) < 3 || len(steps) > 4 {
 851					return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2)
 852				}
 853				if steps[1].Type != QueryPredicateStepTypeString {
 854					return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId))
 855				}
 856				if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString {
 857					return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId))
 858				}
 859			}
 860		}
 861	}
 862
 863	runtime.SetFinalizer(q, (*Query).Close)
 864
 865	return q, nil
 866}
 867
 868// Close should be called to ensure that all the memory used by the query is freed.
 869//
 870// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer,
 871// parser.Close() will be called by Go's garbage collector and users would not have to call this manually.
 872func (q *Query) Close() {
 873	if !q.isClosed {
 874		C.ts_query_delete(q.c)
 875	}
 876
 877	q.isClosed = true
 878}
 879
 880func (q *Query) PatternCount() uint32 {
 881	return uint32(C.ts_query_pattern_count(q.c))
 882}
 883
 884func (q *Query) CaptureCount() uint32 {
 885	return uint32(C.ts_query_capture_count(q.c))
 886}
 887
 888func (q *Query) StringCount() uint32 {
 889	return uint32(C.ts_query_string_count(q.c))
 890}
 891
// QueryPredicateStepType is the kind of a QueryPredicateStep. PredicatesForPattern
// converts tree-sitter's C step type directly, so the order must match that enum.
type QueryPredicateStepType int

const (
	// QueryPredicateStepTypeDone presumably marks the end of one predicate.
	QueryPredicateStepTypeDone QueryPredicateStepType = iota
	// QueryPredicateStepTypeCapture is a capture argument; ValueId is presumably
	// a capture id (see Query.CaptureNameForId).
	QueryPredicateStepTypeCapture
	// QueryPredicateStepTypeString is a string literal; ValueId is looked up
	// with Query.StringValueForId (as NewQuery does for operators).
	QueryPredicateStepTypeString
)

// QueryPredicateStep is one step (operator, argument or terminator) of a
// query predicate.
type QueryPredicateStep struct {
	Type    QueryPredicateStepType
	ValueId uint32
}
 904
 905func (q *Query) PredicatesForPattern(patternIndex uint32) [][]QueryPredicateStep {
 906	var (
 907		length          C.uint32_t
 908		cPredicateSteps []C.TSQueryPredicateStep
 909		predicateSteps  []QueryPredicateStep
 910	)
 911
 912	cPredicateStep := C.ts_query_predicates_for_pattern(q.c, C.uint32_t(patternIndex), &length)
 913
 914	count := int(length)
 915	slice := (*reflect.SliceHeader)((unsafe.Pointer(&cPredicateSteps)))
 916	slice.Cap = count
 917	slice.Len = count
 918	slice.Data = uintptr(unsafe.Pointer(cPredicateStep))
 919	for _, s := range cPredicateSteps {
 920		stepType := QueryPredicateStepType(s._type)
 921		valueId := uint32(s.value_id)
 922		predicateSteps = append(predicateSteps, QueryPredicateStep{stepType, valueId})
 923	}
 924
 925	return splitPredicates(predicateSteps)
 926}
 927
 928func (q *Query) CaptureNameForId(id uint32) string {
 929	var length C.uint32_t
 930	name := C.ts_query_capture_name_for_id(q.c, C.uint32_t(id), &length)
 931	return C.GoStringN(name, C.int(length))
 932}
 933
 934func (q *Query) StringValueForId(id uint32) string {
 935	var length C.uint32_t
 936	value := C.ts_query_string_value_for_id(q.c, C.uint32_t(id), &length)
 937	return C.GoStringN(value, C.int(length))
 938}
 939
// Quantifier describes how many times a capture may occur in a match of a
// pattern (see Query.CaptureQuantifierForId).
type Quantifier int

// Quantifier values. They are converted directly from tree-sitter's C
// TSQuantifier, so their order must match that enum.
// NOTE(review): these constants are untyped ints rather than Quantifier.
// Typing them would be cleaner but could break callers that use them as int.
const (
	QuantifierZero = iota
	QuantifierZeroOrOne
	QuantifierZeroOrMore
	QuantifierOne
	QuantifierOneOrMore
)
 949
 950func (q *Query) CaptureQuantifierForId(id uint32, captureId uint32) Quantifier {
 951	return Quantifier(C.ts_query_capture_quantifier_for_id(q.c, C.uint32_t(id), C.uint32_t(captureId)))
 952}
 953
// QueryCursor carries the state needed for processing the queries.
type QueryCursor struct {
	// c is the C cursor; it is released by Close.
	c *C.TSQueryCursor
	// t is the tree of the node passed to Exec; it is used to wrap captured nodes.
	t *Tree
	// keep a pointer to the query to avoid garbage collection
	q *Query

	isClosed bool
}
 963
 964// NewQueryCursor creates a query cursor.
 965func NewQueryCursor() *QueryCursor {
 966	qc := &QueryCursor{c: C.ts_query_cursor_new(), t: nil}
 967	runtime.SetFinalizer(qc, (*QueryCursor).Close)
 968
 969	return qc
 970}
 971
 972// Exec executes the query on a given syntax node.
 973func (qc *QueryCursor) Exec(q *Query, n *Node) {
 974	qc.q = q
 975	qc.t = n.t
 976	C.ts_query_cursor_exec(qc.c, q.c, n.c)
 977}
 978
 979func (qc *QueryCursor) SetPointRange(startPoint Point, endPoint Point) {
 980	cStartPoint := C.TSPoint{
 981		row:    C.uint32_t(startPoint.Row),
 982		column: C.uint32_t(startPoint.Column),
 983	}
 984	cEndPoint := C.TSPoint{
 985		row:    C.uint32_t(endPoint.Row),
 986		column: C.uint32_t(endPoint.Column),
 987	}
 988	C.ts_query_cursor_set_point_range(qc.c, cStartPoint, cEndPoint)
 989}
 990
 991// Close should be called to ensure that all the memory used by the query cursor is freed.
 992//
 993// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer,
 994// parser.Close() will be called by Go's garbage collector and users would not have to call this manually.
 995func (qc *QueryCursor) Close() {
 996	if !qc.isClosed {
 997		C.ts_query_cursor_delete(qc.c)
 998	}
 999
1000	qc.isClosed = true
1001}
1002
// QueryCapture is a node captured by a query, together with the capture's
// index. The capture's name can be obtained with Query.CaptureNameForId(Index).
type QueryCapture struct {
	Index uint32
	Node  *Node
}
1008
// QueryMatch is a single match of a query pattern, as returned by
// QueryCursor.NextMatch and QueryCursor.NextCapture.
type QueryMatch struct {
	// ID is the match id reported by tree-sitter.
	ID           uint32
	// PatternIndex is the index of the pattern that matched; FilterPredicates
	// uses it to look up that pattern's predicates.
	PatternIndex uint16
	// Captures lists the nodes captured by this match.
	Captures     []QueryCapture
}
1015
1016// NextMatch iterates over matches.
1017// This function will return (nil, false) when there are no more matches.
1018// Otherwise, it will populate the QueryMatch with data
1019// about which pattern matched and which nodes were captured.
1020func (qc *QueryCursor) NextMatch() (*QueryMatch, bool) {
1021	var (
1022		cqm C.TSQueryMatch
1023		cqc []C.TSQueryCapture
1024	)
1025
1026	if ok := C.ts_query_cursor_next_match(qc.c, &cqm); !bool(ok) {
1027		return nil, false
1028	}
1029
1030	qm := &QueryMatch{
1031		ID:           uint32(cqm.id),
1032		PatternIndex: uint16(cqm.pattern_index),
1033	}
1034
1035	count := int(cqm.capture_count)
1036	slice := (*reflect.SliceHeader)((unsafe.Pointer(&cqc)))
1037	slice.Cap = count
1038	slice.Len = count
1039	slice.Data = uintptr(unsafe.Pointer(cqm.captures))
1040	for _, c := range cqc {
1041		idx := uint32(c.index)
1042		node := qc.t.cachedNode(c.node)
1043		qm.Captures = append(qm.Captures, QueryCapture{idx, node})
1044	}
1045
1046	return qm, true
1047}
1048
1049func (qc *QueryCursor) NextCapture() (*QueryMatch, uint32, bool) {
1050	var (
1051		cqm          C.TSQueryMatch
1052		cqc          []C.TSQueryCapture
1053		captureIndex C.uint32_t
1054	)
1055
1056	if ok := C.ts_query_cursor_next_capture(qc.c, &cqm, &captureIndex); !bool(ok) {
1057		return nil, 0, false
1058	}
1059
1060	qm := &QueryMatch{
1061		ID:           uint32(cqm.id),
1062		PatternIndex: uint16(cqm.pattern_index),
1063	}
1064
1065	count := int(cqm.capture_count)
1066	slice := (*reflect.SliceHeader)((unsafe.Pointer(&cqc)))
1067	slice.Cap = count
1068	slice.Len = count
1069	slice.Data = uintptr(unsafe.Pointer(cqm.captures))
1070	for _, c := range cqc {
1071		idx := uint32(c.index)
1072		node := qc.t.cachedNode(c.node)
1073		qm.Captures = append(qm.Captures, QueryCapture{idx, node})
1074	}
1075
1076	return qm, uint32(captureIndex), true
1077}
1078
1079// Copied From: https://github.com/klothoplatform/go-tree-sitter/commit/e351b20167b26d515627a4a1a884528ede5fef79
1080
1081func splitPredicates(steps []QueryPredicateStep) [][]QueryPredicateStep {
1082	var predicateSteps [][]QueryPredicateStep
1083	var currentSteps []QueryPredicateStep
1084	for _, step := range steps {
1085		currentSteps = append(currentSteps, step)
1086		if step.Type == QueryPredicateStepTypeDone {
1087			predicateSteps = append(predicateSteps, currentSteps)
1088			currentSteps = []QueryPredicateStep{}
1089		}
1090	}
1091	return predicateSteps
1092}
1093
1094func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch {
1095	qm := &QueryMatch{
1096		ID:           m.ID,
1097		PatternIndex: m.PatternIndex,
1098	}
1099
1100	q := qc.q
1101
1102	predicates := q.PredicatesForPattern(uint32(qm.PatternIndex))
1103	if len(predicates) == 0 {
1104		qm.Captures = m.Captures
1105		return qm
1106	}
1107
1108	// track if we matched all predicates globally
1109	matchedAll := true
1110
1111	// check each predicate against the match
1112	for _, steps := range predicates {
1113		operator := q.StringValueForId(steps[0].ValueId)
1114
1115		switch operator {
1116		case "eq?", "not-eq?":
1117			isPositive := operator == "eq?"
1118
1119			expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId)
1120
1121			if steps[2].Type == QueryPredicateStepTypeCapture {
1122				expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId)
1123
1124				var nodeLeft, nodeRight *Node
1125
1126				for _, c := range m.Captures {
1127					captureName := q.CaptureNameForId(c.Index)
1128
1129					if captureName == expectedCaptureNameLeft {
1130						nodeLeft = c.Node
1131					}
1132					if captureName == expectedCaptureNameRight {
1133						nodeRight = c.Node
1134					}
1135
1136					if nodeLeft != nil && nodeRight != nil {
1137						if (nodeLeft.Content(input) == nodeRight.Content(input)) != isPositive {
1138							matchedAll = false
1139						}
1140						break
1141					}
1142				}
1143			} else {
1144				expectedValueRight := q.StringValueForId(steps[2].ValueId)
1145
1146				for _, c := range m.Captures {
1147					captureName := q.CaptureNameForId(c.Index)
1148
1149					if expectedCaptureNameLeft != captureName {
1150						continue
1151					}
1152
1153					if (c.Node.Content(input) == expectedValueRight) != isPositive {
1154						matchedAll = false
1155						break
1156					}
1157				}
1158			}
1159
1160			if matchedAll == false {
1161				break
1162			}
1163
1164		case "match?", "not-match?":
1165			isPositive := operator == "match?"
1166
1167			expectedCaptureName := q.CaptureNameForId(steps[1].ValueId)
1168			regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId))
1169
1170			for _, c := range m.Captures {
1171				captureName := q.CaptureNameForId(c.Index)
1172				if expectedCaptureName != captureName {
1173					continue
1174				}
1175
1176				if regex.Match([]byte(c.Node.Content(input))) != isPositive {
1177					matchedAll = false
1178					break
1179				}
1180			}
1181		}
1182	}
1183
1184	if matchedAll {
1185		qm.Captures = append(qm.Captures, m.Captures...)
1186	}
1187
1188	return qm
1189
1190}
1191
1192// keeps callbacks for parser.parse method
1193type readFuncsMap struct {
1194	sync.Mutex
1195
1196	funcs map[int]ReadFunc
1197	count int
1198}
1199
1200func (m *readFuncsMap) register(f ReadFunc) int {
1201	m.Lock()
1202	defer m.Unlock()
1203
1204	m.count++
1205	m.funcs[m.count] = f
1206	return m.count
1207}
1208
1209func (m *readFuncsMap) unregister(id int) {
1210	m.Lock()
1211	defer m.Unlock()
1212
1213	delete(m.funcs, id)
1214}
1215
1216func (m *readFuncsMap) get(id int) ReadFunc {
1217	m.Lock()
1218	defer m.Unlock()
1219
1220	return m.funcs[id]
1221}
1222
1223//export callReadFunc
1224func callReadFunc(id C.int, byteIndex C.uint32_t, position C.TSPoint, bytesRead *C.uint32_t) *C.char {
1225	readFunc := readFuncs.get(int(id))
1226	content := readFunc(uint32(byteIndex), Point{
1227		Row:    uint32(position.row),
1228		Column: uint32(position.column),
1229	})
1230	*bytesRead = C.uint32_t(len(content))
1231
1232	// Note: This memory is freed inside the C code; see bindings.c
1233	input := C.CBytes(content)
1234	return (*C.char)(input)
1235}