-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
256 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
package modusdb | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
|
||
"github.com/dgraph-io/dgo/v240/protos/api" | ||
"github.com/dgraph-io/dgraph/v24/dql" | ||
"github.com/dgraph-io/dgraph/v24/protos/pb" | ||
"github.com/dgraph-io/dgraph/v24/query" | ||
"github.com/dgraph-io/dgraph/v24/worker" | ||
"github.com/dgraph-io/dgraph/v24/x" | ||
) | ||
|
||
type UniqueField interface{ | ||
uint64 | ConstrainedField | ||
} | ||
type ConstrainedField struct { | ||
key string | ||
value any | ||
} | ||
|
||
func getFieldTags(t reflect.Type) (jsonTags map[string]string, reverseEdgeTags map[string]string, err error) { | ||
jsonTags = make(map[string]string) | ||
reverseEdgeTags = make(map[string]string) | ||
for i := 0; i < t.NumField(); i++ { | ||
field := t.Field(i) | ||
jsonTag := field.Tag.Get("json") | ||
if jsonTag == "" { | ||
return nil, nil, fmt.Errorf("field %s has no json tag", field.Name) | ||
} | ||
jsonName := strings.Split(jsonTag, ",")[0] | ||
jsonTags[field.Name] = jsonName | ||
reverseEdgeTag := field.Tag.Get("readFrom") | ||
if reverseEdgeTag != "" { | ||
typeAndField := strings.Split(reverseEdgeTag, ",") | ||
if len(typeAndField) != 2 { | ||
return nil, nil, fmt.Errorf("field %s has invalid readFrom tag, expected format is type=<type>,field=<field>", field.Name) | ||
} | ||
t := strings.Split(typeAndField[0], "=")[1] | ||
f := strings.Split(typeAndField[1], "=")[1] | ||
reverseEdgeTags[field.Name] = getPredicateName(t, f) | ||
} | ||
} | ||
return jsonTags, reverseEdgeTags, nil | ||
} | ||
|
||
func getFieldValues(object any, jsonFields map[string]string) map[string]any { | ||
values := make(map[string]any) | ||
v := reflect.ValueOf(object).Elem() | ||
for fieldName, jsonName := range jsonFields { | ||
fieldValue := v.FieldByName(fieldName) | ||
values[jsonName] = fieldValue.Interface() | ||
|
||
} | ||
return values | ||
} | ||
|
||
func getPredicateName(typeName, fieldName string) string { | ||
return fmt.Sprint(typeName, ".", fieldName) | ||
} | ||
|
||
func valueToValType(v any) *api.Value { | ||
switch val := v.(type) { | ||
case string: | ||
return &api.Value{Val: &api.Value_StrVal{StrVal: val}} | ||
case int: | ||
return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}} | ||
case int64: | ||
return &api.Value{Val: &api.Value_IntVal{IntVal: val}} | ||
case uint64: | ||
return &api.Value{Val: &api.Value_IntVal{IntVal: int64(val)}} | ||
case bool: | ||
return &api.Value{Val: &api.Value_BoolVal{BoolVal: val}} | ||
case float64: | ||
return &api.Value{Val: &api.Value_DoubleVal{DoubleVal: val}} | ||
default: | ||
return &api.Value{Val: &api.Value_DefaultVal{DefaultVal: fmt.Sprint(v)}} | ||
} | ||
} | ||
|
||
func Create[T any](ctx context.Context, n *Namespace, object *T) (uint64, *T, error){ | ||
uids, err := n.db.z.nextUIDs(&pb.Num{Val: uint64(1), Type: pb.Num_UID}) | ||
if err != nil { | ||
return 0, object, err | ||
} | ||
|
||
t := reflect.TypeOf(*object) | ||
if t.Kind() != reflect.Struct { | ||
return 0, object, fmt.Errorf("expected struct, got %s", t.Kind()) | ||
} | ||
|
||
jsonFields, _, err := getFieldTags(t) | ||
if err != nil { | ||
return 0, object, err | ||
} | ||
values := getFieldValues(object, jsonFields) | ||
|
||
|
||
|
||
nquads := make([]*api.NQuad, 0) | ||
for jsonName, value := range values { | ||
if jsonName == "uid" { | ||
continue | ||
} | ||
nquad := &api.NQuad{ | ||
Namespace: n.ID(), | ||
Subject: fmt.Sprint(uids.StartId), | ||
Predicate: getPredicateName(t.Name(), jsonName), | ||
ObjectValue: valueToValType(value), | ||
} | ||
nquads = append(nquads, nquad) | ||
} | ||
|
||
dms := make([]*dql.Mutation, 0) | ||
dms = append(dms, &dql.Mutation{ | ||
Set: nquads, | ||
}) | ||
edges, err := query.ToDirectedEdges(dms, nil) | ||
if err != nil { | ||
return 0, object, err | ||
} | ||
ctx = x.AttachNamespace(ctx, n.ID()) | ||
|
||
n.db.mutex.Lock() | ||
defer n.db.mutex.Unlock() | ||
|
||
if !n.db.isOpen { | ||
return 0, object, ErrClosedDB | ||
} | ||
|
||
startTs, err := n.db.z.nextTs() | ||
if err != nil { | ||
return 0, object, err | ||
} | ||
commitTs, err := n.db.z.nextTs() | ||
if err != nil { | ||
return 0, object, err | ||
} | ||
|
||
m := &pb.Mutations{ | ||
GroupId: 1, | ||
StartTs: startTs, | ||
Edges: edges, | ||
} | ||
m.Edges, err = query.ExpandEdges(ctx, m) | ||
if err != nil { | ||
return 0, object, fmt.Errorf("error expanding edges: %w", err) | ||
} | ||
|
||
for _, edge := range m.Edges { | ||
worker.InitTablet(edge.Attr) | ||
} | ||
|
||
p := &pb.Proposal{Mutations: m, StartTs: startTs} | ||
if err := worker.ApplyMutations(ctx, p); err != nil { | ||
return 0, object, err | ||
} | ||
|
||
err = worker.ApplyCommited(ctx, &pb.OracleDelta{ | ||
Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, | ||
}) | ||
if err != nil { | ||
return 0, object, err | ||
} | ||
|
||
v := reflect.ValueOf(object).Elem() | ||
|
||
uidField := v.FieldByName("Uid") | ||
|
||
if uidField.IsValid() && uidField.CanSet() && uidField.Kind() == reflect.Uint64 { | ||
uidField.SetUint(uids.StartId) | ||
} | ||
|
||
return uids.StartId, object, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package modusdb_test | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/hypermodeinc/modusdb" | ||
) | ||
|
||
type User struct{ | ||
Uid uint64 `json:"uid"` | ||
Name string `json:"name"` | ||
Age int `json:"age"` | ||
} | ||
|
||
func TestCreateApi(t *testing.T) { | ||
db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) | ||
require.NoError(t, err) | ||
defer db.Close() | ||
|
||
db1, err := db.CreateNamespace() | ||
require.NoError(t, err) | ||
|
||
require.NoError(t, db1.DropData(context.Background())) | ||
|
||
user := &User{ | ||
Name: "B", | ||
Age: 20, | ||
} | ||
|
||
uid, _, err := modusdb.Create(context.Background(), db1, user) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, "B", user.Name) | ||
require.Equal(t, uint64(2), uid) | ||
require.Equal(t, uint64(2), user.Uid) | ||
|
||
query := `{ | ||
me(func: has(User.name)) { | ||
uid | ||
User.name | ||
User.age | ||
} | ||
}` | ||
resp, err := db1.Query(context.Background(), query) | ||
require.NoError(t, err) | ||
require.JSONEq(t, `{"me":[{"uid":"0x2","User.name":"B","User.age":20}]}`, string(resp.GetJson())) | ||
} | ||
|
||
func TestCreateApiWithNonStruct(t *testing.T) { | ||
db, err := modusdb.New(modusdb.NewDefaultConfig(t.TempDir())) | ||
require.NoError(t, err) | ||
defer db.Close() | ||
|
||
db1, err := db.CreateNamespace() | ||
require.NoError(t, err) | ||
|
||
require.NoError(t, db1.DropData(context.Background())) | ||
|
||
user := &User{ | ||
Name: "B", | ||
Age: 20, | ||
} | ||
|
||
_, _, err = modusdb.Create[*User](context.Background(), db1, &user) | ||
require.Error(t, err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters