From 7f5856520158ab75d67bd77b79992db6b0499c92 Mon Sep 17 00:00:00 2001 From: Bailin He <15058035+bailinhe@users.noreply.github.com> Date: Thu, 5 Oct 2023 12:03:34 -0400 Subject: [PATCH] [Extend Governor APIs] Extensions Management (#76) * add extensions management * Update pkg/api/v1alpha1/errors.go Co-authored-by: E Camden Fisher * address review issues --------- Co-authored-by: E Camden Fisher --- internal/dbtools/hooks.go | 94 ++++- internal/eventbus/client.go | 3 +- pkg/api/v1alpha1/errors.go | 7 + pkg/api/v1alpha1/extensions.go | 452 +++++++++++++++++++++ pkg/api/v1alpha1/extensions_test.go | 596 ++++++++++++++++++++++++++++ pkg/api/v1alpha1/router.go | 39 ++ pkg/api/v1alpha1/testing.go | 14 + pkg/client/errors.go | 3 + pkg/client/extensions.go | 241 +++++++++++ pkg/client/extensions_test.go | 532 +++++++++++++++++++++++++ pkg/events/v1alpha1/events.go | 5 + 11 files changed, 1966 insertions(+), 20 deletions(-) create mode 100644 pkg/api/v1alpha1/extensions.go create mode 100644 pkg/api/v1alpha1/extensions_test.go create mode 100644 pkg/api/v1alpha1/testing.go create mode 100644 pkg/client/extensions.go create mode 100644 pkg/client/extensions_test.go diff --git a/internal/dbtools/hooks.go b/internal/dbtools/hooks.go index d9396d4..1f9bab1 100644 --- a/internal/dbtools/hooks.go +++ b/internal/dbtools/hooks.go @@ -10,6 +10,7 @@ import ( "github.com/gosimple/slug" "github.com/volatiletech/null/v8" "github.com/volatiletech/sqlboiler/v4/boil" + "github.com/volatiletech/sqlboiler/v4/types" "github.com/metal-toolbox/governor-api/internal/models" ) @@ -39,29 +40,29 @@ func SetApplicationTypeSlug(a *models.ApplicationType) { } func changesetLine(set []string, key string, old, new interface{}) []string { - if old == new { + if reflect.DeepEqual(old, new) { return set } var str string - if old != new { - switch o := old.(type) { - case string: - str = fmt.Sprintf(`%s: "%s" => "%s"`, key, o, new.(string)) - case null.String: - str = fmt.Sprintf(`%s: "%s" => "%s"`, key, o.String, new.(null.String).String) - case int: - str = fmt.Sprintf(`%s: "%d" => "%d"`, key, o, new) - case int64: - str = fmt.Sprintf(`%s: "%d" => "%d"`, key, o, new) - case bool: - str = fmt.Sprintf(`%s: "%t" => "%t"`, key, o, new) - case time.Time: - str = fmt.Sprintf(`%s: "%s" => "%s"`, key, o.UTC().Format(time.RFC3339), new.(time.Time).UTC().Format(time.RFC3339)) - default: - str = fmt.Sprintf(`%s: "%s" => "%s"`, key, o, new) - } + switch o := old.(type) { + case string: + str = fmt.Sprintf(`%s: "%s" => "%s"`, key, o, new.(string)) + case null.String: + str = fmt.Sprintf(`%s: "%s" => "%s"`, key, o.String, new.(null.String).String) + case int: + str = fmt.Sprintf(`%s: "%d" => "%d"`, key, o, new) + case int64: + str = fmt.Sprintf(`%s: "%d" => "%d"`, key, o, new) + case bool: + str = fmt.Sprintf(`%s: "%t" => "%t"`, key, o, new) + case time.Time: + str = fmt.Sprintf(`%s: "%s" => "%s"`, key, o.UTC().Format(time.RFC3339), new.(time.Time).UTC().Format(time.RFC3339)) + case types.JSON: + str = fmt.Sprintf(`%s: "%s" => "%s"`, key, string(o), string(new.(types.JSON))) + default: + str = fmt.Sprintf(`%s: "%s" => "%s"`, key, o, new) } return append(set, str) @@ -1011,3 +1012,60 @@ func AuditNotificationPreferencesUpdated(ctx context.Context, exec boil.ContextE return &event, event.Insert(ctx, exec, boil.Infer()) } + +// AuditExtensionCreated inserts an event representing a extension being created +func AuditExtensionCreated(ctx context.Context, exec boil.ContextExecutor, pID string, actor *models.User, a *models.Extension) (*models.AuditEvent, error) { + // TODO non-user API actors don't exist in the governor database, + // we need to figure out how to handle that relationship in the audit table + var actorID null.String + if actor != nil { + actorID = null.StringFrom(actor.ID) + } + + event := models.AuditEvent{ + ParentID: null.StringFrom(pID), + ActorID: actorID, + Action: "extension.created", + Changeset: calculateChangeset(&models.Extension{}, a), + } + + return &event, event.Insert(ctx, exec, boil.Infer()) +} + +// AuditExtensionUpdated inserts an event representing a extension being created +func AuditExtensionUpdated(ctx context.Context, exec boil.ContextExecutor, pID string, actor *models.User, o, a *models.Extension) (*models.AuditEvent, error) { + // TODO non-user API actors don't exist in the governor database, + // we need to figure out how to handle that relationship in the audit table + var actorID null.String + if actor != nil { + actorID = null.StringFrom(actor.ID) + } + + event := models.AuditEvent{ + ParentID: null.StringFrom(pID), + ActorID: actorID, + Action: "extension.updated", + Changeset: calculateChangeset(o, a), + } + + return &event, event.Insert(ctx, exec, boil.Infer()) +} + +// AuditExtensionDeleted inserts an event representing an extension being deleted +func AuditExtensionDeleted(ctx context.Context, exec boil.ContextExecutor, pID string, actor *models.User, a *models.Extension) (*models.AuditEvent, error) { + // TODO non-user API actors don't exist in the governor database, + // we need to figure out how to handle that relationship in the audit table + var actorID null.String + if actor != nil { + actorID = null.StringFrom(actor.ID) + } + + event := models.AuditEvent{ + ParentID: null.StringFrom(pID), + ActorID: actorID, + Action: "extension.deleted", + Changeset: calculateChangeset(a, &models.Extension{}), + } + + return &event, event.Insert(ctx, exec, boil.Infer()) +} diff --git a/internal/eventbus/client.go b/internal/eventbus/client.go index 9adca1d..257d9d9 100644 --- a/internal/eventbus/client.go +++ b/internal/eventbus/client.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" - "github.com/nats-io/nats.go" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -52,7 +51,7 @@ func NewClient(opts ...Option) *Client { } // WithNATSConn sets the nats connection -func WithNATSConn(nc *nats.Conn) Option { +func WithNATSConn(nc conn) Option { return func(c *Client) { c.conn = nc } diff --git a/pkg/api/v1alpha1/errors.go b/pkg/api/v1alpha1/errors.go index acd756e..4d14379 100644 --- a/pkg/api/v1alpha1/errors.go +++ b/pkg/api/v1alpha1/errors.go @@ -13,6 +13,13 @@ var ( ErrEmptyInput = errors.New("name or description cannot be empty") // ErrUnknownRequestKind is returned a request kind is unknown ErrUnknownRequestKind = errors.New("request kind is unrecognized") + // ErrGetDeleteResourcedWithSlug is returned when user tries to query a deleted + // resource with slug + ErrGetDeleteResourcedWithSlug = errors.New("unable to get deleted resource by slug, use the id") + // ErrExtensionNotFound is returned when an extension is not found + ErrExtensionNotFound = errors.New("extension does not exist") + // ErrERDNotFound is returned when an extension resource definition is not found + ErrERDNotFound = errors.New("ERD does not exist") ) func sendError(c *gin.Context, code int, msg string) { diff --git a/pkg/api/v1alpha1/extensions.go b/pkg/api/v1alpha1/extensions.go new file mode 100644 index 0000000..f162b19 --- /dev/null +++ b/pkg/api/v1alpha1/extensions.go @@ -0,0 +1,452 @@ +package v1alpha1 + +import ( + "database/sql" + "errors" + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/gosimple/slug" + "github.com/metal-toolbox/auditevent/ginaudit" + "github.com/metal-toolbox/governor-api/internal/dbtools" + "github.com/metal-toolbox/governor-api/internal/models" + events "github.com/metal-toolbox/governor-api/pkg/events/v1alpha1" + "github.com/volatiletech/sqlboiler/v4/boil" + "github.com/volatiletech/sqlboiler/v4/queries/qm" + "go.uber.org/zap" +) + +// Extension is the extension response +type Extension struct { + *models.Extension +} + +// ExtensionReq is a request to create an extension +type ExtensionReq struct { + Name string `json:"name"` + Description string `json:"description"` + Enabled *bool `json:"enabled,omitempty"` +} + +// listExtensions lists extensions as JSON +func (r *Router) listExtensions(c *gin.Context) { + queryMods := []qm.QueryMod{ + qm.OrderBy("name"), + } + + if _, ok := c.GetQuery("deleted"); ok { + queryMods = append(queryMods, qm.WithDeleted()) + } + + extensions, err := models.Extensions(queryMods...).All(c.Request.Context(), r.DB) + if err != nil { + r.Logger.Error("error fetching extensions", zap.Error(err)) + sendError(c, http.StatusBadRequest, "error listing extensions: "+err.Error()) + + return + } + + c.JSON(http.StatusOK, extensions) +} + +// createExtension creates an extension in DB +func (r *Router) createExtension(c *gin.Context) { + req := &ExtensionReq{} + if err := c.BindJSON(req); err != nil { + sendError(c, http.StatusBadRequest, "unable to bind request: "+err.Error()) + return + } + + if req.Name == "" { + sendError(c, http.StatusBadRequest, "extension name is required") + return + } + + if req.Description == "" { + sendError(c, http.StatusBadRequest, "extension description is required") + return + } + + if req.Enabled == nil { + sendError(c, http.StatusBadRequest, "extension enabled is required") + return + } + + extension := &models.Extension{ + Name: req.Name, + Description: req.Description, + Enabled: *req.Enabled, + } + + extension.Slug = slug.Make(extension.Name) + + tx, err := r.DB.BeginTx(c.Request.Context(), nil) + if err != nil { + sendError(c, http.StatusBadRequest, "error starting extension create transaction: "+err.Error()) + return + } + + if err := extension.Insert(c.Request.Context(), tx, boil.Infer()); err != nil { + msg := fmt.Sprintf("error creating extension: %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + event, err := dbtools.AuditExtensionCreated( + c.Request.Context(), + tx, + getCtxAuditID(c), + getCtxUser(c), + extension, + ) + if err != nil { + msg := fmt.Sprintf("error creating extension (audit): %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + if err := updateContextWithAuditEventData(c, event); err != nil { + msg := fmt.Sprintf("error creating extension: %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + if err := tx.Commit(); err != nil { + msg := fmt.Sprintf("error committing extension create: %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + err = r.EventBus.Publish( + c.Request.Context(), + events.GovernorExtensionsEventSubject, + &events.Event{ + Version: events.Version, + Action: events.GovernorEventCreate, + AuditID: c.GetString(ginaudit.AuditIDContextKey), + ActorID: getCtxActorID(c), + ExtensionID: extension.ID, + }, + ) + if err != nil { + sendError( + c, + http.StatusBadRequest, + fmt.Sprintf( + "failed to publish extension create event: %s\n%s", + err.Error(), + "downstream changes may be delayed", + ), + ) + + return + } + + c.JSON(http.StatusAccepted, extension) +} + +// getExtension fetch a extension from DB with given id +func (r *Router) getExtension(c *gin.Context) { + queryMods := []qm.QueryMod{} + id := c.Param("eid") + + deleted := false + if _, deleted = c.GetQuery("deleted"); deleted { + queryMods = append(queryMods, qm.WithDeleted()) + } + + q := qm.Where("id = ?", id) + + if _, err := uuid.Parse(id); err != nil { + if deleted { + sendError(c, http.StatusBadRequest, "unable to get deleted extension by slug, use the id") + return + } + + q = qm.Where("slug = ?", id) + } + + queryMods = append(queryMods, q) + + extension, err := models.Extensions(queryMods...).One(c.Request.Context(), r.DB) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + sendError(c, http.StatusNotFound, "extension not found: "+err.Error()) + return + } + + sendError(c, http.StatusInternalServerError, "error getting extension"+err.Error()) + + return + } + + c.JSON(http.StatusOK, Extension{extension}) +} + +// deleteExtension marks an extension deleted +func (r *Router) deleteExtension(c *gin.Context) { + id := c.Param("eid") + + q := qm.Where("id = ?", id) + if _, err := uuid.Parse(id); err != nil { + q = qm.Where("slug = ?", id) + } + + extension, err := models.Extensions(q).One(c.Request.Context(), r.DB) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + sendError(c, http.StatusNotFound, "extension not found: "+err.Error()) + return + } + + sendError(c, http.StatusInternalServerError, "error getting extension: "+err.Error()) + + return + } + + tx, err := r.DB.BeginTx(c.Request.Context(), nil) + if err != nil { + sendError(c, http.StatusBadRequest, "error starting delete transaction: "+err.Error()) + return + } + + if _, err := extension.Delete(c.Request.Context(), tx, false); err != nil { + msg := fmt.Sprintf("error deleting extension: %s. rolling back\n", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + event, err := dbtools.AuditExtensionDeleted( + c.Request.Context(), + tx, + getCtxAuditID(c), + getCtxUser(c), + extension, + ) + if err != nil { + msg := fmt.Sprintf("error deleting extension (audit): %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + if err := updateContextWithAuditEventData(c, event); err != nil { + msg := fmt.Sprintf("error deleting extension: %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + if err := tx.Commit(); err != nil { + msg := fmt.Sprintf("error committing extension delete: %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + err = r.EventBus.Publish( + c.Request.Context(), + events.GovernorExtensionsEventSubject, + &events.Event{ + Version: events.Version, + Action: events.GovernorEventDelete, + AuditID: c.GetString(ginaudit.AuditIDContextKey), + ActorID: getCtxActorID(c), + ExtensionID: extension.ID, + }, + ) + if err != nil { + sendError( + c, + http.StatusBadRequest, + fmt.Sprintf( + "failed to publish extension delete event: %s\n%s", + err.Error(), + "downstream changes may be delayed", + ), + ) + + return + } + + c.JSON(http.StatusAccepted, extension) +} + +// updateExtension updates an extension in DB +func (r *Router) updateExtension(c *gin.Context) { + id := c.Param("eid") + + q := qm.Where("id = ?", id) + if _, err := uuid.Parse(id); err != nil { + q = qm.Where("slug = ?", id) + } + + extension, err := models.Extensions(q).One(c.Request.Context(), r.DB) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + sendError(c, http.StatusNotFound, "extension not found: "+err.Error()) + return + } + + sendError(c, http.StatusInternalServerError, "error getting extension: "+err.Error()) + + return + } + + original := *extension + + req := &ExtensionReq{} + if err := c.BindJSON(req); err != nil { + sendError(c, http.StatusBadRequest, "unable to bind request: "+err.Error()) + return + } + + if req.Name != "" && req.Name != extension.Name { + sendError(c, http.StatusBadRequest, "modifying extension name is not allowed") + return + } + + if req.Description != "" { + extension.Description = req.Description + } + + if req.Enabled != nil { + extension.Enabled = *req.Enabled + } + + tx, err := r.DB.BeginTx(c.Request.Context(), nil) + if err != nil { + sendError(c, http.StatusBadRequest, "error starting update transaction: "+err.Error()) + return + } + + if _, err := extension.Update(c.Request.Context(), tx, boil.Infer()); err != nil { + msg := fmt.Sprintf("error updating extension: %s. rolling back\n", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + event, err := dbtools.AuditExtensionUpdated( + c.Request.Context(), + tx, + getCtxAuditID(c), + getCtxUser(c), + &original, + extension, + ) + if err != nil { + msg := fmt.Sprintf("error updating extension (audit): %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + if err := updateContextWithAuditEventData(c, event); err != nil { + msg := fmt.Sprintf("error updating extension: %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + if err := tx.Commit(); err != nil { + msg := fmt.Sprintf("error committing extension update: %s", err.Error()) + + if err := tx.Rollback(); err != nil { + msg += fmt.Sprintf("error rolling back transaction: %s", err.Error()) + } + + sendError(c, http.StatusBadRequest, msg) + + return + } + + err = r.EventBus.Publish( + c.Request.Context(), + events.GovernorExtensionsEventSubject, + &events.Event{ + Version: events.Version, + Action: events.GovernorEventUpdate, + AuditID: c.GetString(ginaudit.AuditIDContextKey), + ActorID: getCtxActorID(c), + ExtensionID: extension.ID, + }, + ) + if err != nil { + sendError( + c, + http.StatusBadRequest, + fmt.Sprintf( + "failed to publish extension update event: %s\n%s", + err.Error(), + "downstream changes may be delayed", + ), + ) + + return + } + + c.JSON(http.StatusAccepted, extension) +} diff --git a/pkg/api/v1alpha1/extensions_test.go b/pkg/api/v1alpha1/extensions_test.go new file mode 100644 index 0000000..bedcd34 --- /dev/null +++ b/pkg/api/v1alpha1/extensions_test.go @@ -0,0 +1,596 @@ +package v1alpha1 + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/cockroachdb/cockroach-go/v2/testserver" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/jmoiron/sqlx" + "github.com/metal-toolbox/auditevent/ginaudit" + dbm "github.com/metal-toolbox/governor-api/db" + "github.com/metal-toolbox/governor-api/internal/eventbus" + "github.com/metal-toolbox/governor-api/internal/models" + events "github.com/metal-toolbox/governor-api/pkg/events/v1alpha1" + "github.com/pressly/goose/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "go.hollow.sh/toolbox/ginauth" + "go.uber.org/zap" +) + +type ExtensionsTestSuite struct { + suite.Suite + + db *sql.DB + conn *mockNATSConn +} + +func (s *ExtensionsTestSuite) seedTestDB() error { + testData := []string{ + `INSERT INTO extensions (id, name, description, enabled, slug, status) + VALUES ('00000001-0000-0000-0000-000000000001', 'Test Extension', 'some extension', true, 'test-extension', 'online');`, + `INSERT INTO extensions (id, name, description, enabled, slug, status, deleted_at) + VALUES ('00000001-0000-0000-0000-000000000002', 'Deleted Extension', 'some deleted extension', true, 'deleted-extension', 'offline', '2023-07-12 12:00:00.000000+00');`, + `INSERT INTO extensions (id, name, description, enabled, slug, status) + VALUES ('00000001-0000-0000-0000-000000000003', 'Test Extension 3', 'some extension', true, 'test-extension-3', 'online');`, + } + + for _, q := range testData { + _, err := s.db.Query(q) + if err != nil { + return err + } + } + + return nil +} + +func (s *ExtensionsTestSuite) v1alpha1() *Router { + return &Router{ + AdminGroups: []string{"governor-admin"}, + AuthMW: &ginauth.MultiTokenMiddleware{}, + AuditMW: ginaudit.NewJSONMiddleware("governor-api", io.Discard), + DB: sqlx.NewDb(s.db, "postgres"), + EventBus: eventbus.NewClient(eventbus.WithNATSConn(s.conn)), + Logger: &zap.Logger{}, + } +} + +func (s *ExtensionsTestSuite) SetupSuite() { + s.conn = &mockNATSConn{} + + gin.SetMode(gin.TestMode) + + ts, err := testserver.NewTestServer() + if err != nil { + panic(err) + } + + s.db, err = sql.Open("postgres", ts.PGURL().String()) + if err != nil { + panic(err) + } + + goose.SetBaseFS(dbm.Migrations) + + if err := goose.Up(s.db, "migrations"); err != nil { + panic("migration failed - could not set up test db") + } + + if err := s.seedTestDB(); err != nil { + panic("db setup failed - could not seed test db: " + err.Error()) + } +} + +func (s *ExtensionsTestSuite) TestCreateExtension() { + r := s.v1alpha1() + + tests := []struct { + name string + url string + payload string + expectedResp *Extension + expectedStatus int + expectedErrMsg string + expectedEventSubject string + expectedEventPayload *events.Event + }{ + { + name: "ok", + url: "/api/v1alpha1/extensions", + expectedStatus: http.StatusAccepted, + payload: `{ "name": "Test Extension 1", "description": "some test", "enabled": true }`, + expectedEventSubject: "events.extensions", + expectedEventPayload: &events.Event{ + Action: events.GovernorEventCreate, + }, + expectedResp: &Extension{&models.Extension{ + Name: "Test Extension 1", + Description: "some test", + Slug: "test-extension-1", + Enabled: true, + }}, + }, + { + name: "enabled false", + url: "/api/v1alpha1/extensions", + expectedStatus: http.StatusAccepted, + payload: `{ "name": "Test Extension 2", "description": "some test", "enabled": false }`, + expectedEventSubject: "events.extensions", + expectedEventPayload: &events.Event{ + Action: events.GovernorEventCreate, + }, + expectedResp: &Extension{&models.Extension{ + Name: "Test Extension 2", + Description: "some test", + Slug: "test-extension-2", + Enabled: false, + }}, + }, + { + name: "duplicate entry", + url: "/api/v1alpha1/extensions", + payload: `{ "name": "Test Extension 2", "description": "some test", "enabled": true }`, + expectedEventSubject: "events.extensions", + expectedErrMsg: "duplicate key value violates unique constraint", + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + s.T().Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + auditID := uuid.New().String() + + req, _ := http.NewRequest("POST", tt.url, nil) + req.Body = io.NopCloser(bytes.NewBufferString(tt.payload)) + c.Request = req + c.Set(ginaudit.AuditIDContextKey, auditID) + + r.createExtension(c) + + assert.Equal(t, tt.expectedStatus, w.Code, "Expected status %d, got %d", tt.expectedStatus, w.Code) + + if tt.expectedErrMsg != "" { + body := w.Body.String() + assert.Contains( + t, body, tt.expectedErrMsg, + "Expected error message to contain %q, got %s", tt.expectedErrMsg, body, + ) + + return + } + + event := &events.Event{} + err := json.Unmarshal(s.conn.Payload, event) + assert.Nil(t, err) + + ex := &Extension{} + body := w.Body.String() + err = json.Unmarshal([]byte(body), ex) + assert.Nil(t, err) + + assert.Equal( + t, tt.expectedResp.Name, ex.Name, + "Expected extension name %s, got %s", t, tt.expectedResp.Name, ex.Name, + ) + + assert.Equal( + t, tt.expectedResp.Slug, ex.Slug, + "Expected extension slug %s, got %s", t, tt.expectedResp.Slug, ex.Slug, + ) + + assert.Equal( + t, tt.expectedResp.Description, ex.Description, + "Expected extension description %s, got %s", t, tt.expectedResp.Description, ex.Description, + ) + + assert.Equal( + t, tt.expectedResp.Enabled, ex.Enabled, + ) + + assert.Equal( + t, tt.expectedEventSubject, s.conn.Subject, + "Expected event subject %s, got %s", tt.expectedEventSubject, s.conn.Subject, + ) + + assert.Equal( + t, tt.expectedEventPayload.Action, event.Action, + "Expected event action %s, got %s", tt.expectedEventPayload.Action, event.Action, + ) + + assert.Equal( + t, event.ExtensionID, ex.ID, + "Expected event extension ID to match response ID", + ) + }) + } +} + +func (s *ExtensionsTestSuite) TestListExtensions() { + r := s.v1alpha1() + + tests := []struct { + name string + url string + expectedStatus int + expectedErrMsg string + expectedCount int + }{ + { + name: "ok", + url: "/api/v1alpha1/extensions", + expectedStatus: http.StatusOK, + expectedCount: 4, + }, + { + name: "list deleted", + url: "/api/v1alpha1/extensions?deleted", + expectedStatus: http.StatusOK, + expectedCount: 5, + }, + } + + for _, tt := range tests { + s.T().Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + auditID := uuid.New().String() + + req, _ := http.NewRequest("GET", tt.url, nil) + req = req.WithContext(context.Background()) + c.Request = req + c.Set(ginaudit.AuditIDContextKey, auditID) + + r.listExtensions(c) + + assert.Equal(t, tt.expectedStatus, w.Code, "Expected status %d, got %d", tt.expectedStatus, w.Code) + + if tt.expectedErrMsg != "" { + body := w.Body.String() + assert.Contains( + t, body, tt.expectedErrMsg, + "Expected error message to contain %q, got %s", tt.expectedErrMsg, body, + ) + + return + } + + body := w.Body.String() + resp := []interface{}{} + err := json.Unmarshal([]byte(body), &resp) + + assert.Nil(t, err, "expecting unmarshal err to be nil") + assert.Equal(t, tt.expectedCount, len(resp)) + }) + } +} + +func (s *ExtensionsTestSuite) TestGetExtension() { + r := s.v1alpha1() + + tests := []struct { + name string + url string + params gin.Params + expectedStatus int + expectedErrMsg string + }{ + { + name: "get by ID ok", + url: "/api/v1alpha1/extensions/00000001-0000-0000-0000-000000000001", + expectedStatus: http.StatusOK, + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000001"}, + }, + }, + { + name: "get by slug ok", + url: "/api/v1alpha1/extensions/test-extension", + expectedStatus: http.StatusOK, + params: gin.Params{ + gin.Param{Key: "eid", Value: "test-extension"}, + }, + }, + { + name: "get deleted ok", + url: "/api/v1alpha1/extensions/00000001-0000-0000-0000-000000000002?deleted", + expectedStatus: http.StatusOK, + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000002"}, + }, + }, + { + name: "get deleted by slug", + url: "/api/v1alpha1/extensions/deleted-extension?deleted", + expectedStatus: http.StatusBadRequest, + params: gin.Params{ + gin.Param{Key: "eid", Value: "deleted-extension"}, + }, + }, + { + name: "extension not found by slug", + url: "/api/v1alpha1/extensions/nonexistent-extension", + expectedStatus: http.StatusNotFound, + expectedErrMsg: "extension not found", + params: gin.Params{ + gin.Param{Key: "eid", Value: "nonexistent-extension"}, + }, + }, + { + name: "extension not found by ID", + url: "/api/v1alpha1/extensions/00000001-0000-0000-0000-000000000002", + expectedStatus: http.StatusNotFound, + expectedErrMsg: "extension not found", + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000002"}, + }, + }, + } + + for _, tt := range tests { + s.T().Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + auditID := uuid.New().String() + + req, _ := http.NewRequest("GET", tt.url, nil) + req = req.WithContext(context.Background()) + c.Request = req + c.Params = tt.params + c.Set(ginaudit.AuditIDContextKey, auditID) + + r.getExtension(c) + + assert.Equal(t, tt.expectedStatus, w.Code, "Expected status %d, got %d", tt.expectedStatus, w.Code) + + if tt.expectedErrMsg != "" { + body := w.Body.String() + assert.Contains( + t, body, tt.expectedErrMsg, + "Expected error message to contain %q, got %s", tt.expectedErrMsg, body, + ) + + return + } + }) + } +} + +func (s *ExtensionsTestSuite) TestUpdateExtension() { + r := s.v1alpha1() + + tests := []struct { + name string + url string + params gin.Params + payload string + expectedResp *Extension + expectedStatus int + expectedErrMsg string + expectedEventSubject string + expectedEventPayload *events.Event + }{ + { + name: "disable extension", + url: "/api/v1alpha1/extensions/00000001-0000-0000-0000-000000000001", + expectedStatus: http.StatusAccepted, + payload: `{ "name": "Test Extension", "description": "some test", "enabled": false }`, + expectedEventSubject: "events.extensions", + expectedEventPayload: &events.Event{ + Action: events.GovernorEventUpdate, + ExtensionID: "00000001-0000-0000-0000-000000000001", + }, + expectedResp: &Extension{&models.Extension{ + Name: "Test Extension", + Description: "some test", + Slug: "test-extension", + Enabled: false, + }}, + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000001"}, + }, + }, + { + name: "update by slug", + url: "/api/v1alpha1/extensions/test-extension-1", + expectedStatus: http.StatusAccepted, + payload: `{ "name": "Test Extension", "description": "some test", "enabled": true }`, + expectedEventSubject: "events.extensions", + expectedEventPayload: &events.Event{ + Action: events.GovernorEventUpdate, + ExtensionID: "00000001-0000-0000-0000-000000000001", + }, + expectedResp: &Extension{&models.Extension{ + Name: "Test Extension", + Description: "some test", + Slug: "test-extension", + Enabled: true, + }}, + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000001"}, + }, + }, + { + name: "change name", + url: "/api/v1alpha1/extensions/00000001-0000-0000-0000-000000000001", + expectedStatus: http.StatusBadRequest, + payload: `{ "name": "Test Extension 2", "description": "some test", "enabled": false }`, + expectedErrMsg: "modifying extension name is not allowed", + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000001"}, + }, + }, + { + name: "extension not found", + url: "/api/v1alpha1/extensions/00000001-0000-0000-0000-000000000002", + expectedStatus: http.StatusNotFound, + payload: `{ "name": "Test Extension 2", "description": "some test", "enabled": false }`, + expectedErrMsg: "not found", + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000002"}, + }, + }, + } + + for _, tt := range tests { + s.T().Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + auditID := uuid.New().String() + + req, _ := http.NewRequest("PATCH", tt.url, nil) + req = req.WithContext(context.Background()) + req.Body = io.NopCloser(bytes.NewBufferString(tt.payload)) + c.Request = req + c.Params = tt.params + c.Set(ginaudit.AuditIDContextKey, auditID) + + r.updateExtension(c) + + assert.Equal(t, tt.expectedStatus, w.Code, "Expected status %d, got %d", tt.expectedStatus, w.Code) + + if tt.expectedErrMsg != "" { + body := w.Body.String() + assert.Contains( + t, body, tt.expectedErrMsg, + "Expected error message to contain %q, got %s", tt.expectedErrMsg, body, + ) + + return + } + + event := &events.Event{} + err := json.Unmarshal(s.conn.Payload, event) + assert.Nil(t, err) + + ex := &Extension{} + body := w.Body.String() + err = json.Unmarshal([]byte(body), ex) + assert.Nil(t, err) + + assert.Equal( + t, tt.expectedResp.Name, ex.Name, + "Expected extension name %s, got %s", t, tt.expectedResp.Name, ex.Name, + ) + + assert.Equal( + t, tt.expectedResp.Slug, ex.Slug, + "Expected extension slug %s, got %s", t, tt.expectedResp.Slug, ex.Slug, + ) + + assert.Equal( + t, tt.expectedResp.Description, ex.Description, + "Expected extension description %s, got %s", t, tt.expectedResp.Description, ex.Description, + ) + + assert.Equal( + t, tt.expectedResp.Enabled, ex.Enabled, + ) + + assert.Equal( + t, tt.expectedEventPayload.Action, event.Action, + "Expected event action %s, got %s", tt.expectedEventPayload.Action, event.Action, + ) + + assert.Equal( + t, tt.expectedEventSubject, s.conn.Subject, + "Expected event subject %s, got %s", tt.expectedEventSubject, s.conn.Subject, + ) + + assert.Equal( + t, tt.expectedEventPayload.ExtensionID, event.ExtensionID, + "Expected event extension ID %s, got %s", tt.expectedEventPayload.ExtensionID, event.ExtensionID, + ) + }) + } +} + +func (s *ExtensionsTestSuite) TestDeleteExtension() { + r := s.v1alpha1() + + tests := []struct { + name string + url string + params gin.Params + expectedStatus int + expectedErrMsg string + expectedCount int + }{ + { + name: "delete by ID ok", + url: "/api/v1alpha1/extensions/00000001-0000-0000-0000-000000000001", + expectedStatus: http.StatusOK, + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000001"}, + }, + }, + { + name: "delete by slug ok", + url: "/api/v1alpha1/extensions/test-extension-3", + expectedStatus: http.StatusOK, + params: gin.Params{ + gin.Param{Key: "eid", Value: "test-extension-3"}, + }, + }, + { + name: "extension not found by ID", + url: "/api/v1alpha1/extensions/nonexistent-extension", + expectedStatus: http.StatusNotFound, + expectedErrMsg: "extension not found", + params: gin.Params{ + gin.Param{Key: "eid", Value: "nonexistent-extension"}, + }, + }, + { + name: "extension not found by slug", + url: "/api/v1alpha1/extensions/00000001-0000-0000-0000-000000000002", + expectedStatus: http.StatusNotFound, + expectedErrMsg: "extension not found", + params: gin.Params{ + gin.Param{Key: "eid", Value: "00000001-0000-0000-0000-000000000002"}, + }, + }, + } + + for _, tt := range tests { + s.T().Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + auditID := uuid.New().String() + + req, _ := http.NewRequest("GET", tt.url, nil) + req = req.WithContext(context.Background()) + c.Request = req + c.Params = tt.params + c.Set(ginaudit.AuditIDContextKey, auditID) + + r.getExtension(c) + + assert.Equal(t, tt.expectedStatus, w.Code, "Expected status %d, got %d", tt.expectedStatus, w.Code) + + if tt.expectedErrMsg != "" { + body := w.Body.String() + assert.Contains( + t, body, tt.expectedErrMsg, + "Expected error message to contain %q, got %s", tt.expectedErrMsg, body, + ) + + return + } + }) + } +} + +func TestExtensionSuite(t *testing.T) { + suite.Run(t, new(ExtensionsTestSuite)) +} diff --git a/pkg/api/v1alpha1/router.go b/pkg/api/v1alpha1/router.go index fe64853..d475409 100644 --- a/pkg/api/v1alpha1/router.go +++ b/pkg/api/v1alpha1/router.go @@ -566,6 +566,45 @@ func (r *Router) Routes(rg *gin.RouterGroup) { r.mwUserAuthRequired(AuthRoleAdmin), r.deleteNotificationTarget, ) + + // extensions + rg.GET( + "/extensions", + r.AuditMW.AuditWithType("ListExtensions"), + r.AuthMW.AuthRequired(readScopesWithOpenID("governor:extensions")), + r.listExtensions, + ) + + rg.GET( + "/extensions/:eid", + r.AuditMW.AuditWithType("GetExtension"), + r.AuthMW.AuthRequired(readScopesWithOpenID("governor:extensions")), + r.getExtension, + ) + + rg.POST( + "/extensions", + r.AuditMW.AuditWithType("CreateExtension"), + r.AuthMW.AuthRequired(createScopesWithOpenID("governor:extensions")), + r.mwUserAuthRequired(AuthRoleAdmin), + r.createExtension, + ) + + rg.PATCH( + "/extensions/:eid", + r.AuditMW.AuditWithType("UpdateExtension"), + r.AuthMW.AuthRequired(updateScopesWithOpenID("governor:extensions")), + r.mwUserAuthRequired(AuthRoleAdmin), + r.updateExtension, + ) + + rg.DELETE( + "/extensions/:eid", + r.AuditMW.AuditWithType("DeleteExtension"), + r.AuthMW.AuthRequired(deleteScopesWithOpenID("governor:extensions")), + r.mwUserAuthRequired(AuthRoleAdmin), + r.deleteExtension, + ) } func contains(list []string, item string) bool { diff --git a/pkg/api/v1alpha1/testing.go b/pkg/api/v1alpha1/testing.go new file mode 100644 index 0000000..fe6e48d --- /dev/null +++ b/pkg/api/v1alpha1/testing.go @@ -0,0 +1,14 @@ +package v1alpha1 + +type mockNATSConn struct { + Subject string + Payload []byte +} + +func (m *mockNATSConn) Drain() error { return nil } +func (m *mockNATSConn) Publish(s string, p []byte) error { + m.Subject = s + m.Payload = p + + return nil +} diff --git a/pkg/client/errors.go b/pkg/client/errors.go index 7e92bac..3fd17cf 100644 --- a/pkg/client/errors.go +++ b/pkg/client/errors.go @@ -47,4 +47,7 @@ var ( // ErrMissingNotificationTargetID is returned when a a missing or bad notification target ID is passed to a request ErrMissingNotificationTargetID = errors.New("missing notification target id in request") + + // ErrMissingExtensionIDOrSlug is returned when a missing or bad extension ID is passed to a request + ErrMissingExtensionIDOrSlug = errors.New("missing extension id or slug in request") ) diff --git a/pkg/client/extensions.go b/pkg/client/extensions.go new file mode 100644 index 0000000..63933ae --- /dev/null +++ b/pkg/client/extensions.go @@ -0,0 +1,241 @@ +package client + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/metal-toolbox/governor-api/pkg/api/v1alpha1" +) + +// Extension fetch an extension +func (c *Client) Extension(ctx context.Context, idOrSlug string, deleted bool) (*v1alpha1.Extension, error) { + if idOrSlug == "" { + return nil, ErrMissingExtensionIDOrSlug + } + + u := fmt.Sprintf( + "%s/api/%s/extensions/%s", + c.url, + governorAPIVersionAlpha, + idOrSlug, + ) + if deleted { + u += "?deleted" + } + + req, err := c.newGovernorRequest(ctx, http.MethodGet, u) + if err != nil { + return nil, err + } + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, v1alpha1.ErrExtensionNotFound + } + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted && + resp.StatusCode != http.StatusNoContent { + return nil, ErrRequestNonSuccess + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + nt := &v1alpha1.Extension{} + if err := json.Unmarshal(respBody, nt); err != nil { + return nil, err + } + + return nt, nil +} + +// Extensions list all extensions +func (c *Client) Extensions(ctx context.Context, deleted bool) ([]*v1alpha1.Extension, error) { + u := fmt.Sprintf( + "%s/api/%s/extensions", + c.url, + governorAPIVersionAlpha, + ) + if deleted { + u += "?deleted" + } + + req, err := c.newGovernorRequest(ctx, http.MethodGet, u) + if err != nil { + return nil, err + } + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted && + resp.StatusCode != http.StatusNoContent { + return nil, ErrRequestNonSuccess + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + nt := []*v1alpha1.Extension{} + if err := json.Unmarshal(respBody, &nt); err != nil { + return nil, err + } + + return nt, nil +} + +// CreateExtension creates an extension +func (c *Client) CreateExtension(ctx context.Context, exReq *v1alpha1.ExtensionReq) (*v1alpha1.Extension, error) { + req, err := c.newGovernorRequest( + ctx, http.MethodPost, + fmt.Sprintf("%s/api/%s/extensions", c.url, governorAPIVersionAlpha), + ) + if err != nil { + return nil, err + } + + exReqJSON, err := json.Marshal(exReq) + if err != nil { + return nil, err + } + + req.Body = io.NopCloser(bytes.NewReader(exReqJSON)) + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted && + resp.StatusCode != http.StatusNoContent { + return nil, ErrRequestNonSuccess + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + nt := &v1alpha1.Extension{} + if err := json.Unmarshal(respBody, nt); err != nil { + return nil, err + } + + return nt, nil +} + +// UpdateExtension updates an extension +func (c *Client) UpdateExtension( + ctx context.Context, idOrSlug string, exReq *v1alpha1.ExtensionReq, +) (*v1alpha1.Extension, error) { + if idOrSlug == "" { + return nil, ErrMissingExtensionIDOrSlug + } + + req, err := c.newGovernorRequest( + ctx, http.MethodPatch, + fmt.Sprintf( + "%s/api/%s/extensions/%s", + c.url, + governorAPIVersionAlpha, + idOrSlug, + ), + ) + if err != nil { + return nil, err + } + + exReqJSON, err := json.Marshal(exReq) + if err != nil { + return nil, err + } + + req.Body = io.NopCloser(bytes.NewReader(exReqJSON)) + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, v1alpha1.ErrExtensionNotFound + } + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted && + resp.StatusCode != http.StatusNoContent { + return nil, ErrRequestNonSuccess + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + nt := &v1alpha1.Extension{} + if err := json.Unmarshal(respBody, nt); err != nil { + return nil, err + } + + return nt, nil +} + +// DeleteExtension deletes an extension +func (c *Client) DeleteExtension(ctx context.Context, idOrSlug string) error { + if idOrSlug == "" { + return ErrMissingExtensionIDOrSlug + } + + req, err := c.newGovernorRequest( + ctx, http.MethodDelete, + fmt.Sprintf( + "%s/api/%s/extensions/%s", + c.url, + governorAPIVersionAlpha, + idOrSlug, + ), + ) + if err != nil { + return err + } + + resp, err := c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted && + resp.StatusCode != http.StatusNoContent { + return ErrRequestNonSuccess + } + + return nil +} diff --git a/pkg/client/extensions_test.go b/pkg/client/extensions_test.go new file mode 100644 index 0000000..6f14ef4 --- /dev/null +++ b/pkg/client/extensions_test.go @@ -0,0 +1,532 @@ +package client + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/metal-toolbox/governor-api/pkg/api/v1alpha1" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "golang.org/x/oauth2" +) + +const ( + testExtensionsResponse = `[ + { + "id": "35b9861f-83b5-49df-95b0-321cfe5c1532", + "name": "Test Extension 1", + "slug": "test-extension-1", + "description": "some test", + "enabled": true, + "status": "online", + "created_at": "2023-09-26T20:04:19.190374Z", + "updated_at": "2023-09-26T20:04:19.190374Z", + "deleted_at": null + }, + { + "id": "e311a55e-d77f-4289-ba69-e2cbea09e3a3", + "name": "Test Extension 2", + "slug": "test-extension-2", + "description": "some test", + "enabled": true, + "status": "online", + "created_at": "2023-09-26T20:04:19.190374Z", + "updated_at": "2023-09-26T20:04:19.190374Z", + "deleted_at": null + } + ]` + + testExtensionResponse = `{ + "id": "35b9861f-83b5-49df-95b0-321cfe5c1532", + "name": "Test Extension 1", + "slug": "test-extension-1", + "description": "some test", + "enabled": true, + "status": "online", + "created_at": "2023-09-26T20:04:19.190374Z", + "updated_at": "2023-09-26T20:04:19.190374Z", + "deleted_at": null + }` +) + +func TestClient_Extensions(t *testing.T) { + testResp := func(r []byte) []*v1alpha1.Extension { + resp := []*v1alpha1.Extension{} + if err := json.Unmarshal(r, &resp); err != nil { + t.Error(err) + } + + return resp + } + + type fields struct { + httpClient HTTPDoer + } + + tests := []struct { + name string + fields fields + want []*v1alpha1.Extension + wantErr bool + }{ + { + name: "example request", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + resp: []byte(testExtensionsResponse), + statusCode: http.StatusOK, + }, + }, + want: testResp([]byte(testExtensionsResponse)), + }, + { + name: "non-success", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusInternalServerError, + }, + }, + wantErr: true, + }, + { + name: "bad json response", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusOK, + resp: []byte(`{`), + }, + }, + wantErr: true, + }, + { + name: "null response", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusOK, + resp: []byte(`null`), + }, + }, + want: []*v1alpha1.Extension(nil), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + url: "https://the.gov/", + logger: zap.NewNop(), + httpClient: tt.fields.httpClient, + clientCredentialConfig: &mockTokener{t: t}, + token: &oauth2.Token{AccessToken: "topSekret"}, + } + got, err := c.Extensions(context.TODO(), false) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestClient_Extension(t *testing.T) { + testResp := func(r []byte) *v1alpha1.Extension { + resp := &v1alpha1.Extension{} + if err := json.Unmarshal(r, resp); err != nil { + t.Error(err) + } + + return resp + } + + type fields struct { + httpClient HTTPDoer + } + + tests := []struct { + name string + fields fields + want *v1alpha1.Extension + wantErr bool + id string + }{ + { + name: "example request", + id: "test-extension-1", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + resp: []byte(testExtensionResponse), + statusCode: http.StatusOK, + }, + }, + want: testResp([]byte(testExtensionResponse)), + }, + { + name: "non-success", + id: "test-extension-1", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusInternalServerError, + }, + }, + wantErr: true, + }, + { + name: "bad json response", + id: "test-extension-1", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusOK, + resp: []byte(`{`), + }, + }, + wantErr: true, + }, + { + name: "missing id", + id: "test-extension-1", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusOK, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + url: "https://the.gov/", + logger: zap.NewNop(), + httpClient: tt.fields.httpClient, + clientCredentialConfig: &mockTokener{t: t}, + token: &oauth2.Token{AccessToken: "topSekret"}, + } + got, err := c.Extension(context.TODO(), tt.id, false) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestClient_CreateExtension(t *testing.T) { + testResp := func(r []byte) *v1alpha1.Extension { + resp := &v1alpha1.Extension{} + if err := json.Unmarshal(r, resp); err != nil { + t.Error(err) + } + + return resp + } + + enabled := true + + type fields struct { + httpClient HTTPDoer + } + + tests := []struct { + name string + fields fields + req *v1alpha1.ExtensionReq + want *v1alpha1.Extension + wantErr bool + }{ + { + name: "example request", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + resp: []byte(testExtensionResponse), + statusCode: http.StatusOK, + }, + }, + req: &v1alpha1.ExtensionReq{ + Name: "Test Extension 1", + Description: "some test", + Enabled: &enabled, + }, + want: testResp([]byte(testExtensionResponse)), + }, + { + name: "example request status accepted", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + resp: []byte(testExtensionResponse), + statusCode: http.StatusAccepted, + }, + }, + req: &v1alpha1.ExtensionReq{ + Name: "Test Extension 1", + Description: "some test", + Enabled: &enabled, + }, + want: testResp([]byte(testExtensionResponse)), + }, + { + name: "non-success", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusInternalServerError, + }, + }, + req: &v1alpha1.ExtensionReq{ + Name: "Test Extension 1", + Description: "some test", + Enabled: &enabled, + }, + wantErr: true, + }, + { + name: "bad json response", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusOK, + resp: []byte(`{`), + }, + }, + req: &v1alpha1.ExtensionReq{ + Name: "Test Extension 1", + Description: "some test", + Enabled: &enabled, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + url: "https://the.gov/", + logger: zap.NewNop(), + httpClient: tt.fields.httpClient, + clientCredentialConfig: &mockTokener{t: t}, + token: &oauth2.Token{AccessToken: "topSekret"}, + } + got, err := c.CreateExtension(context.TODO(), tt.req) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestClient_UpdateExtension(t *testing.T) { + testResp := func(r []byte) *v1alpha1.Extension { + resp := &v1alpha1.Extension{} + if err := json.Unmarshal(r, resp); err != nil { + t.Error(err) + } + + return resp + } + + type fields struct { + httpClient HTTPDoer + } + + tests := []struct { + name string + fields fields + id string + req *v1alpha1.ExtensionReq + want *v1alpha1.Extension + wantErr bool + }{ + { + name: "example request", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + resp: []byte(testExtensionResponse), + statusCode: http.StatusOK, + }, + }, + id: "test-extension-1", + req: &v1alpha1.ExtensionReq{ + Description: "some test", + }, + want: testResp([]byte(testExtensionResponse)), + }, + { + name: "example request status accepted", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + resp: []byte(testExtensionResponse), + statusCode: http.StatusAccepted, + }, + }, + id: "test-extension-1", + req: &v1alpha1.ExtensionReq{ + Description: "some test", + }, + want: testResp([]byte(testExtensionResponse)), + }, + { + name: "non-success", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusInternalServerError, + }, + }, + id: "test-extension-1", + req: &v1alpha1.ExtensionReq{ + Description: "some test", + }, + wantErr: true, + }, + { + name: "bad json response", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusOK, + resp: []byte(`{`), + }, + }, + id: "test-extension-1", + req: &v1alpha1.ExtensionReq{ + Description: "some test", + }, + wantErr: true, + }, + { + name: "missing id", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + resp: []byte(testExtensionResponse), + statusCode: http.StatusOK, + }, + }, + req: &v1alpha1.ExtensionReq{ + Description: "some test", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + url: "https://the.gov/", + logger: zap.NewNop(), + httpClient: tt.fields.httpClient, + clientCredentialConfig: &mockTokener{t: t}, + token: &oauth2.Token{AccessToken: "topSekret"}, + } + got, err := c.UpdateExtension(context.TODO(), tt.id, tt.req) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestClient_DeleteExtension(t *testing.T) { + testResp := func(r []byte) *v1alpha1.Extension { + resp := &v1alpha1.Extension{} + if err := json.Unmarshal(r, resp); err != nil { + t.Error(err) + } + + return resp + } + + type fields struct { + httpClient HTTPDoer + } + + tests := []struct { + name string + fields fields + want *v1alpha1.Extension + wantErr bool + id string + }{ + { + name: "example request", + id: "test-extension-1", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + resp: []byte(testExtensionResponse), + statusCode: http.StatusOK, + }, + }, + want: testResp([]byte(testExtensionResponse)), + }, + { + name: "non-success", + id: "test-extension-1", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusInternalServerError, + }, + }, + wantErr: true, + }, + { + name: "missing id", + fields: fields{ + httpClient: &mockHTTPDoer{ + t: t, + statusCode: http.StatusOK, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + url: "https://the.gov/", + logger: zap.NewNop(), + httpClient: tt.fields.httpClient, + clientCredentialConfig: &mockTokener{t: t}, + token: &oauth2.Token{AccessToken: "topSekret"}, + } + err := c.DeleteExtension(context.TODO(), tt.id) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + }) + } +} diff --git a/pkg/events/v1alpha1/events.go b/pkg/events/v1alpha1/events.go index 15adc48..7772eb8 100644 --- a/pkg/events/v1alpha1/events.go +++ b/pkg/events/v1alpha1/events.go @@ -39,6 +39,8 @@ const ( GovernorNotificationTypesEventSubject = "notification.types" // GovernorNotificationTargetsEventSubject is the subject name for notification target events (minus the subject prefix) GovernorNotificationTargetsEventSubject = "notification.targets" + // GovernorExtensionsEventSubject is the subject name for extensions events (minus the subject prefix) + GovernorExtensionsEventSubject = "extensions" ) // Event is an event notification from Governor. @@ -54,6 +56,9 @@ type Event struct { NotificationTypeID string `json:"notification_type_id,omitempty"` NotificationTargetID string `json:"notification_target_id,omitempty"` + ExtensionID string `json:"extension_id,omitempty"` + ExtensionResourceDefinitionID string `json:"extension_resource_definition_id,omitempty"` + // TraceContext is a map of values used for OpenTelemetry context propagation. TraceContext map[string]string `json:"traceContext"` }