From 0f14169345a1b4d764af343c74c600d64e60c451 Mon Sep 17 00:00:00 2001 From: Adam Babik Date: Sun, 12 Jan 2025 11:03:00 +0100 Subject: [PATCH] Refactor server in autoconfig --- internal/cmd/beta/server/server_start_cmd.go | 19 +-- internal/command/config.go | 4 +- internal/config/autoconfig/autoconfig.go | 89 +++++++--- internal/config/autoconfig/autoconfig_test.go | 155 +++++++++++++++++- internal/config/config.go | 4 +- internal/config/config.schema.json | 4 + internal/config/config_schema.go | 6 + internal/config/runme.default.yaml | 1 + internal/runnerv2client/client.go | 16 +- internal/runnerv2client/client_test.go | 11 +- internal/runnerv2service/service_execute.go | 11 ++ internal/runnerv2service/service_test.go | 29 ---- internal/server/server.go | 146 +++++++++-------- internal/server/server_test.go | 106 ------------ internal/server/server_unix_test.go | 36 ---- 15 files changed, 336 insertions(+), 301 deletions(-) delete mode 100644 internal/runnerv2service/service_test.go delete mode 100644 internal/server/server_test.go delete mode 100644 internal/server/server_unix_test.go diff --git a/internal/cmd/beta/server/server_start_cmd.go b/internal/cmd/beta/server/server_start_cmd.go index d1831a2c1..ca328ab4c 100644 --- a/internal/cmd/beta/server/server_start_cmd.go +++ b/internal/cmd/beta/server/server_start_cmd.go @@ -22,27 +22,14 @@ func serverStartCmd() *cobra.Command { return autoconfig.InvokeForCommand( func( cfg *config.Config, + server *server.Server, cmdFactory command.Factory, logger *zap.Logger, ) error { defer logger.Sync() - serverCfg := &server.Config{ - Address: cfg.Server.Address, - CertFile: *cfg.Server.Tls.CertFile, // guaranteed by autoconfig - KeyFile: *cfg.Server.Tls.KeyFile, // guaranteed by autoconfig - TLSEnabled: cfg.Server.Tls.Enabled, - } - _ = telemetry.ReportUnlessNoTracking(logger) - logger.Debug("server config", zap.Any("config", serverCfg)) - - s, err := server.New(serverCfg, cmdFactory, logger) - if err != nil { - return err - } - // When using a unix socket, we want to create a file with server's PID. if path := pidFileNameFromAddr(cfg.Server.Address); path != "" { logger.Debug("creating PID file", zap.String("path", path)) @@ -52,9 +39,7 @@ func serverStartCmd() *cobra.Command { defer os.Remove(cfg.Server.Address) } - logger.Debug("starting the server") - - return errors.WithStack(s.Serve()) + return errors.WithStack(server.Serve()) }, ) }, diff --git a/internal/command/config.go b/internal/command/config.go index 96ff2ecb7..a2efc8db0 100644 --- a/internal/command/config.go +++ b/internal/command/config.go @@ -25,10 +25,10 @@ func redactConfig(cfg *ProgramConfig) *ProgramConfig { } func isShell(cfg *ProgramConfig) bool { - return IsShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId) + return isShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId) } -func IsShellProgram(programName string) bool { +func isShellProgram(programName string) bool { switch strings.ToLower(programName) { case "sh", "bash", "zsh", "ksh", "shell": return true diff --git a/internal/config/autoconfig/autoconfig.go b/internal/config/autoconfig/autoconfig.go index 7d9353e1e..2b4400852 100644 --- a/internal/config/autoconfig/autoconfig.go +++ b/internal/config/autoconfig/autoconfig.go @@ -24,8 +24,12 @@ import ( "github.com/stateful/runme/v3/internal/command" "github.com/stateful/runme/v3/internal/config" "github.com/stateful/runme/v3/internal/dockerexec" + "github.com/stateful/runme/v3/internal/project/projectservice" "github.com/stateful/runme/v3/internal/runnerv2client" + "github.com/stateful/runme/v3/internal/runnerv2service" + "github.com/stateful/runme/v3/internal/server" runmetls "github.com/stateful/runme/v3/internal/tls" + "github.com/stateful/runme/v3/pkg/document/editor/editorservice" "github.com/stateful/runme/v3/pkg/project" ) @@ -70,44 +74,27 @@ func init() { mustProvide(container.Provide(getCommandFactory)) mustProvide(container.Provide(getConfigLoader)) mustProvide(container.Provide(getDocker)) + mustProvide(container.Provide(getGRPCClient)) mustProvide(container.Provide(getLogger)) mustProvide(container.Provide(getProject)) mustProvide(container.Provide(getProjectFilters)) mustProvide(container.Provide(getRootConfig)) + mustProvide(container.Provide(getServer)) mustProvide(container.Provide(getUserConfigDir)) } -func getClient(cfg *config.Config, logger *zap.Logger) (*runnerv2client.Client, error) { - if cfg.Server == nil { - return nil, nil - } - - var opts []grpc.DialOption - - if cfg.Server.Tls != nil && cfg.Server.Tls.Enabled { - // It's ok to dereference TLS fields because they are checked in [getRootConfig]. - tlsConfig, err := runmetls.LoadClientConfig(*cfg.Server.Tls.CertFile, *cfg.Server.Tls.KeyFile) - if err != nil { - return nil, errors.WithStack(err) - } - creds := credentials.NewTLS(tlsConfig) - opts = append(opts, grpc.WithTransportCredentials(creds)) - } else { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) +func getClient(cfg *config.Config, clientConn *grpc.ClientConn, logger *zap.Logger) (*runnerv2client.Client, error) { + if clientConn == nil { + return nil, errors.New("client connection is not configured") } - - return runnerv2client.New( - cfg.Server.Address, - logger, - opts..., - ) + return runnerv2client.New(clientConn, logger), nil } type ClientFactory func() (*runnerv2client.Client, error) -func getClientFactory(cfg *config.Config, logger *zap.Logger) ClientFactory { +func getClientFactory(cfg *config.Config, clientConn *grpc.ClientConn, logger *zap.Logger) ClientFactory { return func() (*runnerv2client.Client, error) { - return getClient(cfg, logger) + return getClient(cfg, clientConn, logger) } } @@ -147,6 +134,35 @@ func getDocker(c *config.Config, logger *zap.Logger) (*dockerexec.Docker, error) return dockerexec.New(options) } +func getGRPCClient( + cfg *config.Config, + server *server.Server, + logger *zap.Logger, +) (*grpc.ClientConn, error) { + if cfg.Server == nil { + return nil, nil + } + + opts := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.Server.MaxMessageSize)), + } + + if tls := cfg.Server.Tls; tls != nil && tls.Enabled { + // It's ok to dereference TLS fields because they are checked in [getRootConfig]. + tlsConfig, err := runmetls.LoadClientConfig(*cfg.Server.Tls.CertFile, *cfg.Server.Tls.KeyFile) + if err != nil { + return nil, errors.WithStack(err) + } + creds := credentials.NewTLS(tlsConfig) + opts = append(opts, grpc.WithTransportCredentials(creds)) + } else { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + conn, err := grpc.NewClient(server.Addr(), opts...) + return conn, errors.WithStack(err) +} + func getLogger(c *config.Config) (*zap.Logger, error) { if c == nil || c.Log == nil || !c.Log.Enabled { return zap.NewNop(), nil @@ -166,7 +182,7 @@ func getLogger(c *config.Config) (*zap.Logger, error) { } if c.Log.Verbose { - zapConfig.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + zapConfig.Level = zap.NewAtomicLevelAt(zap.InfoLevel) zapConfig.Development = true zapConfig.Encoding = "console" zapConfig.EncoderConfig = zap.NewDevelopmentEncoderConfig() @@ -297,6 +313,27 @@ func getRootConfig(cfgLoader *config.Loader, userCfgDir UserConfigDir) (*config. return cfg, nil } +func getServer(cfg *config.Config, cmdFactory command.Factory, logger *zap.Logger) (*server.Server, error) { + if cfg.Server == nil { + return nil, nil + } + + parserService := editorservice.NewParserServiceServer(logger) + projectService := projectservice.NewProjectServiceServer(logger) + runnerService, err := runnerv2service.NewRunnerService(cmdFactory, logger) + if err != nil { + return nil, err + } + + return server.New( + cfg, + parserService, + projectService, + runnerService, + logger, + ) +} + type UserConfigDir string func getUserConfigDir() (UserConfigDir, error) { diff --git a/internal/config/autoconfig/autoconfig_test.go b/internal/config/autoconfig/autoconfig_test.go index 534d57576..878416125 100644 --- a/internal/config/autoconfig/autoconfig_test.go +++ b/internal/config/autoconfig/autoconfig_test.go @@ -1,22 +1,28 @@ package autoconfig import ( + "context" "fmt" + "os" + "path/filepath" "testing" "testing/fstest" + "time" + "github.com/pkg/errors" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + healthv1 "google.golang.org/grpc/health/grpc_health_v1" "github.com/stateful/runme/v3/internal/config" + "github.com/stateful/runme/v3/internal/server" ) func TestInvokeForCommand_Config(t *testing.T) { // Create a fake filesystem and set it in [config.Loader]. err := InvokeForCommand(func(loader *config.Loader) error { fsys := fstest.MapFS{ - "README.md": { - Data: []byte("Hello, World!"), - }, "runme.yaml": { Data: []byte(fmt.Sprintf("version: v1alpha1\nproject:\n filename: %s\n", "README.md")), }, @@ -33,3 +39,146 @@ func TestInvokeForCommand_Config(t *testing.T) { }) require.NoError(t, err) } + +func TestInvokeForCommand_ServerClient(t *testing.T) { + tmp := t.TempDir() + readme := filepath.Join(tmp, "README.md") + err := os.WriteFile(readme, []byte("Hello, World!"), 0o644) + require.NoError(t, err) + + t.Run("NoServerInConfig", func(t *testing.T) { + err := InvokeForCommand(func(loader *config.Loader) error { + fsys := fstest.MapFS{ + "runme.yaml": { + Data: []byte(fmt.Sprintf("version: v1alpha1\nproject:\n filename: %s\n", readme)), + }, + } + loader.SetConfigRootPath(fsys) + return nil + }) + require.NoError(t, err) + + err = InvokeForCommand(func( + server *server.Server, + client *grpc.ClientConn, + ) error { + require.Nil(t, server) + require.Nil(t, client) + return nil + }) + require.NoError(t, err) + }) + + t.Run("ServerInConfigWithoutTLS", func(t *testing.T) { + err := InvokeForCommand(func(loader *config.Loader) error { + fsys := fstest.MapFS{ + "runme.yaml": { + Data: []byte(`version: v1alpha1 +project: + filename: ` + readme + ` +server: + address: localhost:0 + tls: + enabled: false +`), + }, + } + loader.SetConfigRootPath(fsys) + return nil + }) + require.NoError(t, err) + + err = InvokeForCommand(func( + server *server.Server, + client *grpc.ClientConn, + ) error { + require.NotNil(t, server) + require.NotNil(t, client) + + var g errgroup.Group + + g.Go(func() error { + return server.Serve() + }) + + g.Go(func() error { + defer server.Shutdown() + return checkHealth(client) + }) + + return g.Wait() + }) + require.NoError(t, err) + }) + + t.Run("ServerInConfigWithTLS", func(t *testing.T) { + // Use a temp dir to store the TLS files. + err = DecorateRoot(func() (UserConfigDir, error) { + return UserConfigDir(tmp), nil + }) + require.NoError(t, err) + + err := InvokeForCommand(func(loader *config.Loader) error { + fsys := fstest.MapFS{ + "runme.yaml": { + Data: []byte(`version: v1alpha1 +project: + filename: ` + readme + ` +server: + address: 127.0.0.1:0 + tls: + enabled: true +`), + }, + } + loader.SetConfigRootPath(fsys) + return nil + }) + require.NoError(t, err) + + err = InvokeForCommand(func( + server *server.Server, + client *grpc.ClientConn, + ) error { + require.NotNil(t, server) + require.NotNil(t, client) + + var g errgroup.Group + + g.Go(func() error { + return server.Serve() + }) + + g.Go(func() error { + defer server.Shutdown() + return errors.WithMessage(checkHealth(client), "failed to check health") + }) + + return g.Wait() + }) + require.NoError(t, err) + }) +} + +func checkHealth(conn *grpc.ClientConn) error { + client := healthv1.NewHealthClient(conn) + + var ( + resp *healthv1.HealthCheckResponse + err error + ) + + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err = client.Check(ctx, &healthv1.HealthCheckRequest{}) + if err != nil || resp.Status != healthv1.HealthCheckResponse_SERVING { + cancel() + time.Sleep(time.Second) + continue + } + cancel() + break + } + + return err +} diff --git a/internal/config/config.go b/internal/config/config.go index 676d67bfa..f4f43914f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -32,10 +32,8 @@ func Default() *Config { } // ParseYAML parses the given YAML items and returns a configuration object. -// Multiple items are merged into a single configuration. It uses a default -// configuration as a base. +// Multiple items are merged into a single configuration. func ParseYAML(items ...[]byte) (*Config, error) { - items = append([][]byte{defaultRunmeYAML}, items...) return parseYAML(items...) } diff --git a/internal/config/config.schema.json b/internal/config/config.schema.json index d8a21f4c0..75da9334f 100644 --- a/internal/config/config.schema.json +++ b/internal/config/config.schema.json @@ -116,6 +116,10 @@ "address": { "type": "string" }, + "max_message_size": { + "type": "integer", + "default": 33554432 + }, "tls": { "type": "object", "properties": { diff --git a/internal/config/config_schema.go b/internal/config/config_schema.go index 3e81b777f..07a6ab94c 100644 --- a/internal/config/config_schema.go +++ b/internal/config/config_schema.go @@ -295,6 +295,9 @@ type ConfigServer struct { // Address corresponds to the JSON schema field "address". Address string `json:"address" yaml:"address"` + // MaxMessageSize corresponds to the JSON schema field "max_message_size". + MaxMessageSize int `json:"max_message_size,omitempty" yaml:"max_message_size,omitempty"` + // Tls corresponds to the JSON schema field "tls". Tls *ConfigServerTls `json:"tls,omitempty" yaml:"tls,omitempty"` } @@ -342,6 +345,9 @@ func (j *ConfigServer) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &plain); err != nil { return err } + if v, ok := raw["max_message_size"]; !ok || v == nil { + plain.MaxMessageSize = 33554432.0 + } *j = ConfigServer(plain) return nil } diff --git a/internal/config/runme.default.yaml b/internal/config/runme.default.yaml index 877375263..59997e3f2 100644 --- a/internal/config/runme.default.yaml +++ b/internal/config/runme.default.yaml @@ -28,6 +28,7 @@ server: # If not specified, default paths will be used. # cert_file: "/path/to/cert.pem" # key_file: "/path/to/key.pem" + max_message_size: 33554432 # 32 MiB log: enabled: false diff --git a/internal/runnerv2client/client.go b/internal/runnerv2client/client.go index d0e6305e3..8ac51516e 100644 --- a/internal/runnerv2client/client.go +++ b/internal/runnerv2client/client.go @@ -12,23 +12,19 @@ import ( runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" ) +const MaxMsgSize = 32 * 1024 * 1024 // 32 MiB + type Client struct { runnerv2.RunnerServiceClient conn *grpc.ClientConn logger *zap.Logger } -func New(target string, logger *zap.Logger, opts ...grpc.DialOption) (*Client, error) { - client, err := grpc.NewClient(target, opts...) - if err != nil { - return nil, errors.WithStack(err) - } - serviceClient := &Client{ - RunnerServiceClient: runnerv2.NewRunnerServiceClient(client), - conn: client, - logger: logger, +func New(clientConn *grpc.ClientConn, logger *zap.Logger) *Client { + return &Client{ + RunnerServiceClient: runnerv2.NewRunnerServiceClient(clientConn), + logger: logger.Named("runnerv2client.Client"), } - return serviceClient, nil } func (c *Client) Close() error { diff --git a/internal/runnerv2client/client_test.go b/internal/runnerv2client/client_test.go index d2d8cb960..795963462 100644 --- a/internal/runnerv2client/client_test.go +++ b/internal/runnerv2client/client_test.go @@ -108,15 +108,18 @@ func TestClient_ExecuteProgram(t *testing.T) { func createClient(t *testing.T, lis *bufconn.Listener) *Client { t.Helper() - logger := zaptest.NewLogger(t) - client, err := New( + + clientConn, err := grpc.NewClient( "passthrough://bufconn", - logger, grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return lis.Dial() }), grpc.WithTransportCredentials(insecure.NewCredentials()), ) require.NoError(t, err) - return client + + return New( + clientConn, + zaptest.NewLogger(t), + ) } diff --git a/internal/runnerv2service/service_execute.go b/internal/runnerv2service/service_execute.go index 8f4a17d30..401432ae0 100644 --- a/internal/runnerv2service/service_execute.go +++ b/internal/runnerv2service/service_execute.go @@ -56,6 +56,17 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error return err } + // exec, err := newExecution( + // req.Config, + // proj, + // session, + // logger, + // req.StoreStdoutInEnv, + // ) + // if err != nil { + // return err + // } + exec, err := newExecution( req.Config, proj, diff --git a/internal/runnerv2service/service_test.go b/internal/runnerv2service/service_test.go deleted file mode 100644 index c661dd5ec..000000000 --- a/internal/runnerv2service/service_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package runnerv2service - -import ( - "testing/fstest" - - "github.com/stateful/runme/v3/internal/command" - "github.com/stateful/runme/v3/internal/config" - "github.com/stateful/runme/v3/internal/config/autoconfig" -) - -func init() { - command.SetEnvDumpCommandForTesting() - - // Server uses autoconfig to get necessary dependencies. - // One of them, implicit, is [config.Config]. With the default - // [config.Loader] it won't be found during testing, so - // we need to provide an override. - if err := autoconfig.DecorateRoot(func(loader *config.Loader) *config.Loader { - fsys := fstest.MapFS{ - "runme.yaml": { - Data: []byte("version: v1alpha1\n"), - }, - } - loader.SetConfigRootPath(fsys) - return loader - }); err != nil { - panic(err) - } -} diff --git a/internal/server/server.go b/internal/server/server.go index a22cf6565..355a14d18 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,92 +6,124 @@ import ( "os" "strings" - "github.com/pkg/errors" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/health" healthv1 "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" - "github.com/stateful/runme/v3/internal/command" - "github.com/stateful/runme/v3/internal/project/projectservice" - "github.com/stateful/runme/v3/internal/runnerv2service" + "github.com/pkg/errors" + "github.com/stateful/runme/v3/internal/config" runmetls "github.com/stateful/runme/v3/internal/tls" parserv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/parser/v1" projectv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/project/v1" runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" - "github.com/stateful/runme/v3/pkg/document/editor/editorservice" -) - -const ( - maxMsgSize = 4 * 1024 * 1024 // 4 MiB ) -type Config struct { - Address string - CertFile string - KeyFile string - TLSEnabled bool -} - type Server struct { - grpcServer *grpc.Server - lis net.Listener - logger *zap.Logger + gs *grpc.Server + lis net.Listener + logger *zap.Logger } func New( - c *Config, - cmdFactory command.Factory, + cfg *config.Config, + parserService parserv1.ParserServiceServer, + projectService projectv1.ProjectServiceServer, + runnerService runnerv2.RunnerServiceServer, logger *zap.Logger, -) (_ *Server, err error) { - var tlsConfig *tls.Config +) (*Server, error) { + tlsCfg, err := createTLSConfig(cfg, logger) + if err != nil { + return nil, err + } + + lis, err := createListener(cfg, tlsCfg) + if err != nil { + return nil, err + } + + grpcServer := createGRPCServer( + cfg, + tlsCfg, + parserService, + projectService, + runnerService, + ) + + s := Server{ + gs: grpcServer, + lis: lis, + logger: logger.Named("Server"), + } + + return &s, nil +} + +func (s *Server) Addr() string { + return s.lis.Addr().String() +} + +func (s *Server) Serve() error { + s.logger.Info("starting gRPC server", zap.String("address", s.Addr())) + return s.gs.Serve(s.lis) +} + +func (s *Server) Shutdown() { + s.logger.Info("stopping gRPC server") + s.gs.GracefulStop() +} - if c.TLSEnabled { +func createTLSConfig(cfg *config.Config, logger *zap.Logger) (*tls.Config, error) { + if tls := cfg.Server.Tls; tls != nil && tls.Enabled { // TODO(adamb): redesign runmetls API. - tlsConfig, err = runmetls.LoadOrGenerateConfig(c.CertFile, c.KeyFile, logger) - if err != nil { - return nil, err - } + return runmetls.LoadOrGenerateConfig( + *tls.CertFile, // guaranteed in [getRootConfig] + *tls.KeyFile, // guaranteed in [getRootConfig] + logger, + ) } + return nil, nil +} - addr := c.Address +func createListener(cfg *config.Config, tlsCfg *tls.Config) (net.Listener, error) { + addr := cfg.Server.Address protocol := "tcp" - var lis net.Listener - if strings.HasPrefix(addr, "unix://") { protocol = "unix" addr = strings.TrimPrefix(addr, "unix://") - if _, err := os.Stat(addr); !os.IsNotExist(err) { return nil, err } } - if tlsConfig == nil { - lis, err = net.Listen(protocol, addr) - } else { - lis, err = tls.Listen(protocol, addr, tlsConfig) - } - if err != nil { - return nil, errors.WithStack(err) + if tlsCfg != nil { + lis, err := tls.Listen(protocol, addr, tlsCfg) + return lis, errors.WithStack(err) } - logger.Info("server listening", zap.String("address", addr)) + lis, err := net.Listen(protocol, addr) + return lis, errors.WithStack(err) +} +func createGRPCServer( + cfg *config.Config, + tlsCfg *tls.Config, + parserService parserv1.ParserServiceServer, + projectService projectv1.ProjectServiceServer, + runnerService runnerv2.RunnerServiceServer, +) *grpc.Server { grpcServer := grpc.NewServer( - grpc.MaxRecvMsgSize(maxMsgSize), - grpc.MaxSendMsgSize(maxMsgSize), + grpc.MaxRecvMsgSize(cfg.Server.MaxMessageSize), + grpc.MaxSendMsgSize(cfg.Server.MaxMessageSize), + grpc.Creds(credentials.NewTLS(tlsCfg)), ) // Register runme services. - parserv1.RegisterParserServiceServer(grpcServer, editorservice.NewParserServiceServer(logger)) - projectv1.RegisterProjectServiceServer(grpcServer, projectservice.NewProjectServiceServer(logger)) - runnerService, err := runnerv2service.NewRunnerService(cmdFactory, logger) - if err != nil { - return nil, err - } + parserv1.RegisterParserServiceServer(grpcServer, parserService) + projectv1.RegisterProjectServiceServer(grpcServer, projectService) runnerv2.RegisterRunnerServiceServer(grpcServer, runnerService) // Register health service. @@ -103,21 +135,5 @@ func New( // Register reflection service. reflection.Register(grpcServer) - return &Server{ - lis: lis, - grpcServer: grpcServer, - logger: logger, - }, nil -} - -func (s *Server) Addr() string { - return s.lis.Addr().String() -} - -func (s *Server) Serve() error { - return s.grpcServer.Serve(s.lis) -} - -func (s *Server) Shutdown() { - s.grpcServer.GracefulStop() + return grpcServer } diff --git a/internal/server/server_test.go b/internal/server/server_test.go deleted file mode 100644 index 22f79b0b6..000000000 --- a/internal/server/server_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package server - -import ( - "context" - "path/filepath" - "runtime" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/require" - "go.uber.org/zap/zaptest" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - healthv1 "google.golang.org/grpc/health/grpc_health_v1" - - "github.com/stateful/runme/v3/internal/command" - runmetls "github.com/stateful/runme/v3/internal/tls" -) - -func TestServer(t *testing.T) { - logger := zaptest.NewLogger(t) - factory := command.NewFactory(command.WithLogger(logger)) - - t.Run("tcp", func(t *testing.T) { - cfg := &Config{ - Address: "localhost:0", - } - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - errc <- s.Serve() - }() - - testConnectivity(t, s.Addr(), insecure.NewCredentials()) - - s.Shutdown() - require.NoError(t, <-errc) - }) - - t.Run("tcp with tls", func(t *testing.T) { - dir := t.TempDir() - cfg := &Config{ - Address: "localhost:0", - CertFile: filepath.Join(dir, "cert.pem"), - KeyFile: filepath.Join(dir, "key.pem"), - TLSEnabled: true, - } - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - errc <- s.Serve() - }() - - tlsConfig, err := runmetls.LoadClientConfig(cfg.CertFile, cfg.KeyFile) - require.NoError(t, err) - - addr := s.Addr() - if runtime.GOOS == "windows" { - addr = strings.TrimPrefix(addr, "unix://") - } - testConnectivity(t, addr, credentials.NewTLS(tlsConfig)) - - s.Shutdown() - require.NoError(t, <-errc) - }) -} - -func testConnectivity(t *testing.T, addr string, creds credentials.TransportCredentials) { - t.Helper() - - var err error - - for i := 0; i < 5; i++ { - var ( - conn *grpc.ClientConn - resp *healthv1.HealthCheckResponse - ) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - conn, err := grpc.NewClient( - addr, - grpc.WithTransportCredentials(creds), - ) - if err != nil { - goto wait - } - - resp, err = healthv1.NewHealthClient(conn).Check(ctx, &healthv1.HealthCheckRequest{}) - if err != nil || resp.Status != healthv1.HealthCheckResponse_SERVING { - goto wait - } - - cancel() - break - - wait: - cancel() - <-time.After(time.Millisecond * 100) - } - - require.NoError(t, err) -} diff --git a/internal/server/server_unix_test.go b/internal/server/server_unix_test.go deleted file mode 100644 index eee7f0865..000000000 --- a/internal/server/server_unix_test.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build !windows - -package server - -import ( - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" - "go.uber.org/zap/zaptest" - "google.golang.org/grpc/credentials/insecure" - - "github.com/stateful/runme/v3/internal/command" -) - -func TestServerUnixSocket(t *testing.T) { - dir := t.TempDir() - sock := filepath.Join(dir, "runme.sock") - cfg := &Config{ - Address: "unix://" + sock, - } - logger := zaptest.NewLogger(t) - factory := command.NewFactory(command.WithLogger(logger)) - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - err := s.Serve() - errc <- err - }() - - testConnectivity(t, cfg.Address, insecure.NewCredentials()) - - s.Shutdown() - require.NoError(t, <-errc) -}