Skip to content

Commit

Permalink
Parse short form flags (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
FollowTheProcess authored Aug 1, 2024
1 parent 0cd6c9d commit 91b8a12
Show file tree
Hide file tree
Showing 4 changed files with 545 additions and 38 deletions.
2 changes: 1 addition & 1 deletion internal/flag/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ func validateFlagShort(short rune) error {
// value parsing errors.
func errParse[T any](name, str string, typ *T, err error) error {
return fmt.Errorf(
"flag %s received invalid value %q (expected %T), detail: %w",
"flag %q received invalid value %q (expected %T), detail: %w",
name,
str,
*typ,
Expand Down
36 changes: 18 additions & 18 deletions internal/flag/flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag int received invalid value "word" (expected int), detail: strconv.ParseInt: parsing "word": invalid syntax`,
`flag "int" received invalid value "word" (expected int), detail: strconv.ParseInt: parsing "word": invalid syntax`,
)
})

Expand All @@ -62,7 +62,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag int received invalid value "word" (expected int8), detail: strconv.ParseInt: parsing "word": invalid syntax`,
`flag "int" received invalid value "word" (expected int8), detail: strconv.ParseInt: parsing "word": invalid syntax`,
)
})

Expand All @@ -88,7 +88,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag int received invalid value "word" (expected int16), detail: strconv.ParseInt: parsing "word": invalid syntax`,
`flag "int" received invalid value "word" (expected int16), detail: strconv.ParseInt: parsing "word": invalid syntax`,
)
})

Expand All @@ -114,7 +114,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag int received invalid value "word" (expected int32), detail: strconv.ParseInt: parsing "word": invalid syntax`,
`flag "int" received invalid value "word" (expected int32), detail: strconv.ParseInt: parsing "word": invalid syntax`,
)
})

Expand All @@ -140,7 +140,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag int received invalid value "word" (expected int64), detail: strconv.ParseInt: parsing "word": invalid syntax`,
`flag "int" received invalid value "word" (expected int64), detail: strconv.ParseInt: parsing "word": invalid syntax`,
)
})

Expand All @@ -166,7 +166,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag uint received invalid value "word" (expected uint), detail: strconv.ParseUint: parsing "word": invalid syntax`,
`flag "uint" received invalid value "word" (expected uint), detail: strconv.ParseUint: parsing "word": invalid syntax`,
)
})

Expand All @@ -192,7 +192,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag uint received invalid value "word" (expected uint8), detail: strconv.ParseUint: parsing "word": invalid syntax`,
`flag "uint" received invalid value "word" (expected uint8), detail: strconv.ParseUint: parsing "word": invalid syntax`,
)
})

Expand All @@ -218,7 +218,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag uint received invalid value "word" (expected uint16), detail: strconv.ParseUint: parsing "word": invalid syntax`,
`flag "uint" received invalid value "word" (expected uint16), detail: strconv.ParseUint: parsing "word": invalid syntax`,
)
})

Expand All @@ -244,7 +244,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag uint received invalid value "word" (expected uint32), detail: strconv.ParseUint: parsing "word": invalid syntax`,
`flag "uint" received invalid value "word" (expected uint32), detail: strconv.ParseUint: parsing "word": invalid syntax`,
)
})

Expand All @@ -270,7 +270,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag uint received invalid value "word" (expected uint64), detail: strconv.ParseUint: parsing "word": invalid syntax`,
`flag "uint" received invalid value "word" (expected uint64), detail: strconv.ParseUint: parsing "word": invalid syntax`,
)
})

Expand All @@ -296,7 +296,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag uintptr received invalid value "word" (expected uintptr), detail: strconv.ParseUint: parsing "word": invalid syntax`,
`flag "uintptr" received invalid value "word" (expected uintptr), detail: strconv.ParseUint: parsing "word": invalid syntax`,
)
})

Expand All @@ -322,7 +322,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag float received invalid value "word" (expected float32), detail: strconv.ParseFloat: parsing "word": invalid syntax`,
`flag "float" received invalid value "word" (expected float32), detail: strconv.ParseFloat: parsing "word": invalid syntax`,
)
})

Expand All @@ -348,7 +348,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag float received invalid value "word" (expected float64), detail: strconv.ParseFloat: parsing "word": invalid syntax`,
`flag "float" received invalid value "word" (expected float64), detail: strconv.ParseFloat: parsing "word": invalid syntax`,
)
})

Expand All @@ -374,7 +374,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag bool received invalid value "word" (expected bool), detail: strconv.ParseBool: parsing "word": invalid syntax`,
`flag "bool" received invalid value "word" (expected bool), detail: strconv.ParseBool: parsing "word": invalid syntax`,
)
})

