Skip to content

Commit

Permalink
Merge pull request #83 from uc-cdis/feat/integrate_arborist_validatio…
Browse files Browse the repository at this point in the history
…n_for_team_project_for_cohort_data_endpoints

Feat: integrate Arborist validation for team project for cohort data endpoints AND remove unused endpoints
  • Loading branch information
pieterlukasse authored Dec 19, 2023
2 parents 198efbf + f7b7fc7 commit 30fb6b4
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 122 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ cd tests/setup_local_db/
JSON summary data endpoints:
```bash
curl http://localhost:8080/sources | python -m json.tool
curl http://localhost:8080/cohortdefinition-stats/by-source-id/1 | python -m json.tool
curl "http://localhost:8080/cohortdefinition-stats/by-source-id/1/by-team-project?team-project=test" | python -m json.tool
curl http://localhost:8080/concept/by-source-id/1 | python -m json.tool
curl -d '{"ConceptIds":[2000000324,2000006885]}' -H "Content-Type: application/json" -X POST http://localhost:8080/concept/by-source-id/1 | python -m json.tool
curl -d '{"ConceptTypes":["Measurement","Person"]}' -H "Content-Type: application/json" -X POST http://localhost:8080/concept/by-source-id/1/by-type | python -m json.tool
Expand Down
35 changes: 32 additions & 3 deletions controllers/cohortdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,21 @@ import (
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/middlewares"
"github.com/uc-cdis/cohort-middleware/models"
"github.com/uc-cdis/cohort-middleware/utils"
)

type CohortDataController struct {
cohortDataModel models.CohortDataI
cohortDataModel models.CohortDataI
teamProjectAuthz middlewares.TeamProjectAuthzI
}

func NewCohortDataController(cohortDataModel models.CohortDataI) CohortDataController {
return CohortDataController{cohortDataModel: cohortDataModel}
func NewCohortDataController(cohortDataModel models.CohortDataI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDataController {
return CohortDataController{
cohortDataModel: cohortDataModel,
teamProjectAuthz: teamProjectAuthz,
}
}

func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Context) {
Expand All @@ -44,6 +49,14 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co
cohortId, _ := strconv.Atoi(cohortIdStr)
histogramConceptId, _ := strconv.ParseInt(histogramIdStr, 10, 64)

validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

cohortData, err := u.cohortDataModel.RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId, cohortId, histogramConceptId, filterConceptIds, cohortPairs)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving concept details", "error": err.Error()})
Expand Down Expand Up @@ -85,6 +98,14 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g
sourceId, _ := strconv.Atoi(sourceIdStr)
cohortId, _ := strconv.Atoi(cohortIdStr)

validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

// call model method:
cohortData, err := u.cohortDataModel.RetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(sourceId, cohortId, conceptIds)
if err != nil {
Expand Down Expand Up @@ -230,6 +251,14 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep
controlCohortId, errors[2] = utils.ParseNumericArg(c, "controlcohortid")
conceptIds, cohortPairs, errors[3] = utils.ParseConceptIdsAndDichotomousDefs(c)

validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{caseCohortId, controlCohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
c.Abort()
return
}

if utils.ContainsNonNil(errors) {
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
Expand Down
49 changes: 1 addition & 48 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package controllers

import (
"net/http"
"strconv"

"github.com/gin-gonic/gin"
"github.com/uc-cdis/cohort-middleware/models"
Expand All @@ -17,56 +16,10 @@ func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinition
return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel}
}

func (u CohortDefinitionController) RetriveById(c *gin.Context) {
cohortDefinitionId := c.Param("id")

if cohortDefinitionId != "" {
cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId)
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"cohort_definition": cohortDefinition})
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
}

func (u CohortDefinitionController) RetriveByName(c *gin.Context) {
cohortDefinitionName := c.Param("name")

if cohortDefinitionName != "" {
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionByName(cohortDefinitionName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"CohortDefinition": cohortDefinition})
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
c.Abort()
}

func (u CohortDefinitionController) RetriveAll(c *gin.Context) {
cohortDefinitions, err := u.cohortDefinitionModel.GetAllCohortDefinitions()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
return
}
c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions})
}

func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) {
// This method returns ALL cohortdefinition entries with cohort size statistics (for a given source)

sourceId, err1 := utils.ParseNumericArg(c, "sourceid")
teamProject := c.Param("teamproject")
teamProject := c.Query("team-project")
if teamProject == "" {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error while parsing request", "error": "team-project is a mandatory parameter but was found to be empty!"})
c.Abort()
Expand Down
4 changes: 2 additions & 2 deletions controllers/concept.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl
c.Abort()
return
}
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
Expand Down Expand Up @@ -198,7 +198,7 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
return
}
_, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
if !validAccessRequest {
log.Printf("Error: invalid request")
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
Expand Down
17 changes: 11 additions & 6 deletions middlewares/teamprojectauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (

type TeamProjectAuthzI interface {
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool
}

type HttpClientI interface {
Expand Down Expand Up @@ -58,16 +59,20 @@ func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects [

func (u TeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
filterCohortPairs := []utils.CustomDichotomousVariableDef{}
return u.TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs)
return u.TeamProjectValidation(ctx, []int{cohortDefinitionId}, filterCohortPairs)
}

func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionIds, filterCohortPairs)
return u.TeamProjectValidationForCohortIdsList(ctx, uniqueCohortDefinitionIdsList)
}

// "team project" related checks:
// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project"
// (1) check if all cohorts belong to the same "team project"
// (2) check if the user has permission in the "team project"
// Returns true if both checks above pass, false otherwise.
func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {

uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
func (u TeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool {
teamProjects, _ := u.cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
if len(teamProjects) == 0 {
log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request")
Expand Down
7 changes: 2 additions & 5 deletions server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ func NewRouter() *gin.Engine {
authorized.GET("/sources", source.RetriveAll)

cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)
authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName)
authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition),
Expand All @@ -46,7 +43,7 @@ func NewRouter() *gin.Engine {
authorized.POST("/concept-stats/by-source-id/:sourceid/by-cohort-definition-id/:cohortid/breakdown-by-concept-id/:breakdownconceptid/csv", concepts.RetrieveAttritionTable)

// cohort stats and checks:
cohortData := controllers.NewCohortDataController(*new(models.CohortData))
cohortData := controllers.NewCohortDataController(*new(models.CohortData), middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
// :casecohortid/:controlcohortid are just labels here and have no special meaning. Could also just be :cohortAId/:cohortBId here:
authorized.POST("/cohort-stats/check-overlap/by-source-id/:sourceid/by-cohort-definition-ids/:casecohortid/:controlcohortid", cohortData.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue)

Expand Down
Loading

0 comments on commit 30fb6b4

Please sign in to comment.