Skip to content

Commit

Permalink
Merge pull request #81 from uc-cdis/feat/integrate_team_project_checks
Browse files Browse the repository at this point in the history
Feat: add "team project" filtering to /cohortdefinition-stats endpoint
  • Loading branch information
pieterlukasse authored Dec 13, 2023
2 parents 4ba94d3 + 88c2289 commit 6751594
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 23 deletions.
10 changes: 8 additions & 2 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,18 @@ func (u CohortDefinitionController) RetriveAll(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions})
}

func (u CohortDefinitionController) RetriveStatsBySourceId(c *gin.Context) {
func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) {
// This method returns ALL cohortdefinition entries with cohort size statistics (for a given source)

sourceId, err1 := utils.ParseNumericArg(c, "sourceid")
teamProject := c.Param("teamproject")
if teamProject == "" {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error while parsing request", "error": "team-project is a mandatory parameter but was found to be empty!"})
c.Abort()
return
}
if err1 == nil {
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId)
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
Expand Down
29 changes: 25 additions & 4 deletions models/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type CohortDefinitionI interface {
GetCohortDefinitionById(id int) (*CohortDefinition, error)
GetCohortDefinitionByName(name string) (*CohortDefinition, error)
GetAllCohortDefinitions() ([]*CohortDefinition, error)
GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error)
GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error)
GetCohortName(cohortId int) (string, error)
}

Expand Down Expand Up @@ -67,7 +67,19 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error)
return cohortDefinition, meta_result.Error
}

func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error) {
// Get the list of cohort_definition ids for a given "team project" (where "team project" is basically
// a security role name of one of the roles in Atlas/WebAPI database).
func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) {
db2 := db.GetAtlasDB().Db
var cohortDefinitionIds []int
query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role").
Select("cohort_definition_id").
Where("sec_role_name = ?", teamProject).
Scan(&cohortDefinitionIds)
return cohortDefinitionIds, query.Error
}

func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error) {

// Connect to source db and gather stats:
var dataSourceModel = new(Source)
Expand All @@ -81,6 +93,11 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI
defer cancel()
meta_result := query.Scan(&cohortDefinitionStats)

// get (from separate Atlas DB - hence not using JOIN above) the list of cohort_definition_ids
// that are allowed for the given teamProject:
allowedCohortDefinitionIds, _ := h.GetCohortDefinitionIdsForTeamProject(teamProject)
log.Printf("INFO: found %d cohorts for this team project", len(allowedCohortDefinitionIds))

// add name details:
finalList := []*CohortDefinitionStats{}
for _, cohortDefinitionStat := range cohortDefinitionStats {
Expand All @@ -91,10 +108,14 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI
cohortDefinitionStat.CohortSize)
continue
} else {
cohortDefinitionStat.Name = cohortDefinition.Name
finalList = append(finalList, cohortDefinitionStat)
// filter to keep only the allowed ones:
if utils.Contains(allowedCohortDefinitionIds, cohortDefinitionStat.Id) {
cohortDefinitionStat.Name = cohortDefinition.Name
finalList = append(finalList, cohortDefinitionStat)
}
}
}

return finalList, meta_result.Error
}

Expand Down
2 changes: 1 addition & 1 deletion server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func NewRouter() *gin.Engine {
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)
authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName)
authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid", cohortdefinitions.RetriveStatsBySourceId)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition))
Expand Down
32 changes: 25 additions & 7 deletions tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, err
return "dummy cohort name", nil
}

