diff --git a/decoder.go b/decoder.go index 3c46ba2..ab949df 100644 --- a/decoder.go +++ b/decoder.go @@ -18,6 +18,12 @@ type Decoder struct { // options (Default: 'csv'). Tag string + // If true, Decoder will return a MissingColumnsError if it discovers + // that any of the columns are missing. This means that a CSV input + // will be required to contain all columns that were defined in the + // provided struct. + DisallowMissingColumns bool + // If not nil, Map is a function that is called for each field in the csv // record before decoding the data. It allows mapping certain string values // for specific columns or types to a known format. Decoder calls Map with @@ -393,13 +399,18 @@ func (d *Decoder) fields(k typeKey) ([]decField, error) { return d.cache, nil } - fields := cachedFields(k) - decFields := make([]decField, 0, len(fields)) - used := make([]bool, len(d.header)) - + var ( + fields = cachedFields(k) + decFields = make([]decField, 0, len(fields)) + used = make([]bool, len(d.header)) + missingCols []string + ) for _, f := range fields { i, ok := d.hmap[f.name] if !ok { + if d.DisallowMissingColumns { + missingCols = append(missingCols, f.name) + } continue } @@ -427,6 +438,12 @@ func (d *Decoder) fields(k typeKey) ([]decField, error) { used[i] = true } + if len(missingCols) > 0 { + return nil, &MissingColumnsError{ + Columns: missingCols, + } + } + d.unused = d.unused[:0] for i, b := range used { if !b { diff --git a/decoder_test.go b/decoder_test.go index 7fd91da..285f936 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -1628,6 +1628,99 @@ string,"{""key"":""value""}" } }) + t.Run("decode with disallow missing columns", func(t *testing.T) { + type Type struct { + String string + Int int + Float float64 + } + + t.Run("all present", func(t *testing.T) { + dec, err := NewDecoder(NewReader( + []string{"String", "Int", "Float"}, + []string{"lol", "1", "2.0"}, + )) + if err != nil { + t.Fatal(err) + } + dec.DisallowMissingColumns = true + + var tt Type + if err := dec.Decode(&tt); err != nil { + t.Fatalf("expected err to be nil; got %v", err) + } + + if expected := (Type{"lol", 1, 2}); !reflect.DeepEqual(tt, expected) { + t.Errorf("want=%v; got %v", expected, tt) + } + }) + + fixtures := []struct { + desc string + recs [][]string + missingCols []string + msg string + }{ + { + desc: "one missing", + recs: [][]string{ + {"String", "Int"}, + {"lol", "1"}, + }, + missingCols: []string{"Float"}, + msg: `csvutil: missing columns: "Float"`, + }, + { + desc: "two missing", + recs: [][]string{ + {"String"}, + {"lol"}, + }, + missingCols: []string{"Int", "Float"}, + msg: `csvutil: missing columns: "Int", "Float"`, + }, + { + desc: "all missing", + recs: [][]string{ + {"w00t"}, + {"lol"}, + }, + missingCols: []string{"String", "Int", "Float"}, + msg: `csvutil: missing columns: "String", "Int", "Float"`, + }, + } + + for _, f := range fixtures { + t.Run(f.desc, func(t *testing.T) { + dec, err := NewDecoder(NewReader(f.recs...)) + if err != nil { + t.Fatal(err) + } + dec.DisallowMissingColumns = true + + var tt Type + err = dec.Decode(&tt) + + if err == nil { + t.Fatal("expected err != nil") + } + + mcerr, ok := err.(*MissingColumnsError) + if !ok { + t.Fatalf("expected err to be of *MissingColumnErr; got %[1]T (%[1]v)", err) + } + + if !reflect.DeepEqual(mcerr.Columns, f.missingCols) { + t.Errorf("expected missing columns to be %v; got %v", f.missingCols, mcerr.Columns) + } + + if err.Error() != f.msg { + t.Errorf("expected err message to be %q; got %q", f.msg, err.Error()) + } + }) + } + }) + t.Run("invalid unmarshal tests", func(t *testing.T) { var fixtures = []struct { v interface{} diff --git a/error.go b/error.go index 797e6c8..6f59887 100644 --- a/error.go +++ b/error.go @@ -1,6 +1,7 @@ package csvutil import ( + "bytes" "errors" "fmt" "reflect" @@ -130,3 +131,21 @@ func (e *MarshalerError) Error() string { func errPtrUnexportedStruct(typ reflect.Type) error { return fmt.Errorf("csvutil: cannot decode into a pointer to unexported struct: %s", typ) } + +// MissingColumnsError is returned by Decoder only when DisallowMissingColumns +// option was set to true. It contains a list of all missing columns. +type MissingColumnsError struct { + Columns []string +} + +func (e *MissingColumnsError) Error() string { + var b bytes.Buffer + b.WriteString("csvutil: missing columns: ") + for i, c := range e.Columns { + if i > 0 { + b.WriteString(", ") + } + fmt.Fprintf(&b, "%q", c) + } + return b.String() +}