diff --git a/examples/no_cycles_test.go b/examples/no_cycles_test.go index 5d63595b352..4ad3f120aa3 100644 --- a/examples/no_cycles_test.go +++ b/examples/no_cycles_test.go @@ -42,63 +42,12 @@ func TestNoCycles(t *testing.T) { } // detect cycles + visited := make(map[string]bool) for _, p := range pkgs { - require.NoError(t, detectCycles(p, pkgs)) + require.NoError(t, detectCycles(p, pkgs, visited)) } } -type testPkg struct { - Dir string - PkgPath string - Imports packages.ImportsMap -} - -// listPkgs lists all packages in rootMod -func listPkgs(rootMod gnomod.Pkg) ([]testPkg, error) { - res := []testPkg{} - rootDir := rootMod.Dir - visited := map[string]struct{}{} - if err := fs.WalkDir(os.DirFS(rootDir), ".", func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - return nil - } - if !strings.HasSuffix(d.Name(), ".gno") { - return nil - } - subPath := filepath.Dir(path) - dir := filepath.Join(rootDir, subPath) - if _, ok := visited[dir]; ok { - return nil - } - visited[dir] = struct{}{} - - subPkgPath := pathlib.Join(rootMod.Name, subPath) - - pkg := testPkg{ - Dir: dir, - PkgPath: subPkgPath, - } - - memPkg, err := readPkg(pkg.Dir, pkg.PkgPath) - if err != nil { - return fmt.Errorf("read pkg %q: %w", pkg.Dir, err) - } - pkg.Imports, err = packages.Imports(memPkg, nil) - if err != nil { - return fmt.Errorf("list imports of %q: %w", memPkg.Path, err) - } - - res = append(res, pkg) - return nil - }); err != nil { - return nil, fmt.Errorf("walk dirs at %q: %w", rootDir, err) - } - return res, nil -} - // detectCycles detects import cycles // // We need to check @@ -113,7 +62,7 @@ func listPkgs(rootMod gnomod.Pkg) ([]testPkg, error) { // // The tricky thing is that we need to split test sources and normal source // while not considering them as distincitive packages. -// Because if we don't we will have false positive for example if we have these edges: +// Otherwise we will have false positive for example if we have these edges: // // - foo_pkg/foo_test.go imports bar_pkg // @@ -125,9 +74,8 @@ func listPkgs(rootMod gnomod.Pkg) ([]testPkg, error) { // - foo_pkg/foo.go imports bar_pkg // // - bar_pkg/bar_test.go imports foo_pkg -func detectCycles(root testPkg, pkgs []testPkg) error { +func detectCycles(root testPkg, pkgs []testPkg, visited map[string]bool) error { // check cycles in package's sources - visited := make(map[string]bool) stack := []string{} if err := visitPackage(root, pkgs, visited, stack); err != nil { return fmt.Errorf("pkgsrc import: %w", err) @@ -138,7 +86,7 @@ func detectCycles(root testPkg, pkgs []testPkg) error { } // check cycles in tests' imports by marking the current package as visited while visiting the tests' imports - // we also coniders PackageSource imports here because tests can call package code + // we also consider PackageSource imports here because tests can call package code visited = map[string]bool{root.PkgPath: true} stack = []string{root.PkgPath} if err := visitImports([]packages.FileKind{packages.FileKindPackageSource, packages.FileKindTest}, root, pkgs, visited, stack); err != nil { @@ -185,6 +133,58 @@ func visitPackage(pkg testPkg, pkgs []testPkg, visited map[string]bool, stack [] return nil } +type testPkg struct { + Dir string + PkgPath string + Imports packages.ImportsMap +} + +// listPkgs lists all packages in rootMod +func listPkgs(rootMod gnomod.Pkg) ([]testPkg, error) { + res := []testPkg{} + rootDir := rootMod.Dir + visited := map[string]struct{}{} + if err := fs.WalkDir(os.DirFS(rootDir), ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + if !strings.HasSuffix(d.Name(), ".gno") { + return nil + } + subPath := filepath.Dir(path) + dir := filepath.Join(rootDir, subPath) + if _, ok := visited[dir]; ok { + return nil + } + visited[dir] = struct{}{} + + subPkgPath := pathlib.Join(rootMod.Name, subPath) + + pkg := testPkg{ + Dir: dir, + PkgPath: subPkgPath, + } + + memPkg, err := readPkg(pkg.Dir, pkg.PkgPath) + if err != nil { + return fmt.Errorf("read pkg %q: %w", pkg.Dir, err) + } + pkg.Imports, err = packages.Imports(memPkg, nil) + if err != nil { + return fmt.Errorf("list imports of %q: %w", memPkg.Path, err) + } + + res = append(res, pkg) + return nil + }); err != nil { + return nil, fmt.Errorf("walk dirs at %q: %w", rootDir, err) + } + return res, nil +} + // readPkg reads the sources of a package. It includes all .gno files but ignores the package name func readPkg(dir string, pkgPath string) (*gnovm.MemPackage, error) { list, err := os.ReadDir(dir)