diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-01-21 20:22:09 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-01-21 20:22:09 +0100 |
| commit | 5a8dbc6347b3541e84fe669b22c17ad3b715e258 (patch) | |
| tree | b148c450939688caaaeb4adac6f2faa1eaffe649 /vendor/github.com/mitjafelicijan/go-tree-sitter/bindings.go | |
| download | qwe-editor-5a8dbc6347b3541e84fe669b22c17ad3b715e258.tar.gz | |
Engage!
Diffstat (limited to 'vendor/github.com/mitjafelicijan/go-tree-sitter/bindings.go')
| -rw-r--r-- | vendor/github.com/mitjafelicijan/go-tree-sitter/bindings.go | 1235 |
1 files changed, 1235 insertions, 0 deletions
diff --git a/vendor/github.com/mitjafelicijan/go-tree-sitter/bindings.go b/vendor/github.com/mitjafelicijan/go-tree-sitter/bindings.go new file mode 100644 index 0000000..5c3cca3 --- /dev/null +++ b/vendor/github.com/mitjafelicijan/go-tree-sitter/bindings.go @@ -0,0 +1,1235 @@ +package sitter + +// #include "bindings.h" +import "C" + +import ( + "context" + "errors" + "fmt" + "math" + "reflect" + "regexp" + "runtime" + "strings" + "sync" + "sync/atomic" + "unsafe" +) + +// maintain a map of read functions that can be called from C +var readFuncs = &readFuncsMap{funcs: make(map[int]ReadFunc)} + +// Parse is a shortcut for parsing bytes of source code, +// returns root node +// +// Deprecated: use ParseCtx instead +func Parse(content []byte, lang *Language) *Node { + n, _ := ParseCtx(context.Background(), content, lang) + return n +} + +// ParseCtx is a shortcut for parsing bytes of source code, +// returns root node +func ParseCtx(ctx context.Context, content []byte, lang *Language) (*Node, error) { + p := NewParser() + p.SetLanguage(lang) + tree, err := p.ParseCtx(ctx, nil, content) + if err != nil { + return nil, err + } + + return tree.RootNode(), nil +} + +// Parser produces concrete syntax tree based on source code using Language +type Parser struct { + isClosed bool + c *C.TSParser + cancel *uintptr +} + +// NewParser creates new Parser +func NewParser() *Parser { + cancel := uintptr(0) + p := &Parser{c: C.ts_parser_new(), cancel: &cancel} + C.ts_parser_set_cancellation_flag(p.c, (*C.size_t)(unsafe.Pointer(p.cancel))) + runtime.SetFinalizer(p, (*Parser).Close) + return p +} + +// SetLanguage assignes Language to a parser +func (p *Parser) SetLanguage(lang *Language) { + cLang := (*C.struct_TSLanguage)(lang.ptr) + C.ts_parser_set_language(p.c, cLang) +} + +// ReadFunc is a function to retrieve a chunk of text at a given byte offset and (row, column) position +// it should return nil to indicate the end of the document +type ReadFunc func(offset uint32, position Point) []byte + +// InputEncoding is a encoding of the text to parse +type InputEncoding int + +const ( + InputEncodingUTF8 InputEncoding = iota + InputEncodingUTF16 +) + +// Input defines parameters for parse method +type Input struct { + Read ReadFunc + Encoding InputEncoding +} + +var ( + ErrOperationLimit = errors.New("operation limit was hit") + ErrNoLanguage = errors.New("cannot parse without language") +) + +// Parse produces new Tree from content using old tree +// +// Deprecated: use ParseCtx instead +func (p *Parser) Parse(oldTree *Tree, content []byte) *Tree { + t, _ := p.ParseCtx(context.Background(), oldTree, content) + return t +} + +// ParseCtx produces new Tree from content using old tree +func (p *Parser) ParseCtx(ctx context.Context, oldTree *Tree, content []byte) (*Tree, error) { + var BaseTree *C.TSTree + if oldTree != nil { + BaseTree = oldTree.c + } + + parseComplete := make(chan struct{}) + + // run goroutine only if context is cancelable to avoid performance impact + if ctx.Done() != nil { + go func() { + select { + case <-ctx.Done(): + atomic.StoreUintptr(p.cancel, 1) + case <-parseComplete: + return + } + }() + } + + input := C.CBytes(content) + BaseTree = C.ts_parser_parse_string(p.c, BaseTree, (*C.char)(input), C.uint32_t(len(content))) + close(parseComplete) + C.free(input) + + return p.convertTSTree(ctx, BaseTree) +} + +// ParseInput produces new Tree by reading from a callback defined in input +// it is useful if your data is stored in specialized data structure +// as it will avoid copying the data into []bytes +// and faster access to edited part of the data +func (p *Parser) ParseInput(oldTree *Tree, input Input) *Tree { + t, _ := p.ParseInputCtx(context.Background(), oldTree, input) + return t +} + +// ParseInputCtx produces new Tree by reading from a callback defined in input +// it is useful if your data is stored in specialized data structure +// as it will avoid copying the data into []bytes +// and faster access to edited part of the data +func (p *Parser) ParseInputCtx(ctx context.Context, oldTree *Tree, input Input) (*Tree, error) { + var BaseTree *C.TSTree + if oldTree != nil { + BaseTree = oldTree.c + } + + funcID := readFuncs.register(input.Read) + BaseTree = C.call_ts_parser_parse(p.c, BaseTree, C.int(funcID), C.TSInputEncoding(input.Encoding)) + readFuncs.unregister(funcID) + + return p.convertTSTree(ctx, BaseTree) +} + +// convertTSTree converts the tree-sitter response into a *Tree or an error. +// +// tree-sitter can fail for 3 reasons: +// - cancelation +// - operation limit hit +// - no language set +// +// We check for all those conditions if ther return value is nil. +// see: https://github.com/tree-sitter/tree-sitter/blob/7890a29db0b186b7b21a0a95d99fa6c562b8316b/lib/include/tree_sitter/api.h#L209-L246 +func (p *Parser) convertTSTree(ctx context.Context, tsTree *C.TSTree) (*Tree, error) { + if tsTree == nil { + if ctx.Err() != nil { + // reset cancellation flag so the parse can be re-used + atomic.StoreUintptr(p.cancel, 0) + // context cancellation caused a timeout, return that error + return nil, ctx.Err() + } + + if C.ts_parser_language(p.c) == nil { + return nil, ErrNoLanguage + } + + return nil, ErrOperationLimit + } + + return p.newTree(tsTree), nil +} + +// OperationLimit returns the duration in microseconds that parsing is allowed to take +func (p *Parser) OperationLimit() int { + return int(C.ts_parser_timeout_micros(p.c)) +} + +// SetOperationLimit limits the maximum duration in microseconds that parsing should be allowed to take before halting +func (p *Parser) SetOperationLimit(limit int) { + C.ts_parser_set_timeout_micros(p.c, C.uint64_t(limit)) +} + +// Reset causes the parser to parse from scratch on the next call to parse, instead of resuming +// so that it sees the changes to the beginning of the source code. +func (p *Parser) Reset() { + C.ts_parser_reset(p.c) +} + +// SetIncludedRanges sets text ranges of a file +func (p *Parser) SetIncludedRanges(ranges []Range) { + cRanges := make([]C.TSRange, len(ranges)) + for i, r := range ranges { + cRanges[i] = C.TSRange{ + start_point: C.TSPoint{ + row: C.uint32_t(r.StartPoint.Row), + column: C.uint32_t(r.StartPoint.Column), + }, + end_point: C.TSPoint{ + row: C.uint32_t(r.EndPoint.Row), + column: C.uint32_t(r.EndPoint.Column), + }, + start_byte: C.uint32_t(r.StartByte), + end_byte: C.uint32_t(r.EndByte), + } + } + C.ts_parser_set_included_ranges(p.c, (*C.TSRange)(unsafe.Pointer(&cRanges[0])), C.uint(len(ranges))) +} + +// Debug enables debug output to stderr +func (p *Parser) Debug() { + logger := C.stderr_logger_new(true) + C.ts_parser_set_logger(p.c, logger) +} + +// Close should be called to ensure that all the memory used by the parse is freed. +// +// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer, +// parser.Close() will be called by Go's garbage collector and users would not have to call this manually. +func (p *Parser) Close() { + if !p.isClosed { + C.ts_parser_delete(p.c) + } + + p.isClosed = true +} + +type Point struct { + Row uint32 + Column uint32 +} + +type Range struct { + StartPoint Point + EndPoint Point + StartByte uint32 + EndByte uint32 +} + +// we use cache for nodes on normal tree object +// it prevent run of SetFinalizer as it introduces cycle +// we can workaround it using separate object +// for details see: https://github.com/golang/go/issues/7358#issuecomment-66091558 +type BaseTree struct { + c *C.TSTree + isClosed bool +} + +// newTree creates a new tree object from a C pointer. The function will set a finalizer for the object, +// thus no free is needed for it. +func (p *Parser) newTree(c *C.TSTree) *Tree { + base := &BaseTree{c: c} + runtime.SetFinalizer(base, (*BaseTree).Close) + + newTree := &Tree{p: p, BaseTree: base, cache: make(map[C.TSNode]*Node)} + return newTree +} + +// Tree represents the syntax tree of an entire source code file +// Note: Tree instances are not thread safe; +// you must copy a tree if you want to use it on multiple threads simultaneously. +type Tree struct { + *BaseTree + + // p is a pointer to a Parser that produced the Tree. Only used to keep Parser alive. + // Otherwise Parser may be GC'ed (and deleted by the finalizer) while some Tree objects are still in use. + p *Parser + + // most probably better save node.id + cache map[C.TSNode]*Node +} + +// Copy returns a new copy of a tree +func (t *Tree) Copy() *Tree { + return t.p.newTree(C.ts_tree_copy(t.c)) +} + +// RootNode returns root node of a tree +func (t *Tree) RootNode() *Node { + ptr := C.ts_tree_root_node(t.c) + return t.cachedNode(ptr) +} + +func (t *Tree) cachedNode(ptr C.TSNode) *Node { + if ptr.id == nil { + return nil + } + + if n, ok := t.cache[ptr]; ok { + return n + } + + n := &Node{ptr, t} + t.cache[ptr] = n + return n +} + +// Close should be called to ensure that all the memory used by the tree is freed. +// +// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer, +// parser.Close() will be called by Go's garbage collector and users would not have to call this manually. +func (t *BaseTree) Close() { + if !t.isClosed { + C.ts_tree_delete(t.c) + } + + t.isClosed = true +} + +type EditInput struct { + StartIndex uint32 + OldEndIndex uint32 + NewEndIndex uint32 + StartPoint Point + OldEndPoint Point + NewEndPoint Point +} + +func (i EditInput) c() *C.TSInputEdit { + return &C.TSInputEdit{ + start_byte: C.uint32_t(i.StartIndex), + old_end_byte: C.uint32_t(i.OldEndIndex), + new_end_byte: C.uint32_t(i.NewEndIndex), + start_point: C.TSPoint{ + row: C.uint32_t(i.StartPoint.Row), + column: C.uint32_t(i.StartPoint.Column), + }, + old_end_point: C.TSPoint{ + row: C.uint32_t(i.OldEndPoint.Row), + column: C.uint32_t(i.OldEndPoint.Column), + }, + new_end_point: C.TSPoint{ + row: C.uint32_t(i.OldEndPoint.Row), + column: C.uint32_t(i.OldEndPoint.Column), + }, + } +} + +// Edit the syntax tree to keep it in sync with source code that has been edited. +func (t *Tree) Edit(i EditInput) { + C.ts_tree_edit(t.c, i.c()) +} + +// Language defines how to parse a particular programming language +type Language struct { + ptr unsafe.Pointer +} + +// NewLanguage creates new Language from c pointer +func NewLanguage(ptr unsafe.Pointer) *Language { + return &Language{ptr} +} + +// SymbolName returns a node type string for the given Symbol. +func (l *Language) SymbolName(s Symbol) string { + return C.GoString(C.ts_language_symbol_name((*C.TSLanguage)(l.ptr), s)) +} + +// SymbolType returns named, anonymous, or a hidden type for a Symbol. +func (l *Language) SymbolType(s Symbol) SymbolType { + return SymbolType(C.ts_language_symbol_type((*C.TSLanguage)(l.ptr), s)) +} + +// SymbolCount returns the number of distinct field names in the language. +func (l *Language) SymbolCount() uint32 { + return uint32(C.ts_language_symbol_count((*C.TSLanguage)(l.ptr))) +} + +func (l *Language) FieldName(idx int) string { + return C.GoString(C.ts_language_field_name_for_id((*C.TSLanguage)(l.ptr), C.ushort(idx))) +} + +// Node represents a single node in the syntax tree +// It tracks its start and end positions in the source code, +// as well as its relation to other nodes like its parent, siblings and children. +type Node struct { + c C.TSNode + t *Tree // keep pointer on tree because node is valid only as long as tree is +} + +type Symbol = C.TSSymbol + +type SymbolType int + +const ( + SymbolTypeRegular SymbolType = iota + SymbolTypeAnonymous + SymbolTypeAuxiliary +) + +var symbolTypeNames = []string{ + "Regular", + "Anonymous", + "Auxiliary", +} + +func (t SymbolType) String() string { + return symbolTypeNames[t] +} + +func (n Node) ID() uintptr { + return uintptr(n.c.id) +} + +// StartByte returns the node's start byte. +func (n Node) StartByte() uint32 { + return uint32(C.ts_node_start_byte(n.c)) +} + +// EndByte returns the node's end byte. +func (n Node) EndByte() uint32 { + return uint32(C.ts_node_end_byte(n.c)) +} + +// StartPoint returns the node's start position in terms of rows and columns. +func (n Node) StartPoint() Point { + p := C.ts_node_start_point(n.c) + return Point{ + Row: uint32(p.row), + Column: uint32(p.column), + } +} + +// EndPoint returns the node's end position in terms of rows and columns. +func (n Node) EndPoint() Point { + p := C.ts_node_end_point(n.c) + return Point{ + Row: uint32(p.row), + Column: uint32(p.column), + } +} + +func (n Node) Range() Range { + return Range{ + StartByte: n.StartByte(), + EndByte: n.EndByte(), + StartPoint: n.StartPoint(), + EndPoint: n.EndPoint(), + } +} + +// Symbol returns the node's type as a Symbol. +func (n Node) Symbol() Symbol { + return C.ts_node_symbol(n.c) +} + +// Type returns the node's type as a string. +func (n Node) Type() string { + return C.GoString(C.ts_node_type(n.c)) +} + +// String returns an S-expression representing the node as a string. +func (n Node) String() string { + ptr := C.ts_node_string(n.c) + defer C.free(unsafe.Pointer(ptr)) + return C.GoString(ptr) +} + +// Equal checks if two nodes are identical. +func (n Node) Equal(other *Node) bool { + return bool(C.ts_node_eq(n.c, other.c)) +} + +// IsNull checks if the node is null. +func (n Node) IsNull() bool { + return bool(C.ts_node_is_null(n.c)) +} + +// IsNamed checks if the node is *named*. +// Named nodes correspond to named rules in the grammar, +// whereas *anonymous* nodes correspond to string literals in the grammar. +func (n Node) IsNamed() bool { + return bool(C.ts_node_is_named(n.c)) +} + +// IsMissing checks if the node is *missing*. +// Missing nodes are inserted by the parser in order to recover from certain kinds of syntax errors. +func (n Node) IsMissing() bool { + return bool(C.ts_node_is_missing(n.c)) +} + +// IsExtra checks if the node is *extra*. +// Extra nodes represent things like comments, which are not required the grammar, but can appear anywhere. +func (n Node) IsExtra() bool { + return bool(C.ts_node_is_extra(n.c)) +} + +// IsError checks if the node is a syntax error. +// Syntax errors represent parts of the code that could not be incorporated into a valid syntax tree. +func (n Node) IsError() bool { + return n.Symbol() == math.MaxUint16 +} + +// HasChanges checks if a syntax node has been edited. +func (n Node) HasChanges() bool { + return bool(C.ts_node_has_changes(n.c)) +} + +// HasError check if the node is a syntax error or contains any syntax errors. +func (n Node) HasError() bool { + return bool(C.ts_node_has_error(n.c)) +} + +// Parent returns the node's immediate parent. +func (n Node) Parent() *Node { + nn := C.ts_node_parent(n.c) + return n.t.cachedNode(nn) +} + +// Child returns the node's child at the given index, where zero represents the first child. +func (n Node) Child(idx int) *Node { + nn := C.ts_node_child(n.c, C.uint32_t(idx)) + return n.t.cachedNode(nn) +} + +// NamedChild returns the node's *named* child at the given index. +func (n Node) NamedChild(idx int) *Node { + nn := C.ts_node_named_child(n.c, C.uint32_t(idx)) + return n.t.cachedNode(nn) +} + +// ChildCount returns the node's number of children. +func (n Node) ChildCount() uint32 { + return uint32(C.ts_node_child_count(n.c)) +} + +// NamedChildCount returns the node's number of *named* children. +func (n Node) NamedChildCount() uint32 { + return uint32(C.ts_node_named_child_count(n.c)) +} + +// ChildByFieldName returns the node's child with the given field name. +func (n Node) ChildByFieldName(name string) *Node { + str := C.CString(name) + defer C.free(unsafe.Pointer(str)) + nn := C.ts_node_child_by_field_name(n.c, str, C.uint32_t(len(name))) + return n.t.cachedNode(nn) +} + +// FieldNameForChild returns the field name of the child at the given index, or "" if not named. +func (n Node) FieldNameForChild(idx int) string { + return C.GoString(C.ts_node_field_name_for_child(n.c, C.uint32_t(idx))) +} + +// NextSibling returns the node's next sibling. +func (n Node) NextSibling() *Node { + nn := C.ts_node_next_sibling(n.c) + return n.t.cachedNode(nn) +} + +// NextNamedSibling returns the node's next *named* sibling. +func (n Node) NextNamedSibling() *Node { + nn := C.ts_node_next_named_sibling(n.c) + return n.t.cachedNode(nn) +} + +// PrevSibling returns the node's previous sibling. +func (n Node) PrevSibling() *Node { + nn := C.ts_node_prev_sibling(n.c) + return n.t.cachedNode(nn) +} + +// PrevNamedSibling returns the node's previous *named* sibling. +func (n Node) PrevNamedSibling() *Node { + nn := C.ts_node_prev_named_sibling(n.c) + return n.t.cachedNode(nn) +} + +// Edit the node to keep it in-sync with source code that has been edited. +func (n Node) Edit(i EditInput) { + C.ts_node_edit(&n.c, i.c()) +} + +// Content returns node's source code from input as a string +func (n Node) Content(input []byte) string { + return string(input[n.StartByte():n.EndByte()]) +} + +func (n Node) NamedDescendantForPointRange(start Point, end Point) *Node { + cStartPoint := C.TSPoint{ + row: C.uint32_t(start.Row), + column: C.uint32_t(start.Column), + } + cEndPoint := C.TSPoint{ + row: C.uint32_t(end.Row), + column: C.uint32_t(end.Column), + } + nn := C.ts_node_named_descendant_for_point_range(n.c, cStartPoint, cEndPoint) + return n.t.cachedNode(nn) +} + +// TreeCursor allows you to walk a syntax tree more efficiently than is +// possible using the `Node` functions. It is a mutable object that is always +// on a certain syntax node, and can be moved imperatively to different nodes. +type TreeCursor struct { + c *C.TSTreeCursor + t *Tree + + isClosed bool +} + +// NewTreeCursor creates a new tree cursor starting from the given node. +func NewTreeCursor(n *Node) *TreeCursor { + cc := C.ts_tree_cursor_new(n.c) + c := &TreeCursor{ + c: &cc, + t: n.t, + } + + runtime.SetFinalizer(c, (*TreeCursor).Close) + return c +} + +// Close should be called to ensure that all the memory used by the tree cursor +// is freed. +// +// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer, +// parser.Close() will be called by Go's garbage collector and users would not have to call this manually. +func (c *TreeCursor) Close() { + if !c.isClosed { + C.ts_tree_cursor_delete(c.c) + } + + c.isClosed = true +} + +// Reset re-initializes a tree cursor to start at a different node. +func (c *TreeCursor) Reset(n *Node) { + c.t = n.t + C.ts_tree_cursor_reset(c.c, n.c) +} + +// CurrentNode of the tree cursor. +func (c *TreeCursor) CurrentNode() *Node { + n := C.ts_tree_cursor_current_node(c.c) + return c.t.cachedNode(n) +} + +// CurrentFieldName gets the field name of the tree cursor's current node. +// +// This returns empty string if the current node doesn't have a field. +func (c *TreeCursor) CurrentFieldName() string { + return C.GoString(C.ts_tree_cursor_current_field_name(c.c)) +} + +// GoToParent moves the cursor to the parent of its current node. +// +// This returns `true` if the cursor successfully moved, and returns `false` +// if there was no parent node (the cursor was already on the root node). +func (c *TreeCursor) GoToParent() bool { + return bool(C.ts_tree_cursor_goto_parent(c.c)) +} + +// GoToNextSibling moves the cursor to the next sibling of its current node. +// +// This returns `true` if the cursor successfully moved, and returns `false` +// if there was no next sibling node. +func (c *TreeCursor) GoToNextSibling() bool { + return bool(C.ts_tree_cursor_goto_next_sibling(c.c)) +} + +// GoToFirstChild moves the cursor to the first child of its current node. +// +// This returns `true` if the cursor successfully moved, and returns `false` +// if there were no children. +func (c *TreeCursor) GoToFirstChild() bool { + return bool(C.ts_tree_cursor_goto_first_child(c.c)) +} + +// GoToFirstChildForByte moves the cursor to the first child of its current node +// that extends beyond the given byte offset. +// +// This returns the index of the child node if one was found, and returns -1 +// if no such child was found. +func (c *TreeCursor) GoToFirstChildForByte(b uint32) int64 { + return int64(C.ts_tree_cursor_goto_first_child_for_byte(c.c, C.uint32_t(b))) +} + +// QueryErrorType - value that indicates the type of QueryError. +type QueryErrorType int + +const ( + QueryErrorNone QueryErrorType = iota + QueryErrorSyntax + QueryErrorNodeType + QueryErrorField + QueryErrorCapture + QueryErrorStructure + QueryErrorLanguage +) + +func QueryErrorTypeToString(errorType QueryErrorType) string { + switch errorType { + case QueryErrorNone: + return "none" + case QueryErrorNodeType: + return "node type" + case QueryErrorField: + return "field" + case QueryErrorCapture: + return "capture" + case QueryErrorSyntax: + return "syntax" + default: + return "unknown" + } + +} + +// QueryError - if there is an error in the query, +// then the Offset argument will be set to the byte offset of the error, +// and the Type argument will be set to a value that indicates the type of error. +type QueryError struct { + Offset uint32 + Type QueryErrorType + Message string +} + +func (qe *QueryError) Error() string { + return qe.Message +} + +// Query API +type Query struct { + c *C.TSQuery + isClosed bool +} + +// NewQuery creates a query by specifying a string containing one or more patterns. +// In case of error returns QueryError. +func NewQuery(pattern []byte, lang *Language) (*Query, error) { + var ( + erroff C.uint32_t + errtype C.TSQueryError + ) + + input := C.CBytes(pattern) + c := C.ts_query_new( + (*C.struct_TSLanguage)(lang.ptr), + (*C.char)(input), + C.uint32_t(len(pattern)), + &erroff, + &errtype, + ) + C.free(input) + if errtype != C.TSQueryError(QueryErrorNone) { + errorOffset := uint32(erroff) + // search for the line containing the offset + line := 1 + line_start := 0 + for i, c := range pattern { + line_start = i + if uint32(i) >= errorOffset { + break + } + if c == '\n' { + line++ + } + } + column := int(errorOffset) - line_start + errorType := QueryErrorType(errtype) + errorTypeToString := QueryErrorTypeToString(errorType) + + var message string + switch errorType { + // errors that apply to a single identifier + case QueryErrorNodeType: + fallthrough + case QueryErrorField: + fallthrough + case QueryErrorCapture: + // find identifier at input[errorOffset] + // and report it in the error message + s := string(pattern[errorOffset:]) + identifierRegexp := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_-]*`) + m := identifierRegexp.FindStringSubmatch(s) + if len(m) > 0 { + message = fmt.Sprintf("invalid %s '%s' at line %d column %d", + errorTypeToString, m[0], line, column) + } else { + message = fmt.Sprintf("invalid %s at line %d column %d", + errorTypeToString, line, column) + } + + // errors the report position + case QueryErrorSyntax: + fallthrough + case QueryErrorStructure: + fallthrough + case QueryErrorLanguage: + fallthrough + default: + s := string(pattern[errorOffset:]) + lines := strings.Split(s, "\n") + whitespace := strings.Repeat(" ", column) + message = fmt.Sprintf("invalid %s at line %d column %d\n%s\n%s^", + errorTypeToString, line, column, + lines[0], whitespace) + } + + return nil, &QueryError{ + Offset: errorOffset, + Type: errorType, + Message: message, + } + } + + q := &Query{c: c} + + // Copied from: https://github.com/klothoplatform/go-tree-sitter/commit/e351b20167b26d515627a4a1a884528ede5fef79 + // this is just used for syntax validation - it does not actually filter anything + for i := uint32(0); i < q.PatternCount(); i++ { + predicates := q.PredicatesForPattern(i) + for _, steps := range predicates { + if len(steps) == 0 { + continue + } + + if steps[0].Type != QueryPredicateStepTypeString { + return nil, errors.New("predicate must begin with a literal value") + } + + operator := q.StringValueForId(steps[0].ValueId) + switch operator { + case "eq?", "not-eq?": + if len(steps) != 4 { + return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2) + } + if steps[1].Type != QueryPredicateStepTypeCapture { + return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId)) + } + case "match?", "not-match?": + if len(steps) != 4 { + return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2) + } + if steps[1].Type != QueryPredicateStepTypeCapture { + return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId)) + } + if steps[2].Type != QueryPredicateStepTypeString { + return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId)) + } + case "set!", "is?", "is-not?": + if len(steps) < 3 || len(steps) > 4 { + return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2) + } + if steps[1].Type != QueryPredicateStepTypeString { + return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId)) + } + if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString { + return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId)) + } + } + } + } + + runtime.SetFinalizer(q, (*Query).Close) + + return q, nil +} + +// Close should be called to ensure that all the memory used by the query is freed. +// +// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer, +// parser.Close() will be called by Go's garbage collector and users would not have to call this manually. +func (q *Query) Close() { + if !q.isClosed { + C.ts_query_delete(q.c) + } + + q.isClosed = true +} + +func (q *Query) PatternCount() uint32 { + return uint32(C.ts_query_pattern_count(q.c)) +} + +func (q *Query) CaptureCount() uint32 { + return uint32(C.ts_query_capture_count(q.c)) +} + +func (q *Query) StringCount() uint32 { + return uint32(C.ts_query_string_count(q.c)) +} + +type QueryPredicateStepType int + +const ( + QueryPredicateStepTypeDone QueryPredicateStepType = iota + QueryPredicateStepTypeCapture + QueryPredicateStepTypeString +) + +type QueryPredicateStep struct { + Type QueryPredicateStepType + ValueId uint32 +} + +func (q *Query) PredicatesForPattern(patternIndex uint32) [][]QueryPredicateStep { + var ( + length C.uint32_t + cPredicateSteps []C.TSQueryPredicateStep + predicateSteps []QueryPredicateStep + ) + + cPredicateStep := C.ts_query_predicates_for_pattern(q.c, C.uint32_t(patternIndex), &length) + + count := int(length) + slice := (*reflect.SliceHeader)((unsafe.Pointer(&cPredicateSteps))) + slice.Cap = count + slice.Len = count + slice.Data = uintptr(unsafe.Pointer(cPredicateStep)) + for _, s := range cPredicateSteps { + stepType := QueryPredicateStepType(s._type) + valueId := uint32(s.value_id) + predicateSteps = append(predicateSteps, QueryPredicateStep{stepType, valueId}) + } + + return splitPredicates(predicateSteps) +} + +func (q *Query) CaptureNameForId(id uint32) string { + var length C.uint32_t + name := C.ts_query_capture_name_for_id(q.c, C.uint32_t(id), &length) + return C.GoStringN(name, C.int(length)) +} + +func (q *Query) StringValueForId(id uint32) string { + var length C.uint32_t + value := C.ts_query_string_value_for_id(q.c, C.uint32_t(id), &length) + return C.GoStringN(value, C.int(length)) +} + +type Quantifier int + +const ( + QuantifierZero = iota + QuantifierZeroOrOne + QuantifierZeroOrMore + QuantifierOne + QuantifierOneOrMore +) + +func (q *Query) CaptureQuantifierForId(id uint32, captureId uint32) Quantifier { + return Quantifier(C.ts_query_capture_quantifier_for_id(q.c, C.uint32_t(id), C.uint32_t(captureId))) +} + +// QueryCursor carries the state needed for processing the queries. +type QueryCursor struct { + c *C.TSQueryCursor + t *Tree + // keep a pointer to the query to avoid garbage collection + q *Query + + isClosed bool +} + +// NewQueryCursor creates a query cursor. +func NewQueryCursor() *QueryCursor { + qc := &QueryCursor{c: C.ts_query_cursor_new(), t: nil} + runtime.SetFinalizer(qc, (*QueryCursor).Close) + + return qc +} + +// Exec executes the query on a given syntax node. +func (qc *QueryCursor) Exec(q *Query, n *Node) { + qc.q = q + qc.t = n.t + C.ts_query_cursor_exec(qc.c, q.c, n.c) +} + +func (qc *QueryCursor) SetPointRange(startPoint Point, endPoint Point) { + cStartPoint := C.TSPoint{ + row: C.uint32_t(startPoint.Row), + column: C.uint32_t(startPoint.Column), + } + cEndPoint := C.TSPoint{ + row: C.uint32_t(endPoint.Row), + column: C.uint32_t(endPoint.Column), + } + C.ts_query_cursor_set_point_range(qc.c, cStartPoint, cEndPoint) +} + +// Close should be called to ensure that all the memory used by the query cursor is freed. +// +// As the constructor in go-tree-sitter would set this func call through runtime.SetFinalizer, +// parser.Close() will be called by Go's garbage collector and users would not have to call this manually. +func (qc *QueryCursor) Close() { + if !qc.isClosed { + C.ts_query_cursor_delete(qc.c) + } + + qc.isClosed = true +} + +// QueryCapture is a captured node by a query with an index +type QueryCapture struct { + Index uint32 + Node *Node +} + +// QueryMatch - you can then iterate over the matches. +type QueryMatch struct { + ID uint32 + PatternIndex uint16 + Captures []QueryCapture +} + +// NextMatch iterates over matches. +// This function will return (nil, false) when there are no more matches. +// Otherwise, it will populate the QueryMatch with data +// about which pattern matched and which nodes were captured. +func (qc *QueryCursor) NextMatch() (*QueryMatch, bool) { + var ( + cqm C.TSQueryMatch + cqc []C.TSQueryCapture + ) + + if ok := C.ts_query_cursor_next_match(qc.c, &cqm); !bool(ok) { + return nil, false + } + + qm := &QueryMatch{ + ID: uint32(cqm.id), + PatternIndex: uint16(cqm.pattern_index), + } + + count := int(cqm.capture_count) + slice := (*reflect.SliceHeader)((unsafe.Pointer(&cqc))) + slice.Cap = count + slice.Len = count + slice.Data = uintptr(unsafe.Pointer(cqm.captures)) + for _, c := range cqc { + idx := uint32(c.index) + node := qc.t.cachedNode(c.node) + qm.Captures = append(qm.Captures, QueryCapture{idx, node}) + } + + return qm, true +} + +func (qc *QueryCursor) NextCapture() (*QueryMatch, uint32, bool) { + var ( + cqm C.TSQueryMatch + cqc []C.TSQueryCapture + captureIndex C.uint32_t + ) + + if ok := C.ts_query_cursor_next_capture(qc.c, &cqm, &captureIndex); !bool(ok) { + return nil, 0, false + } + + qm := &QueryMatch{ + ID: uint32(cqm.id), + PatternIndex: uint16(cqm.pattern_index), + } + + count := int(cqm.capture_count) + slice := (*reflect.SliceHeader)((unsafe.Pointer(&cqc))) + slice.Cap = count + slice.Len = count + slice.Data = uintptr(unsafe.Pointer(cqm.captures)) + for _, c := range cqc { + idx := uint32(c.index) + node := qc.t.cachedNode(c.node) + qm.Captures = append(qm.Captures, QueryCapture{idx, node}) + } + + return qm, uint32(captureIndex), true +} + +// Copied From: https://github.com/klothoplatform/go-tree-sitter/commit/e351b20167b26d515627a4a1a884528ede5fef79 + +func splitPredicates(steps []QueryPredicateStep) [][]QueryPredicateStep { + var predicateSteps [][]QueryPredicateStep + var currentSteps []QueryPredicateStep + for _, step := range steps { + currentSteps = append(currentSteps, step) + if step.Type == QueryPredicateStepTypeDone { + predicateSteps = append(predicateSteps, currentSteps) + currentSteps = []QueryPredicateStep{} + } + } + return predicateSteps +} + +func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch { + qm := &QueryMatch{ + ID: m.ID, + PatternIndex: m.PatternIndex, + } + + q := qc.q + + predicates := q.PredicatesForPattern(uint32(qm.PatternIndex)) + if len(predicates) == 0 { + qm.Captures = m.Captures + return qm + } + + // track if we matched all predicates globally + matchedAll := true + + // check each predicate against the match + for _, steps := range predicates { + operator := q.StringValueForId(steps[0].ValueId) + + switch operator { + case "eq?", "not-eq?": + isPositive := operator == "eq?" + + expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId) + + if steps[2].Type == QueryPredicateStepTypeCapture { + expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId) + + var nodeLeft, nodeRight *Node + + for _, c := range m.Captures { + captureName := q.CaptureNameForId(c.Index) + + if captureName == expectedCaptureNameLeft { + nodeLeft = c.Node + } + if captureName == expectedCaptureNameRight { + nodeRight = c.Node + } + + if nodeLeft != nil && nodeRight != nil { + if (nodeLeft.Content(input) == nodeRight.Content(input)) != isPositive { + matchedAll = false + } + break + } + } + } else { + expectedValueRight := q.StringValueForId(steps[2].ValueId) + + for _, c := range m.Captures { + captureName := q.CaptureNameForId(c.Index) + + if expectedCaptureNameLeft != captureName { + continue + } + + if (c.Node.Content(input) == expectedValueRight) != isPositive { + matchedAll = false + break + } + } + } + + if matchedAll == false { + break + } + + case "match?", "not-match?": + isPositive := operator == "match?" + + expectedCaptureName := q.CaptureNameForId(steps[1].ValueId) + regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId)) + + for _, c := range m.Captures { + captureName := q.CaptureNameForId(c.Index) + if expectedCaptureName != captureName { + continue + } + + if regex.Match([]byte(c.Node.Content(input))) != isPositive { + matchedAll = false + break + } + } + } + } + + if matchedAll { + qm.Captures = append(qm.Captures, m.Captures...) + } + + return qm + +} + +// keeps callbacks for parser.parse method +type readFuncsMap struct { + sync.Mutex + + funcs map[int]ReadFunc + count int +} + +func (m *readFuncsMap) register(f ReadFunc) int { + m.Lock() + defer m.Unlock() + + m.count++ + m.funcs[m.count] = f + return m.count +} + +func (m *readFuncsMap) unregister(id int) { + m.Lock() + defer m.Unlock() + + delete(m.funcs, id) +} + +func (m *readFuncsMap) get(id int) ReadFunc { + m.Lock() + defer m.Unlock() + + return m.funcs[id] +} + +//export callReadFunc +func callReadFunc(id C.int, byteIndex C.uint32_t, position C.TSPoint, bytesRead *C.uint32_t) *C.char { + readFunc := readFuncs.get(int(id)) + content := readFunc(uint32(byteIndex), Point{ + Row: uint32(position.row), + Column: uint32(position.column), + }) + *bytesRead = C.uint32_t(len(content)) + + // Note: This memory is freed inside the C code; see bindings.c + input := C.CBytes(content) + return (*C.char)(input) +} |
