1package arg
  2
  3import (
  4	"encoding"
  5	"encoding/csv"
  6	"errors"
  7	"fmt"
  8	"os"
  9	"path/filepath"
 10	"reflect"
 11	"strings"
 12
 13	scalar "github.com/alexflint/go-scalar"
 14)
 15
 16// path represents a sequence of steps to find the output location for an
 17// argument or subcommand in the final destination struct
 18type path struct {
 19	root   int                   // index of the destination struct
 20	fields []reflect.StructField // sequence of struct fields to traverse
 21}
 22
 23// String gets a string representation of the given path
 24func (p path) String() string {
 25	s := "args"
 26	for _, f := range p.fields {
 27		s += "." + f.Name
 28	}
 29	return s
 30}
 31
 32// Child gets a new path representing a child of this path.
 33func (p path) Child(f reflect.StructField) path {
 34	// copy the entire slice of fields to avoid possible slice overwrite
 35	subfields := make([]reflect.StructField, len(p.fields)+1)
 36	copy(subfields, p.fields)
 37	subfields[len(subfields)-1] = f
 38	return path{
 39		root:   p.root,
 40		fields: subfields,
 41	}
 42}
 43
// spec represents a command line option or positional argument, together with
// the location where its parsed value is stored.
type spec struct {
	dest        path                // location of the destination value, relative to one of the parser's root structs
	field       reflect.StructField // the struct field from which this option was created
	long        string              // the --long form for this option, or empty if none
	short       string              // the -s short form for this option, or empty if none
	cardinality cardinality         // determines how many tokens will be present (possible values: zero, one, multiple)
	required    bool                // if true, this option must be supplied (on the command line or via its environment variable)
	positional  bool                // if true, this option is filled from positional (non-flag) arguments
	separate    bool                // if true, each slice and map entry will have its own --flag
	help        string              // the help text for this option
	env         string              // the name of the environment variable for this option, or empty for none
	defaultVal  string              // text parsed into the field when the option is absent: the `default` tag, or a pre-existing non-zero field value
	placeholder string              // name of the option's value as shown in help text
}
 59
// command represents a named subcommand, or the top-level command
type command struct {
	// name selects this subcommand on the command line; for the top-level
	// command it is the program name used in usage text
	name string
	// help comes from the subcommand field's "help" tag (empty for the top-level command)
	help string
	// dest locates this command's struct pointer, which is allocated only when
	// the subcommand is selected
	dest path
	// specs are the options and positionals belonging to this command
	specs []*spec
	// subcommands are the commands that may follow this one on the command line
	subcommands []*command
	// parent is the enclosing command, or nil for the top-level command
	parent *command
}
 69
 70// ErrHelp indicates that -h or --help were provided
 71var ErrHelp = errors.New("help requested by user")
 72
 73// ErrVersion indicates that --version was provided
 74var ErrVersion = errors.New("version requested by user")
 75
 76// MustParse processes command line arguments and exits upon failure
 77func MustParse(dest ...interface{}) *Parser {
 78	p, err := NewParser(Config{}, dest...)
 79	if err != nil {
 80		fmt.Fprintln(stdout, err)
 81		osExit(-1)
 82		return nil // just in case osExit was monkey-patched
 83	}
 84
 85	err = p.Parse(flags())
 86	switch {
 87	case err == ErrHelp:
 88		p.writeHelpForSubcommand(stdout, p.lastCmd)
 89		osExit(0)
 90	case err == ErrVersion:
 91		fmt.Fprintln(stdout, p.version)
 92		osExit(0)
 93	case err != nil:
 94		p.failWithSubcommand(err.Error(), p.lastCmd)
 95	}
 96
 97	return p
 98}
 99