func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*models.CohortDefinitionStats, error) {
func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*models.CohortDefinitionStats, error) {
cohortDefinitionStats := []*models.CohortDefinitionStats{
{Id: 1, CohortSize: 10, Name: "name1"},
{Id: 2, CohortSize: 22, Name: "name2"},
Expand Down Expand Up @@ -340,22 +340,23 @@ func TestGenerateCSV(t *testing.T) {
}
}

func TestRetriveStatsBySourceIdWrongParams(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProjectWrongParams(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "Abc", Value: "def"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceId(requestContext)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
// Params above are wrong, so request should abort:
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
}

func TestRetriveStatsBySourceIdDbPanic(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProjectDbPanic(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"})
requestContext.Writer = new(tests.CustomResponseWriter)

defer func() {
Expand All @@ -366,16 +367,33 @@ func TestRetriveStatsBySourceIdDbPanic(t *testing.T) {
}
}
}()
cohortDefinitionControllerNeedsDb.RetriveStatsBySourceId(requestContext)
cohortDefinitionControllerNeedsDb.RetriveStatsBySourceIdAndTeamProject(requestContext)
t.Errorf("Expected error")
}

func TestRetriveStatsBySourceId(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProjectCheckMandatoryTeamProject(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceId(requestContext)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
// Params above are wrong, so request should abort:
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
if !strings.Contains(result.CustomResponseWriterOut, "team-project is a mandatory parameter") {
t.Errorf("Expected error about mandatory team-project")
}
}

func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with all of the dummy data:
Expand Down
51 changes: 44 additions & 7 deletions tests/models_tests/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var allConceptIds []int64
var dummyContinuousConceptId = tests.GetTestDummyContinuousConceptId()
var hareConceptId = tests.GetTestHareConceptId()
var histogramConceptId = tests.GetTestHistogramConceptId()
var defaultTeamProject = "defaultteamproject"

func TestMain(m *testing.M) {
setupSuite()
Expand All @@ -47,7 +48,7 @@ func setupSuite() {

// initialize some handy variables to use in tests below:
// (see also tests/setup_local_db/test_data_results_and_cdm.sql for these test cohort details)
allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
largestCohort = allCohortDefinitions[0]
secondLargestCohort = allCohortDefinitions[2]
extendedCopyOfSecondLargestCohort = allCohortDefinitions[1]
Expand Down Expand Up @@ -563,11 +564,29 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH
}
}

func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) {
setUp(t)
testTeamProject := "teamprojectX"
allowedCohortDefinitionIds, _ := cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject)
if len(allowedCohortDefinitionIds) != 1 {
t.Errorf("Expected teamProject '%s' to have one cohort, but found %d",
testTeamProject, len(allowedCohortDefinitionIds))
}
// test data is crafted in such a way that the default "team project" has access to all
// the cohorts. Check if this is indeed the case:
testTeamProject = defaultTeamProject
allowedCohortDefinitionIds, _ = cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject)
allCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitions()
if len(allCohortDefinitions) != len(allowedCohortDefinitionIds) && len(allCohortDefinitions) > 1 {
t.Errorf("Found %d, expected %d", len(allowedCohortDefinitionIds), len(allCohortDefinitions))
}
}

func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions) {
t.Errorf("Found %d", len(cohortDefinitions))
t.Errorf("Found %d, expected %d", len(cohortDefinitions), len(allCohortDefinitions))
}
// check if stats fields are filled and if order is as expected:
previousSize := 1000000
Expand All @@ -580,14 +599,32 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) {
}
previousSize = cohortDefinition.CohortSize
}

// some extra tests to cover also the teamProject option for this method:
testTeamProject := "teamprojectX"
allowedCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, testTeamProject)
if len(allowedCohortDefinitions) != 1 {
t.Errorf("Expected teamProject '%s' to have one cohort, but found %d",
testTeamProject, len(allowedCohortDefinitions))
}
if len(cohortDefinitions) <= len(allowedCohortDefinitions) {
t.Errorf("Expected list of projects for '%s' to be larger than for %s",
defaultTeamProject, testTeamProject)
}
testTeamProject = "teamprojectNonExisting"
allowedCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, testTeamProject)
if len(allowedCohortDefinitions) != 0 {
t.Errorf("Expected teamProject '%s' to have NO cohort, but found %d",
testTeamProject, len(allowedCohortDefinitions))
}
}

// Tests whether the code deals correctly with the (error) situation where
// the `cohort_definition` and `cohort` tables are not in sync (more specifically
// the situation where a cohort still exists in `cohort` table but not in `cohort_definition`).
func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMissing(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions) {
t.Errorf("Found %d", len(cohortDefinitions))
}
Expand All @@ -596,7 +633,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMis
firstCohort := cohortDefinitions[0]
tests.ExecAtlasSQLString(fmt.Sprintf("delete from %s.cohort_definition where id = %d",
db.GetAtlasDB().Schema, firstCohort.Id))
cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions)-1 {
t.Errorf("Number of cohor_definition records expected to be %d, found %d",
len(allCohortDefinitions)-1, len(cohortDefinitions))
Expand Down Expand Up @@ -685,7 +722,7 @@ func TestQueryFilterByConceptIdsHelper(t *testing.T) {

func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
var sumNumeric float32 = 0
textConcat := ""
classIdConcat := ""
Expand Down Expand Up @@ -737,7 +774,7 @@ func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *test
func TestErrorForRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) {
// Tests if the method returns an error when query fails.

cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)

// break something in the Results schema to cause a query failure in the next method:
tests.BreakSomething(models.Results, "cohort", "cohort_definition_id")
Expand Down
49 changes: 49 additions & 0 deletions tests/setup_local_db/ddl_atlas.sql
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,52 @@ CREATE TABLE atlas.cohort_definition
modified_by_id integer,
CONSTRAINT PK_cohort_definition PRIMARY KEY (id)
);

CREATE TABLE atlas.sec_role
(
id integer NOT NULL,
name varchar(255) ,
system_role boolean NOT NULL DEFAULT false,
CONSTRAINT pk_sec_role PRIMARY KEY (id),
CONSTRAINT sec_role_name_uq UNIQUE (name, system_role)
);

