Skip to content

Commit

Permalink
add nested mutation, upsert, nested types, reverse edge support
Browse files Browse the repository at this point in the history
  • Loading branch information
jairad26 committed Dec 23, 2024
1 parent 3e824e2 commit f9db383
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 54 deletions.
40 changes: 35 additions & 5 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"reflect"

"github.com/dgraph-io/dgraph/v24/dql"
"github.com/dgraph-io/dgraph/v24/schema"
"github.com/dgraph-io/dgraph/v24/x"
)

Expand Down Expand Up @@ -55,7 +57,9 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) {
return 0, object, err
}

dms, sch, err := generateCreateDqlMutationsAndSchema(n, object, gid)
dms := make([]*dql.Mutation, 0)
sch := &schema.ParsedSchema{}
err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch)
if err != nil {
return 0, object, err
}
Expand Down Expand Up @@ -91,11 +95,11 @@ func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T,
return 0, nil, err
}
if uid, ok := any(uniqueField).(uint64); ok {
return getByGid[T](ctx, n, uid)
return getByGid[T](ctx, n, uid, true)
}

if cf, ok := any(uniqueField).(ConstrainedField); ok {
return getByConstrainedField[T](ctx, n, cf)
return getByConstrainedField[T](ctx, n, cf, true)
}

return 0, nil, fmt.Errorf("invalid unique field type")
Expand All @@ -109,7 +113,7 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64,
return 0, nil, err
}
if uid, ok := any(uniqueField).(uint64); ok {
uid, obj, err := getByGid[T](ctx, n, uid)
uid, obj, err := getByGid[T](ctx, n, uid, true)
if err != nil {
return 0, nil, err
}
Expand All @@ -125,8 +129,34 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64,
}

if cf, ok := any(uniqueField).(ConstrainedField); ok {
return getByConstrainedField[T](ctx, n, cf)
uid, obj, err := getByConstrainedField[T](ctx, n, cf, true)
if err != nil {
return 0, nil, err
}

dms := generateDeleteDqlMutations(n, uid)

err = applyDqlMutations(ctx, db, dms)
if err != nil {
return 0, nil, err
}

return uid, obj, nil
}

return 0, nil, fmt.Errorf("invalid unique field type")
}

func Upsert[T any](db *DB, object *T, ns ...uint64) (uint64, *T, bool, error) {
db.mutex.Lock()
defer db.mutex.Unlock()
if len(ns) > 1 {
return 0, object, false, fmt.Errorf("only one namespace is allowed")
}
ctx, n, err := getDefaultNamespace(db, ns...)
if err != nil {
return 0, object, false, err
}

return upsertHelper[T](ctx, db, n, object, true)
}
206 changes: 167 additions & 39 deletions api_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func valueToPosting_ValType(v any) (pb.Posting_ValType, error) {
}
}

func valueToValType(v any) (*api.Value, error) {
func valueToApiVal(v any) (*api.Value, error) {
switch val := v.(type) {
case string:
return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil
Expand Down Expand Up @@ -97,42 +97,66 @@ func valueToValType(v any) (*api.Value, error) {
}
}

func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T,
gid uint64) ([]*dql.Mutation, *schema.ParsedSchema, error) {
func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object *T,
gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error {
t := reflect.TypeOf(*object)
if t.Kind() != reflect.Struct {
return nil, nil, fmt.Errorf("expected struct, got %s", t.Kind())
return fmt.Errorf("expected struct, got %s", t.Kind())
}

jsonFields, dbFields, _, err := getFieldTags(t)
fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t)
if err != nil {
return nil, nil, err
return err
}
values := getFieldValues(object, jsonFields)
sch := &schema.ParsedSchema{}
values := getJsonTagToValues(object, fieldToJsonTags)

nquads := make([]*api.NQuad, 0)
for jsonName, value := range values {
if jsonName == "gid" {
continue
}
valType, err := valueToPosting_ValType(value)
if err != nil {
return nil, nil, err
var val *api.Value
var valType pb.Posting_ValType
if reflect.TypeOf(value).Kind() == reflect.Struct {
gid, _, _, err := upsertHelper(ctx, n.db, n, &value, false)
if err != nil {
return err
}
valType, err = valueToPosting_ValType(fmt.Sprint(gid))
if err != nil {
return err
}
val, err = valueToApiVal(fmt.Sprint(gid))
if err != nil {
return err
}
} else {
valType, err = valueToPosting_ValType(value)
if err != nil {
return err
}
val, err = valueToApiVal(value)
if err != nil {
return err
}
}
u := &pb.SchemaUpdate{
Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)),
ValueType: valType,
}
if dbFields[jsonName] != nil && dbFields[jsonName].constraint == "unique" {
u.Directive = pb.SchemaUpdate_INDEX
u.Tokenizer = []string{"exact"}
if jsonToDbTags[jsonName] != nil {
constraint := jsonToDbTags[jsonName].constraint
if constraint == "unique" || constraint == "term" {
u.Directive = pb.SchemaUpdate_INDEX
if constraint == "unique" {
u.Tokenizer = []string{"exact"}
} else {
u.Tokenizer = []string{"term"}
}
}
}

sch.Preds = append(sch.Preds, u)
val, err := valueToValType(value)
if err != nil {
return nil, nil, err
}
nquad := &api.NQuad{
Namespace: n.ID(),
Subject: fmt.Sprint(gid),
Expand All @@ -146,9 +170,9 @@ func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T,
Fields: sch.Preds,
})

val, err := valueToValType(t.Name())
val, err := valueToApiVal(t.Name())
if err != nil {
return nil, nil, err
return err
}
nquad := &api.NQuad{
Namespace: n.ID(),
Expand All @@ -158,12 +182,11 @@ func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T,
}
nquads = append(nquads, nquad)

dms := make([]*dql.Mutation, 0)
dms = append(dms, &dql.Mutation{
*dms = append(*dms, &dql.Mutation{
Set: nquads,
})

return dms, sch, nil
return nil
}

func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation {
Expand All @@ -181,48 +204,71 @@ func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation {
}}
}

func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) {
query := fmt.Sprintf(`
func getByGid[T any](ctx context.Context, n *Namespace, gid uint64, readFrom bool) (uint64, *T, error) {
query := `
{
obj(func: uid(%d)) {
uid
expand(_all_)
dgraph.type
%s
}
}
`, gid)
`

return executeGet[T](ctx, n, query, nil)
return executeGet[T](ctx, n, query, readFrom, gid)
}

func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) {
var obj T

t := reflect.TypeOf(obj)
query := fmt.Sprintf(`
func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField, readFrom bool) (uint64, *T, error) {
query := `
{
obj(func: eq(%s, %s)) {
uid
expand(_all_)
dgraph.type
%s
}
}
`, getPredicateName(t.Name(), cf.Key), cf.Value)
`

return executeGet[T](ctx, n, query, &cf)
return executeGet[T](ctx, n, query, readFrom, cf)
}

func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *ConstrainedField) (uint64, *T, error) {
func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query string, readFrom bool, args ...R) (uint64, *T, error) {
if len(args) != 1 {
return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args))
}

