Skip to content

Commit

Permalink
Merge pull request #87 from uc-cdis/feat/add_missing_arborist_checks
Browse files Browse the repository at this point in the history
Feat/add missing arborist checks
  • Loading branch information
pieterlukasse authored Jan 30, 2024
2 parents ecb2237 + aa75cb3 commit 5e1839b
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 46 deletions.
6 changes: 3 additions & 3 deletions controllers/cohortdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down Expand Up @@ -254,7 +254,7 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{caseCohortId, controlCohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down
29 changes: 25 additions & 4 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,41 @@
package controllers

import (
"log"
"net/http"
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type CohortDefinitionController struct {
cohortDefinitionModel models.CohortDefinitionI
teamProjectAuthz middlewares.TeamProjectAuthzI
}

func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI) CohortDefinitionController {
return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel}
func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDefinitionController {
return CohortDefinitionController{
cohortDefinitionModel: cohortDefinitionModel,
teamProjectAuthz: teamProjectAuthz,
}
}

func (u CohortDefinitionController) RetriveById(c *gin.Context) {
// TODO - add teamproject validation - check if user has the necessary atlas and arborist permissions
cohortDefinitionId := c.Param("id")

if cohortDefinitionId != "" {
cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId)
// validate teamproject access permission for cohort:
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortDefinitionId)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
Expand All @@ -45,7 +58,15 @@ func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.
c.Abort()
return
}
// TODO - validate teamproject against arborist
// validate teamproject access permission:
validAccessRequest := u.teamProjectAuthz.HasAccessToTeamProject(c, teamProject)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}

if err1 == nil {
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions controllers/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down Expand Up @@ -135,7 +135,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down Expand Up @@ -201,7 +201,7 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
Expand Down
45 changes: 28 additions & 17 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type TeamProjectAuthzI interface {
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool
HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool
}

type HttpClientI interface {
Expand All @@ -30,30 +31,40 @@ func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpCli
httpClient: httpClient,
}
}
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {

// query Arborist and return as soon as one of the teamProjects access check returns 200:
for _, teamProject := range teamProjects {
teamProjectAsResourcePath := teamProject
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"
func (u TeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
teamProjectAsResourcePath := teamProject
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"

req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
if err != nil {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
// send the request to Arborist:
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)
req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
if err != nil {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
// send the request to Arborist:
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
return true
} else {
// unauthorized or otherwise:
log.Printf("Authorization check for team project failed with status %d ...", resp.StatusCode)
return false
}
}

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {
for _, teamProject := range teamProjects {
if u.HasAccessToTeamProject(ctx, teamProject) {
return true
} else {
// unauthorized or otherwise:
log.Printf("Status %d does NOT give access to team project...", resp.StatusCode)
// unauthorized:
log.Printf("NO access to team project...checking next one (if any)...")
}
}
log.Printf("NO access to any of the team projects queried...")
return false
}

Expand Down
3 changes: 2 additions & 1 deletion server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ func NewRouter() *gin.Engine {
authorized.GET("/source/by-name/:name", source.RetriveByName)
authorized.GET("/sources", source.RetriveAll)

cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition),
middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)

authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)
Expand Down
65 changes: 52 additions & 13 deletions tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortD
var cohortDataControllerWithFailingTeamProjectAuthz = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyFailingTeamProjectAuthz))

// instance of the controller that talks to the regular model implementation (that needs a real DB):
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition), *new(dummyTeamProjectAuthz))

// instance of the controller that talks to a mock implementation of the model:
var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel))
var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))
var cohortDefinitionControllerWithFailingTeamProjectAuthz = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz))

type dummyCohortDataModel struct{}

Expand Down Expand Up @@ -151,6 +152,10 @@ func (h dummyTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Co
return true
}

func (h dummyTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
return true
}

type dummyFailingTeamProjectAuthz struct{}

func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
Expand All @@ -165,6 +170,10 @@ func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx
return false
}

func (h dummyFailingTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
return false
}

var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))
var conceptControllerWithFailingTeamProjectAuthz = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz))

Expand Down Expand Up @@ -463,18 +472,37 @@ func TestRetriveStatsBySourceIdAndTeamProjectCheckMandatoryTeamProject(t *testin
}
}

func TestRetriveStatsBySourceIdAndTeamProjectAuthorizationError(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Request = &http.Request{URL: &url.URL{}}
teamProject := "/test/dummyname/dummy-team-project"
requestContext.Request.URL.RawQuery = "team-project=" + teamProject
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionControllerWithFailingTeamProjectAuthz.RetriveStatsBySourceIdAndTeamProject(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
if result.Status() != http.StatusForbidden {
t.Errorf("Expected StatusForbidden, got %d", result.Status())
}
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
t.Errorf("Expected 'access denied' in response")
}
}

func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
//requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"})
requestContext.Request = &http.Request{URL: &url.URL{}}
teamProject := "/test/dummyname/dummy-team-project"
requestContext.Request.URL.RawQuery = "team-project=" + teamProject
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with all of the dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "name1_"+teamProject) ||
!strings.Contains(result.CustomResponseWriterOut, "name2_"+teamProject) ||
Expand Down Expand Up @@ -502,7 +530,6 @@ func TestRetriveById(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveById(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "test 1") {
t.Errorf("Expected data in result")
Expand All @@ -522,6 +549,26 @@ func TestRetriveByIdModelError(t *testing.T) {
}
}

func TestRetriveByIdAuthorizationError(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "id", Value: "1"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionControllerWithFailingTeamProjectAuthz.RetriveById(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
if result.Status() != http.StatusForbidden {
t.Errorf("Expected StatusForbidden, got %d", result.Status())
}
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
t.Errorf("Expected 'access denied' in response")
}

}

func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
Expand All @@ -532,7 +579,6 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveBreakdownStatsBySourceIdAndCohortId(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") {
t.Errorf("Expected data in result")
Expand Down Expand Up @@ -563,7 +609,6 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") {
t.Errorf("Expected data in result")
Expand Down Expand Up @@ -608,7 +653,6 @@ func TestRetrieveInfoBySourceIdAndConceptIds(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveInfoBySourceIdAndConceptIds(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "Concept A") ||
!strings.Contains(result.CustomResponseWriterOut, "Concept B") {
Expand All @@ -625,7 +669,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypes(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with dummy data:
if !strings.Contains(result.CustomResponseWriterOut, "Concept A") ||
!strings.Contains(result.CustomResponseWriterOut, "Concept B") {
Expand All @@ -644,7 +687,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesModelError(t *testing.T) {
dummyModelReturnError = true
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
Expand All @@ -662,7 +704,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesArgsError(t *testing.T) {
dummyModelReturnError = true
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
Expand All @@ -680,7 +721,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesMissingBody(t *testing.T) {
dummyModelReturnError = true
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
Expand Down Expand Up @@ -982,7 +1022,6 @@ func TestRetrieveAttritionTable(t *testing.T) {
requestContext.Writer = new(tests.CustomResponseWriter)
conceptController.RetrieveAttritionTable(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result.CustomResponseWriterOut)
// check result vs expect result:
csvLines := strings.Split(strings.TrimRight(result.CustomResponseWriterOut, "\n"), "\n")
expectedLines := []string{
Expand Down
Loading

0 comments on commit 5e1839b

Please sign in to comment.