1package rubex
  2
  3/*
  4#cgo CFLAGS: -I/usr/local/include
  5#cgo LDFLAGS: -L/usr/local/lib -lonig
  6#include <stdlib.h>
  7#include <oniguruma.h>
  8#include "chelper.h"
  9*/
 10import "C"
 11
 12import (
 13	"bytes"
 14	"errors"
 15	"fmt"
 16	"io"
 17	"runtime"
 18	"strconv"
 19	"sync"
 20	"unicode/utf8"
 21	"unsafe"
 22)
 23
 24const numMatchStartSize = 4
 25const numReadBufferStartSize = 256
 26
 27var mutex sync.Mutex
 28
// NamedGroupInfo maps the name of each named capture group to its group
// number, i.e. the index g whose offsets sit at positions 2*g and 2*g+1 of a
// submatch index slice.
type NamedGroupInfo map[string]int
 30
// Regexp is a regular expression compiled by Oniguruma. Values are built by
// NewRegexp / NewRegexpASCII (and the Compile helpers on top of them); the C
// resources they hold are released by Free, which initRegexp also registers
// as a finalizer.
type Regexp struct {
	// pattern is the source text the expression was compiled from (see String).
	pattern   string
	// regex is the compiled Oniguruma handle; Free sets it back to nil.
	regex     C.OnigRegex
	// encoding is ONIG_ENCODING_UTF8 or ONIG_ENCODING_ASCII, chosen by the constructor.
	encoding  C.OnigEncoding
	// errorInfo and errorBuf are filled in by the C compile helper; errorBuf
	// carries the compile error message. Both are released with C.free in Free.
	errorInfo *C.OnigErrorInfo
	errorBuf  *C.char

	// numCaptures is the number of capture groups plus one for the whole match.
	numCaptures    int32
	// namedGroupInfo maps group names to numbers; nil when the pattern has no named groups.
	namedGroupInfo NamedGroupInfo
}
 41
 42// NewRegexp creates and initializes a new Regexp with the given pattern and option.
 43func NewRegexp(pattern string, option int) (*Regexp, error) {
 44	return initRegexp(&Regexp{pattern: pattern, encoding: C.ONIG_ENCODING_UTF8}, option)
 45}
 46
 47// NewRegexpASCII is equivalent to NewRegexp, but with the encoding restricted to ASCII.
 48func NewRegexpASCII(pattern string, option int) (*Regexp, error) {
 49	return initRegexp(&Regexp{pattern: pattern, encoding: C.ONIG_ENCODING_ASCII}, option)
 50}
 51
 52func initRegexp(re *Regexp, option int) (*Regexp, error) {
 53	patternCharPtr := C.CString(re.pattern)
 54	defer C.free(unsafe.Pointer(patternCharPtr))
 55
 56	mutex.Lock()
 57	defer mutex.Unlock()
 58
 59	errorCode := C.NewOnigRegex(patternCharPtr, C.int(len(re.pattern)), C.int(option), &re.regex, &re.encoding, &re.errorInfo, &re.errorBuf)
 60	if errorCode != C.ONIG_NORMAL {
 61		return re, errors.New(C.GoString(re.errorBuf))
 62	}
 63
 64	re.numCaptures = int32(C.onig_number_of_captures(re.regex)) + 1
 65	re.namedGroupInfo = re.getNamedGroupInfo()
 66
 67	runtime.SetFinalizer(re, (*Regexp).Free)
 68
 69	return re, nil
 70}
 71
 72func Compile(str string) (*Regexp, error) {
 73	return NewRegexp(str, ONIG_OPTION_DEFAULT)
 74}
 75
 76func MustCompile(str string) *Regexp {
 77	regexp, error := NewRegexp(str, ONIG_OPTION_DEFAULT)
 78	if error != nil {
 79		panic("regexp: compiling " + str + ": " + error.Error())
 80	}
 81
 82	return regexp
 83}
 84
 85func CompileWithOption(str string, option int) (*Regexp, error) {
 86	return NewRegexp(str, option)
 87}
 88
 89func MustCompileWithOption(str string, option int) *Regexp {
 90	regexp, error := NewRegexp(str, option)
 91	if error != nil {
 92		panic("regexp: compiling " + str + ": " + error.Error())
 93	}
 94
 95	return regexp
 96}
 97
 98// MustCompileASCII is equivalent to MustCompile, but with the encoding restricted to ASCII.
 99func MustCompileASCII(str string) *Regexp {
100	regexp, error := NewRegexpASCII(str, ONIG_OPTION_DEFAULT)
101	if error != nil {
102		panic("regexp: compiling " + str + ": " + error.Error())
103	}
104
105	return regexp
106}
107
108func (re *Regexp) Free() {
109	mutex.Lock()
110	if re.regex != nil {
111		C.onig_free(re.regex)
112		re.regex = nil
113	}
114	mutex.Unlock()
115	if re.errorInfo != nil {
116		C.free(unsafe.Pointer(re.errorInfo))
117		re.errorInfo = nil
118	}
119	if re.errorBuf != nil {
120		C.free(unsafe.Pointer(re.errorBuf))
121		re.errorBuf = nil
122	}
123}
124
125func (re *Regexp) getNamedGroupInfo() NamedGroupInfo {
126	numNamedGroups := int(C.onig_number_of_names(re.regex))
127	// when any named capture exists, there is no numbered capture even if
128	// there are unnamed captures.
129	if numNamedGroups == 0 {
130		return nil
131	}
132
133	namedGroupInfo := make(map[string]int)
134
135	//try to get the names
136	bufferSize := len(re.pattern) * 2
137	nameBuffer := make([]byte, bufferSize)
138	groupNumbers := make([]int32, numNamedGroups)
139	bufferPtr := unsafe.Pointer(&nameBuffer[0])
140	numbersPtr := unsafe.Pointer(&groupNumbers[0])
141
142	length := int(C.GetCaptureNames(re.regex, bufferPtr, (C.int)(bufferSize), (*C.int)(numbersPtr)))
143	if length == 0 {
144		panic(fmt.Errorf("could not get the capture group names from %q", re.String()))
145	}
146
147	namesAsBytes := bytes.Split(nameBuffer[:length], ([]byte)(";"))
148	if len(namesAsBytes) != numNamedGroups {
149		panic(fmt.Errorf(
150			"the number of named groups (%d) does not match the number names found (%d)",
151			numNamedGroups, len(namesAsBytes),
152		))
153	}
154
155	for i, nameAsBytes := range namesAsBytes {
156		name := string(nameAsBytes)
157		namedGroupInfo[name] = int(groupNumbers[i])
158	}
159
160	return namedGroupInfo
161}
162
163func (re *Regexp) find(b []byte, n int, offset int) []int {
164	match := make([]int, re.numCaptures*2)
165
166	if n == 0 {
167		b = []byte{0}
168	}
169
170	bytesPtr := unsafe.Pointer(&b[0])
171
172	// captures contains two pairs of ints, start and end, so we need list
173	// twice the size of the capture groups.
174	captures := make([]C.int, re.numCaptures*2)
175	capturesPtr := unsafe.Pointer(&captures[0])
176
177	var numCaptures int32
178	numCapturesPtr := unsafe.Pointer(&numCaptures)
179
180	pos := int(C.SearchOnigRegex(
181		bytesPtr, C.int(n), C.int(offset), C.int(ONIG_OPTION_DEFAULT),
182		re.regex, re.errorInfo, (*C.char)(nil), (*C.int)(capturesPtr), (*C.int)(numCapturesPtr),
183	))
184
185	if pos < 0 {
186		return nil
187	}
188
189	if numCaptures <= 0 {
190		panic("cannot have 0 captures when processing a match")
191	}
192
193	if re.numCaptures != numCaptures {
194		panic(fmt.Errorf("expected %d captures but got %d", re.numCaptures, numCaptures))
195	}
196
197	for i := range captures {
198		match[i] = int(captures[i])
199	}
200
201	return match
202}
203
// getCapture returns the sub-slice b[beg:end], or nil when either offset is
// negative (the group did not participate in the match). The result aliases b.
func getCapture(b []byte, beg int, end int) []byte {
	if beg >= 0 && end >= 0 {
		return b[beg:end]
	}
	return nil
}
211
212func (re *Regexp) match(b []byte, n int, offset int) bool {
213	if n == 0 {
214		b = []byte{0}
215	}
216
217	bytesPtr := unsafe.Pointer(&b[0])
218	pos := int(C.SearchOnigRegex(
219		bytesPtr, C.int(n), C.int(offset), C.int(ONIG_OPTION_DEFAULT),
220		re.regex, re.errorInfo, nil, nil, nil,
221	))
222
223	return pos >= 0
224}
225
226func (re *Regexp) findAll(b []byte, n int) [][]int {
227	if n < 0 {
228		n = len(b)
229	}
230
231	capture := make([][]int, 0, numMatchStartSize)
232	var offset int
233	for offset <= n {
234		match := re.find(b, n, offset)
235		if match == nil {
236			break
237		}
238
239		capture = append(capture, match)
240
241		// move offset to the ending index of the current match and prepare to
242		// find the next non-overlapping match.
243		offset = match[1]
244
245		// if match[0] == match[1], it means the current match does not advance
246		// the search. we need to exit the loop to avoid getting stuck here.
247		if match[0] == match[1] {
248			if offset < n && offset >= 0 {
249				//there are more bytes, so move offset by a word
250				_, width := utf8.DecodeRune(b[offset:])
251				offset += width
252			} else {
253				//search is over, exit loop
254				break
255			}
256		}
257	}
258
259	return capture
260}
261
262func (re *Regexp) FindIndex(b []byte) []int {
263	match := re.find(b, len(b), 0)
264	if len(match) == 0 {
265		return nil
266	}
267
268	return match[:2]
269}
270
271func (re *Regexp) Find(b []byte) []byte {
272	loc := re.FindIndex(b)
273	if loc == nil {
274		return nil
275	}
276
277	return getCapture(b, loc[0], loc[1])
278}
279
280func (re *Regexp) FindString(s string) string {
281	mb := re.Find([]byte(s))
282	if mb == nil {
283		return ""
284	}
285
286	return string(mb)
287}
288
289func (re *Regexp) FindStringIndex(s string) []int {
290	return re.FindIndex([]byte(s))
291}
292
293func (re *Regexp) FindAllIndex(b []byte, n int) [][]int {
294	matches := re.findAll(b, n)
295	if len(matches) == 0 {
296		return nil
297	}
298
299	return matches
300}
301
302func (re *Regexp) FindAll(b []byte, n int) [][]byte {
303	matches := re.FindAllIndex(b, n)
304	if matches == nil {
305		return nil
306	}
307
308	matchBytes := make([][]byte, 0, len(matches))
309	for _, match := range matches {
310		matchBytes = append(matchBytes, getCapture(b, match[0], match[1]))
311	}
312
313	return matchBytes
314}
315
316func (re *Regexp) FindAllString(s string, n int) []string {
317	b := []byte(s)
318	matches := re.FindAllIndex(b, n)
319	if matches == nil {
320		return nil
321	}
322
323	matchStrings := make([]string, 0, len(matches))
324	for _, match := range matches {
325		m := getCapture(b, match[0], match[1])
326		if m == nil {
327			matchStrings = append(matchStrings, "")
328		} else {
329			matchStrings = append(matchStrings, string(m))
330		}
331	}
332
333	return matchStrings
334
335}
336
337func (re *Regexp) FindAllStringIndex(s string, n int) [][]int {
338	return re.FindAllIndex([]byte(s), n)
339}
340
341func (re *Regexp) FindSubmatchIndex(b []byte) []int {
342	match := re.find(b, len(b), 0)
343	if len(match) == 0 {
344		return nil
345	}
346
347	return match
348}
349
350func (re *Regexp) FindSubmatch(b []byte) [][]byte {
351	match := re.FindSubmatchIndex(b)
352	if match == nil {
353		return nil
354	}
355
356	length := len(match) / 2
357	if length == 0 {
358		return nil
359	}
360
361	results := make([][]byte, 0, length)
362	for i := 0; i < length; i++ {
363		results = append(results, getCapture(b, match[2*i], match[2*i+1]))
364	}
365
366	return results
367}
368
369func (re *Regexp) FindStringSubmatch(s string) []string {
370	b := []byte(s)
371	match := re.FindSubmatchIndex(b)
372	if match == nil {
373		return nil
374	}
375
376	length := len(match) / 2
377	if length == 0 {
378		return nil
379	}
380
381	results := make([]string, 0, length)
382	for i := 0; i < length; i++ {
383		cap := getCapture(b, match[2*i], match[2*i+1])
384		if cap == nil {
385			results = append(results, "")
386		} else {
387			results = append(results, string(cap))
388		}
389	}
390
391	return results
392}
393
394func (re *Regexp) FindStringSubmatchIndex(s string) []int {
395	return re.FindSubmatchIndex([]byte(s))
396}
397
398func (re *Regexp) FindAllSubmatchIndex(b []byte, n int) [][]int {
399	matches := re.findAll(b, n)
400	if len(matches) == 0 {
401		return nil
402	}
403
404	return matches
405}
406
407func (re *Regexp) FindAllSubmatch(b []byte, n int) [][][]byte {
408	matches := re.findAll(b, n)
409	if len(matches) == 0 {
410		return nil
411	}
412
413	allCapturedBytes := make([][][]byte, 0, len(matches))
414	for _, match := range matches {
415		length := len(match) / 2
416		capturedBytes := make([][]byte, 0, length)
417		for i := 0; i < length; i++ {
418			capturedBytes = append(capturedBytes, getCapture(b, match[2*i], match[2*i+1]))
419		}
420
421		allCapturedBytes = append(allCapturedBytes, capturedBytes)
422	}
423
424	return allCapturedBytes
425}
426
427func (re *Regexp) FindAllStringSubmatch(s string, n int) [][]string {
428	b := []byte(s)
429
430	matches := re.findAll(b, n)
431	if len(matches) == 0 {
432		return nil
433	}
434
435	allCapturedStrings := make([][]string, 0, len(matches))
436	for _, match := range matches {
437		length := len(match) / 2
438		capturedStrings := make([]string, 0, length)
439		for i := 0; i < length; i++ {
440			cap := getCapture(b, match[2*i], match[2*i+1])
441			if cap == nil {
442				capturedStrings = append(capturedStrings, "")
443			} else {
444				capturedStrings = append(capturedStrings, string(cap))
445			}
446		}
447
448		allCapturedStrings = append(allCapturedStrings, capturedStrings)
449	}
450
451	return allCapturedStrings
452}
453
454func (re *Regexp) FindAllStringSubmatchIndex(s string, n int) [][]int {
455	return re.FindAllSubmatchIndex([]byte(s), n)
456}
457
458func (re *Regexp) Match(b []byte) bool {
459	return re.match(b, len(b), 0)
460}
461
462func (re *Regexp) MatchString(s string) bool {
463	return re.Match([]byte(s))
464}
465
466func (re *Regexp) NumSubexp() int {
467	return (int)(C.onig_number_of_captures(re.regex))
468}
469
// fillCapturedValues expands the back-references in repl using capturedBytes
// and returns the resulting bytes. Recognized forms:
//
//	\1 .. \9   replaced by capturedBytes["1"] .. capturedBytes["9"]
//	\k<name>   replaced by capturedBytes["name"]
//
// Unknown keys expand to nothing. Any other escaped character is copied
// through together with its backslash, and a trailing lone backslash is
// dropped. The second parameter (the whole match) is unused.
func fillCapturedValues(repl []byte, _ []byte, capturedBytes map[string][]byte) []byte {
	out := make([]byte, 0, len(repl)*3)
	name := make([]byte, 0, len(repl))

	escaped, readingName := false, false
	for i := 0; i < len(repl); i++ {
		c := repl[i]
		switch {
		case readingName && c == '<':
			// A stray '<' inside a group name is ignored.
		case readingName && c == '>':
			readingName = false
			out = append(out, capturedBytes[string(name)]...)
			name = name[:0]
		case readingName:
			name = append(name, c)
		case escaped && '1' <= c && c <= '9':
			out = append(out, capturedBytes[string(c)]...)
		case escaped && c == 'k' && i+1 < len(repl) && repl[i+1] == '<':
			readingName = true
			escaped = false
			i++ // skip the '<' that opens the name
		case escaped:
			out = append(out, '\\', c)
		case c != '\\':
			out = append(out, c)
		}
		// A backslash enters escape mode; the character after it (whatever
		// branch consumed it) leaves it again.
		if c == '\\' || escaped {
			escaped = !escaped
		}
	}

	return out
}
507
508func (re *Regexp) replaceAll(src, repl []byte, replFunc func([]byte, []byte, map[string][]byte) []byte) []byte {
509	srcLen := len(src)
510	matches := re.findAll(src, srcLen)
511	if len(matches) == 0 {
512		return src
513	}
514
515	dest := make([]byte, 0, srcLen)
516	for i, match := range matches {
517		length := len(match) / 2
518		capturedBytes := make(map[string][]byte)
519
520		if re.namedGroupInfo == nil {
521			for j := 0; j < length; j++ {
522				capturedBytes[strconv.Itoa(j)] = getCapture(src, match[2*j], match[2*j+1])
523			}
524		} else {
525			for name, j := range re.namedGroupInfo {
526				capturedBytes[name] = getCapture(src, match[2*j], match[2*j+1])
527			}
528		}
529
530		matchBytes := getCapture(src, match[0], match[1])
531		newRepl := replFunc(repl, matchBytes, capturedBytes)
532		prevEnd := 0
533		if i > 0 {
534			prevMatch := matches[i-1][:2]
535			prevEnd = prevMatch[1]
536		}
537
538		if match[0] > prevEnd && prevEnd >= 0 && match[0] <= srcLen {
539			dest = append(dest, src[prevEnd:match[0]]...)
540		}
541
542		dest = append(dest, newRepl...)
543	}
544
545	lastEnd := matches[len(matches)-1][1]
546	if lastEnd < srcLen && lastEnd >= 0 {
547		dest = append(dest, src[lastEnd:]...)
548	}
549
550	return dest
551}
552
553func (re *Regexp) ReplaceAll(src, repl []byte) []byte {
554	return re.replaceAll(src, repl, fillCapturedValues)
555}
556
557func (re *Regexp) ReplaceAllFunc(src []byte, repl func([]byte) []byte) []byte {
558	return re.replaceAll(src, nil, func(_ []byte, matchBytes []byte, _ map[string][]byte) []byte {
559		return repl(matchBytes)
560	})
561}
562
563func (re *Regexp) ReplaceAllString(src, repl string) string {
564	return string(re.ReplaceAll([]byte(src), []byte(repl)))
565}
566
567func (re *Regexp) ReplaceAllStringFunc(src string, repl func(string) string) string {
568	return string(re.replaceAll([]byte(src), nil, func(_ []byte, matchBytes []byte, _ map[string][]byte) []byte {
569		return []byte(repl(string(matchBytes)))
570	}))
571}
572
573func (re *Regexp) String() string {
574	return re.pattern
575}
576
// growBuffer makes sure n more bytes fit after offset. If offset+n is within
// cap(b), it returns b unchanged. Otherwise it returns a new zeroed buffer of
// length 2*cap(b)+n holding a copy of b[:offset].
func growBuffer(b []byte, offset int, n int) []byte {
	if need := offset + n; need <= cap(b) {
		return b
	}

	grown := make([]byte, 2*cap(b)+n)
	copy(grown, b[:offset])
	return grown
}
586
587func fromReader(r io.RuneReader) []byte {
588	b := make([]byte, numReadBufferStartSize)
589
590	var offset int
591	for {
592		rune, runeWidth, err := r.ReadRune()
593		if err != nil {
594			break
595		}
596
597		b = growBuffer(b, offset, runeWidth)
598		writeWidth := utf8.EncodeRune(b[offset:], rune)
599		if runeWidth != writeWidth {
600			panic("reading rune width not equal to the written rune width")
601		}
602
603		offset += writeWidth
604	}
605
606	return b[:offset]
607}
608
609func (re *Regexp) FindReaderIndex(r io.RuneReader) []int {
610	b := fromReader(r)
611	return re.FindIndex(b)
612}
613
614func (re *Regexp) FindReaderSubmatchIndex(r io.RuneReader) []int {
615	b := fromReader(r)
616	return re.FindSubmatchIndex(b)
617}
618
619func (re *Regexp) MatchReader(r io.RuneReader) bool {
620	b := fromReader(r)
621	return re.Match(b)
622}
623
624func (re *Regexp) LiteralPrefix() (prefix string, complete bool) {
625	//no easy way to implement this
626	return "", false
627}
628
629func MatchString(pattern string, s string) (matched bool, error error) {
630	re, err := Compile(pattern)
631	if err != nil {
632		return false, err
633	}
634
635	return re.MatchString(s), nil
636}
637
638func (re *Regexp) Gsub(src, repl string) string {
639	return string(re.replaceAll([]byte(src), []byte(repl), fillCapturedValues))
640}
641
642func (re *Regexp) GsubFunc(src string, replFunc func(string, map[string]string) string) string {
643	replaced := re.replaceAll([]byte(src), nil,
644		func(_ []byte, matchBytes []byte, capturedBytes map[string][]byte) []byte {
645			capturedStrings := make(map[string]string)
646			for name, capBytes := range capturedBytes {
647				capturedStrings[name] = string(capBytes)
648			}
649			matchString := string(matchBytes)
650			return ([]byte)(replFunc(matchString, capturedStrings))
651		},
652	)
653
654	return string(replaced)
655}