Skip to content

Commit

Permalink
Refactor server in autoconfig
Browse files Browse the repository at this point in the history
  • Loading branch information
adambabik committed Jan 12, 2025
1 parent 11c7d9c commit 0f14169
Show file tree
Hide file tree
Showing 15 changed files with 336 additions and 301 deletions.
19 changes: 2 additions & 17 deletions internal/cmd/beta/server/server_start_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,14 @@ func serverStartCmd() *cobra.Command {
return autoconfig.InvokeForCommand(
func(
cfg *config.Config,
server *server.Server,
cmdFactory command.Factory,
logger *zap.Logger,
) error {
defer logger.Sync()

serverCfg := &server.Config{
Address: cfg.Server.Address,
CertFile: *cfg.Server.Tls.CertFile, // guaranteed by autoconfig
KeyFile: *cfg.Server.Tls.KeyFile, // guaranteed by autoconfig
TLSEnabled: cfg.Server.Tls.Enabled,
}

_ = telemetry.ReportUnlessNoTracking(logger)

logger.Debug("server config", zap.Any("config", serverCfg))

s, err := server.New(serverCfg, cmdFactory, logger)
if err != nil {
return err
}

// When using a unix socket, we want to create a file with server's PID.
if path := pidFileNameFromAddr(cfg.Server.Address); path != "" {
logger.Debug("creating PID file", zap.String("path", path))
Expand All @@ -52,9 +39,7 @@ func serverStartCmd() *cobra.Command {
defer os.Remove(cfg.Server.Address)
}

logger.Debug("starting the server")

return errors.WithStack(s.Serve())
return errors.WithStack(server.Serve())
},
)
},
Expand Down
4 changes: 2 additions & 2 deletions internal/command/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ func redactConfig(cfg *ProgramConfig) *ProgramConfig {
}

func isShell(cfg *ProgramConfig) bool {
return IsShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId)
return isShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId)
}

