Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/add missing arborist checks #87

Merged
merged 6 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions controllers/cohortdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down Expand Up @@ -254,7 +254,7 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{caseCohortId, controlCohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down
29 changes: 25 additions & 4 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,41 @@
package controllers

import (
"log"
"net/http"
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type CohortDefinitionController struct {
cohortDefinitionModel models.CohortDefinitionI
teamProjectAuthz middlewares.TeamProjectAuthzI
}

func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI) CohortDefinitionController {
return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel}
func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDefinitionController {
return CohortDefinitionController{
cohortDefinitionModel: cohortDefinitionModel,
teamProjectAuthz: teamProjectAuthz,
}
}

func (u CohortDefinitionController) RetriveById(c *gin.Context) {
// TODO - add teamproject validation - check if user has the necessary atlas and arborist permissions
cohortDefinitionId := c.Param("id")

if cohortDefinitionId != "" {
cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId)
// validate teamproject access permission for cohort:
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortDefinitionId)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
Expand All @@ -45,7 +58,15 @@ func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.
c.Abort()
return
}
// TODO - validate teamproject against arborist
// validate teamproject access permission:
validAccessRequest := u.teamProjectAuthz.HasAccessToTeamProject(c, teamProject)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}

if err1 == nil {
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions controllers/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down Expand Up @@ -135,7 +135,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down Expand Up @@ -201,7 +201,7 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down
45 changes: 28 additions & 17 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type TeamProjectAuthzI interface {
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool
HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool
}

type HttpClientI interface {
Expand All @@ -30,30 +31,40 @@ func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpCli
httpClient: httpClient,
}
}
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {

// query Arborist and return as soon as one of the teamProjects access check returns 200:
for _, teamProject := range teamProjects {
teamProjectAsResourcePath := teamProject
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"
func (u TeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
teamProjectAsResourcePath := teamProject
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"

req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
if err != nil {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
// send the request to Arborist:
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)
req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
if err != nil {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
// send the request to Arborist:
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
return true
} else {
// unauthorized or otherwise:
log.Printf("Authorization check for team project failed with status %d ...", resp.StatusCode)
return false
}
}

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {
for _, teamProject := range teamProjects {
if u.HasAccessToTeamProject(ctx, teamProject) {
return true
} else {
// unauthorized or otherwise:
log.Printf("Status %d does NOT give access to team project...", resp.StatusCode)
// unauthorized:
log.Printf("NO access to team project...checking next one (if any)...")
}
}
log.Printf("NO access to any of the team projects queried...")
return false
}

Expand Down
3 changes: 2 additions & 1 deletion server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ func NewRouter() *gin.Engine {
authorized.GET("/source/by-name/:name", source.RetriveByName)
authorized.GET("/sources", source.RetriveAll)

cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition),
middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)

authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)
Expand Down
65 changes: 52 additions & 13 deletions tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortD
var cohortDataControllerWithFailingTeamProjectAuthz = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyFailingTeamProjectAuthz))

// instance of the controller that talks to the regular model implementation (that needs a real DB):
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition), *new(dummyTeamProjectAuthz))

// instance of the controller that talks to a mock implementation of the model:
var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel))
var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))
var cohortDefinitionControllerWithFailingTeamProjectAuthz = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz))

type dummyCohortDataModel struct{}

Expand Down Expand Up @@ -151,6 +152,10 @@ func (h dummyTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Co
return true
}

func (h dummyTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
return true
}

type dummyFailingTeamProjectAuthz struct{}

func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
Expand All @@ -165,6 +170,10 @@ func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx
return false
}

func (h dummyFailingTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
return false
}

var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))
var conceptControllerWithFailingTeamProjectAuthz = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz))

Expand Down Expand Up @@ -463,18 +472,37 @@ func TestRetriveStatsBySourceIdAndTeamProjectCheckMandatoryTeamProject(t *testin
}
}

func TestRetriveStatsBySourceIdAndTeamProjectAuthorizationError(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Request = &http.Request{URL: &url.URL{}}
teamProject := "/test/dummyname/dummy-team-project"
requestContext.Request.URL.RawQuery = "team-project=" + teamProject
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionControllerWithFailingTeamProjectAuthz.RetriveStatsBySourceIdAndTeamProject(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
if result.Status() != http.StatusForbidden {
t.Errorf("Expected StatusForbidden, got %d", result.Status())
}
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
t.Errorf("Expected 'access denied' in response")
}
}

func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
//requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"})
requestContext.Request = &http.Request{URL: &url.URL{}}
teamProject := "/test/dummyname/dummy-team-project"
requestContext.Request.URL.RawQuery = "team-project=" + teamProject
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with all of the dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "name1_"+teamProject) ||
!strings.Contains(result.CustomResponseWriterOut, "name2_"+teamProject) ||
Expand Down Expand Up @@ -502,7 +530,6 @@ func TestRetriveById(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveById(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "test 1") {
t.Errorf("Expected data in result")
Expand All @@ -522,6 +549,26 @@ func TestRetriveByIdModelError(t *testing.T) {
}
}

func TestRetriveByIdAuthorizationError(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "id", Value: "1"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionControllerWithFailingTeamProjectAuthz.RetriveById(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
if result.Status() != http.StatusForbidden {
t.Errorf("Expected StatusForbidden, got %d", result.Status())
}
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
t.Errorf("Expected 'access denied' in response")
}

}

func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
Expand All @@ -532,7 +579,6 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveBreakdownStatsBySourceIdAndCohortId(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") {
t.Errorf("Expected data in result")
Expand Down Expand Up @@ -563,7 +609,6 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") {
t.Errorf("Expected data in result")
Expand Down Expand Up @@ -608,7 +653,6 @@ func TestRetrieveInfoBySourceIdAndConceptIds(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveInfoBySourceIdAndConceptIds(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "Concept A") ||
!strings.Contains(result.CustomResponseWriterOut, "Concept B") {
Expand All @@ -625,7 +669,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypes(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "Concept A") ||
!strings.Contains(result.CustomResponseWriterOut, "Concept B") {
Expand All @@ -644,7 +687,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesModelError(t *testing.T) {
dummyModelReturnError = true
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
Expand All @@ -662,7 +704,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesArgsError(t *testing.T) {
dummyModelReturnError = true
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
Expand All @@ -680,7 +721,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesMissingBody(t *testing.T) {
dummyModelReturnError = true
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
Expand Down Expand Up @@ -982,7 +1022,6 @@ func TestRetrieveAttritionTable(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveAttritionTable(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result.CustomResponseWriterOut)
// check result vs expect result:
csvLines := strings.Split(strings.TrimRight(result.CustomResponseWriterOut, "\n"), "\n")
expectedLines := []string{
Expand Down
Loading
Loading