100// Parse processes command line arguments and stores them in dest
101func Parse(dest ...interface{}) error {
102	p, err := NewParser(Config{}, dest...)
103	if err != nil {
104		return err
105	}
106	return p.Parse(flags())
107}
108
// flags returns the command line arguments that follow the program name, or
// nil when os.Args is empty (which can happen, so it must not be indexed blindly).
func flags() []string {
	if n := len(os.Args); n > 0 {
		return os.Args[1:n]
	}
	return nil
}
116
// Config represents configuration options for an argument parser
type Config struct {
	// Program is the name of the program used in the help text. When empty,
	// NewParser falls back to the base name of os.Args[0], or to "program"
	// if os.Args is empty.
	Program string

	// IgnoreEnv instructs the library not to read environment variables, so
	// options with an env tag are filled only from the command line or their
	// defaults.
	IgnoreEnv bool
}
125
// Parser represents a set of command line options with destination values.
// Construct one with NewParser; the zero value has no command and cannot parse.
type Parser struct {
	// cmd is the top-level command, merging the options and subcommands of
	// every destination passed to NewParser
	cmd         *command
	// roots holds one value per destination; path.root indexes into it
	roots       []reflect.Value
	config      Config
	// version comes from a destination implementing Versioned, if any
	version     string
	// description comes from a destination implementing Described, if any
	description string

	// the following field changes during processing of command line arguments:
	// it is the innermost command reached so far, which MustParse uses when
	// writing help or reporting an error
	lastCmd *command
}
137
// Versioned is the interface that the destination struct should implement to
// make a version string appear at the top of the help message.
//
// NewParser records the result of Version when the parser is built (if several
// destinations implement it, the last one wins), and MustParse prints it when
// --version is given.
type Versioned interface {
	// Version returns the version string that will be printed on a line by itself
	// at the top of the help message.
	Version() string
}

// Described is the interface that the destination struct should implement to
// make a description string appear at the top of the help message.
//
// NewParser records the result of Description when the parser is built; if
// several destinations implement it, the last one wins.
type Described interface {
	// Description returns the string that will be printed on a line by itself
	// at the top of the help message.
	Description() string
}
153
// walkFields calls visit for every field of the struct type t, in declaration
// order. visit receives the field and the struct type that declares it; when
// it returns true for a struct-typed field, that field's own fields are
// visited next (depth first).
//
// Each field handed to visit has its Index rewritten. For fields reached
// through embedded (anonymous) structs, the Index is the full path from t, as
// accepted by reflect.Value.FieldByIndex. Fields of an expanded named
// (non-embedded) struct field instead get an Index relative to that nested struct.
func walkFields(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool) {
	walkFieldsImpl(t, visit, nil)
}

