From bcd42e1bbcb3b100e45386eb22c071c4b2833026 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Fri, 3 Jan 2025 08:25:39 -0800 Subject: [PATCH] feat: add readfrom json tag to support reverse edges (#49) **Description** This PR adds support for reverse edges via the readFrom json tag **Checklist** - [x] Code compiles correctly and linting passes locally - [x] For all _code_ changes, an entry added to the `CHANGELOG.md` file describing and linking to this PR - [x] Tests added for new functionality, or regression tests for bug fixes added as applicable - [ ] For public APIs, new features, etc., PR on [docs repo](https://github.com/hypermodeinc/docs) staged and linked here --- CHANGELOG.md | 5 ++ api.go | 2 +- api_dql.go | 16 +++-- api_mutate_helper.go | 37 ++++++++--- api_query_helper.go | 20 ++---- api_reflect.go | 54 ++++++++++++---- api_test.go | 149 +++++++++++++++++++++++++++++++++++++++++-- 7 files changed, 236 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2999085..4d92b85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## UNRELEASED + +- feat: add readfrom json tag to support reverse edges + [#49](https://github.com/hypermodeinc/modusDB/pull/49) + ## 2025-01-02 - Version 0.1.0 Baseline for the changelog. diff --git a/api.go b/api.go index 379fccd..652ea19 100644 --- a/api.go +++ b/api.go @@ -155,7 +155,7 @@ func Query[T any](db *DB, queryParams QueryParams, ns ...uint64) ([]uint64, []T, return nil, nil, err } - return executeQuery[T](ctx, n, queryParams, false) + return executeQuery[T](ctx, n, queryParams, true) } func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, T, error) { diff --git a/api_dql.go b/api_dql.go index 203440a..6c63171 100644 --- a/api_dql.go +++ b/api_dql.go @@ -21,9 +21,9 @@ const ( objQuery = ` { obj(func: %s) { - uid + gid: uid expand(_all_) { - uid + gid: uid expand(_all_) dgraph.type } @@ -36,9 +36,9 @@ const ( objsQuery = ` { objs(func: type("%s")%s) @filter(%s) { - uid + gid: uid expand(_all_) { - uid + gid: uid expand(_all_) dgraph.type } @@ -48,6 +48,14 @@ const ( } ` + reverseEdgeQuery = ` + %s: ~%s { + gid: uid + expand(_all_) + dgraph.type + } + ` + funcUid = `uid(%d)` funcEq = `eq(%s, %s)` funcSimilarTo = `similar_to(%s, %d, "[%s]")` diff --git a/api_mutate_helper.go b/api_mutate_helper.go index 2dd4188..f90c258 100644 --- a/api_mutate_helper.go +++ b/api_mutate_helper.go @@ -13,6 +13,7 @@ import ( "context" "fmt" "reflect" + "strings" "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/dgraph/v24/dql" @@ -39,18 +40,34 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac nquads := make([]*api.NQuad, 0) uniqueConstraintFound := false for jsonName, value := range jsonTagToValue { + var val *api.Value + var valType pb.Posting_ValType + + reflectValueType := reflect.TypeOf(value) + var nquad *api.NQuad + if jsonToReverseEdgeTags[jsonName] != "" { + if reflectValueType.Kind() != reflect.Slice || reflectValueType.Elem().Kind() != reflect.Struct { + return fmt.Errorf("reverse edge %s should be a slice of structs", jsonName) + } + reverseEdge := jsonToReverseEdgeTags[jsonName] + typeName := strings.Split(reverseEdge, ".")[0] + u := &pb.SchemaUpdate{ + Predicate: addNamespace(n.id, reverseEdge), + ValueType: pb.Posting_UID, + Directive: pb.SchemaUpdate_REVERSE, + } + sch.Preds = append(sch.Preds, u) + sch.Types = append(sch.Types, &pb.TypeUpdate{ + TypeName: addNamespace(n.id, typeName), + Fields: []*pb.SchemaUpdate{u}, + }) continue } if jsonName == "gid" { uniqueConstraintFound = true continue } - var val *api.Value - var valType pb.Posting_ValType - - reflectValueType := reflect.TypeOf(value) - var nquad *api.NQuad if reflectValueType.Kind() == reflect.Struct { value = reflect.ValueOf(value).Interface() @@ -87,16 +104,18 @@ func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespac Predicate: getPredicateName(t.Name(), jsonName), } + u := &pb.SchemaUpdate{ + Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), + ValueType: valType, + } + if valType == pb.Posting_UID { nquad.ObjectId = fmt.Sprint(value) + u.Directive = pb.SchemaUpdate_REVERSE } else { nquad.ObjectValue = val } - u := &pb.SchemaUpdate{ - Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), - ValueType: valType, - } if jsonToDbTags[jsonName] != nil { constraint := jsonToDbTags[jsonName].constraint if constraint == "vector" && valType != pb.Posting_VFLOAT { diff --git a/api_query_helper.go b/api_query_helper.go index 5c2552f..9ec211c 100644 --- a/api_query_helper.go +++ b/api_query_helper.go @@ -54,13 +54,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(` - %s: ~%s { - uid - expand(_all_) - dgraph.type - } - `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(reverseEdgeQuery, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } @@ -106,7 +100,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *Namespac // Map the dynamic struct to the final type T finalObject := reflect.New(t).Interface() - gid, err = mapDynamicToFinal(result.Obj[0], finalObject) + gid, err = mapDynamicToFinal(result.Obj[0], finalObject, false) if err != nil { return 0, obj, err } @@ -152,13 +146,7 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar readFromQuery := "" if withReverse { for jsonTag, reverseEdgeTag := range jsonToReverseEdgeTags { - readFromQuery += fmt.Sprintf(` - %s: ~%s { - uid - expand(_all_) - dgraph.type - } - `, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) + readFromQuery += fmt.Sprintf(reverseEdgeQuery, getPredicateName(t.Name(), jsonTag), reverseEdgeTag) } } @@ -197,7 +185,7 @@ func executeQuery[T any](ctx context.Context, n *Namespace, queryParams QueryPar var objs []T for _, obj := range result.Objs { finalObject := reflect.New(t).Interface() - gid, err := mapDynamicToFinal(obj, finalObject) + gid, err := mapDynamicToFinal(obj, finalObject, false) if err != nil { return nil, nil, err } diff --git a/api_reflect.go b/api_reflect.go index 1779a81..5a32b01 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -82,7 +82,7 @@ func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, dept field, _ := t.FieldByName(fieldName) if fieldName != "Gid" { if field.Type.Kind() == reflect.Struct { - if depth <= 2 { + if depth <= 1 { nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type) nestedType := createDynamicStruct(field.Type, nestedFieldToJsonTags, depth+1) fields = append(fields, reflect.StructField{ @@ -100,6 +100,15 @@ func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, dept Type: reflect.PointerTo(nestedType), Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), }) + } else if field.Type.Kind() == reflect.Slice && + field.Type.Elem().Kind() == reflect.Struct { + nestedFieldToJsonTags, _, _, _ := getFieldTags(field.Type.Elem()) + nestedType := createDynamicStruct(field.Type.Elem(), nestedFieldToJsonTags, depth+1) + fields = append(fields, reflect.StructField{ + Name: field.Name, + Type: reflect.SliceOf(nestedType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s.%s"`, t.Name(), jsonName)), + }) } else { fields = append(fields, reflect.StructField{ Name: field.Name, @@ -111,9 +120,9 @@ func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, dept } } fields = append(fields, reflect.StructField{ - Name: "Uid", + Name: "Gid", Type: reflect.TypeOf(""), - Tag: reflect.StructTag(`json:"uid"`), + Tag: reflect.StructTag(`json:"gid"`), }, reflect.StructField{ Name: "DgraphType", Type: reflect.TypeOf([]string{}), @@ -122,7 +131,7 @@ func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string, dept return reflect.StructOf(fields) } -func mapDynamicToFinal(dynamic any, final any) (uint64, error) { +func mapDynamicToFinal(dynamic any, final any, isNested bool) (uint64, error) { vFinal := reflect.ValueOf(final).Elem() vDynamic := reflect.ValueOf(dynamic).Elem() @@ -135,35 +144,54 @@ func mapDynamicToFinal(dynamic any, final any) (uint64, error) { dynamicValue := vDynamic.Field(i) var finalField reflect.Value - if dynamicField.Name == "Uid" { + if dynamicField.Name == "Gid" { finalField = vFinal.FieldByName("Gid") gidStr := dynamicValue.String() gid, _ = strconv.ParseUint(gidStr, 0, 64) } else if dynamicField.Name == "DgraphType" { - fieldArr := dynamicValue.Interface().([]string) - if len(fieldArr) == 0 { - return 0, ErrNoObjFound + fieldArrInterface := dynamicValue.Interface() + fieldArr, ok := fieldArrInterface.([]string) + if ok { + if len(fieldArr) == 0 { + if !isNested { + return 0, ErrNoObjFound + } else { + continue + } + } + } else { + return 0, fmt.Errorf("DgraphType field should be an array of strings") } } else { finalField = vFinal.FieldByName(dynamicField.Name) } if dynamicFieldType.Kind() == reflect.Struct { - _, err := mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface()) + _, err := mapDynamicToFinal(dynamicValue.Addr().Interface(), finalField.Addr().Interface(), true) if err != nil { return 0, err } } else if dynamicFieldType.Kind() == reflect.Ptr && dynamicFieldType.Elem().Kind() == reflect.Struct { // if field is a pointer, find if the underlying is a struct - _, err := mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface()) + _, err := mapDynamicToFinal(dynamicValue.Interface(), finalField.Interface(), true) if err != nil { return 0, err } - + } else if dynamicFieldType.Kind() == reflect.Slice && + dynamicFieldType.Elem().Kind() == reflect.Struct { + for j := 0; j < dynamicValue.Len(); j++ { + sliceElem := dynamicValue.Index(j).Addr().Interface() + finalSliceElem := reflect.New(finalField.Type().Elem()).Elem() + _, err := mapDynamicToFinal(sliceElem, finalSliceElem.Addr().Interface(), true) + if err != nil { + return 0, err + } + finalField.Set(reflect.Append(finalField, finalSliceElem)) + } } else { if finalField.IsValid() && finalField.CanSet() { - // if field name is uid, convert it to uint64 - if dynamicField.Name == "Uid" { + // if field name is gid, convert it to uint64 + if dynamicField.Name == "Gid" { finalField.SetUint(gid) } else { finalField.Set(dynamicValue) diff --git a/api_test.go b/api_test.go index 0ae4bea..f7f15e0 100644 --- a/api_test.go +++ b/api_test.go @@ -400,10 +400,10 @@ func TestQueryApiWithPaginiationAndSorting(t *testing.T) { } type Project struct { - Gid uint64 `json:"gid,omitempty"` - Name string `json:"name,omitempty"` - ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` - // Branches []Branch `json:"branches,omitempty" readFrom:"type=Branch,field=proj"` + Gid uint64 `json:"gid,omitempty"` + Name string `json:"name,omitempty"` + ClerkId string `json:"clerk_id,omitempty" db:"constraint=unique"` + Branches []Branch `json:"branches,omitempty" readFrom:"type=Branch,field=proj"` } type Branch struct { @@ -413,6 +413,147 @@ type Branch struct { Proj Project `json:"proj,omitempty"` } +func TestReverseEdgeGet(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projGid, project, err := modusdb.Create(db, Project{ + Name: "P", + ClerkId: "456", + Branches: []Branch{ + {Name: "B", ClerkId: "123"}, + {Name: "B2", ClerkId: "456"}, + }, + }, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "P", project.Name) + require.Equal(t, project.Gid, projGid) + + // modifying a read-only field will be a no-op + require.Len(t, project.Branches, 0) + + branch1 := Branch{ + Name: "B", + ClerkId: "123", + Proj: Project{ + Gid: projGid, + }, + } + + branch1Gid, branch1, err := modusdb.Create(db, branch1, db1.ID()) + require.NoError(t, err) + + require.Equal(t, "B", branch1.Name) + require.Equal(t, branch1.Gid, branch1Gid) + require.Equal(t, projGid, branch1.Proj.Gid) + require.Equal(t, "P", branch1.Proj.Name) + + branch2 := Branch{ + Name: "B2", + ClerkId: "456", + Proj: Project{ + Gid: projGid, + }, + } + + branch2Gid, branch2, err := modusdb.Create(db, branch2, db1.ID()) + require.NoError(t, err) + require.Equal(t, "B2", branch2.Name) + require.Equal(t, branch2.Gid, branch2Gid) + require.Equal(t, projGid, branch2.Proj.Gid) + + getProjGid, queriedProject, err := modusdb.Get[Project](db, projGid, db1.ID()) + require.NoError(t, err) + require.Equal(t, projGid, getProjGid) + require.Equal(t, "P", queriedProject.Name) + require.Len(t, queriedProject.Branches, 2) + require.Equal(t, "B", queriedProject.Branches[0].Name) + require.Equal(t, "B2", queriedProject.Branches[1].Name) + + queryBranchesGids, queriedBranches, err := modusdb.Query[Branch](db, modusdb.QueryParams{}, db1.ID()) + require.NoError(t, err) + require.Len(t, queriedBranches, 2) + require.Len(t, queryBranchesGids, 2) + require.Equal(t, "B", queriedBranches[0].Name) + require.Equal(t, "B2", queriedBranches[1].Name) + + // max depth is 2, so we should not see the branches within project + require.Len(t, queriedBranches[0].Proj.Branches, 0) + + _, _, err = modusdb.Delete[Project](db, projGid, db1.ID()) + require.NoError(t, err) + + queryBranchesGids, queriedBranches, err = modusdb.Query[Branch](db, modusdb.QueryParams{}, db1.ID()) + require.NoError(t, err) + require.Len(t, queriedBranches, 2) + require.Len(t, queryBranchesGids, 2) + require.Equal(t, "B", queriedBranches[0].Name) + require.Equal(t, "B2", queriedBranches[1].Name) +} + +func TestReverseEdgeQuery(t *testing.T) { + ctx := context.Background() + db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) + require.NoError(t, err) + defer db.Close() + + db1, err := db.CreateNamespace() + require.NoError(t, err) + + require.NoError(t, db1.DropData(ctx)) + + projects := []Project{ + {Name: "P1", ClerkId: "456"}, + {Name: "P2", ClerkId: "789"}, + } + + branchCounter := 1 + clerkCounter := 100 + + for _, project := range projects { + projGid, project, err := modusdb.Create(db, project, db1.ID()) + require.NoError(t, err) + require.Equal(t, project.Name, project.Name) + require.Equal(t, project.Gid, projGid) + + branches := []Branch{ + {Name: fmt.Sprintf("B%d", branchCounter), ClerkId: fmt.Sprintf("%d", clerkCounter), Proj: Project{Gid: projGid}}, + {Name: fmt.Sprintf("B%d", branchCounter+1), ClerkId: fmt.Sprintf("%d", clerkCounter+1), Proj: Project{Gid: projGid}}, + } + branchCounter += 2 + clerkCounter += 2 + + for _, branch := range branches { + branchGid, branch, err := modusdb.Create(db, branch, db1.ID()) + require.NoError(t, err) + require.Equal(t, branch.Name, branch.Name) + require.Equal(t, branch.Gid, branchGid) + require.Equal(t, projGid, branch.Proj.Gid) + } + } + + queriedProjectsGids, queriedProjects, err := modusdb.Query[Project](db, modusdb.QueryParams{}, db1.ID()) + require.NoError(t, err) + require.Len(t, queriedProjects, 2) + require.Len(t, queriedProjectsGids, 2) + require.Equal(t, "P1", queriedProjects[0].Name) + require.Equal(t, "P2", queriedProjects[1].Name) + require.Len(t, queriedProjects[0].Branches, 2) + require.Len(t, queriedProjects[1].Branches, 2) + require.Equal(t, "B1", queriedProjects[0].Branches[0].Name) + require.Equal(t, "B2", queriedProjects[0].Branches[1].Name) + require.Equal(t, "B3", queriedProjects[1].Branches[0].Name) + require.Equal(t, "B4", queriedProjects[1].Branches[1].Name) +} + func TestNestedObjectMutation(t *testing.T) { ctx := context.Background() db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir()))