diff --git a/packages.go b/packages.go index 0cd11fe77..274fc3081 100644 --- a/packages.go +++ b/packages.go @@ -192,6 +192,7 @@ func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *as for _, astDeclaration := range astFile.Decls { funcDeclaration, ok := astDeclaration.(*ast.FuncDecl) if ok && funcDeclaration.Body != nil { + functionScopedTypes := make(map[string]*TypeSpecDef) for _, stmt := range funcDeclaration.Body.List { if declStmt, ok := (stmt).(*ast.DeclStmt); ok { if genDecl, ok := (declStmt.Decl).(*ast.GenDecl); ok && genDecl.Tok == token.TYPE { @@ -212,12 +213,28 @@ func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *as } } + fullName := typeSpecDef.TypeName() + if structType, ok := typeSpecDef.TypeSpec.Type.(*ast.StructType); ok { + for _, field := range structType.Fields.List { + if idt, ok := field.Type.(*ast.Ident); ok && !IsGolangPrimitiveType(idt.Name) { + if functype, ok := functionScopedTypes[idt.Name]; ok { + idt.Name = functype.TypeName() + } + } + if art, ok := field.Type.(*ast.ArrayType); ok { + if idt, ok := art.Elt.(*ast.Ident); ok && !IsGolangPrimitiveType(idt.Name) { + if functype, ok := functionScopedTypes[idt.Name]; ok { + idt.Name = functype.TypeName() + } + } + } + } + } + if pkgDefs.uniqueDefinitions == nil { pkgDefs.uniqueDefinitions = make(map[string]*TypeSpecDef) } - fullName := typeSpecDef.TypeName() - anotherTypeDef, ok := pkgDefs.uniqueDefinitions[fullName] if ok { if anotherTypeDef == nil { @@ -234,6 +251,7 @@ func (pkgDefs *PackagesDefinitions) parseFunctionScopedTypesFromFile(astFile *as } } else { pkgDefs.uniqueDefinitions[fullName] = typeSpecDef + functionScopedTypes[typeSpec.Name.Name] = typeSpecDef } if pkgDefs.packages[typeSpecDef.PkgPath] == nil { diff --git a/parser_test.go b/parser_test.go index c92f3dbda..c2bc7546d 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3389,6 +3389,49 @@ func Fun() { assert.True(t, ok) } +func TestParseFunctionScopedComplexStructDefinition(t *testing.T) { + t.Parallel() + + src := ` +package main + +// @Param request body main.Fun.request true "query params" +// @Success 200 {object} main.Fun.response +// @Router /test [post] +func Fun() { + type request struct { + Name string + } + + type grandChild struct { + Name string + } + + type child struct { + GrandChild grandChild + } + + type response struct { + Children []child + } +} +` + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + _, ok := p.swagger.Definitions["main.Fun.response"] + assert.True(t, ok) + _, ok = p.swagger.Definitions["main.Fun.child"] + assert.True(t, ok) + _, ok = p.swagger.Definitions["main.Fun.grandChild"] + assert.True(t, ok) +} + func TestParseFunctionScopedStructRequestResponseJSON(t *testing.T) { t.Parallel() @@ -3474,6 +3517,130 @@ func Fun() { assert.Equal(t, expected, string(b)) } +func TestParseFunctionScopedComplexStructRequestResponseJSON(t *testing.T) { + t.Parallel() + + src := ` +package main + +type PublicChild struct { + Name string +} + +// @Param request body main.Fun.request true "query params" +// @Success 200 {object} main.Fun.response +// @Router /test [post] +func Fun() { + type request struct { + Name string + } + + type grandChild struct { + Name string + } + + type child struct { + GrandChild grandChild + } + + type response struct { + Children []child + PublicChild PublicChild + } +} +` + expected := `{ + "info": { + "contact": {} + }, + "paths": { + "/test": { + "post": { + "parameters": [ + { + "description": "query params", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/main.Fun.request" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/main.Fun.response" + } + } + } + } + } + }, + "definitions": { + "main.Fun.child": { + "type": "object", + "properties": { + "grandChild": { + "$ref": "#/definitions/main.Fun.grandChild" + } + } + }, + "main.Fun.grandChild": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + "main.Fun.request": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + "main.Fun.response": { + "type": "object", + "properties": { + "children": { + "type": "array", + "items": { + "$ref": "#/definitions/main.Fun.child" + } + }, + "publicChild": { + "$ref": "#/definitions/main.PublicChild" + } + } + }, + "main.PublicChild": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + } +}` + + p := New() + _ = p.packages.ParseFile("api", "api/api.go", src, ParseAll) + + _, err := p.packages.ParseTypes() + assert.NoError(t, err) + + err = p.packages.RangeFiles(p.ParseRouterAPIInfo) + assert.NoError(t, err) + + b, _ := json.MarshalIndent(p.swagger, "", " ") + assert.Equal(t, expected, string(b)) +} + func TestPackagesDefinitions_CollectAstFileInit(t *testing.T) { t.Parallel() diff --git a/testdata/simple/api/api.go b/testdata/simple/api/api.go index 85a7fa48f..3bc6ec391 100644 --- a/testdata/simple/api/api.go +++ b/testdata/simple/api/api.go @@ -138,3 +138,15 @@ func GetPet6FunctionScopedResponse() { Name string } } + +// @Success 200 {object} api.GetPet6FunctionScopedComplexResponse.response "ok" +// @Router /GetPet6FunctionScopedComplexResponse [get] +func GetPet6FunctionScopedComplexResponse() { + type child struct { + Name string + } + + type response struct { + Child child + } +} diff --git a/testdata/simple/expected.json b/testdata/simple/expected.json index b3f642629..e2a0dd71a 100644 --- a/testdata/simple/expected.json +++ b/testdata/simple/expected.json @@ -113,6 +113,18 @@ } } }, + "/GetPet6FunctionScopedComplexResponse": { + "get": { + "responses": { + "200": { + "description": "ok", + "schema": { + "$ref": "#/definitions/api.GetPet6FunctionScopedComplexResponse.response" + } + } + } + } + }, "/GetPet6MapString": { "get": { "responses": { @@ -401,6 +413,25 @@ } }, "definitions": { + "api.GetPet6FunctionScopedComplexResponse.pet": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + }, + "api.GetPet6FunctionScopedComplexResponse.response": { + "type": "object", + "properties": { + "Pets": { + "type": "array", + "items": { + "$ref": "#/definitions/api.GetPet6FunctionScopedComplexResponse.pet" + } + } + } + }, "api.GetPet6FunctionScopedResponse.response": { "type": "object", "properties": {