Skip to content

Commit

Permalink
feat: add missing arborist checks
Browse files Browse the repository at this point in the history
  • Loading branch information
pieterlukasse committed Jan 26, 2024
1 parent ecb2237 commit 75e4b48
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 24 deletions.
29 changes: 25 additions & 4 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,41 @@
package controllers

import (
"log"
"net/http"
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type CohortDefinitionController struct {
cohortDefinitionModel models.CohortDefinitionI
teamProjectAuthz middlewares.TeamProjectAuthzI
}

func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI) CohortDefinitionController {
return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel}
func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDefinitionController {
return CohortDefinitionController{
cohortDefinitionModel: cohortDefinitionModel,
teamProjectAuthz: teamProjectAuthz,
}
}

func (u CohortDefinitionController) RetriveById(c *gin.Context) {
// TODO - add teamproject validation - check if user has the necessary atlas and arborist permissions
cohortDefinitionId := c.Param("id")

if cohortDefinitionId != "" {
cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId)
// validate teamproject access permission for cohort:
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortDefinitionId)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
Expand All @@ -45,7 +58,15 @@ func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.
c.Abort()
return
}
// TODO - validate teamproject against arborist
// validate teamproject access permission:
validAccessRequest := u.teamProjectAuthz.HasAccessToTeamProject(c, teamProject)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
c.Abort()
return
}

if err1 == nil {
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject)
if err != nil {
Expand Down
45 changes: 28 additions & 17 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type TeamProjectAuthzI interface {
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool
HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool
}

type HttpClientI interface {
Expand All @@ -30,30 +31,40 @@ func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpCli
httpClient: httpClient,
}
}
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {

// query Arborist and return as soon as one of the teamProjects access check returns 200:
for _, teamProject := range teamProjects {
teamProjectAsResourcePath := teamProject
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"
func (u TeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
teamProjectAsResourcePath := teamProject
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"

req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
if err != nil {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
// send the request to Arborist:
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)
req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
if err != nil {
ctx.AbortWithStatus(500)
panic("Error while preparing Arborist request")
}
// send the request to Arborist:
resp, _ := u.httpClient.Do(req)
log.Printf("Got response status %d from Arborist...", resp.StatusCode)

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
return true
} else {
// unauthorized or otherwise:
log.Printf("Authorization check for team project failed with status %d ...", resp.StatusCode)
return false
}
}

// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
if resp.StatusCode == 200 {
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {
for _, teamProject := range teamProjects {
if u.HasAccessToTeamProject(ctx, teamProject) {
return true
} else {
// unauthorized or otherwise:
log.Printf("Status %d does NOT give access to team project...", resp.StatusCode)
// unauthorized:
log.Printf("NO access to team project...checking next one (if any)...")
}
}
log.Printf("NO access to any of the team projects queried...")
return false
}

Expand Down
3 changes: 2 additions & 1 deletion server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ func NewRouter() *gin.Engine {
authorized.GET("/source/by-name/:name", source.RetriveByName)
authorized.GET("/sources", source.RetriveAll)

cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition),
middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)

authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)
Expand Down
12 changes: 10 additions & 2 deletions tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortD
var cohortDataControllerWithFailingTeamProjectAuthz = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyFailingTeamProjectAuthz))

// instance of the controller that talks to the regular model implementation (that needs a real DB):
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition), *new(dummyTeamProjectAuthz))

// instance of the controller that talks to a mock implementation of the model:
var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel))
var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))

type dummyCohortDataModel struct{}

Expand Down Expand Up @@ -151,6 +151,10 @@ func (h dummyTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Co
return true
}

func (h dummyTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
return true
}

type dummyFailingTeamProjectAuthz struct{}

func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
Expand All @@ -165,6 +169,10 @@ func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx
return false
}

func (h dummyFailingTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
return false
}

var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))
var conceptControllerWithFailingTeamProjectAuthz = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz))

Expand Down

0 comments on commit 75e4b48

Please sign in to comment.