// walkFieldsImpl does the work for walkFields; prefix is the index path from
// the outermost struct down to t.
func walkFieldsImpl(t reflect.Type, visit func(field reflect.StructField, owner reflect.Type) bool, prefix []int) {
	for i, n := 0, t.NumField(); i < n; i++ {
		field := t.Field(i)

		// give every field its own index slice so no two fields share storage
		index := make([]int, len(prefix), len(prefix)+1)
		copy(index, prefix)
		field.Index = append(index, i)

		if !visit(field, t) || field.Type.Kind() != reflect.Struct {
			continue
		}

		// only embedded structs extend the index path, because only their
		// fields are promoted into (and reachable from) the outer struct
		var subpath []int
		if field.Anonymous {
			subpath = append(prefix[:len(prefix):len(prefix)], i)
		}
		walkFieldsImpl(field.Type, visit, subpath)
	}
}
174
175// NewParser constructs a parser from a list of destination structs
176func NewParser(config Config, dests ...interface{}) (*Parser, error) {
177	// first pick a name for the command for use in the usage text
178	var name string
179	switch {
180	case config.Program != "":
181		name = config.Program
182	case len(os.Args) > 0:
183		name = filepath.Base(os.Args[0])
184	default:
185		name = "program"
186	}
187
188	// construct a parser
189	p := Parser{
190		cmd:    &command{name: name},
191		config: config,
192	}
193
194	// make a list of roots
195	for _, dest := range dests {
196		p.roots = append(p.roots, reflect.ValueOf(dest))
197	}
198
199	// process each of the destination values
200	for i, dest := range dests {
201		t := reflect.TypeOf(dest)
202		if t.Kind() != reflect.Ptr {
203			panic(fmt.Sprintf("%s is not a pointer (did you forget an ampersand?)", t))
204		}
205
206		cmd, err := cmdFromStruct(name, path{root: i}, t)
207		if err != nil {
208			return nil, err
209		}
210
211		// add nonzero field values as defaults
212		for _, spec := range cmd.specs {
213			if v := p.val(spec.dest); v.IsValid() && !isZero(v) {
214				if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
215					str, err := defaultVal.MarshalText()
216					if err != nil {
217						return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
218					}
219					spec.defaultVal = string(str)
220				} else {
221					spec.defaultVal = fmt.Sprintf("%v", v)
222				}
223			}
224		}
225
226		p.cmd.specs = append(p.cmd.specs, cmd.specs...)
227		p.cmd.subcommands = append(p.cmd.subcommands, cmd.subcommands...)
228
229		if dest, ok := dest.(Versioned); ok {
230			p.version = dest.Version()
231		}
232		if dest, ok := dest.(Described); ok {
233			p.description = dest.Description()
234		}
235	}
236
237	return &p, nil
238}
239
240func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
241	// commands can only be created from pointers to structs
242	if t.Kind() != reflect.Ptr {
243		return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a %s",
244			dest, t.Kind())
245	}
246
247	t = t.Elem()
248	if t.Kind() != reflect.Struct {
249		return nil, fmt.Errorf("subcommands must be pointers to structs but %s is a pointer to %s",
250			dest, t.Kind())
251	}
252
253	cmd := command{
254		name: name,
255		dest: dest,
256	}
257
258	var errs []string
259	walkFields(t, func(field reflect.StructField, t reflect.Type) bool {
260		// check for the ignore switch in the tag
261		tag := field.Tag.Get("arg")
262		if tag == "-" {
263			return false
264		}
265
266		// if this is an embedded struct then recurse into its fields, even if
267		// it is unexported, because exported fields on unexported embedded
268		// structs are still writable
269		if field.Anonymous && field.Type.Kind() == reflect.Struct {
270			return true
271		}
272
273		// ignore any other unexported field
274		if !isExported(field.Name) {
275			return false
276		}
277
278		// duplicate the entire path to avoid slice overwrites
279		subdest := dest.Child(field)
280		spec := spec{
281			dest:  subdest,
282			field: field,
283			long:  strings.ToLower(field.Name),
284		}
285
286		help, exists := field.Tag.Lookup("help")
287		if exists {
288			spec.help = help
289		}
290
291		defaultVal, hasDefault := field.Tag.Lookup("default")
292		if hasDefault {
293			spec.defaultVal = defaultVal
294		}
295
296		// Look at the tag
297		var isSubcommand bool // tracks whether this field is a subcommand
298		for _, key := range strings.Split(tag, ",") {
299			if key == "" {
300				continue
301			}
302			key = strings.TrimLeft(key, " ")
303			var value string
304			if pos := strings.Index(key, ":"); pos != -1 {
305				value = key[pos+1:]
306				key = key[:pos]
307			}
308
309			switch {
310			case strings.HasPrefix(key, "---"):
311				errs = append(errs, fmt.Sprintf("%s.%s: too many hyphens", t.Name(), field.Name))
312			case strings.HasPrefix(key, "--"):
313				spec.long = key[2:]
314			case strings.HasPrefix(key, "-"):
315				if len(key) != 2 {
316					errs = append(errs, fmt.Sprintf("%s.%s: short arguments must be one character only",
317						t.Name(), field.Name))
318					return false
319				}
320				spec.short = key[1:]
321			case key == "required":
322				if hasDefault {
323					errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
324						t.Name(), field.Name))
325					return false
326				}
327				spec.required = true
328			case key == "positional":
329				spec.positional = true
330			case key == "separate":
331				spec.separate = true
332			case key == "help": // deprecated
333				spec.help = value
334			case key == "env":
335				// Use override name if provided
336				if value != "" {
337					spec.env = value
338				} else {
339					spec.env = strings.ToUpper(field.Name)
340				}
341			case key == "subcommand":
342				// decide on a name for the subcommand
343				cmdname := value
344				if cmdname == "" {
345					cmdname = strings.ToLower(field.Name)
346				}
347
348				// parse the subcommand recursively
349				subcmd, err := cmdFromStruct(cmdname, subdest, field.Type)
350				if err != nil {
351					errs = append(errs, err.Error())
352					return false
353				}
354
355				subcmd.parent = &cmd
356				subcmd.help = field.Tag.Get("help")
357
358				cmd.subcommands = append(cmd.subcommands, subcmd)
359				isSubcommand = true
360			default:
361				errs = append(errs, fmt.Sprintf("unrecognized tag '%s' on field %s", key, tag))
362				return false
363			}
364		}
365
366		placeholder, hasPlaceholder := field.Tag.Lookup("placeholder")
367		if hasPlaceholder {
368			spec.placeholder = placeholder
369		} else if spec.long != "" {
370			spec.placeholder = strings.ToUpper(spec.long)
371		} else {
372			spec.placeholder = strings.ToUpper(spec.field.Name)
373		}
374
375		// Check whether this field is supported. It's good to do this here rather than
376		// wait until ParseValue because it means that a program with invalid argument
377		// fields will always fail regardless of whether the arguments it received
378		// exercised those fields.
379		if !isSubcommand {
380			cmd.specs = append(cmd.specs, &spec)
381
382			var err error
383			spec.cardinality, err = cardinalityOf(field.Type)
384			if err != nil {
385				errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
386					t.Name(), field.Name, field.Type.String()))
387				return false
388			}
389			if spec.cardinality == multiple && hasDefault {
390				errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
391					t.Name(), field.Name))
392				return false
393			}
394		}
395
396		// if this was an embedded field then we already returned true up above
397		return false
398	})
399
400	if len(errs) > 0 {
401		return nil, errors.New(strings.Join(errs, "\n"))
402	}
403
404	// check that we don't have both positionals and subcommands
405	var hasPositional bool
406	for _, spec := range cmd.specs {
407		if spec.positional {
408			hasPositional = true
409		}
410	}
411	if hasPositional && len(cmd.subcommands) > 0 {
412		return nil, fmt.Errorf("%s cannot have both subcommands and positional arguments", dest)
413	}
414
415	return &cmd, nil
416}
417
418// Parse processes the given command line option, storing the results in the field
419// of the structs from which NewParser was constructed
420func (p *Parser) Parse(args []string) error {
421	err := p.process(args)
422	if err != nil {
423		// If -h or --help were specified then make sure help text supercedes other errors
424		for _, arg := range args {
425			if arg == "-h" || arg == "--help" {
426				return ErrHelp
427			}
428			if arg == "--" {
429				break
430			}
431		}
432	}
433	return err
434}
435
436// process environment vars for the given arguments
437func (p *Parser) captureEnvVars(specs []*spec, wasPresent map[*spec]bool) error {
438	for _, spec := range specs {
439		if spec.env == "" {
440			continue
441		}
442
443		value, found := os.LookupEnv(spec.env)
444		if !found {
445			continue
446		}
447
448		if spec.cardinality == multiple {
449			// expect a CSV string in an environment
450			// variable in the case of multiple values
451			var values []string
452			var err error
453			if len(strings.TrimSpace(value)) > 0 {
454				values, err = csv.NewReader(strings.NewReader(value)).Read()
455				if err != nil {
456					return fmt.Errorf(
457						"error reading a CSV string from environment variable %s with multiple values: %v",
458						spec.env,
459						err,
460					)
461				}
462			}
463			if err = setSliceOrMap(p.val(spec.dest), values, !spec.separate); err != nil {
464				return fmt.Errorf(
465					"error processing environment variable %s with multiple values: %v",
466					spec.env,
467					err,
468				)
469			}
470		} else {
471			if err := scalar.ParseValue(p.val(spec.dest), value); err != nil {
472				return fmt.Errorf("error processing environment variable %s: %v", spec.env, err)
473			}
474		}
475		wasPresent[spec] = true
476	}
477
478	return nil
479}
480
481// process goes through arguments one-by-one, parses them, and assigns the result to
482// the underlying struct field
483func (p *Parser) process(args []string) error {
484	// track the options we have seen
485	wasPresent := make(map[*spec]bool)
486
487	// union of specs for the chain of subcommands encountered so far
488	curCmd := p.cmd
489	p.lastCmd = curCmd
490
491	// make a copy of the specs because we will add to this list each time we expand a subcommand
492	specs := make([]*spec, len(curCmd.specs))
493	copy(specs, curCmd.specs)
494
495	// deal with environment vars
496	if !p.config.IgnoreEnv {
497		err := p.captureEnvVars(specs, wasPresent)
498		if err != nil {
499			return err
500		}
501	}
502
503	// process each string from the command line
504	var allpositional bool
505	var positionals []string
506
507	// must use explicit for loop, not range, because we manipulate i inside the loop
508	for i := 0; i < len(args); i++ {
509		arg := args[i]
510		if arg == "--" {
511			allpositional = true
512			continue
513		}
514
515		if !isFlag(arg) || allpositional {
516			// each subcommand can have either subcommands or positionals, but not both
517			if len(curCmd.subcommands) == 0 {
518				positionals = append(positionals, arg)
519				continue
520			}
521
522			// if we have a subcommand then make sure it is valid for the current context
523			subcmd := findSubcommand(curCmd.subcommands, arg)
524			if subcmd == nil {
525				return fmt.Errorf("invalid subcommand: %s", arg)
526			}
527
528			// instantiate the field to point to a new struct
529			v := p.val(subcmd.dest)
530			v.Set(reflect.New(v.Type().Elem())) // we already checked that all subcommands are struct pointers
531
532			// add the new options to the set of allowed options
533			specs = append(specs, subcmd.specs...)
534
535			// capture environment vars for these new options
536			if !p.config.IgnoreEnv {
537				err := p.captureEnvVars(subcmd.specs, wasPresent)
538				if err != nil {
539					return err
540				}
541			}
542
543			curCmd = subcmd
544			p.lastCmd = curCmd
545			continue
546		}
547
548		// check for special --help and --version flags
549		switch arg {
550		case "-h", "--help":
551			return ErrHelp
552		case "--version":
553			return ErrVersion
554		}
555
556		// check for an equals sign, as in "--foo=bar"
557		var value string
558		opt := strings.TrimLeft(arg, "-")
559		if pos := strings.Index(opt, "="); pos != -1 {
560			value = opt[pos+1:]
561			opt = opt[:pos]
562		}
563
564		// lookup the spec for this option (note that the "specs" slice changes as
565		// we expand subcommands so it is better not to use a map)
566		spec := findOption(specs, opt)
567		if spec == nil {
568			return fmt.Errorf("unknown argument %s", arg)
569		}
570		wasPresent[spec] = true
571
572		// deal with the case of multiple values
573		if spec.cardinality == multiple {
574			var values []string
575			if value == "" {
576				for i+1 < len(args) && !isFlag(args[i+1]) && args[i+1] != "--" {
577					values = append(values, args[i+1])
578					i++
579					if spec.separate {
580						break
581					}
582				}
583			} else {
584				values = append(values, value)
585			}
586			err := setSliceOrMap(p.val(spec.dest), values, !spec.separate)
587			if err != nil {
588				return fmt.Errorf("error processing %s: %v", arg, err)
589			}
590			continue
591		}
592
593		// if it's a flag and it has no value then set the value to true
594		// use boolean because this takes account of TextUnmarshaler
595		if spec.cardinality == zero && value == "" {
596			value = "true"
597		}
598
599		// if we have something like "--foo" then the value is the next argument
600		if value == "" {
601			if i+1 == len(args) {
602				return fmt.Errorf("missing value for %s", arg)
603			}
604			if !nextIsNumeric(spec.field.Type, args[i+1]) && isFlag(args[i+1]) {
605				return fmt.Errorf("missing value for %s", arg)
606			}
607			value = args[i+1]
608			i++
609		}
610
611		err := scalar.ParseValue(p.val(spec.dest), value)
612		if err != nil {
613			return fmt.Errorf("error processing %s: %v", arg, err)
614		}
615	}
616
617	// process positionals
618	for _, spec := range specs {
619		if !spec.positional {
620			continue
621		}
622		if len(positionals) == 0 {
623			break
624		}
625		wasPresent[spec] = true
626		if spec.cardinality == multiple {
627			err := setSliceOrMap(p.val(spec.dest), positionals, true)
628			if err != nil {
629				return fmt.Errorf("error processing %s: %v", spec.field.Name, err)
630			}
631			positionals = nil
632		} else {
633			err := scalar.ParseValue(p.val(spec.dest), positionals[0])
634			if err != nil {
635				return fmt.Errorf("error processing %s: %v", spec.field.Name, err)
636			}
637			positionals = positionals[1:]
638		}
639	}
640	if len(positionals) > 0 {
641		return fmt.Errorf("too many positional arguments at '%s'", positionals[0])
642	}
643
644	// fill in defaults and check that all the required args were provided
645	for _, spec := range specs {
646		if wasPresent[spec] {
647			continue
648		}
649
650		name := strings.ToLower(spec.field.Name)
651		if spec.long != "" && !spec.positional {
652			name = "--" + spec.long
653		}
654
655		if spec.required {
656			msg := fmt.Sprintf("%s is required", name)
657			if spec.env != "" {
658				msg += " (or environment variable " + spec.env + ")"
659			}
660			return errors.New(msg)
661		}
662		if spec.defaultVal != "" {
663			err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal)
664			if err != nil {
665				return fmt.Errorf("error processing default value for %s: %v", name, err)
666			}
667		}
668	}
669
670	return nil
671}
672
673func nextIsNumeric(t reflect.Type, s string) bool {
674	switch t.Kind() {
675	case reflect.Ptr:
676		return nextIsNumeric(t.Elem(), s)
677	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
678		v := reflect.New(t)
679		err := scalar.ParseValue(v, s)
680		return err == nil
681	default:
682		return false
683	}
684}
685
// isFlag reports whether token s looks like a flag such as "-v" or "--user".
// It must start with a hyphen and contain something besides hyphens, so the
// bare tokens "-" and "--" are not flags.
func isFlag(s string) bool {
	if len(s) == 0 || s[0] != '-' {
		return false
	}
	// whatever is left after the leading hyphens makes this a real flag
	return len(strings.TrimLeft(s, "-")) > 0
}
690
691// val returns a reflect.Value corresponding to the current value for the
692// given path
693func (p *Parser) val(dest path) reflect.Value {
694	v := p.roots[dest.root]
695	for _, field := range dest.fields {
696		if v.Kind() == reflect.Ptr {
697			if v.IsNil() {
698				return reflect.Value{}
699			}
700			v = v.Elem()
701		}
702
703		v = v.FieldByIndex(field.Index)
704	}
705	return v
706}
707
708// findOption finds an option from its name, or returns null if no spec is found
709func findOption(specs []*spec, name string) *spec {
710	for _, spec := range specs {
711		if spec.positional {
712			continue
713		}
714		if spec.long == name || spec.short == name {
715			return spec
716		}
717	}
718	return nil
719}
720
721// findSubcommand finds a subcommand using its name, or returns null if no subcommand is found
722func findSubcommand(cmds []*command, name string) *command {
723	for _, cmd := range cmds {
724		if cmd.name == name {
725			return cmd
726		}
727	}
728	return nil
729}