diff --git a/controllers/cohortdata.go b/controllers/cohortdata.go index b0b7ae1..7e0ba8c 100644 --- a/controllers/cohortdata.go +++ b/controllers/cohortdata.go @@ -52,7 +52,7 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs) if !validAccessRequest { log.Printf("Error: invalid request") - c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.JSON(http.StatusForbidden, gin.H{"message": "access denied"}) c.Abort() return } @@ -101,7 +101,7 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs) if !validAccessRequest { log.Printf("Error: invalid request") - c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.JSON(http.StatusForbidden, gin.H{"message": "access denied"}) c.Abort() return } @@ -254,7 +254,7 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{caseCohortId, controlCohortId}, cohortPairs) if !validAccessRequest { log.Printf("Error: invalid request") - c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.JSON(http.StatusForbidden, gin.H{"message": "access denied"}) c.Abort() return } diff --git a/controllers/cohortdefinition.go b/controllers/cohortdefinition.go index d75cf25..ab4104d 100644 --- a/controllers/cohortdefinition.go +++ b/controllers/cohortdefinition.go @@ -1,28 +1,41 @@ package controllers import ( + "log" "net/http" "strconv" "github.com/gin-gonic/gin" + "github.com/uc-cdis/cohort-middleware/middlewares" "github.com/uc-cdis/cohort-middleware/models" "github.com/uc-cdis/cohort-middleware/utils" ) type CohortDefinitionController struct { cohortDefinitionModel models.CohortDefinitionI + teamProjectAuthz middlewares.TeamProjectAuthzI } -func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI) CohortDefinitionController { - return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel} +func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDefinitionController { + return CohortDefinitionController{ + cohortDefinitionModel: cohortDefinitionModel, + teamProjectAuthz: teamProjectAuthz, + } } func (u CohortDefinitionController) RetriveById(c *gin.Context) { - // TODO - add teamproject validation - check if user has the necessary atlas and arborist permissions cohortDefinitionId := c.Param("id") if cohortDefinitionId != "" { cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId) + // validate teamproject access permission for cohort: + validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortDefinitionId) + if !validAccessRequest { + log.Printf("Error: invalid request") + c.JSON(http.StatusForbidden, gin.H{"message": "access denied"}) + c.Abort() + return + } cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()}) @@ -45,7 +58,15 @@ func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin. c.Abort() return } - // TODO - validate teamproject against arborist + // validate teamproject access permission: + validAccessRequest := u.teamProjectAuthz.HasAccessToTeamProject(c, teamProject) + if !validAccessRequest { + log.Printf("Error: invalid request") + c.JSON(http.StatusForbidden, gin.H{"message": "access denied"}) + c.Abort() + return + } + if err1 == nil { cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject) if err != nil { diff --git a/controllers/concept.go b/controllers/concept.go index 78fd7d8..c32e603 100644 --- a/controllers/concept.go +++ b/controllers/concept.go @@ -102,7 +102,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId) if !validAccessRequest { log.Printf("Error: invalid request") - c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.JSON(http.StatusForbidden, gin.H{"message": "access denied"}) c.Abort() return } @@ -135,7 +135,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs) if !validAccessRequest { log.Printf("Error: invalid request") - c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.JSON(http.StatusForbidden, gin.H{"message": "access denied"}) c.Abort() return } @@ -201,7 +201,7 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) { validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs) if !validAccessRequest { log.Printf("Error: invalid request") - c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"}) + c.JSON(http.StatusForbidden, gin.H{"message": "access denied"}) c.Abort() return } diff --git a/middlewares/teamprojectauthz.go b/middlewares/teamprojectauthz.go index 3b044ff..20493b0 100644 --- a/middlewares/teamprojectauthz.go +++ b/middlewares/teamprojectauthz.go @@ -13,6 +13,7 @@ type TeamProjectAuthzI interface { TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool + HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool } type HttpClientI interface { @@ -30,30 +31,40 @@ func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpCli httpClient: httpClient, } } -func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { - // query Arborist and return as soon as one of the teamProjects access check returns 200: - for _, teamProject := range teamProjects { - teamProjectAsResourcePath := teamProject - teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware" +func (u TeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool { + teamProjectAsResourcePath := teamProject + teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware" - req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService) - if err != nil { - ctx.AbortWithStatus(500) - panic("Error while preparing Arborist request") - } - // send the request to Arborist: - resp, _ := u.httpClient.Do(req) - log.Printf("Got response status %d from Arborist...", resp.StatusCode) + req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService) + if err != nil { + ctx.AbortWithStatus(500) + panic("Error while preparing Arborist request") + } + // send the request to Arborist: + resp, _ := u.httpClient.Do(req) + log.Printf("Got response status %d from Arborist...", resp.StatusCode) + + // arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx: + if resp.StatusCode == 200 { + return true + } else { + // unauthorized or otherwise: + log.Printf("Authorization check for team project failed with status %d ...", resp.StatusCode) + return false + } +} - // arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx: - if resp.StatusCode == 200 { +func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool { + for _, teamProject := range teamProjects { + if u.HasAccessToTeamProject(ctx, teamProject) { return true } else { - // unauthorized or otherwise: - log.Printf("Status %d does NOT give access to team project...", resp.StatusCode) + // unauthorized: + log.Printf("NO access to team project...checking next one (if any)...") } } + log.Printf("NO access to any of the team projects queried...") return false } diff --git a/server/router.go b/server/router.go index 5a11ca2..910cf57 100644 --- a/server/router.go +++ b/server/router.go @@ -28,7 +28,8 @@ func NewRouter() *gin.Engine { authorized.GET("/source/by-name/:name", source.RetriveByName) authorized.GET("/sources", source.RetriveAll) - cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition)) + cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition), + middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{})) authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById) authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject) diff --git a/tests/controllers_tests/controllers_test.go b/tests/controllers_tests/controllers_test.go index 045022e..2fa3c42 100644 --- a/tests/controllers_tests/controllers_test.go +++ b/tests/controllers_tests/controllers_test.go @@ -54,10 +54,11 @@ var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortD var cohortDataControllerWithFailingTeamProjectAuthz = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyFailingTeamProjectAuthz)) // instance of the controller that talks to the regular model implementation (that needs a real DB): -var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition)) +var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition), *new(dummyTeamProjectAuthz)) // instance of the controller that talks to a mock implementation of the model: -var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel)) +var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz)) +var cohortDefinitionControllerWithFailingTeamProjectAuthz = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz)) type dummyCohortDataModel struct{} @@ -151,6 +152,10 @@ func (h dummyTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Co return true } +func (h dummyTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool { + return true +} + type dummyFailingTeamProjectAuthz struct{} func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool { @@ -165,6 +170,10 @@ func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx return false } +func (h dummyFailingTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool { + return false +} + var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz)) var conceptControllerWithFailingTeamProjectAuthz = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz)) @@ -463,18 +472,37 @@ func TestRetriveStatsBySourceIdAndTeamProjectCheckMandatoryTeamProject(t *testin } } +func TestRetriveStatsBySourceIdAndTeamProjectAuthorizationError(t *testing.T) { + setUp(t) + requestContext := new(gin.Context) + requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) + requestContext.Request = &http.Request{URL: &url.URL{}} + teamProject := "/test/dummyname/dummy-team-project" + requestContext.Request.URL.RawQuery = "team-project=" + teamProject + requestContext.Writer = new(tests.CustomResponseWriter) + cohortDefinitionControllerWithFailingTeamProjectAuthz.RetriveStatsBySourceIdAndTeamProject(requestContext) + result := requestContext.Writer.(*tests.CustomResponseWriter) + if !requestContext.IsAborted() { + t.Errorf("Expected aborted request") + } + if result.Status() != http.StatusForbidden { + t.Errorf("Expected StatusForbidden, got %d", result.Status()) + } + if !strings.Contains(result.CustomResponseWriterOut, "access denied") { + t.Errorf("Expected 'access denied' in response") + } +} + func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) - //requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"}) requestContext.Request = &http.Request{URL: &url.URL{}} teamProject := "/test/dummyname/dummy-team-project" requestContext.Request.URL.RawQuery = "team-project=" + teamProject requestContext.Writer = new(tests.CustomResponseWriter) cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) // expect result with all of the dummy data: if !strings.Contains(result.CustomResponseWriterOut, "name1_"+teamProject) || !strings.Contains(result.CustomResponseWriterOut, "name2_"+teamProject) || @@ -502,7 +530,6 @@ func TestRetriveById(t *testing.T) { requestContext.Writer = new(tests.CustomResponseWriter) cohortDefinitionController.RetriveById(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) // expect result with dummy data: if !strings.Contains(result.CustomResponseWriterOut, "test 1") { t.Errorf("Expected data in result") @@ -522,6 +549,26 @@ func TestRetriveByIdModelError(t *testing.T) { } } +func TestRetriveByIdAuthorizationError(t *testing.T) { + setUp(t) + requestContext := new(gin.Context) + requestContext.Params = append(requestContext.Params, gin.Param{Key: "id", Value: "1"}) + requestContext.Writer = new(tests.CustomResponseWriter) + cohortDefinitionControllerWithFailingTeamProjectAuthz.RetriveById(requestContext) + result := requestContext.Writer.(*tests.CustomResponseWriter) + if !requestContext.IsAborted() { + t.Errorf("Expected aborted request") + } + if result.Status() != http.StatusForbidden { + t.Errorf("Expected StatusForbidden, got %d", result.Status()) + } + // expect result with dummy data: + if !strings.Contains(result.CustomResponseWriterOut, "access denied") { + t.Errorf("Expected 'access denied' in response") + } + +} + func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) { setUp(t) requestContext := new(gin.Context) @@ -532,7 +579,6 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) { requestContext.Writer = new(tests.CustomResponseWriter) conceptController.RetrieveBreakdownStatsBySourceIdAndCohortId(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) // expect result with dummy data: if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") { t.Errorf("Expected data in result") @@ -563,7 +609,6 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(t *testing.T) { requestContext.Writer = new(tests.CustomResponseWriter) conceptController.RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) // expect result with dummy data: if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") { t.Errorf("Expected data in result") @@ -608,7 +653,6 @@ func TestRetrieveInfoBySourceIdAndConceptIds(t *testing.T) { requestContext.Writer = new(tests.CustomResponseWriter) conceptController.RetrieveInfoBySourceIdAndConceptIds(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) // expect result with dummy data: if !strings.Contains(result.CustomResponseWriterOut, "Concept A") || !strings.Contains(result.CustomResponseWriterOut, "Concept B") { @@ -625,7 +669,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypes(t *testing.T) { requestContext.Writer = new(tests.CustomResponseWriter) conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) // expect result with dummy data: if !strings.Contains(result.CustomResponseWriterOut, "Concept A") || !strings.Contains(result.CustomResponseWriterOut, "Concept B") { @@ -644,7 +687,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesModelError(t *testing.T) { dummyModelReturnError = true conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) if !requestContext.IsAborted() { t.Errorf("Expected aborted request") } @@ -662,7 +704,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesArgsError(t *testing.T) { dummyModelReturnError = true conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) if !requestContext.IsAborted() { t.Errorf("Expected aborted request") } @@ -680,7 +721,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesMissingBody(t *testing.T) { dummyModelReturnError = true conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result) if !requestContext.IsAborted() { t.Errorf("Expected aborted request") } @@ -982,7 +1022,6 @@ func TestRetrieveAttritionTable(t *testing.T) { requestContext.Writer = new(tests.CustomResponseWriter) conceptController.RetrieveAttritionTable(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) - log.Printf("result: %s", result.CustomResponseWriterOut) // check result vs expect result: csvLines := strings.Split(strings.TrimRight(result.CustomResponseWriterOut, "\n"), "\n") expectedLines := []string{ diff --git a/tests/middlewares_tests/middlewares_test.go b/tests/middlewares_tests/middlewares_test.go index 883d596..c31c2b8 100644 --- a/tests/middlewares_tests/middlewares_test.go +++ b/tests/middlewares_tests/middlewares_test.go @@ -97,9 +97,11 @@ func (h dummyCohortDefinitionDataModel) GetCohortDefinitionIdsForTeamProject(tea } func (h dummyCohortDefinitionDataModel) GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList []int) ([]string, error) { - // dummy switch just to support two test scenarios: + // dummy switch just to support three test scenarios: if uniqueCohortDefinitionIdsList[0] == 0 { return nil, nil + } else if len(uniqueCohortDefinitionIdsList) == 1 { + return []string{"teamProject1"}, nil } else { return []string{"teamProject1", "teamProject2"}, nil } @@ -122,6 +124,48 @@ func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitions() ([]*models.Coh return nil, nil } +func TestTeamProjectValidationForCohort(t *testing.T) { + setUp(t) + config.Init("mocktest") + arboristAuthzResponseCode := 200 + dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode} + teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel), + dummyHttpClient) + requestContext := new(gin.Context) + requestContext.Request = new(http.Request) + requestContext.Request.Header = map[string][]string{ + "Authorization": {"dummy_token_value"}, + } + result := teamProjectAuthz.TeamProjectValidationForCohort(requestContext, 1) + if result == false { + t.Errorf("Expected TeamProjectValidationForCohort result to be 'true'") + } + if dummyHttpClient.nrCalls != 1 { + t.Errorf("Expected dummyHttpClient to have been only once") + } +} + +func TestTeamProjectValidationForCohortArborist401(t *testing.T) { + setUp(t) + config.Init("mocktest") + arboristAuthzResponseCode := 401 + dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode} + teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel), + dummyHttpClient) + requestContext := new(gin.Context) + requestContext.Request = new(http.Request) + requestContext.Request.Header = map[string][]string{ + "Authorization": {"dummy_token_value"}, + } + result := teamProjectAuthz.TeamProjectValidationForCohort(requestContext, 1) + if result == true { + t.Errorf("Expected TeamProjectValidationForCohort result to be 'false'") + } + if dummyHttpClient.nrCalls != 1 { + t.Errorf("Expected dummyHttpClient to have been only once") + } +} + func TestTeamProjectValidation(t *testing.T) { setUp(t) config.Init("mocktest") @@ -134,7 +178,7 @@ func TestTeamProjectValidation(t *testing.T) { requestContext.Request.Header = map[string][]string{ "Authorization": {"dummy_token_value"}, } - result := teamProjectAuthz.TeamProjectValidation(requestContext, []int{1}, nil) + result := teamProjectAuthz.TeamProjectValidation(requestContext, []int{1, 2}, nil) if result == false { t.Errorf("Expected TeamProjectValidation result to be 'true'") } @@ -155,7 +199,7 @@ func TestTeamProjectValidationArborist401(t *testing.T) { requestContext.Request.Header = map[string][]string{ "Authorization": {"dummy_token_value"}, } - result := teamProjectAuthz.TeamProjectValidation(requestContext, []int{1}, nil) + result := teamProjectAuthz.TeamProjectValidation(requestContext, []int{1, 2}, nil) if result == true { t.Errorf("Expected TeamProjectValidation result to be 'false'") } @@ -184,3 +228,30 @@ func TestTeamProjectValidationNoTeamProjectMatchingAllCohortDefinitions(t *testi t.Errorf("Expected dummyHttpClient to NOT have been called") } } + +func TestHasAccessToTeamProjectAbortOnArboristPrepError(t *testing.T) { + setUp(t) + config.Init("mocktest") + arboristAuthzResponseCode := 200 + dummyHttpClient := &dummyHttpClient{statusCode: arboristAuthzResponseCode} + requestContext := new(gin.Context) + requestContext.Request = new(http.Request) + requestContext.Writer = new(tests.CustomResponseWriter) + // add empty header to force an error during PrepareNewArboristRequestForResourceAndService: + requestContext.Request.Header = map[string][]string{ + "Authorization": {""}, + } + teamProjectAuthz := middlewares.NewTeamProjectAuthz(*new(dummyCohortDefinitionDataModel), + dummyHttpClient) + + defer func() { + if err := recover(); err != nil { + log.Println("panic occurred:", err) + if err != "Error while preparing Arborist request" { + t.Errorf("Expected error: 'Error while preparing Arborist request'") + } + } + }() + teamProjectAuthz.HasAccessToTeamProject(requestContext, "dummyTeam") + t.Errorf("Expected error") +} diff --git a/tests/testutils.go b/tests/testutils.go index 3ebc7dc..b82d9f6 100644 --- a/tests/testutils.go +++ b/tests/testutils.go @@ -219,6 +219,7 @@ func Map[T, U any](items []T, f func(T) U) []U { // to use when mocking request context (gin.Context) in controller tests: type CustomResponseWriter struct { CustomResponseWriterOut string + StatusCode int } func (w *CustomResponseWriter) Header() http.Header { @@ -234,7 +235,8 @@ func (w *CustomResponseWriter) Write(b []byte) (int, error) { } func (w *CustomResponseWriter) WriteHeader(statusCode int) { - // do nothing + // Store the status code + w.StatusCode = statusCode } func (w *CustomResponseWriter) CloseNotify() <-chan bool { @@ -254,7 +256,7 @@ func (w *CustomResponseWriter) Pusher() (pusher http.Pusher) { } func (w *CustomResponseWriter) Status() int { - return 0 + return w.StatusCode } func (w *CustomResponseWriter) Size() int {