diff --git a/controllers/cohortdefinition.go b/controllers/cohortdefinition.go index 229aef7..e163838 100644 --- a/controllers/cohortdefinition.go +++ b/controllers/cohortdefinition.go @@ -62,12 +62,18 @@ func (u CohortDefinitionController) RetriveAll(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions}) } -func (u CohortDefinitionController) RetriveStatsBySourceId(c *gin.Context) { +func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) { // This method returns ALL cohortdefinition entries with cohort size statistics (for a given source) sourceId, err1 := utils.ParseNumericArg(c, "sourceid") + teamProject := c.Param("teamproject") + if teamProject == "" { + c.JSON(http.StatusInternalServerError, gin.H{"message": "Error while parsing request", "error": "team-project is a mandatory parameter but was found to be empty!"}) + c.Abort() + return + } if err1 == nil { - cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId) + cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()}) c.Abort() diff --git a/models/cohortdefinition.go b/models/cohortdefinition.go index e3e2102..710700c 100644 --- a/models/cohortdefinition.go +++ b/models/cohortdefinition.go @@ -13,7 +13,7 @@ type CohortDefinitionI interface { GetCohortDefinitionById(id int) (*CohortDefinition, error) GetCohortDefinitionByName(name string) (*CohortDefinition, error) GetAllCohortDefinitions() ([]*CohortDefinition, error) - GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error) + GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error) GetCohortName(cohortId int) (string, error) } @@ -67,7 +67,19 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error) return cohortDefinition, meta_result.Error } -func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error) { +// Get the list of cohort_definition ids for a given "team project" (where "team project" is basically +// a security role name of one of the roles in Atlas/WebAPI database). +func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) { + db2 := db.GetAtlasDB().Db + var cohortDefinitionIds []int + query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role"). + Select("cohort_definition_id"). + Where("sec_role_name = ?", teamProject). + Scan(&cohortDefinitionIds) + return cohortDefinitionIds, query.Error +} + +func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error) { // Connect to source db and gather stats: var dataSourceModel = new(Source) @@ -81,6 +93,11 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI defer cancel() meta_result := query.Scan(&cohortDefinitionStats) + // get (from separate Atlas DB - hence not using JOIN above) the list of cohort_definition_ids + // that are allowed for the given teamProject: + allowedCohortDefinitionIds, _ := h.GetCohortDefinitionIdsForTeamProject(teamProject) + log.Printf("INFO: found %d cohorts for this team project", len(allowedCohortDefinitionIds)) + // add name details: finalList := []*CohortDefinitionStats{} for _, cohortDefinitionStat := range cohortDefinitionStats { @@ -91,10 +108,14 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI cohortDefinitionStat.CohortSize) continue } else { - cohortDefinitionStat.Name = cohortDefinition.Name - finalList = append(finalList, cohortDefinitionStat) + // filter to keep only the allowed ones: + if utils.Contains(allowedCohortDefinitionIds, cohortDefinitionStat.Id) { + cohortDefinitionStat.Name = cohortDefinition.Name + finalList = append(finalList, cohortDefinitionStat) + } } } + return finalList, meta_result.Error } diff --git a/server/router.go b/server/router.go index 81d9ee8..1357cbf 100644 --- a/server/router.go +++ b/server/router.go @@ -30,7 +30,7 @@ func NewRouter() *gin.Engine { authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById) authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName) authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll) - authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid", cohortdefinitions.RetriveStatsBySourceId) + authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject) // concept endpoints: concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition)) diff --git a/tests/controllers_tests/controllers_test.go b/tests/controllers_tests/controllers_test.go index 368b660..6a3f6ff 100644 --- a/tests/controllers_tests/controllers_test.go +++ b/tests/controllers_tests/controllers_test.go @@ -100,7 +100,7 @@ func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, err return "dummy cohort name", nil } -func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*models.CohortDefinitionStats, error) { +func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*models.CohortDefinitionStats, error) { cohortDefinitionStats := []*models.CohortDefinitionStats{ {Id: 1, CohortSize: 10, Name: "name1"}, {Id: 2, CohortSize: 22, Name: "name2"}, @@ -340,22 +340,23 @@ func TestGenerateCSV(t *testing.T) { } } -func TestRetriveStatsBySourceIdWrongParams(t *testing.T) { +func TestRetriveStatsBySourceIdAndTeamProjectWrongParams(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "Abc", Value: "def"}) requestContext.Writer = new(tests.CustomResponseWriter) - cohortDefinitionController.RetriveStatsBySourceId(requestContext) + cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext) // Params above are wrong, so request should abort: if !requestContext.IsAborted() { t.Errorf("Expected aborted request") } } -func TestRetriveStatsBySourceIdDbPanic(t *testing.T) { +func TestRetriveStatsBySourceIdAndTeamProjectDbPanic(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) + requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"}) requestContext.Writer = new(tests.CustomResponseWriter) defer func() { @@ -366,16 +367,33 @@ func TestRetriveStatsBySourceIdDbPanic(t *testing.T) { } } }() - cohortDefinitionControllerNeedsDb.RetriveStatsBySourceId(requestContext) + cohortDefinitionControllerNeedsDb.RetriveStatsBySourceIdAndTeamProject(requestContext) t.Errorf("Expected error") } -func TestRetriveStatsBySourceId(t *testing.T) { +func TestRetriveStatsBySourceIdAndTeamProjectCheckMandatoryTeamProject(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) requestContext.Writer = new(tests.CustomResponseWriter) - cohortDefinitionController.RetriveStatsBySourceId(requestContext) + cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext) + result := requestContext.Writer.(*tests.CustomResponseWriter) + // Params above are wrong, so request should abort: + if !requestContext.IsAborted() { + t.Errorf("Expected aborted request") + } + if !strings.Contains(result.CustomResponseWriterOut, "team-project is a mandatory parameter") { + t.Errorf("Expected error about mandatory team-project") + } +} + +func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) { + setUp(t) + requestContext := new(gin.Context) + requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) + requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"}) + requestContext.Writer = new(tests.CustomResponseWriter) + cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) log.Printf("result: %s", result) // expect result with all of the dummy data: diff --git a/tests/models_tests/models_test.go b/tests/models_tests/models_test.go index ccc007b..a495d9f 100644 --- a/tests/models_tests/models_test.go +++ b/tests/models_tests/models_test.go @@ -27,6 +27,7 @@ var allConceptIds []int64 var dummyContinuousConceptId = tests.GetTestDummyContinuousConceptId() var hareConceptId = tests.GetTestHareConceptId() var histogramConceptId = tests.GetTestHistogramConceptId() +var defaultTeamProject = "defaultteamproject" func TestMain(m *testing.M) { setupSuite() @@ -47,7 +48,7 @@ func setupSuite() { // initialize some handy variables to use in tests below: // (see also tests/setup_local_db/test_data_results_and_cdm.sql for these test cohort details) - allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) largestCohort = allCohortDefinitions[0] secondLargestCohort = allCohortDefinitions[2] extendedCopyOfSecondLargestCohort = allCohortDefinitions[1] @@ -563,11 +564,29 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH } } +func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) { + setUp(t) + testTeamProject := "teamprojectX" + allowedCohortDefinitionIds, _ := cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject) + if len(allowedCohortDefinitionIds) != 1 { + t.Errorf("Expected teamProject '%s' to have one cohort, but found %d", + testTeamProject, len(allowedCohortDefinitionIds)) + } + // test data is crafted in such a way that the default "team project" has access to all + // the cohorts. Check if this is indeed the case: + testTeamProject = defaultTeamProject + allowedCohortDefinitionIds, _ = cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject) + allCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitions() + if len(allCohortDefinitions) != len(allowedCohortDefinitionIds) && len(allCohortDefinitions) > 1 { + t.Errorf("Found %d, expected %d", len(allowedCohortDefinitionIds), len(allCohortDefinitions)) + } +} + func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) { setUp(t) - cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) if len(cohortDefinitions) != len(allCohortDefinitions) { - t.Errorf("Found %d", len(cohortDefinitions)) + t.Errorf("Found %d, expected %d", len(cohortDefinitions), len(allCohortDefinitions)) } // check if stats fields are filled and if order is as expected: previousSize := 1000000 @@ -580,6 +599,24 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) { } previousSize = cohortDefinition.CohortSize } + + // some extra tests to cover also the teamProject option for this method: + testTeamProject := "teamprojectX" + allowedCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, testTeamProject) + if len(allowedCohortDefinitions) != 1 { + t.Errorf("Expected teamProject '%s' to have one cohort, but found %d", + testTeamProject, len(allowedCohortDefinitions)) + } + if len(cohortDefinitions) <= len(allowedCohortDefinitions) { + t.Errorf("Expected list of projects for '%s' to be larger than for %s", + defaultTeamProject, testTeamProject) + } + testTeamProject = "teamprojectNonExisting" + allowedCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, testTeamProject) + if len(allowedCohortDefinitions) != 0 { + t.Errorf("Expected teamProject '%s' to have NO cohort, but found %d", + testTeamProject, len(allowedCohortDefinitions)) + } } // Tests whether the code deals correctly with the (error) situation where @@ -587,7 +624,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) { // the situation where a cohort still exists in `cohort` table but not in `cohort_definition`). func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMissing(t *testing.T) { setUp(t) - cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) if len(cohortDefinitions) != len(allCohortDefinitions) { t.Errorf("Found %d", len(cohortDefinitions)) } @@ -596,7 +633,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMis firstCohort := cohortDefinitions[0] tests.ExecAtlasSQLString(fmt.Sprintf("delete from %s.cohort_definition where id = %d", db.GetAtlasDB().Schema, firstCohort.Id)) - cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) if len(cohortDefinitions) != len(allCohortDefinitions)-1 { t.Errorf("Number of cohor_definition records expected to be %d, found %d", len(allCohortDefinitions)-1, len(cohortDefinitions)) @@ -685,7 +722,7 @@ func TestQueryFilterByConceptIdsHelper(t *testing.T) { func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) { setUp(t) - cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) var sumNumeric float32 = 0 textConcat := "" classIdConcat := "" @@ -737,7 +774,7 @@ func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *test func TestErrorForRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) { // Tests if the method returns an error when query fails. - cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) // break something in the Results schema to cause a query failure in the next method: tests.BreakSomething(models.Results, "cohort", "cohort_definition_id") diff --git a/tests/setup_local_db/ddl_atlas.sql b/tests/setup_local_db/ddl_atlas.sql index 7ca52b1..b49eef1 100644 --- a/tests/setup_local_db/ddl_atlas.sql +++ b/tests/setup_local_db/ddl_atlas.sql @@ -37,3 +37,52 @@ CREATE TABLE atlas.cohort_definition modified_by_id integer, CONSTRAINT PK_cohort_definition PRIMARY KEY (id) ); + +CREATE TABLE atlas.sec_role +( + id integer NOT NULL, + name varchar(255) , + system_role boolean NOT NULL DEFAULT false, + CONSTRAINT pk_sec_role PRIMARY KEY (id), + CONSTRAINT sec_role_name_uq UNIQUE (name, system_role) +); + +CREATE TABLE atlas.sec_permission +( + id integer NOT NULL, + value varchar(255) NOT NULL, + description varchar(255), + CONSTRAINT pk_sec_permission PRIMARY KEY (id), + CONSTRAINT permission_unique UNIQUE (value) +); + +CREATE TABLE atlas.sec_role_permission +( + id integer NOT NULL, + role_id integer NOT NULL, + permission_id integer NOT NULL, + status varchar(255), + CONSTRAINT pk_sec_role_permission PRIMARY KEY (id), + CONSTRAINT role_permission_unique UNIQUE (role_id, permission_id), + CONSTRAINT fk_role_permission_to_permission FOREIGN KEY (permission_id) + REFERENCES atlas.sec_permission (id) MATCH SIMPLE + ON UPDATE NO ACTION + ON DELETE NO ACTION, + CONSTRAINT fk_role_permission_to_role FOREIGN KEY (role_id) + REFERENCES atlas.sec_role (id) MATCH SIMPLE + ON UPDATE NO ACTION + ON DELETE NO ACTION +); + +CREATE VIEW atlas.COHORT_DEFINITION_SEC_ROLE AS + select + distinct cast(regexp_replace(sec_permission.value, + '^cohortdefinition:([0-9]+):.*','\1') as integer) as cohort_definition_id, + sec_role.name as sec_role_name + from + atlas.sec_role + inner join atlas.sec_role_permission on sec_role.id = sec_role_permission.role_id + inner join atlas.sec_permission on sec_role_permission.permission_id = sec_permission.id + where + sec_permission.value ~ 'cohortdefinition:[0-9]+' +; diff --git a/tests/setup_local_db/test_data_atlas.sql b/tests/setup_local_db/test_data_atlas.sql index 9cc5a83..52ef5f0 100644 --- a/tests/setup_local_db/test_data_atlas.sql +++ b/tests/setup_local_db/test_data_atlas.sql @@ -26,3 +26,71 @@ values (32,'Test cohort3b','Copy of Larger cohort'), (4,'Test cohort4','Extra Larger cohort') ; + +insert into atlas.sec_role + (id, name, system_role) +values + (1,'public',true), + (1005,'teamprojectX',false), + (1009,'teamprojectY',false), + (3000,'someotherrole',false), + (4000,'defaultteamproject',false) +; + +insert into atlas.sec_permission + (id, value, description) +values + (1181, 'cohortdefinition:2:check:post', 'Fix Cohort Definition with ID = 2'), + (1182, 'cohortdefinition:2:put', 'Update Cohort Definition with ID = 2'), + (1183, 'cohortdefinition:2:delete', 'Delete Cohort Definition with ID = 2'), + (1184, 'cohortdefinition:2:version:*:get', 'Get cohort version'), + (1185, 'cohortdefinition:2:info:get', 'no description'), + (1186, 'cohortdefinition:2:get', 'Get Cohort Definition by ID'), + (1187, 'cohortdefinition:2:version:get', 'Get list of cohort versions'), + (1188, 'cohortdefinition:4:check:post', 'Fix Cohort Definition with ID = 4'), + (1189, 'cohortdefinition:4:put', 'Update Cohort Definition with ID = 4'), + (1190, 'cohortdefinition:4:delete', 'Delete Cohort Definition with ID = 4'), + (1191, 'cohortdefinition:4:version:*:get', 'Get cohort version'), + (1192, 'cohortdefinition:4:info:get', 'no description'), + (1193, 'cohortdefinition:4:get', 'Get Cohort Definition by ID'), + (1194, 'cohortdefinition:4:version:get', 'Get list of cohort versions'), + (2193, 'cohortdefinition:1:get', 'Get Cohort Definition by ID'), + (3193, 'cohortdefinition:3:get', 'Get Cohort Definition by ID'), + (4193, 'cohortdefinition:32:get', 'Get Cohort Definition by ID') +; + +insert into atlas.sec_role_permission + (id, role_id, permission_id) +values + (1454, 1005, 1181), + (1455, 1005, 1182), + (1456, 1005, 1183), + (1457, 1005, 1184), + (1458, 1005, 1185), + (1459, 1005, 1186), + (1460, 1005, 1187), + (1461, 1009, 1188), + (1462, 1009, 1189), + (1463, 1009, 1190), + (1464, 1009, 1191), + (1465, 1009, 1192), + (1466, 1009, 1193), + (1467, 1009, 1194), + (2454, 4000, 1181), + (2455, 4000, 1182), + (2456, 4000, 1183), + (2457, 4000, 1184), + (2458, 4000, 1185), + (2459, 4000, 1186), + (2460, 4000, 1187), + (2461, 4000, 1188), + (2462, 4000, 1189), + (2463, 4000, 1190), + (2464, 4000, 1191), + (2465, 4000, 1192), + (2466, 4000, 1193), + (2467, 4000, 1194), + (2468, 4000, 2193), + (2469, 4000, 3193), + (2470, 4000, 4193) +; diff --git a/utils/parsing.go b/utils/parsing.go index e6a25f4..e29c8c6 100644 --- a/utils/parsing.go +++ b/utils/parsing.go @@ -43,6 +43,15 @@ func Pos(value int64, list []int64) int { return -1 } +func Contains(list []int, value int) bool { + for _, item := range list { + if item == value { + return true + } + } + return false +} + func ParseInt64(strValue string) int64 { value, error := strconv.ParseInt(strValue, 10, 64) if error != nil { @@ -52,8 +61,8 @@ func ParseInt64(strValue string) int64 { } func ContainsNonNil(errors []error) bool { - for _, v := range errors { - if v != nil { + for _, item := range errors { + if item != nil { return true } }