1package chroma
  2
import (
	"compress/gzip"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"math"
	"path/filepath"
	"reflect"
	"regexp"
	"sort"
	"strings"

	"github.com/dlclark/regexp2"
)
 18
// Serialisation of Chroma rules to XML. The format is:
//
//	<rules>
//	  <state name="$STATE">
//	    <rule [pattern="$PATTERN"]>
//	      [<$EMITTER ...>]
//	      [<$MUTATOR ...>]
//	    </rule>
//	  </state>
//	</rules>
//
// Patterns are stored as ordinary XML attribute values: backslashes are
// written verbatim, and characters such as `"` are written as XML entities.
//
// e.g. Include("String") would become:
//
//	<rule>
//	  <include state="String" />
//	</rule>
//
// e.g. Rule{`\d+`, Text, nil} would become:
//
//	<rule pattern="\d+">
//	  <token type="Text"/>
//	</rule>
//
// e.g. Rule{`"`, String, Push("String")} would become:
//
//	<rule pattern="&#34;">
//	  <token type="String" />
//	  <push state="String" />
//	</rule>
//
// e.g. Rule{`(\w+)(\n)`, ByGroups(Keyword, Whitespace), Push("String")} would become:
//
//	<rule pattern="(\w+)(\n)">
//	  <bygroups>
//	    <token type="Keyword" />
//	    <token type="Whitespace" />
//	  </bygroups>
//	  <push state="String" />
//	</rule>
 57var (
 58	// ErrNotSerialisable is returned if a lexer contains Rules that cannot be serialised.
 59	ErrNotSerialisable = fmt.Errorf("not serialisable")
 60	emitterTemplates   = func() map[string]SerialisableEmitter {
 61		out := map[string]SerialisableEmitter{}
 62		for _, emitter := range []SerialisableEmitter{
 63			&byGroupsEmitter{},
 64			&usingSelfEmitter{},
 65			TokenType(0),
 66			&usingEmitter{},
 67			&usingByGroup{},
 68		} {
 69			out[emitter.EmitterKind()] = emitter
 70		}
 71		return out
 72	}()
 73	mutatorTemplates = func() map[string]SerialisableMutator {
 74		out := map[string]SerialisableMutator{}
 75		for _, mutator := range []SerialisableMutator{
 76			&includeMutator{},
 77			&combinedMutator{},
 78			&multiMutator{},
 79			&pushMutator{},
 80			&popMutator{},
 81		} {
 82			out[mutator.MutatorKind()] = mutator
 83		}
 84		return out
 85	}()
 86)
 87
 88// fastUnmarshalConfig unmarshals only the Config from a serialised lexer.
 89func fastUnmarshalConfig(from fs.FS, path string) (*Config, error) {
 90	r, err := from.Open(path)
 91	if err != nil {
 92		return nil, err
 93	}
 94	defer r.Close()
 95	dec := xml.NewDecoder(r)
 96	for {
 97		token, err := dec.Token()
 98		if err != nil {
 99			if errors.Is(err, io.EOF) {
100				return nil, fmt.Errorf("could not find <config> element")
101			}
102			return nil, err
103		}
104		switch se := token.(type) {
105		case xml.StartElement:
106			if se.Name.Local != "config" {
107				break
108			}
109
110			var config Config
111			err = dec.DecodeElement(&config, &se)
112			if err != nil {
113				return nil, fmt.Errorf("%s: %w", path, err)
114			}
115			return &config, nil
116		}
117	}
118}
119
120// MustNewXMLLexer constructs a new RegexLexer from an XML file or panics.
121func MustNewXMLLexer(from fs.FS, path string) *RegexLexer {
122	lex, err := NewXMLLexer(from, path)
123	if err != nil {
124		panic(err)
125	}
126	return lex
127}
128
129// NewXMLLexer creates a new RegexLexer from a serialised RegexLexer.
130func NewXMLLexer(from fs.FS, path string) (*RegexLexer, error) {
131	config, err := fastUnmarshalConfig(from, path)
132	if err != nil {
133		return nil, err
134	}
135
136	for _, glob := range append(config.Filenames, config.AliasFilenames...) {
137		_, err := filepath.Match(glob, "")
138		if err != nil {
139			return nil, fmt.Errorf("%s: %q is not a valid glob: %w", config.Name, glob, err)
140		}
141	}
142
143	var analyserFn func(string) float32
144
145	if config.Analyse != nil {
146		type regexAnalyse struct {
147			re    *regexp2.Regexp
148			score float32
149		}
150
151		regexAnalysers := make([]regexAnalyse, 0, len(config.Analyse.Regexes))
152
153		regexFlags := regexp2.None
154		if config.CaseInsensitive {
155			regexFlags = regexp2.IgnoreCase
156		}
157		for _, ra := range config.Analyse.Regexes {
158			re, err := regexp2.Compile(ra.Pattern, regexFlags)
159			if err != nil {
160				return nil, fmt.Errorf("%s: %q is not a valid analyser regex: %w", config.Name, ra.Pattern, err)
161			}
162
163			regexAnalysers = append(regexAnalysers, regexAnalyse{re, ra.Score})
164		}
165
166		analyserFn = func(text string) float32 {
167			var score float32
168
169			for _, ra := range regexAnalysers {
170				ok, err := ra.re.MatchString(text)
171				if err != nil {
172					return 0
173				}
174
175				if ok && config.Analyse.First {
176					return float32(math.Min(float64(ra.score), 1.0))
177				}
178
179				if ok {
180					score += ra.score
181				}
182			}
183
184			return float32(math.Min(float64(score), 1.0))
185		}
186	}
187
188	return &RegexLexer{
189		config:   config,
190		analyser: analyserFn,
191		fetchRulesFunc: func() (Rules, error) {
192			var lexer struct {
193				Config
194				Rules Rules `xml:"rules"`
195			}
196			// Try to open .xml fallback to .xml.gz
197			fr, err := from.Open(path)
198			if err != nil {
199				if errors.Is(err, fs.ErrNotExist) {
200					path += ".gz"
201					fr, err = from.Open(path)
202					if err != nil {
203						return nil, err
204					}
205				} else {
206					return nil, err
207				}
208			}
209			defer fr.Close()
210			var r io.Reader = fr
211			if strings.HasSuffix(path, ".gz") {
212				r, err = gzip.NewReader(r)
213				if err != nil {
214					return nil, fmt.Errorf("%s: %w", path, err)
215				}
216			}
217			err = xml.NewDecoder(r).Decode(&lexer)
218			if err != nil {
219				return nil, fmt.Errorf("%s: %w", path, err)
220			}
221			return lexer.Rules, nil
222		},
223	}, nil
224}
225
226// Marshal a RegexLexer to XML.
227func Marshal(l *RegexLexer) ([]byte, error) {
228	type lexer struct {
229		Config Config `xml:"config"`
230		Rules  Rules  `xml:"rules"`
231	}
232
233	rules, err := l.Rules()
234	if err != nil {
235		return nil, err
236	}
237	root := &lexer{
238		Config: *l.Config(),
239		Rules:  rules,
240	}
241	data, err := xml.MarshalIndent(root, "", "  ")
242	if err != nil {
243		return nil, err
244	}
245	re := regexp.MustCompile(`></[a-zA-Z]+>`)
246	data = re.ReplaceAll(data, []byte(`/>`))
247	return data, nil
248}
249
250// Unmarshal a RegexLexer from XML.
251func Unmarshal(data []byte) (*RegexLexer, error) {
252	type lexer struct {
253		Config Config `xml:"config"`
254		Rules  Rules  `xml:"rules"`
255	}
256	root := &lexer{}
257	err := xml.Unmarshal(data, root)
258	if err != nil {
259		return nil, fmt.Errorf("invalid Lexer XML: %w", err)
260	}
261	lex, err := NewLexer(&root.Config, func() Rules { return root.Rules })
262	if err != nil {
263		return nil, err
264	}
265	return lex, nil
266}
267
268func marshalMutator(e *xml.Encoder, mutator Mutator) error {
269	if mutator == nil {
270		return nil
271	}
272	smutator, ok := mutator.(SerialisableMutator)
273	if !ok {
274		return fmt.Errorf("unsupported mutator: %w", ErrNotSerialisable)
275	}
276	return e.EncodeElement(mutator, xml.StartElement{Name: xml.Name{Local: smutator.MutatorKind()}})
277}
278
279func unmarshalMutator(d *xml.Decoder, start xml.StartElement) (Mutator, error) {
280	kind := start.Name.Local
281	mutator, ok := mutatorTemplates[kind]
282	if !ok {
283		return nil, fmt.Errorf("unknown mutator %q: %w", kind, ErrNotSerialisable)
284	}
285	value, target := newFromTemplate(mutator)
286	if err := d.DecodeElement(target, &start); err != nil {
287		return nil, err
288	}
289	return value().(SerialisableMutator), nil
290}
291
292func marshalEmitter(e *xml.Encoder, emitter Emitter) error {
293	if emitter == nil {
294		return nil
295	}
296	semitter, ok := emitter.(SerialisableEmitter)
297	if !ok {
298		return fmt.Errorf("unsupported emitter %T: %w", emitter, ErrNotSerialisable)
299	}
300	return e.EncodeElement(emitter, xml.StartElement{
301		Name: xml.Name{Local: semitter.EmitterKind()},
302	})
303}
304
305func unmarshalEmitter(d *xml.Decoder, start xml.StartElement) (Emitter, error) {
306	kind := start.Name.Local
307	mutator, ok := emitterTemplates[kind]
308	if !ok {
309		return nil, fmt.Errorf("unknown emitter %q: %w", kind, ErrNotSerialisable)
310	}
311	value, target := newFromTemplate(mutator)
312	if err := d.DecodeElement(target, &start); err != nil {
313		return nil, err
314	}
315	return value().(SerialisableEmitter), nil
316}
317
318func (r Rule) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
319	start := xml.StartElement{
320		Name: xml.Name{Local: "rule"},
321	}
322	if r.Pattern != "" {
323		start.Attr = append(start.Attr, xml.Attr{
324			Name:  xml.Name{Local: "pattern"},
325			Value: r.Pattern,
326		})
327	}
328	if err := e.EncodeToken(start); err != nil {
329		return err
330	}
331	if err := marshalEmitter(e, r.Type); err != nil {
332		return err
333	}
334	if err := marshalMutator(e, r.Mutator); err != nil {
335		return err
336	}
337	return e.EncodeToken(xml.EndElement{Name: start.Name})
338}
339
340func (r *Rule) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
341	for _, attr := range start.Attr {
342		if attr.Name.Local == "pattern" {
343			r.Pattern = attr.Value
344			break
345		}
346	}
347	for {
348		token, err := d.Token()
349		if err != nil {
350			return err
351		}
352		switch token := token.(type) {
353		case xml.StartElement:
354			mutator, err := unmarshalMutator(d, token)
355			if err != nil && !errors.Is(err, ErrNotSerialisable) {
356				return err
357			} else if err == nil {
358				if r.Mutator != nil {
359					return fmt.Errorf("duplicate mutator")
360				}
361				r.Mutator = mutator
362				continue
363			}
364			emitter, err := unmarshalEmitter(d, token)
365			if err != nil && !errors.Is(err, ErrNotSerialisable) { // nolint: gocritic
366				return err
367			} else if err == nil {
368				if r.Type != nil {
369					return fmt.Errorf("duplicate emitter")
370				}
371				r.Type = emitter
372				continue
373			} else {
374				return err
375			}
376
377		case xml.EndElement:
378			return nil
379		}
380	}
381}
382
383type xmlRuleState struct {
384	Name  string `xml:"name,attr"`
385	Rules []Rule `xml:"rule"`
386}
387
388type xmlRules struct {
389	States []xmlRuleState `xml:"state"`
390}
391
392func (r Rules) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
393	xr := xmlRules{}
394	for state, rules := range r {
395		xr.States = append(xr.States, xmlRuleState{
396			Name:  state,
397			Rules: rules,
398		})
399	}
400	return e.EncodeElement(xr, xml.StartElement{Name: xml.Name{Local: "rules"}})
401}
402
403func (r *Rules) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
404	xr := xmlRules{}
405	if err := d.DecodeElement(&xr, &start); err != nil {
406		return err
407	}
408	if *r == nil {
409		*r = Rules{}
410	}
411	for _, state := range xr.States {
412		(*r)[state.Name] = state.Rules
413	}
414	return nil
415}
416
417type xmlTokenType struct {
418	Type string `xml:"type,attr"`
419}
420
421func (t *TokenType) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
422	el := xmlTokenType{}
423	if err := d.DecodeElement(&el, &start); err != nil {
424		return err
425	}
426	tt, err := TokenTypeString(el.Type)
427	if err != nil {
428		return err
429	}
430	*t = tt
431	return nil
432}
433
434func (t TokenType) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
435	start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: t.String()})
436	if err := e.EncodeToken(start); err != nil {
437		return err
438	}
439	return e.EncodeToken(xml.EndElement{Name: start.Name})
440}
441
// newFromTemplate allocates a fresh zero value with the same dynamic type as
// template, ready to be decoded into.
//
// target is always a pointer to the new value. value reports the result in
// the template's own shape:
//   - when template is a pointer, value returns that same pointer;
//   - otherwise, value returns the pointed-to value.
//
// The non-pointer case is what allows decoding into value types such as
// TokenType.
func newFromTemplate(template interface{}) (value func() interface{}, target interface{}) {
	typ := reflect.TypeOf(template)
	if typ.Kind() != reflect.Ptr {
		holder := reflect.New(typ)
		return func() interface{} { return holder.Elem().Interface() }, holder.Interface()
	}
	holder := reflect.New(typ.Elem())
	return holder.Interface, holder.Interface()
}
452
453func (b *Emitters) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
454	for {
455		token, err := d.Token()
456		if err != nil {
457			return err
458		}
459		switch token := token.(type) {
460		case xml.StartElement:
461			emitter, err := unmarshalEmitter(d, token)
462			if err != nil {
463				return err
464			}
465			*b = append(*b, emitter)
466
467		case xml.EndElement:
468			return nil
469		}
470	}
471}
472
473func (b Emitters) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
474	if err := e.EncodeToken(start); err != nil {
475		return err
476	}
477	for _, m := range b {
478		if err := marshalEmitter(e, m); err != nil {
479			return err
480		}
481	}
482	return e.EncodeToken(xml.EndElement{Name: start.Name})
483}