Expand Down Expand Up @@ -414,7 +414,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag byte received invalid value "0xF" (expected []uint8), detail: encoding/hex: invalid byte: U+0078 'x'`,
`flag "byte" received invalid value "0xF" (expected []uint8), detail: encoding/hex: invalid byte: U+0078 'x'`,
)
})

Expand Down Expand Up @@ -443,7 +443,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag time received invalid value "not a time" (expected time.Time), detail: parsing time "not a time" as "2006-01-02T15:04:05Z07:00": cannot parse "not a time" as "2006"`,
`flag "time" received invalid value "not a time" (expected time.Time), detail: parsing time "not a time" as "2006-01-02T15:04:05Z07:00": cannot parse "not a time" as "2006"`,
)
})

Expand Down Expand Up @@ -484,7 +484,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag duration received invalid value "not a duration" (expected time.Duration), detail: time: invalid duration "not a duration"`,
`flag "duration" received invalid value "not a duration" (expected time.Duration), detail: time: invalid duration "not a duration"`,
)
})

Expand All @@ -510,7 +510,7 @@ func TestFlagValue(t *testing.T) {
test.Equal(
t,
err.Error(),
`flag ip received invalid value "not an ip" (expected net.IP), detail: invalid IP address`,
`flag "ip" received invalid value "not an ip" (expected net.IP), detail: invalid IP address`,
)
})
}
Expand Down
144 changes: 130 additions & 14 deletions internal/flag/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"strings"
"unicode/utf8"
)

// Set is a set of command line flags.
Expand All @@ -23,6 +24,10 @@ func NewSet() *Set {

// AddToSet adds a flag to the given Set.
func AddToSet[T Flaggable](set *Set, flag Flag[T]) error {
if set == nil {
return errors.New("cannot add flag to a nil set")
}
// TODO: Would this be better as a method on Flag[T]?
_, exists := set.flags[flag.name]
if exists {
return fmt.Errorf("flag %q already defined", flag.name)
Expand Down Expand Up @@ -53,6 +58,9 @@ func AddToSet[T Flaggable](set *Set, flag Flag[T]) error {
// Get gets a flag Value from the Set by name and a boolean to indicate
// whether it was present.
func (s *Set) Get(name string) (Value, bool) {
if s == nil {
return nil, false
}
entry, ok := s.flags[name]
if !ok {
return nil, false
Expand All @@ -62,6 +70,9 @@ func (s *Set) Get(name string) (Value, bool) {

// Parse parses flags and their values from the command line.
func (s *Set) Parse(args []string) (err error) {
if s == nil {
return errors.New("Parse called on a nil set")
}
for len(args) > 0 {
arg := args[0] // The argument we're currently inspecting
args = args[1:] // Remainder
Expand All @@ -75,7 +86,10 @@ func (s *Set) Parse(args []string) (err error) {
}
case strings.HasPrefix(arg, "-"):
// Short flag e.g. -d
return errors.New("TODO")
args, err = s.parseShortFlag(arg, args)
if err != nil {
return err
}
default:
// Regular positional argument
s.args = append(s.args, arg)
Expand All @@ -85,6 +99,16 @@ func (s *Set) Parse(args []string) (err error) {
return nil
}

// flagEntry represents a single flag in the set.
type flagEntry struct {
value Value // The actual Flag[T]
name string // The full name of the flag e.g. "delete"
usage string // The flag's usage message
defaultValue string // String representation of the default flag value
defaultValueNoArg string // String representation of the default flag value if used without an arg, e.g. boolean flags "--force" implies "--force true"
shorthand rune // The optional shorthand e.g. 'd' or [NoShortHand]
}

// parseLongFlag parses a single long flag e.g. --delete. It is passed
// the possible long flag and the rest of the argument list and returns
// the remaining arguments after it's done parsing to the caller.
Expand All @@ -95,7 +119,7 @@ func (s *Set) parseLongFlag(long string, rest []string) (remaining []string, err
name := strings.TrimPrefix(long, "--")

// name will either be the entire string or the name before the "="
name, equalsValue, containsEquals := strings.Cut(name, "=")
name, value, containsEquals := strings.Cut(name, "=")
if err := validateFlagName(name); err != nil {
return nil, fmt.Errorf("invalid flag name %q: %w", name, err)
}
Expand All @@ -106,9 +130,9 @@ func (s *Set) parseLongFlag(long string, rest []string) (remaining []string, err

if containsEquals {
// Must be "flag=value"
err := flag.value.Set(equalsValue)
err := flag.value.Set(value)
if err != nil {
return nil, fmt.Errorf("failed to set value %s for flag --%s: %w", equalsValue, name, err)
return nil, err
}

// We're done, no need to cut anything from rest as this was a single arg
Expand All @@ -121,7 +145,7 @@ func (s *Set) parseLongFlag(long string, rest []string) (remaining []string, err
// --flag (boolean)
err := flag.value.Set(flag.defaultValueNoArg)
if err != nil {
return nil, fmt.Errorf("failed to set value %s for flag --%s: %w", flag.defaultValueNoArg, name, err)
return nil, err
}
// Done, as above no need to cut anything
return rest, nil
Expand All @@ -130,7 +154,7 @@ func (s *Set) parseLongFlag(long string, rest []string) (remaining []string, err
value := rest[0]
err := flag.value.Set(value)
if err != nil {
return nil, fmt.Errorf("failed to set value %s for flag --%s: %w", value, name, err)
return nil, err
}
// Done, cut value from args and return
return rest[1:], nil
Expand All @@ -140,12 +164,104 @@ func (s *Set) parseLongFlag(long string, rest []string) (remaining []string, err
}
}

// flagEntry represents a single flag in the set.
type flagEntry struct {
value Value // The actual Flag[T]
name string // The full name of the flag e.g. "delete"
usage string // The flag's usage message
defaultValue string // String representation of the default flag value
defaultValueNoArg string // String representation of the default flag value if used without an arg, e.g. boolean flags "--force" implies "--force true"
shorthand rune // The optional shorthand e.g. 'd' or [NoShortHand]
// parseShortFlag parses short flags from the command line. It is passed the possible
// short flag and the rest of the argument list and returns the remaining arguments
// after it's done parsing to the caller.
//
// The forms it expects are "-f", "-vfg", "-f value" and "-f=value".
func (s *Set) parseShortFlag(short string, rest []string) (remaining []string, err error) {
// TODO: Refactor this to clean it up and reduce duplication

// Could either be "f", "vfg" or "f=value"
shorthands := strings.TrimPrefix(short, "-")

// Is it e.g. f=value
shorthand, value, containsEquals := strings.Cut(shorthands, "=")
if err := validateFlagName(shorthand); err != nil {
return nil, fmt.Errorf("invalid flag name %q: %w", shorthand, err)
}
if containsEquals {
// Yes, it is. If the thing on the left of the equals is > 1 char it's an error
if len(shorthand) != 1 {
return nil, fmt.Errorf("invalid shorthand syntax: expected e.g. -f=<value> got %s", short)
}

char, _ := utf8.DecodeRuneInString(shorthand)
if err := validateFlagShort(char); err != nil {
return nil, fmt.Errorf("invalid flag shorthand %q: %w", string(char), err)
}

flag, exists := s.shorthands[char]
if !exists {
return nil, fmt.Errorf("unrecognised shorthand flag: -%s", string(char))
}

if err := flag.value.Set(value); err != nil {
return nil, err
}

// We're done, nothing to trim off
return rest, nil
}

// It's not "f=value" so must be one of "fvalue", "f value", or "vvv"
// len("fvalue") is > 1 but len("f") isn't (value in that last case is the first arg in 'rest')
if len(shorthands) > 1 {
// It must be "fvalue", so extract "value"
char, _ := utf8.DecodeRuneInString(shorthands)
if err := validateFlagShort(char); err != nil {
return nil, fmt.Errorf("invalid flag shorthand %q: %w", string(char), err)
}
value = shorthands[1:]
flag, exists := s.shorthands[char]
if !exists {
return nil, fmt.Errorf("unrecognised shorthand flag: -%s", string(char))
}
if err := flag.value.Set(value); err != nil {
return nil, err
}

// We're done, nothing to trim off
return rest, nil
}

// Any arguments after the short flag?
if len(rest) > 0 {
// It must be "f value" and value is the next argument in rest
char, _ := utf8.DecodeRuneInString(shorthands)
if err := validateFlagShort(char); err != nil {
return nil, fmt.Errorf("invalid flag shorthand %q: %w", string(char), err)
}
value = rest[0]
flag, exists := s.shorthands[char]
if !exists {
return nil, fmt.Errorf("unrecognised shorthand flag: -%s", string(char))
}
if err := flag.value.Set(value); err != nil {
return nil, err
}

// We've consumed "value" from rest so trim it off
return rest[1:], nil
}

// If we get here, it must be the "-vvv" form
for _, char := range shorthands {
flag, exists := s.shorthands[char]
if !exists {
return nil, fmt.Errorf("unrecognised shorthand flag: %q in -%s", string(char), shorthands)
}

// -f (boolean flag)
if flag.defaultValueNoArg != "" {
err := flag.value.Set(flag.defaultValueNoArg)
if err != nil {
return nil, err
}
// Done, as above no need to cut anything
return rest, nil
}
}

return rest, nil
}
Loading

0 comments on commit 91b8a12

Please sign in to comment.