diff --git a/controllers/cohortdefinition.go b/controllers/cohortdefinition.go index 229aef7..5f96d62 100644 --- a/controllers/cohortdefinition.go +++ b/controllers/cohortdefinition.go @@ -62,12 +62,13 @@ func (u CohortDefinitionController) RetriveAll(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions}) } -func (u CohortDefinitionController) RetriveStatsBySourceId(c *gin.Context) { +func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) { // This method returns ALL cohortdefinition entries with cohort size statistics (for a given source) sourceId, err1 := utils.ParseNumericArg(c, "sourceid") + teamProject := c.Param("teamproject") if err1 == nil { - cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId) + cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()}) c.Abort() diff --git a/models/cohortdefinition.go b/models/cohortdefinition.go index e3e2102..9893f07 100644 --- a/models/cohortdefinition.go +++ b/models/cohortdefinition.go @@ -13,7 +13,7 @@ type CohortDefinitionI interface { GetCohortDefinitionById(id int) (*CohortDefinition, error) GetCohortDefinitionByName(name string) (*CohortDefinition, error) GetAllCohortDefinitions() ([]*CohortDefinition, error) - GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error) + GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error) GetCohortName(cohortId int) (string, error) } @@ -67,7 +67,19 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error) return cohortDefinition, meta_result.Error } -func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error) { +// Get the list of cohort_definition ids for a given "team project" (where "team project" is basically +// a security role name of one of the roles in Atlas/WebAPI database). +func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) { + db2 := db.GetAtlasDB().Db + var cohortDefinitionIds []int + query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role"). + Select("cohort_definition_id"). + Where("sec_role_name = ?", teamProject). + Scan(&cohortDefinitionIds) + return cohortDefinitionIds, query.Error +} + +func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error) { // Connect to source db and gather stats: var dataSourceModel = new(Source) @@ -81,6 +93,11 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI defer cancel() meta_result := query.Scan(&cohortDefinitionStats) + // get (from separate Atlas DB - hence not using JOIN above) the list of cohort_definition_ids + // that are allowed for the given teamProject: + allowedCohortDefinitionIds, _ := h.GetCohortDefinitionIdsForTeamProject(teamProject) + log.Printf("INFO: found %d cohorts for this team project", len(allowedCohortDefinitionIds)) + // add name details: finalList := []*CohortDefinitionStats{} for _, cohortDefinitionStat := range cohortDefinitionStats { @@ -95,7 +112,15 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI finalList = append(finalList, cohortDefinitionStat) } } - return finalList, meta_result.Error + // filter to keep only the allowed ones: + filteredFinalList := []*CohortDefinitionStats{} + for _, cohortDefinitionStat := range finalList { + if utils.Contains(allowedCohortDefinitionIds, cohortDefinitionStat.Id) { + filteredFinalList = append(filteredFinalList, cohortDefinitionStat) + } + } + + return filteredFinalList, meta_result.Error } func (h CohortDefinition) GetCohortName(cohortId int) (string, error) { diff --git a/server/router.go b/server/router.go index 81d9ee8..1357cbf 100644 --- a/server/router.go +++ b/server/router.go @@ -30,7 +30,7 @@ func NewRouter() *gin.Engine { authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById) authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName) authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll) - authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid", cohortdefinitions.RetriveStatsBySourceId) + authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject) // concept endpoints: concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition)) diff --git a/tests/controllers_tests/controllers_test.go b/tests/controllers_tests/controllers_test.go index 368b660..4645b01 100644 --- a/tests/controllers_tests/controllers_test.go +++ b/tests/controllers_tests/controllers_test.go @@ -100,7 +100,7 @@ func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, err return "dummy cohort name", nil } -func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*models.CohortDefinitionStats, error) { +func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*models.CohortDefinitionStats, error) { cohortDefinitionStats := []*models.CohortDefinitionStats{ {Id: 1, CohortSize: 10, Name: "name1"}, {Id: 2, CohortSize: 22, Name: "name2"}, @@ -340,19 +340,19 @@ func TestGenerateCSV(t *testing.T) { } } -func TestRetriveStatsBySourceIdWrongParams(t *testing.T) { +func TestRetriveStatsBySourceIdAndTeamProjectWrongParams(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "Abc", Value: "def"}) requestContext.Writer = new(tests.CustomResponseWriter) - cohortDefinitionController.RetriveStatsBySourceId(requestContext) + cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext) // Params above are wrong, so request should abort: if !requestContext.IsAborted() { t.Errorf("Expected aborted request") } } -func TestRetriveStatsBySourceIdDbPanic(t *testing.T) { +func TestRetriveStatsBySourceIdAndTeamProjectDbPanic(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) @@ -366,16 +366,16 @@ func TestRetriveStatsBySourceIdDbPanic(t *testing.T) { } } }() - cohortDefinitionControllerNeedsDb.RetriveStatsBySourceId(requestContext) + cohortDefinitionControllerNeedsDb.RetriveStatsBySourceIdAndTeamProject(requestContext) t.Errorf("Expected error") } -func TestRetriveStatsBySourceId(t *testing.T) { +func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) { setUp(t) requestContext := new(gin.Context) requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())}) requestContext.Writer = new(tests.CustomResponseWriter) - cohortDefinitionController.RetriveStatsBySourceId(requestContext) + cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext) result := requestContext.Writer.(*tests.CustomResponseWriter) log.Printf("result: %s", result) // expect result with all of the dummy data: diff --git a/tests/models_tests/models_test.go b/tests/models_tests/models_test.go index ccc007b..9e3b280 100644 --- a/tests/models_tests/models_test.go +++ b/tests/models_tests/models_test.go @@ -27,6 +27,7 @@ var allConceptIds []int64 var dummyContinuousConceptId = tests.GetTestDummyContinuousConceptId() var hareConceptId = tests.GetTestHareConceptId() var histogramConceptId = tests.GetTestHistogramConceptId() +var defaultTeamProject = "defaultteamproject" func TestMain(m *testing.M) { setupSuite() @@ -47,7 +48,7 @@ func setupSuite() { // initialize some handy variables to use in tests below: // (see also tests/setup_local_db/test_data_results_and_cdm.sql for these test cohort details) - allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) largestCohort = allCohortDefinitions[0] secondLargestCohort = allCohortDefinitions[2] extendedCopyOfSecondLargestCohort = allCohortDefinitions[1] @@ -563,11 +564,29 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH } } +func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) { + setUp(t) + testTeamProject := "teamprojectX" + allowedCohortDefinitionIds, _ := cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject) + if len(allowedCohortDefinitionIds) != 1 { + t.Errorf("Expected teamProject '%s' to have one cohort, but found %d", + testTeamProject, len(allowedCohortDefinitionIds)) + } + // test data is crafted in such a way that the default "team project" has access to all + // the cohorts. Check if this is indeed the case: + testTeamProject = defaultTeamProject + allowedCohortDefinitionIds, _ = cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject) + allCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitions() + if len(allCohortDefinitions) != len(allowedCohortDefinitionIds) && len(allCohortDefinitions) > 1 { + t.Errorf("Found %d, expected %d", len(allowedCohortDefinitionIds), len(allCohortDefinitions)) + } +} + func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) { setUp(t) - cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) if len(cohortDefinitions) != len(allCohortDefinitions) { - t.Errorf("Found %d", len(cohortDefinitions)) + t.Errorf("Found %d, expected %d", len(cohortDefinitions), len(allCohortDefinitions)) } // check if stats fields are filled and if order is as expected: previousSize := 1000000 @@ -587,7 +606,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) { // the situation where a cohort still exists in `cohort` table but not in `cohort_definition`). func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMissing(t *testing.T) { setUp(t) - cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) if len(cohortDefinitions) != len(allCohortDefinitions) { t.Errorf("Found %d", len(cohortDefinitions)) } @@ -596,7 +615,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMis firstCohort := cohortDefinitions[0] tests.ExecAtlasSQLString(fmt.Sprintf("delete from %s.cohort_definition where id = %d", db.GetAtlasDB().Schema, firstCohort.Id)) - cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) if len(cohortDefinitions) != len(allCohortDefinitions)-1 { t.Errorf("Number of cohor_definition records expected to be %d, found %d", len(allCohortDefinitions)-1, len(cohortDefinitions)) @@ -685,7 +704,7 @@ func TestQueryFilterByConceptIdsHelper(t *testing.T) { func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) { setUp(t) - cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) var sumNumeric float32 = 0 textConcat := "" classIdConcat := "" @@ -737,7 +756,7 @@ func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *test func TestErrorForRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) { // Tests if the method returns an error when query fails. - cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId) + cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject) // break something in the Results schema to cause a query failure in the next method: tests.BreakSomething(models.Results, "cohort", "cohort_definition_id") diff --git a/tests/setup_local_db/test_data_atlas.sql b/tests/setup_local_db/test_data_atlas.sql index b07ec4c..52ef5f0 100644 --- a/tests/setup_local_db/test_data_atlas.sql +++ b/tests/setup_local_db/test_data_atlas.sql @@ -33,7 +33,8 @@ values (1,'public',true), (1005,'teamprojectX',false), (1009,'teamprojectY',false), - (3000,'someotherrole',false) + (3000,'someotherrole',false), + (4000,'defaultteamproject',false) ; insert into atlas.sec_permission @@ -52,7 +53,10 @@ values (1191, 'cohortdefinition:4:version:*:get', 'Get cohort version'), (1192, 'cohortdefinition:4:info:get', 'no description'), (1193, 'cohortdefinition:4:get', 'Get Cohort Definition by ID'), - (1194, 'cohortdefinition:4:version:get', 'Get list of cohort versions') + (1194, 'cohortdefinition:4:version:get', 'Get list of cohort versions'), + (2193, 'cohortdefinition:1:get', 'Get Cohort Definition by ID'), + (3193, 'cohortdefinition:3:get', 'Get Cohort Definition by ID'), + (4193, 'cohortdefinition:32:get', 'Get Cohort Definition by ID') ; insert into atlas.sec_role_permission @@ -71,5 +75,22 @@ values (1464, 1009, 1191), (1465, 1009, 1192), (1466, 1009, 1193), - (1467, 1009, 1194) + (1467, 1009, 1194), + (2454, 4000, 1181), + (2455, 4000, 1182), + (2456, 4000, 1183), + (2457, 4000, 1184), + (2458, 4000, 1185), + (2459, 4000, 1186), + (2460, 4000, 1187), + (2461, 4000, 1188), + (2462, 4000, 1189), + (2463, 4000, 1190), + (2464, 4000, 1191), + (2465, 4000, 1192), + (2466, 4000, 1193), + (2467, 4000, 1194), + (2468, 4000, 2193), + (2469, 4000, 3193), + (2470, 4000, 4193) ; diff --git a/utils/parsing.go b/utils/parsing.go index e6a25f4..e29c8c6 100644 --- a/utils/parsing.go +++ b/utils/parsing.go @@ -43,6 +43,15 @@ func Pos(value int64, list []int64) int { return -1 } +func Contains(list []int, value int) bool { + for _, item := range list { + if item == value { + return true + } + } + return false +} + func ParseInt64(strValue string) int64 { value, error := strconv.ParseInt(strValue, 10, 64) if error != nil { @@ -52,8 +61,8 @@ func ParseInt64(strValue string) int64 { } func ContainsNonNil(errors []error) bool { - for _, v := range errors { - if v != nil { + for _, item := range errors { + if item != nil { return true } }