Skip to content

Commit

Permalink
chore: Refactoring schema generation (#605)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjohnsonpint authored Nov 22, 2024
1 parent 05abc06 commit 8073ca4
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 131 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- feat: Add API explorer to runtime [#578](https://github.com/hypermodeinc/modus/pull/578)
- feat: Add API explorer component to runtime [#584](https://github.com/hypermodeinc/modus/pull/584)
- fix: logic for jwks endpoint unmarshalling was incorrect [#594](https://github.com/hypermodeinc/modus/pull/594)
- chore: Refactoring schema generation [#605](https://github.com/hypermodeinc/modus/pull/605)

## 2024-11-20 - CLI 0.13.9

Expand Down
111 changes: 56 additions & 55 deletions runtime/graphql/schemagen/schemagen.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,20 @@ func GetGraphQLSchema(ctx context.Context, md *metadata.Metadata) (*GraphQLSchem
inputTypeDefs, errors := transformTypes(md.Types, lti, true)
resultTypeDefs, errs := transformTypes(md.Types, lti, false)
errors = append(errors, errs...)
fieldsToFunctions, queryFields, mutationFields, errs := transformFunctions(md.FnExports, inputTypeDefs, resultTypeDefs, lti)
root, errs := transformFunctions(md.FnExports, inputTypeDefs, resultTypeDefs, lti)
errors = append(errors, errs...)

if len(errors) > 0 {
return nil, fmt.Errorf("failed to generate schema: %+v", errors)
}

queryFields = filterFields(queryFields)
mutationFields = filterFields(mutationFields)
allFields := append(queryFields, mutationFields...)

allFields := root.AllFields()
scalarTypes := extractCustomScalarTypes(inputTypeDefs, resultTypeDefs)
inputTypes := filterTypes(utils.MapValues(inputTypeDefs), allFields, true)
resultTypes := filterTypes(utils.MapValues(resultTypeDefs), allFields, false)

buf := bytes.Buffer{}
writeSchema(&buf, queryFields, mutationFields, scalarTypes, inputTypes, resultTypes)
writeSchema(&buf, root, scalarTypes, inputTypes, resultTypes)

mapTypes := make([]string, 0, len(resultTypeDefs))
for _, t := range resultTypeDefs {
Expand All @@ -68,6 +65,11 @@ func GetGraphQLSchema(ctx context.Context, md *metadata.Metadata) (*GraphQLSchem
}
}

fieldsToFunctions := make(map[string]string, len(allFields))
for _, f := range allFields {
fieldsToFunctions[f.Name] = f.Function
}

return &GraphQLSchema{
Schema: buf.String(),
FieldsToFunctions: fieldsToFunctions,
Expand Down Expand Up @@ -125,35 +127,38 @@ func transformTypes(types metadata.TypeMap, lti langsupport.LanguageTypeInfo, fo
}

type FieldDefinition struct {
Name string
Arguments []*ArgumentDefinition
ReturnType string
Name string
Type string
Arguments []*ArgumentDefinition
Function string
}

type TypeDefinition struct {
Name string
Fields []*NameTypePair
Fields []*FieldDefinition
IsMapType bool
}

type NameTypePair struct {
Name string
Type string
}

type ArgumentDefinition struct {
Name string
Type string
Default *any
}

// TODO: refactor for readability
type RootObjects struct {
QueryFields []*FieldDefinition
MutationFields []*FieldDefinition
}

func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTypeDefs map[string]*TypeDefinition, lti langsupport.LanguageTypeInfo) (map[string]string, []*FieldDefinition, []*FieldDefinition, []*TransformError) {
fieldsToFunctions := make(map[string]string, len(functions))
func (r *RootObjects) AllFields() []*FieldDefinition {
return append(r.QueryFields, r.MutationFields...)
}

func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTypeDefs map[string]*TypeDefinition, lti langsupport.LanguageTypeInfo) (*RootObjects, []*TransformError) {
queryFields := make([]*FieldDefinition, 0, len(functions))
mutationFields := make([]*FieldDefinition, 0, len(functions))
errors := make([]*TransformError, 0)
filter := getFieldFilter()

fnNames := utils.MapKeys(functions)
sort.Strings(fnNames)
Expand All @@ -173,34 +178,29 @@ func transformFunctions(functions metadata.FunctionMap, inputTypeDefs, resultTyp
}

fieldName := getFieldName(fn.Name)
fieldsToFunctions[fieldName] = fn.Name

field := &FieldDefinition{
Name: fieldName,
Arguments: args,
ReturnType: returnType,
Name: fieldName,
Arguments: args,
Type: returnType,
Function: fn.Name,
}

if isMutation(fn.Name) {
mutationFields = append(mutationFields, field)
} else {
queryFields = append(queryFields, field)
if filter(field) {
if isMutation(fn.Name) {
mutationFields = append(mutationFields, field)
} else {
queryFields = append(queryFields, field)
}
}
}

return fieldsToFunctions, queryFields, mutationFields, errors
}

func filterFields(fields []*FieldDefinition) []*FieldDefinition {
filter := getFieldFilter()
results := make([]*FieldDefinition, 0, len(fields))
for _, f := range fields {
if filter(f) {
results = append(results, f)
}
results := &RootObjects{
QueryFields: queryFields,
MutationFields: mutationFields,
}

return results
return results, errors
}

func filterTypes(types []*TypeDefinition, fields []*FieldDefinition, forInput bool) []*TypeDefinition {
Expand All @@ -222,7 +222,7 @@ func filterTypes(types []*TypeDefinition, fields []*FieldDefinition, forInput bo
addUsedTypes(p.Type, typeMap, usedTypes)
}
} else {
addUsedTypes(f.ReturnType, typeMap, usedTypes)
addUsedTypes(f.Type, typeMap, usedTypes)
}
}

Expand Down Expand Up @@ -278,16 +278,16 @@ func getBaseType(name string) string {
return name
}

func writeSchema(buf *bytes.Buffer, queryFields []*FieldDefinition, mutationFields []*FieldDefinition, scalarTypes []string, inputTypeDefs, resultTypeDefs []*TypeDefinition) {
func writeSchema(buf *bytes.Buffer, root *RootObjects, scalarTypes []string, inputTypeDefs, resultTypeDefs []*TypeDefinition) {

// write header
buf.WriteString("# Modus GraphQL Schema (auto-generated)\n")

// sort everything
slices.SortFunc(queryFields, func(a, b *FieldDefinition) int {
slices.SortFunc(root.QueryFields, func(a, b *FieldDefinition) int {
return cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
slices.SortFunc(mutationFields, func(a, b *FieldDefinition) int {
slices.SortFunc(root.MutationFields, func(a, b *FieldDefinition) int {
return cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
slices.SortFunc(scalarTypes, func(a, b string) int {
Expand All @@ -301,20 +301,20 @@ func writeSchema(buf *bytes.Buffer, queryFields []*FieldDefinition, mutationFiel
})

// write query object
if len(queryFields) > 0 {
if len(root.QueryFields) > 0 {
buf.WriteByte('\n')
buf.WriteString("type Query {\n")
for _, field := range queryFields {
for _, field := range root.QueryFields {
writeField(buf, field)
}
buf.WriteString("}\n")
}

// write mutation object
if len(mutationFields) > 0 {
if len(root.MutationFields) > 0 {
buf.WriteByte('\n')
buf.WriteString("type Mutation {\n")
for _, field := range mutationFields {
for _, field := range root.MutationFields {
writeField(buf, field)
}
buf.WriteString("}\n")
Expand Down Expand Up @@ -386,7 +386,7 @@ func writeField(buf *bytes.Buffer, field *FieldDefinition) {
buf.WriteByte(')')
}
buf.WriteString(": ")
buf.WriteString(field.ReturnType)
buf.WriteString(field.Type)
buf.WriteByte('\n')
}

Expand Down Expand Up @@ -421,7 +421,7 @@ func convertResults(results []*metadata.Result, lti langsupport.LanguageTypeInfo
return convertType(results[0].Type, lti, typeDefs, false, false)
}

fields := make([]*NameTypePair, len(results))
fields := make([]*FieldDefinition, len(results))
for i, r := range results {
name := r.Name
if name == "" {
Expand All @@ -433,7 +433,7 @@ func convertResults(results []*metadata.Result, lti langsupport.LanguageTypeInfo
return "", err
}

fields[i] = &NameTypePair{
fields[i] = &FieldDefinition{
Name: name,
Type: typ,
}
Expand All @@ -443,7 +443,7 @@ func convertResults(results []*metadata.Result, lti langsupport.LanguageTypeInfo
return t, nil
}

func getTypeForFields(fields []*NameTypePair, typeDefs map[string]*TypeDefinition) string {
func getTypeForFields(fields []*FieldDefinition, typeDefs map[string]*TypeDefinition) string {
// see if an existing type already matches
for _, t := range typeDefs {
if len(t.Fields) != len(fields) {
Expand Down Expand Up @@ -475,18 +475,18 @@ func getTypeForFields(fields []*NameTypePair, typeDefs map[string]*TypeDefinitio
return newType(name, fields, typeDefs)
}

func convertFields(fields []*metadata.Field, lti langsupport.LanguageTypeInfo, typeDefs map[string]*TypeDefinition, forInput bool) ([]*NameTypePair, error) {
func convertFields(fields []*metadata.Field, lti langsupport.LanguageTypeInfo, typeDefs map[string]*TypeDefinition, forInput bool) ([]*FieldDefinition, error) {
if len(fields) == 0 {
return nil, nil
}

results := make([]*NameTypePair, len(fields))
results := make([]*FieldDefinition, len(fields))
for i, f := range fields {
t, err := convertType(f.Type, lti, typeDefs, true, forInput)
if err != nil {
return nil, err
}
results[i] = &NameTypePair{
results[i] = &FieldDefinition{
Name: f.Name,
Type: t,
}
Expand Down Expand Up @@ -633,7 +633,8 @@ func convertType(typ string, lti langsupport.LanguageTypeInfo, typeDefs map[stri
typeName += "Input"
}

newMapType(typeName, []*NameTypePair{{"key", kt}, {"value", vt}}, typeDefs)
fields := []*FieldDefinition{{Name: "key", Type: kt}, {Name: "value", Type: vt}}
newMapType(typeName, fields, typeDefs)

// The map is represented as a list of the pair type.
// The list might be nullable, but the pair type within the list is always non-nullable.
Expand Down Expand Up @@ -675,7 +676,7 @@ func newScalar(name string, typeDefs map[string]*TypeDefinition) string {
return newType(name, nil, typeDefs)
}

func newType(name string, fields []*NameTypePair, typeDefs map[string]*TypeDefinition) string {
func newType(name string, fields []*FieldDefinition, typeDefs map[string]*TypeDefinition) string {
if _, ok := typeDefs[name]; !ok {
typeDefs[name] = &TypeDefinition{
Name: name,
Expand All @@ -685,7 +686,7 @@ func newType(name string, fields []*NameTypePair, typeDefs map[string]*TypeDefin
return name
}

func newMapType(name string, fields []*NameTypePair, typeDefs map[string]*TypeDefinition) string {
func newMapType(name string, fields []*FieldDefinition, typeDefs map[string]*TypeDefinition) string {
if _, ok := typeDefs[name]; !ok {
typeDefs[name] = &TypeDefinition{
Name: name,
Expand Down
Loading

0 comments on commit 8073ca4

Please sign in to comment.