var obj T

t := reflect.TypeOf(obj)

jsonFields, dbTags, _, err := getFieldTags(t)
fieldToJsonTags, jsonToDbTag, reverseEdgeTags, err := getFieldTags(t)
if err != nil {
return 0, nil, err
}
readFromQuery := ""
for fieldName, reverseEdgeTag := range reverseEdgeTags {
readFromQuery += fmt.Sprintf(`
%s: ~%s {
uid
expand(_all_)
dgraph.type
}
`, getPredicateName(t.Name(), fieldToJsonTags[fieldName]), reverseEdgeTag)
}

if cf != nil && dbTags[cf.Key].constraint == "" {
var cf ConstrainedField
gid, ok := any(args[0]).(uint64)
if ok {
query = fmt.Sprintf(query, gid, readFromQuery)
} else if cf, ok = any(args[0]).(ConstrainedField); ok {
query = fmt.Sprintf(query, getPredicateName(t.Name(), cf.Key), cf.Value, readFromQuery)
} else {
return 0, nil, fmt.Errorf("invalid unique field type")
}

if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" {
return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key)
}

Expand All @@ -231,7 +277,7 @@ func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *Cons
return 0, nil, err
}

dynamicType := createDynamicStruct(t, jsonFields)
dynamicType := createDynamicStruct(t, fieldToJsonTags)

dynamicInstance := reflect.New(dynamicType).Interface()

Expand All @@ -253,7 +299,7 @@ func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *Cons

// Map the dynamic struct to the final type T
finalObject := reflect.New(t).Interface()
gid, err := mapDynamicToFinal(result.Obj[0], finalObject)
gid, err = mapDynamicToFinal(result.Obj[0], finalObject)
if err != nil {
return 0, nil, err
}
Expand Down Expand Up @@ -299,3 +345,85 @@ func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error {
Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}},
})
}

func getUniqueConstraint[T any](object *T) (uint64, *ConstrainedField, error) {
t := reflect.TypeOf(*object)
fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t)
if err != nil {
return 0, nil, err
}
values := getJsonTagToValues(object, fieldToJsonTags)

for jsonName, value := range values {
if jsonName == "gid" {
gid, ok := value.(uint64)
if !ok {
return 0, nil, fmt.Errorf("expected uint64 type for gid, got %T", value)
}
if gid != 0 {
return gid, nil, nil
}
}
if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" {
return 0, &ConstrainedField{
Key: jsonName,
Value: value,
}, nil
}
}

return 0, nil, fmt.Errorf("unique constraint not defined for any field on type %s", t.Name())
}

func upsertHelper[T any](ctx context.Context, db *DB, n *Namespace, object *T, readFrom bool) (uint64, *T, bool, error) {
gid, cf, err := getUniqueConstraint(object)
if err != nil {
return 0, object, false, err
}
if gid != 0 {
gid, object, err := getByGid[T](ctx, n, gid, readFrom)
if err != nil {
return 0, object, false, err
}
return gid, object, true, nil
} else if cf != nil {
gid, object, err := getByConstrainedField[T](ctx, n, *cf, readFrom)
if err == nil {
return gid, object, true, nil
}
}

gid, err = db.z.nextUID()
if err != nil {
return 0, object, false, err
}

dms := make([]*dql.Mutation, 0)
sch := &schema.ParsedSchema{}
err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch)
if err != nil {
return 0, object, false, err
}

ctx = x.AttachNamespace(ctx, n.ID())

err = n.alterSchemaWithParsed(ctx, sch)
if err != nil {
return 0, object, false, err
}

err = applyDqlMutations(ctx, db, dms)
if err != nil {
return 0, object, false, err
}

v := reflect.ValueOf(object).Elem()

gidField := v.FieldByName("Gid")

if gidField.IsValid() && gidField.CanSet() && gidField.Kind() == reflect.Uint64 {
gidField.SetUint(gid)
}

return gid, object, false, nil
}
Loading

0 comments on commit f9db383

Please sign in to comment.