Skip to content

Commit

Permalink
BED-5305 fix: inability to swap SSO providers directly (#1076)
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 authored Jan 15, 2025
1 parent 7f7b17f commit be76b91
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
7 changes: 6 additions & 1 deletion cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,17 +418,22 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
} else if provider, err := s.db.GetSAMLProvider(request.Context(), samlProviderID); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if ssoProvider, err := s.db.GetSSOProviderById(request.Context(), provider.SSOProviderID.Int32); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
// Ensure that the AuthSecret reference is nil and the SSO provider is set
user.AuthSecret = nil // Required or the below updateUser will re-add the authSecret
user.SSOProvider = &ssoProvider
user.SSOProviderID = provider.SSOProviderID
}
} else if updateUserRequest.SSOProviderID.Valid {
if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
if ssoProvider, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
user.AuthSecret = nil // Required or the below updateUser will re-add the authSecret
user.SSOProvider = &ssoProvider
user.SSOProviderID = updateUserRequest.SSOProviderID
}
} else {
Expand Down
41 changes: 26 additions & 15 deletions cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,8 @@ import (
)

const (
samlProviderPathFmt = "/api/v2/saml/providers/%d"
updateUserSecretPathFmt = "/api/v2/auth/users/%s/secret"
ssoProviderID int32 = 123
samlProviderID int32 = 1234
samlProviderIDStr = "1234"
samlProviderPathFmt = "/api/v2/saml/providers/%d"
updateUserSecretPathFmt = "/api/v2/auth/users/%s/secret"
)

func TestManagementResource_PutUserAuthSecret(t *testing.T) {
Expand Down Expand Up @@ -165,12 +162,24 @@ func TestManagementResource_PutUserAuthSecret(t *testing.T) {

func TestManagementResource_EnableUserSAML(t *testing.T) {
var (
adminUser = model.User{Unique: model.Unique{ID: must.NewUUIDv4()}}
goodRoles = []int32{0}
goodUserID = must.NewUUIDv4()
badUserID = must.NewUUIDv4()
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)

adminUser = model.User{Unique: model.Unique{ID: must.NewUUIDv4()}}
goodRoles = []int32{0}
goodUserID = must.NewUUIDv4()
badUserID = must.NewUUIDv4()

ssoProviderID int32 = 123
samlProviderIDStr = "1234"

ssoProvider = model.SSOProvider{
Serial: model.Serial{ID: ssoProviderID},
SAMLProvider: &model.SAMLProvider{
Serial: model.Serial{ID: 1234},
SSOProviderID: null.Int32From(ssoProviderID),
},
}
)

bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}))
Expand All @@ -181,7 +190,8 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
t.Run("Successfully update user with deprecated saml provider", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), ssoProvider.SAMLProvider.ID).Return(*ssoProvider.SAMLProvider, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProvider.ID).Return(ssoProvider, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
Expand All @@ -200,7 +210,8 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
t.Run("Fails if auth secret set", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), ssoProvider.SAMLProvider.ID).Return(*ssoProvider.SAMLProvider, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProvider.ID).Return(ssoProvider, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
Expand All @@ -219,7 +230,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
t.Run("Successful user update with sso provider-saml", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProviderID).Return(model.SSOProvider{}, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProvider.ID).Return(ssoProvider, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
Expand All @@ -228,7 +239,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SSOProviderID: null.Int32From(123),
SSOProviderID: null.Int32From(ssoProvider.ID),
}).
OnHandlerFunc(resources.UpdateUser).
Require().
Expand Down Expand Up @@ -1544,14 +1555,14 @@ func TestManagementResource_UpdateUser_UserSelfModify(t *testing.T) {
t.Run("Prevent users from changing their own SSO provider", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{adminRole}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), adminUser.ID).Return(adminUser, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProviderID).Return(model.SSOProvider{}, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(model.SSOProvider{}, nil)
test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": adminUser.ID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SSOProviderID: null.Int32From(123),
SSOProviderID: null.Int32From(1),
}).
OnHandlerFunc(resources.UpdateUser).
Require().
Expand Down

0 comments on commit be76b91

Please sign in to comment.