diff --git a/api.go b/api.go index 88a42b2..4966e5b 100644 --- a/api.go +++ b/api.go @@ -5,6 +5,8 @@ import ( "fmt" "reflect" + "github.com/dgraph-io/dgraph/v24/dql" + "github.com/dgraph-io/dgraph/v24/schema" "github.com/dgraph-io/dgraph/v24/x" ) @@ -55,7 +57,9 @@ func Create[T any](db *DB, object *T, ns ...uint64) (uint64, *T, error) { return 0, object, err } - dms, sch, err := generateCreateDqlMutationsAndSchema(n, object, gid) + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) if err != nil { return 0, object, err } @@ -91,11 +95,11 @@ func Get[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, *T, return 0, nil, err } if uid, ok := any(uniqueField).(uint64); ok { - return getByGid[T](ctx, n, uid) + return getByGid[T](ctx, n, uid, true) } if cf, ok := any(uniqueField).(ConstrainedField); ok { - return getByConstrainedField[T](ctx, n, cf) + return getByConstrainedField[T](ctx, n, cf, true) } return 0, nil, fmt.Errorf("invalid unique field type") @@ -109,7 +113,7 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, return 0, nil, err } if uid, ok := any(uniqueField).(uint64); ok { - uid, obj, err := getByGid[T](ctx, n, uid) + uid, obj, err := getByGid[T](ctx, n, uid, true) if err != nil { return 0, nil, err } @@ -125,8 +129,34 @@ func Delete[T any, R UniqueField](db *DB, uniqueField R, ns ...uint64) (uint64, } if cf, ok := any(uniqueField).(ConstrainedField); ok { - return getByConstrainedField[T](ctx, n, cf) + uid, obj, err := getByConstrainedField[T](ctx, n, cf, true) + if err != nil { + return 0, nil, err + } + + dms := generateDeleteDqlMutations(n, uid) + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, nil, err + } + + return uid, obj, nil } return 0, nil, fmt.Errorf("invalid unique field type") } + +func Upsert[T any](db *DB, object *T, ns ...uint64) (uint64, *T, bool, error) { + db.mutex.Lock() + defer db.mutex.Unlock() + if len(ns) > 1 { + return 0, object, false, fmt.Errorf("only one namespace is allowed") + } + ctx, n, err := getDefaultNamespace(db, ns...) + if err != nil { + return 0, object, false, err + } + + return upsertHelper[T](ctx, db, n, object, true) +} diff --git a/api_helper.go b/api_helper.go index 0cda75b..6baad79 100644 --- a/api_helper.go +++ b/api_helper.go @@ -50,7 +50,7 @@ func valueToPosting_ValType(v any) (pb.Posting_ValType, error) { } } -func valueToValType(v any) (*api.Value, error) { +func valueToApiVal(v any) (*api.Value, error) { switch val := v.(type) { case string: return &api.Value{Val: &api.Value_StrVal{StrVal: val}}, nil @@ -97,42 +97,66 @@ func valueToValType(v any) (*api.Value, error) { } } -func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T, - gid uint64) ([]*dql.Mutation, *schema.ParsedSchema, error) { +func generateCreateDqlMutationsAndSchema[T any](ctx context.Context, n *Namespace, object *T, + gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { t := reflect.TypeOf(*object) if t.Kind() != reflect.Struct { - return nil, nil, fmt.Errorf("expected struct, got %s", t.Kind()) + return fmt.Errorf("expected struct, got %s", t.Kind()) } - jsonFields, dbFields, _, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) if err != nil { - return nil, nil, err + return err } - values := getFieldValues(object, jsonFields) - sch := &schema.ParsedSchema{} + values := getJsonTagToValues(object, fieldToJsonTags) nquads := make([]*api.NQuad, 0) for jsonName, value := range values { if jsonName == "gid" { continue } - valType, err := valueToPosting_ValType(value) - if err != nil { - return nil, nil, err + var val *api.Value + var valType pb.Posting_ValType + if reflect.TypeOf(value).Kind() == reflect.Struct { + gid, _, _, err := upsertHelper(ctx, n.db, n, &value, false) + if err != nil { + return err + } + valType, err = valueToPosting_ValType(fmt.Sprint(gid)) + if err != nil { + return err + } + val, err = valueToApiVal(fmt.Sprint(gid)) + if err != nil { + return err + } + } else { + valType, err = valueToPosting_ValType(value) + if err != nil { + return err + } + val, err = valueToApiVal(value) + if err != nil { + return err + } } u := &pb.SchemaUpdate{ Predicate: addNamespace(n.id, getPredicateName(t.Name(), jsonName)), ValueType: valType, } - if dbFields[jsonName] != nil && dbFields[jsonName].constraint == "unique" { - u.Directive = pb.SchemaUpdate_INDEX - u.Tokenizer = []string{"exact"} + if jsonToDbTags[jsonName] != nil { + constraint := jsonToDbTags[jsonName].constraint + if constraint == "unique" || constraint == "term" { + u.Directive = pb.SchemaUpdate_INDEX + if constraint == "unique" { + u.Tokenizer = []string{"exact"} + } else { + u.Tokenizer = []string{"term"} + } + } } + sch.Preds = append(sch.Preds, u) - val, err := valueToValType(value) - if err != nil { - return nil, nil, err - } nquad := &api.NQuad{ Namespace: n.ID(), Subject: fmt.Sprint(gid), @@ -146,9 +170,9 @@ func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T, Fields: sch.Preds, }) - val, err := valueToValType(t.Name()) + val, err := valueToApiVal(t.Name()) if err != nil { - return nil, nil, err + return err } nquad := &api.NQuad{ Namespace: n.ID(), @@ -158,12 +182,11 @@ func generateCreateDqlMutationsAndSchema[T any](n *Namespace, object *T, } nquads = append(nquads, nquad) - dms := make([]*dql.Mutation, 0) - dms = append(dms, &dql.Mutation{ + *dms = append(*dms, &dql.Mutation{ Set: nquads, }) - return dms, sch, nil + return nil } func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { @@ -181,48 +204,71 @@ func generateDeleteDqlMutations(n *Namespace, gid uint64) []*dql.Mutation { }} } -func getByGid[T any](ctx context.Context, n *Namespace, gid uint64) (uint64, *T, error) { - query := fmt.Sprintf(` +func getByGid[T any](ctx context.Context, n *Namespace, gid uint64, readFrom bool) (uint64, *T, error) { + query := ` { obj(func: uid(%d)) { uid expand(_all_) dgraph.type + %s } } - `, gid) + ` - return executeGet[T](ctx, n, query, nil) + return executeGet[T](ctx, n, query, readFrom, gid) } -func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField) (uint64, *T, error) { - var obj T - - t := reflect.TypeOf(obj) - query := fmt.Sprintf(` +func getByConstrainedField[T any](ctx context.Context, n *Namespace, cf ConstrainedField, readFrom bool) (uint64, *T, error) { + query := ` { obj(func: eq(%s, %s)) { uid expand(_all_) dgraph.type + %s } } - `, getPredicateName(t.Name(), cf.Key), cf.Value) + ` - return executeGet[T](ctx, n, query, &cf) + return executeGet[T](ctx, n, query, readFrom, cf) } -func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *ConstrainedField) (uint64, *T, error) { +func executeGet[T any, R UniqueField](ctx context.Context, n *Namespace, query string, readFrom bool, args ...R) (uint64, *T, error) { + if len(args) != 1 { + return 0, nil, fmt.Errorf("expected 1 argument, got %d", len(args)) + } + var obj T t := reflect.TypeOf(obj) - jsonFields, dbTags, _, err := getFieldTags(t) + fieldToJsonTags, jsonToDbTag, reverseEdgeTags, err := getFieldTags(t) if err != nil { return 0, nil, err } + readFromQuery := "" + for fieldName, reverseEdgeTag := range reverseEdgeTags { + readFromQuery += fmt.Sprintf(` + %s: ~%s { + uid + expand(_all_) + dgraph.type + } + `, getPredicateName(t.Name(), fieldToJsonTags[fieldName]), reverseEdgeTag) + } - if cf != nil && dbTags[cf.Key].constraint == "" { + var cf ConstrainedField + gid, ok := any(args[0]).(uint64) + if ok { + query = fmt.Sprintf(query, gid, readFromQuery) + } else if cf, ok = any(args[0]).(ConstrainedField); ok { + query = fmt.Sprintf(query, getPredicateName(t.Name(), cf.Key), cf.Value, readFromQuery) + } else { + return 0, nil, fmt.Errorf("invalid unique field type") + } + + if jsonToDbTag[cf.Key] != nil && jsonToDbTag[cf.Key].constraint == "" { return 0, nil, fmt.Errorf("constraint not defined for field %s", cf.Key) } @@ -231,7 +277,7 @@ func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *Cons return 0, nil, err } - dynamicType := createDynamicStruct(t, jsonFields) + dynamicType := createDynamicStruct(t, fieldToJsonTags) dynamicInstance := reflect.New(dynamicType).Interface() @@ -253,7 +299,7 @@ func executeGet[T any](ctx context.Context, n *Namespace, query string, cf *Cons // Map the dynamic struct to the final type T finalObject := reflect.New(t).Interface() - gid, err := mapDynamicToFinal(result.Obj[0], finalObject) + gid, err = mapDynamicToFinal(result.Obj[0], finalObject) if err != nil { return 0, nil, err } @@ -299,3 +345,85 @@ func applyDqlMutations(ctx context.Context, db *DB, dms []*dql.Mutation) error { Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, }) } + +func getUniqueConstraint[T any](object *T) (uint64, *ConstrainedField, error) { + t := reflect.TypeOf(*object) + fieldToJsonTags, jsonToDbTags, _, err := getFieldTags(t) + if err != nil { + return 0, nil, err + } + values := getJsonTagToValues(object, fieldToJsonTags) + + for jsonName, value := range values { + if jsonName == "gid" { + gid, ok := value.(uint64) + if !ok { + return 0, nil, fmt.Errorf("expected uint64 type for gid, got %T", value) + } + if gid != 0 { + return gid, nil, nil + } + } + if jsonToDbTags[jsonName] != nil && jsonToDbTags[jsonName].constraint == "unique" { + return 0, &ConstrainedField{ + Key: jsonName, + Value: value, + }, nil + } + } + + return 0, nil, fmt.Errorf("unique constraint not defined for any field on type %s", t.Name()) +} + +func upsertHelper[T any](ctx context.Context, db *DB, n *Namespace, object *T, readFrom bool) (uint64, *T, bool, error) { + gid, cf, err := getUniqueConstraint(object) + if err != nil { + return 0, object, false, err + } + if gid != 0 { + gid, object, err := getByGid[T](ctx, n, gid, readFrom) + if err != nil { + return 0, object, false, err + } + return gid, object, true, nil + } else if cf != nil { + gid, object, err := getByConstrainedField[T](ctx, n, *cf, readFrom) + if err == nil { + return gid, object, true, nil + } + } + + gid, err = db.z.nextUID() + if err != nil { + return 0, object, false, err + } + + dms := make([]*dql.Mutation, 0) + sch := &schema.ParsedSchema{} + err = generateCreateDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + if err != nil { + return 0, object, false, err + } + + ctx = x.AttachNamespace(ctx, n.ID()) + + err = n.alterSchemaWithParsed(ctx, sch) + if err != nil { + return 0, object, false, err + } + + err = applyDqlMutations(ctx, db, dms) + if err != nil { + return 0, object, false, err + } + + v := reflect.ValueOf(object).Elem() + + gidField := v.FieldByName("Gid") + + if gidField.IsValid() && gidField.CanSet() && gidField.Kind() == reflect.Uint64 { + gidField.SetUint(gid) + } + + return gid, object, false, nil +} diff --git a/api_reflect.go b/api_reflect.go index 4c84813..9e1ac2e 100644 --- a/api_reflect.go +++ b/api_reflect.go @@ -11,10 +11,10 @@ type dbTag struct { constraint string } -func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[string]*dbTag, - reverseEdgeTags map[string]string, err error) { +func getFieldTags(t reflect.Type) (fieldToJsonTags map[string]string, + jsonToDbTags map[string]*dbTag, reverseEdgeTags map[string]string, err error) { - jsonTags = make(map[string]string) + fieldToJsonTags = make(map[string]string) jsonToDbTags = make(map[string]*dbTag) reverseEdgeTags = make(map[string]string) for i := 0; i < t.NumField(); i++ { @@ -24,7 +24,7 @@ func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[ return nil, nil, nil, fmt.Errorf("field %s has no json tag", field.Name) } jsonName := strings.Split(jsonTag, ",")[0] - jsonTags[field.Name] = jsonName + fieldToJsonTags[field.Name] = jsonName reverseEdgeTag := field.Tag.Get("readFrom") if reverseEdgeTag != "" { @@ -50,13 +50,13 @@ func getFieldTags(t reflect.Type) (jsonTags map[string]string, jsonToDbTags map[ } } } - return jsonTags, jsonToDbTags, reverseEdgeTags, nil + return fieldToJsonTags, jsonToDbTags, reverseEdgeTags, nil } -func getFieldValues(object any, jsonFields map[string]string) map[string]any { +func getJsonTagToValues(object any, fieldToJsonTags map[string]string) map[string]any { values := make(map[string]any) v := reflect.ValueOf(object).Elem() - for fieldName, jsonName := range jsonFields { + for fieldName, jsonName := range fieldToJsonTags { fieldValue := v.FieldByName(fieldName) values[jsonName] = fieldValue.Interface() @@ -64,9 +64,9 @@ func getFieldValues(object any, jsonFields map[string]string) map[string]any { return values } -func createDynamicStruct(t reflect.Type, jsonFields map[string]string) reflect.Type { - fields := make([]reflect.StructField, 0, len(jsonFields)) - for fieldName, jsonName := range jsonFields { +func createDynamicStruct(t reflect.Type, fieldToJsonTags map[string]string) reflect.Type { + fields := make([]reflect.StructField, 0, len(fieldToJsonTags)) + for fieldName, jsonName := range fieldToJsonTags { field, _ := t.FieldByName(fieldName) if fieldName != "Gid" { fields = append(fields, reflect.StructField{