diff --git a/command.go b/command.go index d909f5a..5ac8754 100644 --- a/command.go +++ b/command.go @@ -26,6 +26,7 @@ type Command struct { helpFlag bool hidden bool positionalArgsMap map[string]reflect.Value + sliceSeparator map[string]string } // NewCommand creates a new Command @@ -37,6 +38,7 @@ func NewCommand(name string, description string) *Command { subCommandsMap: make(map[string]*Command), hidden: false, positionalArgsMap: make(map[string]reflect.Value), + sliceSeparator: make(map[string]string), } return result @@ -284,7 +286,11 @@ func (c *Command) AddFlags(optionStruct interface{}) *Command { description := tag.Get("description") defaultValue := tag.Get("default") pos := tag.Get("pos") + sep := tag.Get("sep") c.positionalArgsMap[pos] = field + if sep != "" { + c.sliceSeparator[pos] = sep + } if name == "" { name = strings.ToLower(t.Elem().Field(i).Name) } @@ -427,7 +433,7 @@ func (c *Command) AddFlags(optionStruct interface{}) *Command { } c.Float64Flag(name, description, field.Addr().Interface().(*float64)) case reflect.Slice: - c.addSliceField(field, defaultValue, tag.Get("sep")) + c.addSliceField(field, defaultValue, sep) c.addSliceFlags(name, description, field) default: if pos != "" { @@ -489,9 +495,6 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str if defaultValue == "" { return c } - if separator == "" { - separator = "," - } if field.Kind() != reflect.Slice { panic("addSliceField() requires a pointer to a slice") } @@ -502,11 +505,14 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str if t.Elem().Kind() != reflect.Slice { panic("addSliceField() requires a pointer to a slice") } + defaultSlice := []string{defaultValue} + if separator != "" { + defaultSlice = strings.Split(defaultValue, separator) + } switch t.Elem().Elem().Kind() { case reflect.Bool: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]bool, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]bool, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.ParseBool(value) if err != nil { panic("Invalid default value for bool flag") @@ -515,12 +521,10 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.String: - defaultValues := strings.Split(defaultValue, separator) - field.Set(reflect.ValueOf(defaultValues)) + field.Set(reflect.ValueOf(defaultSlice)) case reflect.Int: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]int, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]int, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for int flag") @@ -529,9 +533,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Int8: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]int8, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]int8, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for int8 flag") @@ -540,9 +543,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Int16: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]int16, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]int16, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for int16 flag") @@ -551,10 +553,9 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Int32: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]int32, 0, len(defaultSplit)) - for _, value := range defaultSplit { - val, err := strconv.ParseInt(value, 10, 64) + defaultValues := make([]int32, 0, len(defaultSlice)) + for _, value := range defaultSlice { + val, err := strconv.ParseInt(value, 10, 32) if err != nil { panic("Invalid default value for int32 flag") } @@ -562,9 +563,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Int64: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]int64, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]int64, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.ParseInt(value, 10, 64) if err != nil { panic("Invalid default value for int64 flag") @@ -573,9 +573,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Uint: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]uint, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]uint, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for uint flag") @@ -584,9 +583,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Uint8: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]uint8, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]uint8, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for uint8 flag") @@ -595,9 +593,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Uint16: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]uint16, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]uint16, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for uint16 flag") @@ -606,9 +603,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Uint32: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]uint32, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]uint32, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for uint32 flag") @@ -617,9 +613,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Uint64: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]uint64, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]uint64, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for uint64 flag") @@ -628,9 +623,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Float32: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]float32, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]float32, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for float32 flag") @@ -639,9 +633,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str } field.Set(reflect.ValueOf(defaultValues)) case reflect.Float64: - defaultSplit := strings.Split(defaultValue, separator) - defaultValues := make([]float64, 0, len(defaultSplit)) - for _, value := range defaultSplit { + defaultValues := make([]float64, 0, len(defaultSlice)) + for _, value := range defaultSlice { val, err := strconv.Atoi(value) if err != nil { panic("Invalid default value for float64 flag") @@ -1332,7 +1325,7 @@ func (c *Command) parsePositionalArgs(args []string) error { } field.SetFloat(value) case reflect.Slice: - c.addSliceField(field, posArg, "") + c.addSliceField(field, posArg, c.sliceSeparator[key]) default: return errors.New("Unsupported type for positional argument: " + fieldType.Name()) } diff --git a/examples/flags-slice/main.go b/examples/flags-slice/main.go index 69f8d66..79542ef 100644 --- a/examples/flags-slice/main.go +++ b/examples/flags-slice/main.go @@ -9,59 +9,59 @@ import ( type Flags struct { String string `name:"string" description:"The string" pos:"1"` Strings []string `name:"strings" description:"The strings" pos:"2"` - StringsDefault []string `name:"strings_default" description:"The strings default" default:"one,two,three" pos:"3"` + StringsDefault []string `name:"strings_default" description:"The strings default" default:"one|two|three" sep:"|" pos:"3"` Int int `name:"int" description:"The int" pos:"4"` Ints []int `name:"ints" description:"The ints" pos:"5"` - IntsDefault []int `name:"ints_default" description:"The ints default" default:"3|4|5" sep:"|" pos:"6"` + IntsDefault []int `name:"ints_default" description:"The ints default" default:"3|4|5" sep:"|" pos:"6"` Int8 int8 `name:"int8" description:"The int8" pos:"7"` Int8s []int8 `name:"int8s" description:"The int8s" pos:"8"` - Int8sDefault []int8 `name:"int8s_default" description:"The int8s default" default:"3,4,5" pos:"9"` + Int8sDefault []int8 `name:"int8s_default" description:"The int8s default" default:"3,4,5" sep:"," pos:"9"` Int16 int16 `name:"int16" description:"The int16" pos:"10"` Int16s []int16 `name:"int16s" description:"The int16s" pos:"11"` - Int16sDefault []int16 `name:"int16s_default" description:"The int16s default" default:"3,4,5" pos:"12"` + Int16sDefault []int16 `name:"int16s_default" description:"The int16s default" default:"3,4,5" sep:"," pos:"12"` Int32 int32 `name:"int32" description:"The int32" pos:"13"` Int32s []int32 `name:"int32s" description:"The int32s" pos:"14"` - Int32sDefault []int32 `name:"int32s_default" description:"The int32 default" default:"3,4,5" pos:"15"` + Int32sDefault []int32 `name:"int32s_default" description:"The int32 default" default:"3,4,5" sep:"," pos:"15"` Int64 int64 `name:"int64" description:"The int64" pos:"16"` Int64s []int64 `name:"int64s" description:"The int64s" pos:"17"` - Int64sDefault []int64 `name:"int64s_default" description:"The int64s default" default:"3,4,5" pos:"18"` + Int64sDefault []int64 `name:"int64s_default" description:"The int64s default" default:"3,4,5" sep:"," pos:"18"` Uint uint `name:"uint" description:"The uint" pos:"19"` Uints []uint `name:"uints" description:"The uints" pos:"20"` - UintsDefault []uint `name:"uints_default" description:"The uints default" default:"3,4,5" pos:"21"` + UintsDefault []uint `name:"uints_default" description:"The uints default" default:"3,4,5" sep:"," pos:"21"` Uint8 uint8 `name:"uint8" description:"The uint8" pos:"22"` Uint8s []uint8 `name:"uint8s" description:"The uint8s" pos:"23"` - Uint8sDefault []uint8 `name:"uint8s_default" description:"The uint8s default" default:"3,4,5" pos:"24"` + Uint8sDefault []uint8 `name:"uint8s_default" description:"The uint8s default" default:"3,4,5" sep:"," pos:"24"` Uint16 uint16 `name:"uint16" description:"The uint16" pos:"25"` Uint16s []uint16 `name:"uint16s" description:"The uint16s" pos:"26"` - Uint16sDefault []uint16 `name:"uint16s_default" description:"The uint16 default" default:"3,4,5" pos:"27"` + Uint16sDefault []uint16 `name:"uint16s_default" description:"The uint16 default" default:"3,4,5" sep:"," pos:"27"` Uint32 uint32 `name:"uint32" description:"The uint32" pos:"28"` Uint32s []uint32 `name:"uint32s" description:"The uint32s" pos:"29"` - Uint32sDefault []uint32 `name:"uint32s_default" description:"The uint32s default" default:"3,4,5" pos:"30"` + Uint32sDefault []uint32 `name:"uint32s_default" description:"The uint32s default" default:"3,4,5" sep:"," pos:"30"` Uint64 uint64 `name:"uint64" description:"The uint64" pos:"31"` Uint64s []uint64 `name:"uint64s" description:"The uint64s" pos:"32"` - Uint64sDefault []uint64 `name:"uint64s_default" description:"The uint64s default" default:"3,4,5" pos:"33"` + Uint64sDefault []uint64 `name:"uint64s_default" description:"The uint64s default" default:"3,4,5" sep:"," pos:"33"` Float32 float32 `name:"float32" description:"The float32" pos:"34"` Float32s []float32 `name:"float32s" description:"The float32s" pos:"35"` - Float32sDefault []float32 `name:"float32s_default" description:"The float32s default" default:"3,4,5" pos:"36"` + Float32sDefault []float32 `name:"float32s_default" description:"The float32s default" default:"3|4|5" sep:"|" pos:"36"` Float64 float64 `name:"float64" description:"The float64" pos:"37"` Float64s []float64 `name:"float64s" description:"The float64s" pos:"38"` - Float64sDefault []float64 `name:"float64s_default" description:"The float64s default" default:"3,4,5" pos:"39"` + Float64sDefault []float64 `name:"float64s_default" description:"The float64s default" default:"3|4|5" sep:"|" pos:"39"` Bool bool `name:"bool" description:"The bool" pos:"40"` Bools []bool `name:"bools" description:"The bools" pos:"41"` - BoolsDefault []bool `name:"bools_default" description:"The bools default" default:"false,true,false,true" pos:"42"` + BoolsDefault []bool `name:"bools_default" description:"The bools default" default:"false|true|false|true" sep:"|" pos:"42"` } func main() { @@ -190,12 +190,12 @@ func main() { panic(fmt.Sprintf("expected 'hello', got '%v'", f.String)) } - if !reflect.DeepEqual(f.Strings, []string{"zkep", "hello", "clir"}) { - panic(fmt.Sprintf("expected '[zkep hello clir]', got '%v'", f.Strings)) + if !reflect.DeepEqual(f.Strings, []string{"zkep,hello,clir"}) { + panic(fmt.Sprintf("expected 'zkep,hello,clir', got '%v'", f.Strings)) } if !reflect.DeepEqual(f.StringsDefault, []string{"zkep", "clir", "hello"}) { - panic(fmt.Sprintf("expected '[zkep clir hello]', got '%v'", f.StringsDefault)) + panic(fmt.Sprintf("expected '[zkep,clir,hello]', got '%v'", f.StringsDefault)) } println("string:", fmt.Sprintf("%#v", f.String)) @@ -207,7 +207,8 @@ func main() { }) // Run! - if err := cli.Run("positional", "hello", "zkep,hello,clir", "zkep,clir,hello"); err != nil { + // The pos 3 slice separator is '|' in struct tag + if err := cli.Run("positional", "hello", "zkep,hello,clir", "zkep|clir|hello"); err != nil { panic(err) }