Skip to content

Commit

Permalink
parse AST instead of doing extra steps
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewpeterkort committed Dec 20, 2024
1 parent e87bca4 commit c4c5621
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 265 deletions.
193 changes: 120 additions & 73 deletions gql-gen/graph/collectFields.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,99 +2,146 @@ package graph

import (
"context"
"sort"
"strings"
"fmt"

"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/ast"

"github.com/bmeg/grip/gripql"
)

func RemoveIndex(s []string, index int) []string {
return append(s[:index], s[index+1:]...)
type Resolver struct {
GripDb gripql.Client
Schema *ast.Schema
}

func refinePaths(fields []string) []string {
/*Remove fields that aren't neccessary in traversing the graph*/
var refined []string
for _, field := range fields {
segments := strings.Split(field, ".")
if len(segments) > 1 && !strings.HasSuffix(segments[0], "Type") {
for i := 1; i < len(segments); i++ {
if strings.HasSuffix(segments[i], "Type") {
segments = RemoveIndex(segments, i-1)
}
}
}
refinedField := strings.Join(segments, ".")
refined = append(refined, refinedField)
type renderTree struct {
prevName string
moved bool
fields []string
parent map[string]string
fieldName map[string]string
}

}
type objectMap struct {
edgeLabel map[string]map[string]struct{} // Maps vertex labels to edge names
edgeDstType map[string]map[string]string // Maps vertex labels and edge names to destination labels
}

return refined
func (rt *renderTree) NewElement(cur string, fieldName string) string {
rName := fmt.Sprintf("f%d", len(rt.fields))
rt.fields = append(rt.fields, rName)
rt.parent[rName] = cur
rt.fieldName[rName] = fieldName
return rName
}

func SortFieldPaths(fields []string, rootType string) []string {
//refined := refinePaths(fields)
sort.Slice(fields, func(i, j int) bool {
countTypeI, nestingLevelI := countTypeOccurrencesAndLevels(fields[i])
countTypeJ, nestingLevelJ := countTypeOccurrencesAndLevels(fields[j])
func traversalBuild(reqCtx *graphql.OperationContext, query **gripql.Query, selSet ast.SelectionSet, curElement string, rt *renderTree, visited map[string]bool) []graphql.CollectedField {
fmt.Printf("\n\n")
groupedFields := make([]graphql.CollectedField, 0, len(selSet))
for _, s := range selSet {
switch sel := s.(type) {
case *ast.Field:
if _, ok := visited[sel.Name]; ok {
continue
}
visited[sel.Name] = true

// Prioritize fewer "Type" segments
if countTypeI != countTypeJ {
return countTypeI < countTypeJ
}
fmt.Printf("FIELD NAME: %s\n", sel.Name)
fmt.Printf("FIELD DEF NAME %s\n", sel.ObjectDefinition.Name)
if rt.moved {
*query = (*query).Select(curElement)
rt.moved = false
}
rt.prevName = sel.Name
elem := rt.NewElement(sel.ObjectDefinition.Name, sel.Name)
for _, childField := range traversalBuild(reqCtx, query, sel.SelectionSet, elem, rt, visited) {
_ = rt.NewElement(sel.Name, childField.Name)
fmt.Println("CUR NAME: ", sel.Name, "CHILD NAME: ", childField.Name)
}
case *ast.InlineFragment:
elem := rt.NewElement(rt.prevName, sel.TypeCondition)

// Within same Type count, prioritize fewer nesting levels
if nestingLevelI != nestingLevelJ {
return nestingLevelI < nestingLevelJ
}
return fields[i] < fields[j]
})
return fields
}
fmt.Printf("INLINE FRAG Type CONDITION: %s\n", sel.TypeCondition)
fmt.Printf("InlineFragment DEF NAME %s\n", sel.ObjectDefinition.Name)
typeConditionLen := len(sel.TypeCondition)
*query = (*query).OutNull(rt.prevName + "_" + sel.TypeCondition[:typeConditionLen-4]).As(elem)
for _, childField := range traversalBuild(reqCtx, query, sel.SelectionSet, elem, rt, visited) {
_ = rt.NewElement(rt.prevName, childField.Name)
fmt.Println("InlineFragment CUR NAME: ", sel.ObjectDefinition.Name, "CHILD NAME: ", childField.Name)
}
rt.moved = true

// Helper function to count "Type" segments and nesting levels
func countTypeOccurrencesAndLevels(field string) (typeCount int, nestingLevel int) {
typeCount = 0
segments := strings.Split(field, ".")
for _, segment := range segments {
if strings.HasSuffix(segment, "Type") {
typeCount++
case *ast.FragmentSpread:
fmt.Println("FRAG SPREAD: ", sel.Definition.Name)
default:
panic(fmt.Errorf("unsupported %T", sel))
}

}
nestingLevel = len(segments) - 1
return typeCount, nestingLevel
return groupedFields
}

func GetQueryFields(ctx context.Context, rootType string) []string {
fields := GetNestedPreloads(
graphql.GetOperationContext(ctx),
graphql.CollectFieldsCtx(ctx, []string{}),
"",
rootType,
)
return SortFieldPaths(fields, rootType)
}
func (r *queryResolver) GetSelectedFieldsAst(ctx context.Context, sourceType string) {
resctx := graphql.GetFieldContext(ctx)
opCtx := graphql.GetOperationContext(ctx)
rt := &renderTree{
fields: []string{"f0"},
parent: map[string]string{},
fieldName: map[string]string{},
}
q := gripql.V().HasLabel(sourceType[:len(sourceType)-4]).As("f0")
//for _, field := range resctx.Field.Selections {
_ = traversalBuild(opCtx, &q, resctx.Field.Selections, "f0", rt, map[string]bool{})
fmt.Println("QUERY AFTER: ", q)
fmt.Printf("RENDER TREE FIELDS: %#v\n", rt.fields)
fmt.Printf("RENDER TREE PARENT: %#v\n", rt.parent)
fmt.Printf("RENDER TREE FieldName: %#v\n", rt.fieldName)

func GetNestedPreloads(ctx *graphql.OperationContext, fields []graphql.CollectedField, prefix string, rootType string) (preloads []string) {
for _, column := range fields {
prefixColumn := GetPreloadString(prefix, column, rootType)
nestedFields := graphql.CollectFields(ctx, column.Selections, []string{})
if len(nestedFields) == 0 {
preloads = append(preloads, prefixColumn)
} else {
preloads = append(preloads, GetNestedPreloads(ctx, nestedFields, prefixColumn, rootType)...)
}
render := map[string]any{}
for _, i := range rt.fields {
render[i+"_gid"] = "$" + i + "._gid"
render[i+"_data"] = "$" + i + "._data"
}
return preloads
}

func GetPreloadString(prefix string, name graphql.CollectedField, rootType string) string {
// If edge out to another type, traverse to that type
if strings.HasSuffix(name.ObjectDefinition.Name, "Type") && name.ObjectDefinition.Name != rootType {
return prefix + "." + name.ObjectDefinition.Name + "." + name.Name
fmt.Printf("RENDER: %#v\n", render)
q = q.Render(render)

result, err := r.GripDb.Traversal(context.Background(), &gripql.GraphQuery{Graph: "CALIPER", Query: q.Statements})
if err != nil {
fmt.Printf("ERR: %s\n", err)
}
if len(prefix) > 0 {
return prefix + "." + name.Name

out := []any{}
for r := range result {
values := r.GetRender().GetStructValue().AsMap()
fmt.Println("VALUES: ", values)

data := map[string]map[string]any{}
for _, r := range rt.fields {
v := values[r+"_data"]
fmt.Println("V:", v)
if d, ok := v.(map[string]any); ok {
fmt.Println("HELLO IN HERE")
d["id"] = values[r+"_gid"]
fmt.Println("D: ", d)
if d["id"] != "" {
data[r] = d
}
}
}
for _, r := range rt.fields {
fmt.Println("RT PARENT: ", rt.parent, "R: ", r)
if parent, ok := rt.parent[r]; ok {
fieldName := rt.fieldName[r]
fmt.Println("DATA: ", data)
if data[r] != nil {
data[parent][fieldName] = []any{data[r]}
}
}
}
fmt.Println("DATA: ", data)
out = append(out, data["f0"])
}
return name.Name

}
57 changes: 0 additions & 57 deletions gql-gen/graph/gripFetch.go

This file was deleted.

10 changes: 5 additions & 5 deletions gql-gen/graph/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package graph

import (
"context"
"fmt"

"github.com/bmeg/grip-graphql/gql-gen/generated"

Expand Down Expand Up @@ -63,10 +62,11 @@ func (r *queryResolver) Specimen(ctx context.Context, offset *int, first *int, f

// Observation is the resolver for the observation field.
func (r *queryResolver) Observation(ctx context.Context, offset *int, first *int, filter *string, sort *string, accessibility *model.Accessibility, format *model.Format) ([]*model.ObservationType, error) {
sourceType := "ObservationType"
fields := GetQueryFields(ctx, sourceType)
res := r.gripQuery(fields, sourceType)
fmt.Println("RES: ", res)
//sourceType := "ObservationType"
//fields := GetQueryFields(ctx, sourceType)
//res := r.gripQuery(fields, sourceType)
//fmt.Println("RES: ", res)
r.GetSelectedFieldsAst(ctx, "ObservationType")

/*for _, field := range fields {
fmt.Println("PATH: ", field)
Expand Down
Loading

0 comments on commit c4c5621

Please sign in to comment.