1package chroma
  2
import (
	"compress/gzip"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"path/filepath"
	"reflect"
	"regexp"
	"sort"
	"strings"
)
 15
// Serialisation of Chroma lexers to XML. A serialised lexer holds a <config>
// element and a <rules> element; the rules use the format:
 17//
 18//     <rules>
 19//       <state name="$STATE">
 20//         <rule [pattern="$PATTERN"]>
 21//           [<$EMITTER ...>]
 22//           [<$MUTATOR ...>]
 23//         </rule>
 24//       </state>
 25//     </rules>
 26//
// e.g. Include("String") would become:
//
//     <rule>
//       <include state="String" />
//     </rule>
//
// e.g. Rule{`\d+`, Text, nil} would become (backslashes are not escaped in XML):
//
//     <rule pattern="\d+">
//       <token type="Text"/>
//     </rule>
 40//
// e.g. Rule{`"`, String, Push("String")} would become (the quote is escaped):
//
//     <rule pattern="&#34;">
//       <token type="String" />
//       <push state="String" />
//     </rule>
 47//
// e.g. Rule{`(\w+)(\n)`, ByGroups(Keyword, Whitespace), nil} would become:
//
//     <rule pattern="(\w+)(\n)">
//       <bygroups>
//         <token type="Keyword" />
//         <token type="Whitespace" />
//       </bygroups>
//     </rule>
 54var (
 55	// ErrNotSerialisable is returned if a lexer contains Rules that cannot be serialised.
 56	ErrNotSerialisable = fmt.Errorf("not serialisable")
 57	emitterTemplates   = func() map[string]SerialisableEmitter {
 58		out := map[string]SerialisableEmitter{}
 59		for _, emitter := range []SerialisableEmitter{
 60			&byGroupsEmitter{},
 61			&usingSelfEmitter{},
 62			TokenType(0),
 63			&usingEmitter{},
 64			&usingByGroup{},
 65		} {
 66			out[emitter.EmitterKind()] = emitter
 67		}
 68		return out
 69	}()
 70	mutatorTemplates = func() map[string]SerialisableMutator {
 71		out := map[string]SerialisableMutator{}
 72		for _, mutator := range []SerialisableMutator{
 73			&includeMutator{},
 74			&combinedMutator{},
 75			&multiMutator{},
 76			&pushMutator{},
 77			&popMutator{},
 78		} {
 79			out[mutator.MutatorKind()] = mutator
 80		}
 81		return out
 82	}()
 83)
 84
 85// fastUnmarshalConfig unmarshals only the Config from a serialised lexer.
 86func fastUnmarshalConfig(from fs.FS, path string) (*Config, error) {
 87	r, err := from.Open(path)
 88	if err != nil {
 89		return nil, err
 90	}
 91	defer r.Close()
 92	dec := xml.NewDecoder(r)
 93	for {
 94		token, err := dec.Token()
 95		if err != nil {
 96			if errors.Is(err, io.EOF) {
 97				return nil, fmt.Errorf("could not find <config> element")
 98			}
 99			return nil, err
100		}
101		switch se := token.(type) {
102		case xml.StartElement:
103			if se.Name.Local != "config" {
104				break
105			}
106
107			var config Config
108			err = dec.DecodeElement(&config, &se)
109			if err != nil {
110				panic(err)
111			}
112			return &config, nil
113		}
114	}
115}
116
117// MustNewXMLLexer constructs a new RegexLexer from an XML file or panics.
118func MustNewXMLLexer(from fs.FS, path string) *RegexLexer {
119	lex, err := NewXMLLexer(from, path)
120	if err != nil {
121		panic(err)
122	}
123	return lex
124}
125
126// NewXMLLexer creates a new RegexLexer from a serialised RegexLexer.
127func NewXMLLexer(from fs.FS, path string) (*RegexLexer, error) {
128	config, err := fastUnmarshalConfig(from, path)
129	if err != nil {
130		return nil, err
131	}
132	for _, glob := range append(config.Filenames, config.AliasFilenames...) {
133		_, err := filepath.Match(glob, "")
134		if err != nil {
135			return nil, fmt.Errorf("%s: %q is not a valid glob: %w", config.Name, glob, err)
136		}
137	}
138	return &RegexLexer{
139		config: config,
140		fetchRulesFunc: func() (Rules, error) {
141			var lexer struct {
142				Config
143				Rules Rules `xml:"rules"`
144			}
145			// Try to open .xml fallback to .xml.gz
146			fr, err := from.Open(path)
147			if err != nil {
148				if errors.Is(err, fs.ErrNotExist) {
149					path += ".gz"
150					fr, err = from.Open(path)
151					if err != nil {
152						return nil, err
153					}
154				} else {
155					return nil, err
156				}
157			}
158			defer fr.Close()
159			var r io.Reader = fr
160			if strings.HasSuffix(path, ".gz") {
161				r, err = gzip.NewReader(r)
162				if err != nil {
163					return nil, fmt.Errorf("%s: %w", path, err)
164				}
165			}
166			err = xml.NewDecoder(r).Decode(&lexer)
167			if err != nil {
168				return nil, fmt.Errorf("%s: %w", path, err)
169			}
170			return lexer.Rules, nil
171		},
172	}, nil
173}
174
175// Marshal a RegexLexer to XML.
176func Marshal(l *RegexLexer) ([]byte, error) {
177	type lexer struct {
178		Config Config `xml:"config"`
179		Rules  Rules  `xml:"rules"`
180	}
181
182	rules, err := l.Rules()
183	if err != nil {
184		return nil, err
185	}
186	root := &lexer{
187		Config: *l.Config(),
188		Rules:  rules,
189	}
190	data, err := xml.MarshalIndent(root, "", "  ")
191	if err != nil {
192		return nil, err
193	}
194	re := regexp.MustCompile(`></[a-zA-Z]+>`)
195	data = re.ReplaceAll(data, []byte(`/>`))
196	return data, nil
197}
198
199// Unmarshal a RegexLexer from XML.
200func Unmarshal(data []byte) (*RegexLexer, error) {
201	type lexer struct {
202		Config Config `xml:"config"`
203		Rules  Rules  `xml:"rules"`
204	}
205	root := &lexer{}
206	err := xml.Unmarshal(data, root)
207	if err != nil {
208		return nil, fmt.Errorf("invalid Lexer XML: %w", err)
209	}
210	lex, err := NewLexer(&root.Config, func() Rules { return root.Rules })
211	if err != nil {
212		return nil, err
213	}
214	return lex, nil
215}
216
217func marshalMutator(e *xml.Encoder, mutator Mutator) error {
218	if mutator == nil {
219		return nil
220	}
221	smutator, ok := mutator.(SerialisableMutator)
222	if !ok {
223		return fmt.Errorf("unsupported mutator: %w", ErrNotSerialisable)
224	}
225	return e.EncodeElement(mutator, xml.StartElement{Name: xml.Name{Local: smutator.MutatorKind()}})
226}
227
228func unmarshalMutator(d *xml.Decoder, start xml.StartElement) (Mutator, error) {
229	kind := start.Name.Local
230	mutator, ok := mutatorTemplates[kind]
231	if !ok {
232		return nil, fmt.Errorf("unknown mutator %q: %w", kind, ErrNotSerialisable)
233	}
234	value, target := newFromTemplate(mutator)
235	if err := d.DecodeElement(target, &start); err != nil {
236		return nil, err
237	}
238	return value().(SerialisableMutator), nil
239}
240
241func marshalEmitter(e *xml.Encoder, emitter Emitter) error {
242	if emitter == nil {
243		return nil
244	}
245	semitter, ok := emitter.(SerialisableEmitter)
246	if !ok {
247		return fmt.Errorf("unsupported emitter %T: %w", emitter, ErrNotSerialisable)
248	}
249	return e.EncodeElement(emitter, xml.StartElement{
250		Name: xml.Name{Local: semitter.EmitterKind()},
251	})
252}
253
254func unmarshalEmitter(d *xml.Decoder, start xml.StartElement) (Emitter, error) {
255	kind := start.Name.Local
256	mutator, ok := emitterTemplates[kind]
257	if !ok {
258		return nil, fmt.Errorf("unknown emitter %q: %w", kind, ErrNotSerialisable)
259	}
260	value, target := newFromTemplate(mutator)
261	if err := d.DecodeElement(target, &start); err != nil {
262		return nil, err
263	}
264	return value().(SerialisableEmitter), nil
265}
266
267func (r Rule) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
268	start := xml.StartElement{
269		Name: xml.Name{Local: "rule"},
270	}
271	if r.Pattern != "" {
272		start.Attr = append(start.Attr, xml.Attr{
273			Name:  xml.Name{Local: "pattern"},
274			Value: r.Pattern,
275		})
276	}
277	if err := e.EncodeToken(start); err != nil {
278		return err
279	}
280	if err := marshalEmitter(e, r.Type); err != nil {
281		return err
282	}
283	if err := marshalMutator(e, r.Mutator); err != nil {
284		return err
285	}
286	return e.EncodeToken(xml.EndElement{Name: start.Name})
287}
288
289func (r *Rule) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
290	for _, attr := range start.Attr {
291		if attr.Name.Local == "pattern" {
292			r.Pattern = attr.Value
293			break
294		}
295	}
296	for {
297		token, err := d.Token()
298		if err != nil {
299			return err
300		}
301		switch token := token.(type) {
302		case xml.StartElement:
303			mutator, err := unmarshalMutator(d, token)
304			if err != nil && !errors.Is(err, ErrNotSerialisable) {
305				return err
306			} else if err == nil {
307				if r.Mutator != nil {
308					return fmt.Errorf("duplicate mutator")
309				}
310				r.Mutator = mutator
311				continue
312			}
313			emitter, err := unmarshalEmitter(d, token)
314			if err != nil && !errors.Is(err, ErrNotSerialisable) { // nolint: gocritic
315				return err
316			} else if err == nil {
317				if r.Type != nil {
318					return fmt.Errorf("duplicate emitter")
319				}
320				r.Type = emitter
321				continue
322			} else {
323				return err
324			}
325
326		case xml.EndElement:
327			return nil
328		}
329	}
330}
331
332type xmlRuleState struct {
333	Name  string `xml:"name,attr"`
334	Rules []Rule `xml:"rule"`
335}
336
337type xmlRules struct {
338	States []xmlRuleState `xml:"state"`
339}
340
341func (r Rules) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
342	xr := xmlRules{}
343	for state, rules := range r {
344		xr.States = append(xr.States, xmlRuleState{
345			Name:  state,
346			Rules: rules,
347		})
348	}
349	return e.EncodeElement(xr, xml.StartElement{Name: xml.Name{Local: "rules"}})
350}
351
352func (r *Rules) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
353	xr := xmlRules{}
354	if err := d.DecodeElement(&xr, &start); err != nil {
355		return err
356	}
357	if *r == nil {
358		*r = Rules{}
359	}
360	for _, state := range xr.States {
361		(*r)[state.Name] = state.Rules
362	}
363	return nil
364}
365
// xmlTokenType is the wire form of a TokenType: its name carried in a "type"
// attribute, e.g. <token type="Text"/>.
type xmlTokenType struct{ Type string `xml:"type,attr"` }
369
370func (t *TokenType) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
371	el := xmlTokenType{}
372	if err := d.DecodeElement(&el, &start); err != nil {
373		return err
374	}
375	for tt, text := range _TokenType_map {
376		if text == el.Type {
377			*t = tt
378			return nil
379		}
380	}
381	return fmt.Errorf("unknown TokenType %q", el.Type)
382}
383
384func (t TokenType) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
385	start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: t.String()})
386	if err := e.EncodeToken(start); err != nil {
387		return err
388	}
389	return e.EncodeToken(xml.EndElement{Name: start.Name})
390}
391
// newFromTemplate allocates a fresh zero value of the template's type. It
// returns target, a pointer to decode into, and value, an accessor that yields
// the decoded result in the same shape as the template.
//
// For a pointer template, value returns target itself. For a non-pointer
// template such as TokenType(0), value dereferences target. This indirection
// is needed because decoding requires a pointer, but callers need the plain
// value back.
func newFromTemplate(template interface{}) (value func() interface{}, target interface{}) {
	typ := reflect.TypeOf(template)
	isPointer := typ.Kind() == reflect.Ptr
	if isPointer {
		typ = typ.Elem()
	}
	fresh := reflect.New(typ)
	target = fresh.Interface()
	if isPointer {
		value = fresh.Interface
	} else {
		value = func() interface{} { return fresh.Elem().Interface() }
	}
	return value, target
}
402
403func (b *Emitters) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
404	for {
405		token, err := d.Token()
406		if err != nil {
407			return err
408		}
409		switch token := token.(type) {
410		case xml.StartElement:
411			emitter, err := unmarshalEmitter(d, token)
412			if err != nil {
413				return err
414			}
415			*b = append(*b, emitter)
416
417		case xml.EndElement:
418			return nil
419		}
420	}
421}
422
423func (b Emitters) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
424	if err := e.EncodeToken(start); err != nil {
425		return err
426	}
427	for _, m := range b {
428		if err := marshalEmitter(e, m); err != nil {
429			return err
430		}
431	}
432	return e.EncodeToken(xml.EndElement{Name: start.Name})
433}