CREATE TABLE atlas.sec_permission
(
id integer NOT NULL,
value varchar(255) NOT NULL,
description varchar(255),
CONSTRAINT pk_sec_permission PRIMARY KEY (id),
CONSTRAINT permission_unique UNIQUE (value)
);

CREATE TABLE atlas.sec_role_permission
(
id integer NOT NULL,
role_id integer NOT NULL,
permission_id integer NOT NULL,
status varchar(255),
CONSTRAINT pk_sec_role_permission PRIMARY KEY (id),
CONSTRAINT role_permission_unique UNIQUE (role_id, permission_id),
CONSTRAINT fk_role_permission_to_permission FOREIGN KEY (permission_id)
REFERENCES atlas.sec_permission (id) MATCH SIMPLE
ON UPDATE NO ACTION
ON DELETE NO ACTION,
CONSTRAINT fk_role_permission_to_role FOREIGN KEY (role_id)
REFERENCES atlas.sec_role (id) MATCH SIMPLE
ON UPDATE NO ACTION
ON DELETE NO ACTION
);

CREATE VIEW atlas.COHORT_DEFINITION_SEC_ROLE AS
select
distinct cast(regexp_replace(sec_permission.value,
'^cohortdefinition:([0-9]+):.*','\1') as integer) as cohort_definition_id,
sec_role.name as sec_role_name
from
atlas.sec_role
inner join atlas.sec_role_permission on sec_role.id = sec_role_permission.role_id
inner join atlas.sec_permission on sec_role_permission.permission_id = sec_permission.id
where
sec_permission.value ~ 'cohortdefinition:[0-9]+'
;
68 changes: 68 additions & 0 deletions tests/setup_local_db/test_data_atlas.sql
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,71 @@ values
(32,'Test cohort3b','Copy of Larger cohort'),
(4,'Test cohort4','Extra Larger cohort')
;

insert into atlas.sec_role
(id, name, system_role)
values
(1,'public',true),
(1005,'teamprojectX',false),
(1009,'teamprojectY',false),
(3000,'someotherrole',false),
(4000,'defaultteamproject',false)
;

insert into atlas.sec_permission
(id, value, description)
values
(1181, 'cohortdefinition:2:check:post', 'Fix Cohort Definition with ID = 2'),
(1182, 'cohortdefinition:2:put', 'Update Cohort Definition with ID = 2'),
(1183, 'cohortdefinition:2:delete', 'Delete Cohort Definition with ID = 2'),
(1184, 'cohortdefinition:2:version:*:get', 'Get cohort version'),
(1185, 'cohortdefinition:2:info:get', 'no description'),
(1186, 'cohortdefinition:2:get', 'Get Cohort Definition by ID'),
(1187, 'cohortdefinition:2:version:get', 'Get list of cohort versions'),
(1188, 'cohortdefinition:4:check:post', 'Fix Cohort Definition with ID = 4'),
(1189, 'cohortdefinition:4:put', 'Update Cohort Definition with ID = 4'),
(1190, 'cohortdefinition:4:delete', 'Delete Cohort Definition with ID = 4'),
(1191, 'cohortdefinition:4:version:*:get', 'Get cohort version'),
(1192, 'cohortdefinition:4:info:get', 'no description'),
(1193, 'cohortdefinition:4:get', 'Get Cohort Definition by ID'),
(1194, 'cohortdefinition:4:version:get', 'Get list of cohort versions'),
(2193, 'cohortdefinition:1:get', 'Get Cohort Definition by ID'),
(3193, 'cohortdefinition:3:get', 'Get Cohort Definition by ID'),
(4193, 'cohortdefinition:32:get', 'Get Cohort Definition by ID')
;

insert into atlas.sec_role_permission
(id, role_id, permission_id)
values
(1454, 1005, 1181),
(1455, 1005, 1182),
(1456, 1005, 1183),
(1457, 1005, 1184),
(1458, 1005, 1185),
(1459, 1005, 1186),
(1460, 1005, 1187),
(1461, 1009, 1188),
(1462, 1009, 1189),
(1463, 1009, 1190),
(1464, 1009, 1191),
(1465, 1009, 1192),
(1466, 1009, 1193),
(1467, 1009, 1194),
(2454, 4000, 1181),
(2455, 4000, 1182),
(2456, 4000, 1183),
(2457, 4000, 1184),
(2458, 4000, 1185),
(2459, 4000, 1186),
(2460, 4000, 1187),
(2461, 4000, 1188),
(2462, 4000, 1189),
(2463, 4000, 1190),
(2464, 4000, 1191),
(2465, 4000, 1192),
(2466, 4000, 1193),
(2467, 4000, 1194),
(2468, 4000, 2193),
(2469, 4000, 3193),
(2470, 4000, 4193)
;
Loading

0 comments on commit 6751594

Please sign in to comment.