func IsShellProgram(programName string) bool {
func isShellProgram(programName string) bool {
switch strings.ToLower(programName) {
case "sh", "bash", "zsh", "ksh", "shell":
return true
Expand Down
89 changes: 63 additions & 26 deletions internal/config/autoconfig/autoconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ import (
"github.com/stateful/runme/v3/internal/command"
"github.com/stateful/runme/v3/internal/config"
"github.com/stateful/runme/v3/internal/dockerexec"
"github.com/stateful/runme/v3/internal/project/projectservice"
"github.com/stateful/runme/v3/internal/runnerv2client"
"github.com/stateful/runme/v3/internal/runnerv2service"
"github.com/stateful/runme/v3/internal/server"
runmetls "github.com/stateful/runme/v3/internal/tls"
"github.com/stateful/runme/v3/pkg/document/editor/editorservice"
"github.com/stateful/runme/v3/pkg/project"
)

Expand Down Expand Up @@ -70,44 +74,27 @@ func init() {
mustProvide(container.Provide(getCommandFactory))
mustProvide(container.Provide(getConfigLoader))
mustProvide(container.Provide(getDocker))
mustProvide(container.Provide(getGRPCClient))
mustProvide(container.Provide(getLogger))
mustProvide(container.Provide(getProject))
mustProvide(container.Provide(getProjectFilters))
mustProvide(container.Provide(getRootConfig))
mustProvide(container.Provide(getServer))
mustProvide(container.Provide(getUserConfigDir))
}

func getClient(cfg *config.Config, logger *zap.Logger) (*runnerv2client.Client, error) {
if cfg.Server == nil {
return nil, nil
}

var opts []grpc.DialOption

if cfg.Server.Tls != nil && cfg.Server.Tls.Enabled {
// It's ok to dereference TLS fields because they are checked in [getRootConfig].
tlsConfig, err := runmetls.LoadClientConfig(*cfg.Server.Tls.CertFile, *cfg.Server.Tls.KeyFile)
if err != nil {
return nil, errors.WithStack(err)
}
creds := credentials.NewTLS(tlsConfig)
opts = append(opts, grpc.WithTransportCredentials(creds))
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
func getClient(cfg *config.Config, clientConn *grpc.ClientConn, logger *zap.Logger) (*runnerv2client.Client, error) {
if clientConn == nil {
return nil, errors.New("client connection is not configured")
}

return runnerv2client.New(
cfg.Server.Address,
logger,
opts...,
)
return runnerv2client.New(clientConn, logger), nil
}

type ClientFactory func() (*runnerv2client.Client, error)

func getClientFactory(cfg *config.Config, logger *zap.Logger) ClientFactory {
func getClientFactory(cfg *config.Config, clientConn *grpc.ClientConn, logger *zap.Logger) ClientFactory {
return func() (*runnerv2client.Client, error) {
return getClient(cfg, logger)
return getClient(cfg, clientConn, logger)
}
}

Expand Down Expand Up @@ -147,6 +134,35 @@ func getDocker(c *config.Config, logger *zap.Logger) (*dockerexec.Docker, error)
return dockerexec.New(options)
}

func getGRPCClient(
cfg *config.Config,
server *server.Server,
logger *zap.Logger,
) (*grpc.ClientConn, error) {
if cfg.Server == nil {
return nil, nil
}

opts := []grpc.DialOption{
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.Server.MaxMessageSize)),
}

if tls := cfg.Server.Tls; tls != nil && tls.Enabled {
// It's ok to dereference TLS fields because they are checked in [getRootConfig].
tlsConfig, err := runmetls.LoadClientConfig(*cfg.Server.Tls.CertFile, *cfg.Server.Tls.KeyFile)
if err != nil {
return nil, errors.WithStack(err)
}
creds := credentials.NewTLS(tlsConfig)
opts = append(opts, grpc.WithTransportCredentials(creds))
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}

conn, err := grpc.NewClient(server.Addr(), opts...)
return conn, errors.WithStack(err)
}

func getLogger(c *config.Config) (*zap.Logger, error) {
if c == nil || c.Log == nil || !c.Log.Enabled {
return zap.NewNop(), nil
Expand All @@ -166,7 +182,7 @@ func getLogger(c *config.Config) (*zap.Logger, error) {
}

if c.Log.Verbose {
zapConfig.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
zapConfig.Level = zap.NewAtomicLevelAt(zap.InfoLevel)
zapConfig.Development = true
zapConfig.Encoding = "console"
zapConfig.EncoderConfig = zap.NewDevelopmentEncoderConfig()
Expand Down Expand Up @@ -297,6 +313,27 @@ func getRootConfig(cfgLoader *config.Loader, userCfgDir UserConfigDir) (*config.
return cfg, nil
}

func getServer(cfg *config.Config, cmdFactory command.Factory, logger *zap.Logger) (*server.Server, error) {
if cfg.Server == nil {
return nil, nil
}

parserService := editorservice.NewParserServiceServer(logger)
projectService := projectservice.NewProjectServiceServer(logger)
runnerService, err := runnerv2service.NewRunnerService(cmdFactory, logger)
if err != nil {
return nil, err
}

return server.New(
cfg,
parserService,
projectService,
runnerService,
logger,
)
}

type UserConfigDir string

func getUserConfigDir() (UserConfigDir, error) {
Expand Down
155 changes: 152 additions & 3 deletions internal/config/autoconfig/autoconfig_test.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
package autoconfig

import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
"testing/fstest"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
healthv1 "google.golang.org/grpc/health/grpc_health_v1"

"github.com/stateful/runme/v3/internal/config"
"github.com/stateful/runme/v3/internal/server"
)

func TestInvokeForCommand_Config(t *testing.T) {
// Create a fake filesystem and set it in [config.Loader].
err := InvokeForCommand(func(loader *config.Loader) error {
fsys := fstest.MapFS{
"README.md": {
Data: []byte("Hello, World!"),
},
"runme.yaml": {
Data: []byte(fmt.Sprintf("version: v1alpha1\nproject:\n filename: %s\n", "README.md")),
},
Expand All @@ -33,3 +39,146 @@ func TestInvokeForCommand_Config(t *testing.T) {
})
require.NoError(t, err)
}

func TestInvokeForCommand_ServerClient(t *testing.T) {
tmp := t.TempDir()
readme := filepath.Join(tmp, "README.md")
err := os.WriteFile(readme, []byte("Hello, World!"), 0o644)
require.NoError(t, err)

t.Run("NoServerInConfig", func(t *testing.T) {
err := InvokeForCommand(func(loader *config.Loader) error {
fsys := fstest.MapFS{
"runme.yaml": {
Data: []byte(fmt.Sprintf("version: v1alpha1\nproject:\n filename: %s\n", readme)),
},
}
loader.SetConfigRootPath(fsys)
return nil
})
require.NoError(t, err)

err = InvokeForCommand(func(
server *server.Server,
client *grpc.ClientConn,
) error {
require.Nil(t, server)
require.Nil(t, client)
return nil
})
require.NoError(t, err)
})

t.Run("ServerInConfigWithoutTLS", func(t *testing.T) {
err := InvokeForCommand(func(loader *config.Loader) error {
fsys := fstest.MapFS{
"runme.yaml": {
Data: []byte(`version: v1alpha1
project:
filename: ` + readme + `
server:
address: localhost:0
tls:
enabled: false
`),
},
}
loader.SetConfigRootPath(fsys)
return nil
})
require.NoError(t, err)

err = InvokeForCommand(func(
server *server.Server,
client *grpc.ClientConn,
) error {
require.NotNil(t, server)
require.NotNil(t, client)

var g errgroup.Group

g.Go(func() error {
return server.Serve()
})

g.Go(func() error {
defer server.Shutdown()
return checkHealth(client)
})

return g.Wait()
})
require.NoError(t, err)
})

t.Run("ServerInConfigWithTLS", func(t *testing.T) {
// Use a temp dir to store the TLS files.
err = DecorateRoot(func() (UserConfigDir, error) {
return UserConfigDir(tmp), nil
})
require.NoError(t, err)

err := InvokeForCommand(func(loader *config.Loader) error {
fsys := fstest.MapFS{
"runme.yaml": {
Data: []byte(`version: v1alpha1
project:
filename: ` + readme + `
server:
address: 127.0.0.1:0
tls:
enabled: true
`),
},
}
loader.SetConfigRootPath(fsys)
return nil
})
require.NoError(t, err)

err = InvokeForCommand(func(
server *server.Server,
client *grpc.ClientConn,
) error {
require.NotNil(t, server)
require.NotNil(t, client)

var g errgroup.Group

g.Go(func() error {
return server.Serve()
})

g.Go(func() error {
defer server.Shutdown()
return errors.WithMessage(checkHealth(client), "failed to check health")
})

return g.Wait()
})
require.NoError(t, err)
})
}

func checkHealth(conn *grpc.ClientConn) error {
client := healthv1.NewHealthClient(conn)

var (
resp *healthv1.HealthCheckResponse
err error
)

for i := 0; i < 5; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err = client.Check(ctx, &healthv1.HealthCheckRequest{})
if err != nil || resp.Status != healthv1.HealthCheckResponse_SERVING {
cancel()
time.Sleep(time.Second)
continue
}
cancel()
break
}

return err
}
4 changes: 1 addition & 3 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ func Default() *Config {
}

// ParseYAML parses the given YAML items and returns a configuration object.
// Multiple items are merged into a single configuration. It uses a default
// configuration as a base.
// Multiple items are merged into a single configuration.
func ParseYAML(items ...[]byte) (*Config, error) {
items = append([][]byte{defaultRunmeYAML}, items...)
return parseYAML(items...)
}

Expand Down
4 changes: 4 additions & 0 deletions internal/config/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@
"address": {
"type": "string"
},
"max_message_size": {
"type": "integer",
"default": 33554432
},
"tls": {
"type": "object",
"properties": {
Expand Down
Loading

0 comments on commit 0f14169

Please sign in to comment.