diff --git a/models_test.go b/models_test.go index 1b6a5ac..df3b43b 100644 --- a/models_test.go +++ b/models_test.go @@ -184,11 +184,22 @@ type Company struct { ID string `jsonapi:"primary,companies"` Name string `jsonapi:"attr,name"` Boss Employee `jsonapi:"attr,boss"` + Manager *Employee `jsonapi:"attr,manager"` Teams []Team `jsonapi:"attr,teams"` People []*People `jsonapi:"attr,people"` FoundedAt time.Time `jsonapi:"attr,founded-at,iso8601"` } +type CompanyOmitEmpty struct { + ID string `jsonapi:"primary,companies"` + Name string `jsonapi:"attr,name,omitempty"` + Boss Employee `jsonapi:"attr,boss,omitempty"` + Manager *Employee `jsonapi:"attr,manager,omitempty"` + Teams []Team `jsonapi:"attr,teams,omitempty"` + People []*People `jsonapi:"attr,people,omitempty"` + FoundedAt time.Time `jsonapi:"attr,founded-at,iso8601,omitempty"` +} + type People struct { Name string `jsonapi:"attr,name"` Age int `jsonapi:"attr,age"` diff --git a/response.go b/response.go index dea77d8..cc989f2 100644 --- a/response.go +++ b/response.go @@ -221,18 +221,300 @@ func selectChoiceTypeStructField(structValue reflect.Value) (reflect.Value, erro return reflect.Value{}, errors.New("no non-nil choice field was found in the specified struct") } +func visitModelNodeAttribute(args []string, node *Node, fieldValue reflect.Value) error { + var omitEmpty, iso8601, rfc3339 bool + + if len(args) > 2 { + for _, arg := range args[2:] { + switch arg { + case annotationOmitEmpty: + omitEmpty = true + case annotationISO8601: + iso8601 = true + case annotationRFC3339: + rfc3339 = true + } + } + } + + if node.Attributes == nil { + node.Attributes = make(map[string]interface{}) + } + + // Handle Nullable[T] + if strings.HasPrefix(fieldValue.Type().Name(), "NullableAttr[") { + // handle unspecified + if fieldValue.IsNil() { + return nil + } + + // handle null + if fieldValue.MapIndex(reflect.ValueOf(false)).IsValid() { + node.Attributes[args[1]] = json.RawMessage("null") + return nil + } else { + + // handle value + fieldValue = fieldValue.MapIndex(reflect.ValueOf(true)) + } + } + + if fieldValue.Type() == reflect.TypeOf(time.Time{}) { + t := fieldValue.Interface().(time.Time) + + if t.IsZero() { + return nil + } + + if iso8601 { + node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat) + } else if rfc3339 { + node.Attributes[args[1]] = t.UTC().Format(time.RFC3339) + } else { + node.Attributes[args[1]] = t.Unix() + } + } else if fieldValue.Type() == reflect.TypeOf(new(time.Time)) { + // A time pointer may be nil + if fieldValue.IsNil() { + if omitEmpty { + return nil + } + + node.Attributes[args[1]] = nil + } else { + tm := fieldValue.Interface().(*time.Time) + + if tm.IsZero() && omitEmpty { + return nil + } + + if iso8601 { + node.Attributes[args[1]] = tm.UTC().Format(iso8601TimeFormat) + } else if rfc3339 { + node.Attributes[args[1]] = tm.UTC().Format(time.RFC3339) + } else { + node.Attributes[args[1]] = tm.Unix() + } + } + } else { + // Dealing with a fieldValue that is not a time + emptyValue := reflect.Zero(fieldValue.Type()) + + // See if we need to omit this field + if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) { + return nil + } + + isStruct := fieldValue.Type().Kind() == reflect.Struct + isPointerToStruct := fieldValue.Type().Kind() == reflect.Pointer && fieldValue.Elem().Kind() == reflect.Struct + isSliceOfStruct := fieldValue.Type().Kind() == reflect.Slice && fieldValue.Type().Elem().Kind() == reflect.Struct + isSliceOfPointerToStruct := fieldValue.Type().Kind() == reflect.Slice && fieldValue.Type().Elem().Kind() == reflect.Pointer && fieldValue.Type().Elem().Elem().Kind() == reflect.Struct + + if isSliceOfStruct || isSliceOfPointerToStruct { + if fieldValue.Len() == 0 && omitEmpty { + return nil + } + // Nested slice of object attributes + manyNested, err := visitModelNodeRelationships(fieldValue, nil, false) + if err != nil { + return fmt.Errorf("failed to marshal slice of nested attribute %q: %w", args[1], err) + } + nestedNodes := make([]any, len(manyNested.Data)) + for i, n := range manyNested.Data { + nestedNodes[i] = n.Attributes + } + node.Attributes[args[1]] = nestedNodes + } else if isStruct || isPointerToStruct { + // Nested object attribute + nested, err := visitModelNode(fieldValue.Interface(), nil, false) + if err != nil { + return fmt.Errorf("failed to marshal nested attribute %q: %w", args[1], err) + } + node.Attributes[args[1]] = nested.Attributes + } else { + // Primitive attribute + strAttr, ok := fieldValue.Interface().(string) + if ok { + node.Attributes[args[1]] = strAttr + } else { + node.Attributes[args[1]] = fieldValue.Interface() + } + } + } + + return nil +} + +func visitModelNodeRelation(model any, annotation string, args []string, node *Node, fieldValue reflect.Value, included *map[string]*Node, sideload bool) error { + var omitEmpty bool + + //add support for 'omitempty' struct tag for marshaling as absent + if len(args) > 2 { + omitEmpty = args[2] == annotationOmitEmpty + } + + isSlice := fieldValue.Type().Kind() == reflect.Slice + if omitEmpty && + (isSlice && fieldValue.Len() < 1 || + (!isSlice && fieldValue.IsNil())) { + return nil + } + + if annotation == annotationPolyRelation { + // for polyrelation, we'll snoop out the actual relation model + // through the choice type value by choosing the first non-nil + // field that has a jsonapi type annotation and overwriting + // `fieldValue` so normal annotation-assisted marshaling + // can continue + if !isSlice { + choiceValue := fieldValue + + // must be a pointer type + if choiceValue.Type().Kind() != reflect.Ptr { + return ErrUnexpectedType + } + + if choiceValue.IsNil() { + fieldValue = reflect.ValueOf(nil) + } + structValue := choiceValue.Elem() + + // Short circuit if field is omitted from model + if !structValue.IsValid() { + return nil + } + + if found, err := selectChoiceTypeStructField(structValue); err == nil { + fieldValue = found + } + } else { + // A slice polyrelation field can be... polymorphic... meaning + // that we might snoop different types within each slice element. + // Each snooped value will added to this collection and then + // the recursion will take care of the rest. The only special case + // is nil. For that, we'll just choose the first + collection := make([]interface{}, 0) + + for i := 0; i < fieldValue.Len(); i++ { + itemValue := fieldValue.Index(i) + // Once again, must be a pointer type + if itemValue.Type().Kind() != reflect.Ptr { + return ErrUnexpectedType + } + + if itemValue.IsNil() { + return ErrUnexpectedNil + } + + structValue := itemValue.Elem() + + if found, err := selectChoiceTypeStructField(structValue); err == nil { + collection = append(collection, found.Interface()) + } + } + + fieldValue = reflect.ValueOf(collection) + } + } + + if node.Relationships == nil { + node.Relationships = make(map[string]interface{}) + } + + var relLinks *Links + if linkableModel, ok := model.(RelationshipLinkable); ok { + relLinks = linkableModel.JSONAPIRelationshipLinks(args[1]) + } + + var relMeta *Meta + if metableModel, ok := model.(RelationshipMetable); ok { + relMeta = metableModel.JSONAPIRelationshipMeta(args[1]) + } + + if isSlice { + // to-many relationship + relationship, err := visitModelNodeRelationships( + fieldValue, + included, + sideload, + ) + if err != nil { + return err + } + relationship.Links = relLinks + relationship.Meta = relMeta + + if sideload { + shallowNodes := []*Node{} + for _, n := range relationship.Data { + appendIncluded(included, n) + shallowNodes = append(shallowNodes, toShallowNode(n)) + } + + node.Relationships[args[1]] = &RelationshipManyNode{ + Data: shallowNodes, + Links: relationship.Links, + Meta: relationship.Meta, + } + } else { + node.Relationships[args[1]] = relationship + } + } else { + // to-one relationships + + // Handle null relationship case + if fieldValue.IsNil() { + node.Relationships[args[1]] = &RelationshipOneNode{Data: nil} + return nil + } + + relationship, err := visitModelNode( + fieldValue.Interface(), + included, + sideload, + ) + + if err != nil { + return err + } + + if sideload { + appendIncluded(included, relationship) + node.Relationships[args[1]] = &RelationshipOneNode{ + Data: toShallowNode(relationship), + Links: relLinks, + Meta: relMeta, + } + } else { + node.Relationships[args[1]] = &RelationshipOneNode{ + Data: relationship, + Links: relLinks, + Meta: relMeta, + } + } + } + return nil +} + func visitModelNode(model interface{}, included *map[string]*Node, sideload bool) (*Node, error) { node := new(Node) var er error + var modelValue reflect.Value + var modelType reflect.Type value := reflect.ValueOf(model) - if value.IsNil() { - return nil, nil - } - modelValue := value.Elem() - modelType := value.Type().Elem() + if value.Type().Kind() == reflect.Pointer { + if value.IsNil() { + return nil, nil + } + modelValue = value.Elem() + modelType = value.Type().Elem() + } else { + modelValue = value + modelType = value.Type() + } for i := 0; i < modelValue.NumField(); i++ { fieldValue := modelValue.Field(i) @@ -312,251 +594,14 @@ func visitModelNode(model interface{}, included *map[string]*Node, node.ClientID = clientID } } else if annotation == annotationAttribute { - var omitEmpty, iso8601, rfc3339 bool - - if len(args) > 2 { - for _, arg := range args[2:] { - switch arg { - case annotationOmitEmpty: - omitEmpty = true - case annotationISO8601: - iso8601 = true - case annotationRFC3339: - rfc3339 = true - } - } - } - - if node.Attributes == nil { - node.Attributes = make(map[string]interface{}) - } - - // Handle Nullable[T] - if strings.HasPrefix(fieldValue.Type().Name(), "NullableAttr[") { - // handle unspecified - if fieldValue.IsNil() { - continue - } - - // handle null - if fieldValue.MapIndex(reflect.ValueOf(false)).IsValid() { - node.Attributes[args[1]] = json.RawMessage("null") - continue - } else { - - // handle value - fieldValue = fieldValue.MapIndex(reflect.ValueOf(true)) - } - } - - if fieldValue.Type() == reflect.TypeOf(time.Time{}) { - t := fieldValue.Interface().(time.Time) - - if t.IsZero() { - continue - } - - if iso8601 { - node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat) - } else if rfc3339 { - node.Attributes[args[1]] = t.UTC().Format(time.RFC3339) - } else { - node.Attributes[args[1]] = t.Unix() - } - } else if fieldValue.Type() == reflect.TypeOf(new(time.Time)) { - // A time pointer may be nil - if fieldValue.IsNil() { - if omitEmpty { - continue - } - - node.Attributes[args[1]] = nil - } else { - tm := fieldValue.Interface().(*time.Time) - - if tm.IsZero() && omitEmpty { - continue - } - - if iso8601 { - node.Attributes[args[1]] = tm.UTC().Format(iso8601TimeFormat) - } else if rfc3339 { - node.Attributes[args[1]] = tm.UTC().Format(time.RFC3339) - } else { - node.Attributes[args[1]] = tm.Unix() - } - } - } else { - // Dealing with a fieldValue that is not a time - emptyValue := reflect.Zero(fieldValue.Type()) - - // See if we need to omit this field - if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) { - continue - } - - strAttr, ok := fieldValue.Interface().(string) - if ok { - node.Attributes[args[1]] = strAttr - } else { - node.Attributes[args[1]] = fieldValue.Interface() - } + er = visitModelNodeAttribute(args, node, fieldValue) + if er != nil { + break } } else if annotation == annotationRelation || annotation == annotationPolyRelation { - var omitEmpty bool - - //add support for 'omitempty' struct tag for marshaling as absent - if len(args) > 2 { - omitEmpty = args[2] == annotationOmitEmpty - } - - isSlice := fieldValue.Type().Kind() == reflect.Slice - if omitEmpty && - (isSlice && fieldValue.Len() < 1 || - (!isSlice && fieldValue.IsNil())) { - continue - } - - if annotation == annotationPolyRelation { - // for polyrelation, we'll snoop out the actual relation model - // through the choice type value by choosing the first non-nil - // field that has a jsonapi type annotation and overwriting - // `fieldValue` so normal annotation-assisted marshaling - // can continue - if !isSlice { - choiceValue := fieldValue - - // must be a pointer type - if choiceValue.Type().Kind() != reflect.Ptr { - er = ErrUnexpectedType - break - } - - if choiceValue.IsNil() { - fieldValue = reflect.ValueOf(nil) - } - structValue := choiceValue.Elem() - - // Short circuit if field is omitted from model - if !structValue.IsValid() { - break - } - - if found, err := selectChoiceTypeStructField(structValue); err == nil { - fieldValue = found - } - } else { - // A slice polyrelation field can be... polymorphic... meaning - // that we might snoop different types within each slice element. - // Each snooped value will added to this collection and then - // the recursion will take care of the rest. The only special case - // is nil. For that, we'll just choose the first - collection := make([]interface{}, 0) - - for i := 0; i < fieldValue.Len(); i++ { - itemValue := fieldValue.Index(i) - // Once again, must be a pointer type - if itemValue.Type().Kind() != reflect.Ptr { - er = ErrUnexpectedType - break - } - - if itemValue.IsNil() { - er = ErrUnexpectedNil - break - } - - structValue := itemValue.Elem() - - if found, err := selectChoiceTypeStructField(structValue); err == nil { - collection = append(collection, found.Interface()) - } - } - - if er != nil { - break - } - - fieldValue = reflect.ValueOf(collection) - } - } - - if node.Relationships == nil { - node.Relationships = make(map[string]interface{}) - } - - var relLinks *Links - if linkableModel, ok := model.(RelationshipLinkable); ok { - relLinks = linkableModel.JSONAPIRelationshipLinks(args[1]) - } - - var relMeta *Meta - if metableModel, ok := model.(RelationshipMetable); ok { - relMeta = metableModel.JSONAPIRelationshipMeta(args[1]) - } - - if isSlice { - // to-many relationship - relationship, err := visitModelNodeRelationships( - fieldValue, - included, - sideload, - ) - if err != nil { - er = err - break - } - relationship.Links = relLinks - relationship.Meta = relMeta - - if sideload { - shallowNodes := []*Node{} - for _, n := range relationship.Data { - appendIncluded(included, n) - shallowNodes = append(shallowNodes, toShallowNode(n)) - } - - node.Relationships[args[1]] = &RelationshipManyNode{ - Data: shallowNodes, - Links: relationship.Links, - Meta: relationship.Meta, - } - } else { - node.Relationships[args[1]] = relationship - } - } else { - // to-one relationships - - // Handle null relationship case - if fieldValue.IsNil() { - node.Relationships[args[1]] = &RelationshipOneNode{Data: nil} - continue - } - - relationship, err := visitModelNode( - fieldValue.Interface(), - included, - sideload, - ) - if err != nil { - er = err - break - } - - if sideload { - appendIncluded(included, relationship) - node.Relationships[args[1]] = &RelationshipOneNode{ - Data: toShallowNode(relationship), - Links: relLinks, - Meta: relMeta, - } - } else { - node.Relationships[args[1]] = &RelationshipOneNode{ - Data: relationship, - Links: relLinks, - Meta: relMeta, - } - } + er = visitModelNodeRelation(model, annotation, args, node, fieldValue, included, sideload) + if er != nil { + break } } else if annotation == annotationLinks { // Nothing. Ignore this field, as Links fields are only for unmarshaling requests. @@ -611,7 +656,7 @@ func visitModelNodeRelationships(models reflect.Value, included *map[string]*Nod for i := 0; i < models.Len(); i++ { model := models.Index(i) - if !model.IsValid() || model.IsNil() { + if !model.IsValid() || (model.Kind() == reflect.Pointer && model.IsNil()) { return nil, ErrUnexpectedNil } diff --git a/response_test.go b/response_test.go index d79d64f..2691a1f 100644 --- a/response_test.go +++ b/response_test.go @@ -682,6 +682,131 @@ func TestSupportsAttributes(t *testing.T) { } } +func TestMarshalObjectAttribute(t *testing.T) { + now := time.Now() + testModel := &Company{ + ID: "5", + Name: "test", + Boss: Employee{ + HiredAt: &now, + }, + Manager: &Employee{ + Firstname: "Dave", + HiredAt: &now, + }, + Teams: []Team{ + {Name: "Team 1"}, + {Name: "Team-2"}, + }, + People: []*People{ + {Name: "Person-1"}, + {Name: "Person-2"}, + }, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, testModel); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.NewDecoder(out).Decode(resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.Attributes == nil { + t.Fatalf("Expected attributes") + } + + boss, ok := data.Attributes["boss"].(map[string]interface{}) + if !ok { + t.Fatalf("Expected boss attribute, got %v", data.Attributes) + } + + hiredAt, ok := boss["hired-at"] + if !ok { + t.Fatalf("Expected boss attribute to contain a \"hired-at\" property, got %v", boss) + } + + if hiredAt != now.UTC().Format(iso8601TimeFormat) { + t.Fatalf("Expected hired-at to be %s, got %s", now.UTC().Format(iso8601TimeFormat), hiredAt) + } + + manager, ok := data.Attributes["manager"].(map[string]interface{}) + if !ok { + t.Fatalf("Expected manager attribute, got %v", data.Attributes) + } + + if manager["firstname"] != "Dave" { + t.Fatalf("Expected manager.firstname to be \"Dave\", got %v", manager) + } + + people, ok := data.Attributes["people"].([]interface{}) + if !ok { + t.Fatalf("Expected people attribute, got %v", data.Attributes) + } + if len(people) != 2 { + t.Fatalf("Expected 2 people, got %v", people) + } + + teams, ok := data.Attributes["teams"].([]interface{}) + if !ok { + t.Fatalf("Expected teams attribute, got %v", data.Attributes) + } + if len(teams) != 2 { + t.Fatalf("Expected 2 teams, got %v", teams) + } +} + +func TestMarshalObjectAttributeWithEmptyNested(t *testing.T) { + testModel := &CompanyOmitEmpty{ + ID: "5", + Name: "test", + Boss: Employee{}, + Manager: nil, + Teams: []Team{}, + People: nil, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, testModel); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.NewDecoder(out).Decode(resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.Attributes == nil { + t.Fatalf("Expected attributes") + } + + _, ok := data.Attributes["boss"].(map[string]interface{}) + if ok { + t.Fatalf("Expected omitted boss attribute, got %v", data.Attributes) + } + + _, ok = data.Attributes["manager"].(map[string]interface{}) + if ok { + t.Fatalf("Expected omitted manager attribute, got %v", data.Attributes) + } + + _, ok = data.Attributes["people"].([]interface{}) + if ok { + t.Fatalf("Expected omitted people attribute, got %v", data.Attributes) + } + + _, ok = data.Attributes["teams"].([]interface{}) + if ok { + t.Fatalf("Expected omitted teams attribute, got %v", data.Attributes) + } +} + func TestOmitsZeroTimes(t *testing.T) { testModel := &Blog{ ID: 5,