diff --git a/cmd/api/src/api/v2/auth/auth.go b/cmd/api/src/api/v2/auth/auth.go index b5ef6b738..9f5a737a1 100644 --- a/cmd/api/src/api/v2/auth/auth.go +++ b/cmd/api/src/api/v2/auth/auth.go @@ -418,17 +418,22 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht } else if provider, err := s.db.GetSAMLProvider(request.Context(), samlProviderID); err != nil { api.HandleDatabaseError(request, response, err) return + } else if ssoProvider, err := s.db.GetSSOProviderById(request.Context(), provider.SSOProviderID.Int32); err != nil { + api.HandleDatabaseError(request, response, err) + return } else { // Ensure that the AuthSecret reference is nil and the SSO provider is set user.AuthSecret = nil // Required or the below updateUser will re-add the authSecret + user.SSOProvider = &ssoProvider user.SSOProviderID = provider.SSOProviderID } } else if updateUserRequest.SSOProviderID.Valid { - if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil { + if ssoProvider, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil { api.HandleDatabaseError(request, response, err) return } else { user.AuthSecret = nil // Required or the below updateUser will re-add the authSecret + user.SSOProvider = &ssoProvider user.SSOProviderID = updateUserRequest.SSOProviderID } } else { diff --git a/cmd/api/src/api/v2/auth/auth_test.go b/cmd/api/src/api/v2/auth/auth_test.go index ab4e4d6e2..a55a4a0a6 100644 --- a/cmd/api/src/api/v2/auth/auth_test.go +++ b/cmd/api/src/api/v2/auth/auth_test.go @@ -55,11 +55,8 @@ import ( ) const ( - samlProviderPathFmt = "/api/v2/saml/providers/%d" - updateUserSecretPathFmt = "/api/v2/auth/users/%s/secret" - ssoProviderID int32 = 123 - samlProviderID int32 = 1234 - samlProviderIDStr = "1234" + samlProviderPathFmt = "/api/v2/saml/providers/%d" + updateUserSecretPathFmt = "/api/v2/auth/users/%s/secret" ) func TestManagementResource_PutUserAuthSecret(t *testing.T) { @@ -165,12 +162,24 @@ func TestManagementResource_PutUserAuthSecret(t *testing.T) { func TestManagementResource_EnableUserSAML(t *testing.T) { var ( - adminUser = model.User{Unique: model.Unique{ID: must.NewUUIDv4()}} - goodRoles = []int32{0} - goodUserID = must.NewUUIDv4() - badUserID = must.NewUUIDv4() mockCtrl = gomock.NewController(t) resources, mockDB = apitest.NewAuthManagementResource(mockCtrl) + + adminUser = model.User{Unique: model.Unique{ID: must.NewUUIDv4()}} + goodRoles = []int32{0} + goodUserID = must.NewUUIDv4() + badUserID = must.NewUUIDv4() + + ssoProviderID int32 = 123 + samlProviderIDStr = "1234" + + ssoProvider = model.SSOProvider{ + Serial: model.Serial{ID: ssoProviderID}, + SAMLProvider: &model.SAMLProvider{ + Serial: model.Serial{ID: 1234}, + SSOProviderID: null.Int32From(ssoProviderID), + }, + } ) bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{})) @@ -181,7 +190,8 @@ func TestManagementResource_EnableUserSAML(t *testing.T) { t.Run("Successfully update user with deprecated saml provider", func(t *testing.T) { mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil) mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil) - mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil) + mockDB.EXPECT().GetSAMLProvider(gomock.Any(), ssoProvider.SAMLProvider.ID).Return(*ssoProvider.SAMLProvider, nil) + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProvider.ID).Return(ssoProvider, nil) mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil) test.Request(t). @@ -200,7 +210,8 @@ func TestManagementResource_EnableUserSAML(t *testing.T) { t.Run("Fails if auth secret set", func(t *testing.T) { mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil) mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil) - mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil) + mockDB.EXPECT().GetSAMLProvider(gomock.Any(), ssoProvider.SAMLProvider.ID).Return(*ssoProvider.SAMLProvider, nil) + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProvider.ID).Return(ssoProvider, nil) mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil) test.Request(t). @@ -219,7 +230,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) { t.Run("Successful user update with sso provider-saml", func(t *testing.T) { mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil) mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil) - mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProviderID).Return(model.SSOProvider{}, nil) + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProvider.ID).Return(ssoProvider, nil) mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil) test.Request(t). @@ -228,7 +239,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) { WithBody(v2.UpdateUserRequest{ Principal: "tester", Roles: goodRoles, - SSOProviderID: null.Int32From(123), + SSOProviderID: null.Int32From(ssoProvider.ID), }). OnHandlerFunc(resources.UpdateUser). Require(). @@ -1544,14 +1555,14 @@ func TestManagementResource_UpdateUser_UserSelfModify(t *testing.T) { t.Run("Prevent users from changing their own SSO provider", func(t *testing.T) { mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{adminRole}, nil) mockDB.EXPECT().GetUser(gomock.Any(), adminUser.ID).Return(adminUser, nil) - mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProviderID).Return(model.SSOProvider{}, nil) + mockDB.EXPECT().GetSSOProviderById(gomock.Any(), int32(1)).Return(model.SSOProvider{}, nil) test.Request(t). WithContext(bhCtx). WithURLPathVars(map[string]string{"user_id": adminUser.ID.String()}). WithBody(v2.UpdateUserRequest{ Principal: "tester", Roles: goodRoles, - SSOProviderID: null.Int32From(123), + SSOProviderID: null.Int32From(1), }). OnHandlerFunc(resources.UpdateUser). Require().