Skip to content

Commit

Permalink
feat: added "team project" to /cohortdefinition-stats/by-source-id/ e…
Browse files Browse the repository at this point in the history
…ndpoint
  • Loading branch information
pieterlukasse committed Dec 1, 2023
1 parent c90591b commit ff23235
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 25 deletions.
5 changes: 3 additions & 2 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ func (u CohortDefinitionController) RetriveAll(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions})
}

func (u CohortDefinitionController) RetriveStatsBySourceId(c *gin.Context) {
func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) {
// This method returns ALL cohortdefinition entries with cohort size statistics (for a given source)

sourceId, err1 := utils.ParseNumericArg(c, "sourceid")
teamProject := c.Param("teamproject")
if err1 == nil {
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId)
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
Expand Down
31 changes: 28 additions & 3 deletions models/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type CohortDefinitionI interface {
GetCohortDefinitionById(id int) (*CohortDefinition, error)
GetCohortDefinitionByName(name string) (*CohortDefinition, error)
GetAllCohortDefinitions() ([]*CohortDefinition, error)
GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error)
GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error)
GetCohortName(cohortId int) (string, error)
}

Expand Down Expand Up @@ -67,7 +67,19 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error)
return cohortDefinition, meta_result.Error
}

func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error) {
// Get the list of cohort_definition ids for a given "team project" (where "team project" is basically
// a security role name of one of the roles in Atlas/WebAPI database).
func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) {
db2 := db.GetAtlasDB().Db
var cohortDefinitionIds []int
query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role").
Select("cohort_definition_id").
Where("sec_role_name = ?", teamProject).
Scan(&cohortDefinitionIds)
return cohortDefinitionIds, query.Error
}

func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error) {

// Connect to source db and gather stats:
var dataSourceModel = new(Source)
Expand All @@ -81,6 +93,11 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI
defer cancel()
meta_result := query.Scan(&cohortDefinitionStats)

// get (from separate Atlas DB - hence not using JOIN above) the list of cohort_definition_ids
// that are allowed for the given teamProject:
allowedCohortDefinitionIds, _ := h.GetCohortDefinitionIdsForTeamProject(teamProject)
log.Printf("INFO: found %d cohorts for this team project", len(allowedCohortDefinitionIds))

// add name details:
finalList := []*CohortDefinitionStats{}
for _, cohortDefinitionStat := range cohortDefinitionStats {
Expand All @@ -95,7 +112,15 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI
finalList = append(finalList, cohortDefinitionStat)
}
}
return finalList, meta_result.Error
// filter to keep only the allowed ones:
filteredFinalList := []*CohortDefinitionStats{}
for _, cohortDefinitionStat := range finalList {
if utils.Contains(allowedCohortDefinitionIds, cohortDefinitionStat.Id) {
filteredFinalList = append(filteredFinalList, cohortDefinitionStat)
}
}

return filteredFinalList, meta_result.Error
}

func (h CohortDefinition) GetCohortName(cohortId int) (string, error) {
Expand Down
2 changes: 1 addition & 1 deletion server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func NewRouter() *gin.Engine {
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)
authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName)
authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid", cohortdefinitions.RetriveStatsBySourceId)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition))
Expand Down
14 changes: 7 additions & 7 deletions tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, err
return "dummy cohort name", nil
}

func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*models.CohortDefinitionStats, error) {
func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*models.CohortDefinitionStats, error) {
cohortDefinitionStats := []*models.CohortDefinitionStats{
{Id: 1, CohortSize: 10, Name: "name1"},
{Id: 2, CohortSize: 22, Name: "name2"},
Expand Down Expand Up @@ -340,19 +340,19 @@ func TestGenerateCSV(t *testing.T) {
}
}

func TestRetriveStatsBySourceIdWrongParams(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProjectWrongParams(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "Abc", Value: "def"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceId(requestContext)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
// Params above are wrong, so request should abort:
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
}

func TestRetriveStatsBySourceIdDbPanic(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProjectDbPanic(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
Expand All @@ -366,16 +366,16 @@ func TestRetriveStatsBySourceIdDbPanic(t *testing.T) {
}
}
}()
cohortDefinitionControllerNeedsDb.RetriveStatsBySourceId(requestContext)
cohortDefinitionControllerNeedsDb.RetriveStatsBySourceIdAndTeamProject(requestContext)
t.Errorf("Expected error")
}

func TestRetriveStatsBySourceId(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceId(requestContext)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with all of the dummy data:
Expand Down
33 changes: 26 additions & 7 deletions tests/models_tests/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var allConceptIds []int64
var dummyContinuousConceptId = tests.GetTestDummyContinuousConceptId()
var hareConceptId = tests.GetTestHareConceptId()
var histogramConceptId = tests.GetTestHistogramConceptId()
var defaultTeamProject = "defaultteamproject"

func TestMain(m *testing.M) {
setupSuite()
Expand All @@ -47,7 +48,7 @@ func setupSuite() {

// initialize some handy variables to use in tests below:
// (see also tests/setup_local_db/test_data_results_and_cdm.sql for these test cohort details)
allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
largestCohort = allCohortDefinitions[0]
secondLargestCohort = allCohortDefinitions[2]
extendedCopyOfSecondLargestCohort = allCohortDefinitions[1]
Expand Down Expand Up @@ -563,11 +564,29 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH
}
}

func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) {
setUp(t)
testTeamProject := "teamprojectX"
allowedCohortDefinitionIds, _ := cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject)
if len(allowedCohortDefinitionIds) != 1 {
t.Errorf("Expected teamProject '%s' to have one cohort, but found %d",
testTeamProject, len(allowedCohortDefinitionIds))
}
// test data is crafted in such a way that the default "team project" has access to all
// the cohorts. Check if this is indeed the case:
testTeamProject = defaultTeamProject
allowedCohortDefinitionIds, _ = cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject)
allCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitions()
if len(allCohortDefinitions) != len(allowedCohortDefinitionIds) && len(allCohortDefinitions) > 1 {
t.Errorf("Found %d, expected %d", len(allowedCohortDefinitionIds), len(allCohortDefinitions))
}
}

func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions) {
t.Errorf("Found %d", len(cohortDefinitions))
t.Errorf("Found %d, expected %d", len(cohortDefinitions), len(allCohortDefinitions))
}
// check if stats fields are filled and if order is as expected:
previousSize := 1000000
Expand All @@ -587,7 +606,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) {
// the situation where a cohort still exists in `cohort` table but not in `cohort_definition`).
func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMissing(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions) {
t.Errorf("Found %d", len(cohortDefinitions))
}
Expand All @@ -596,7 +615,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMis
firstCohort := cohortDefinitions[0]
tests.ExecAtlasSQLString(fmt.Sprintf("delete from %s.cohort_definition where id = %d",
db.GetAtlasDB().Schema, firstCohort.Id))
cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions)-1 {
t.Errorf("Number of cohor_definition records expected to be %d, found %d",
len(allCohortDefinitions)-1, len(cohortDefinitions))
Expand Down Expand Up @@ -685,7 +704,7 @@ func TestQueryFilterByConceptIdsHelper(t *testing.T) {

func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
var sumNumeric float32 = 0
textConcat := ""
classIdConcat := ""
Expand Down Expand Up @@ -737,7 +756,7 @@ func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *test
func TestErrorForRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) {
// Tests if the method returns an error when query fails.

cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)

// break something in the Results schema to cause a query failure in the next method:
tests.BreakSomething(models.Results, "cohort", "cohort_definition_id")
Expand Down
27 changes: 24 additions & 3 deletions tests/setup_local_db/test_data_atlas.sql
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ values
(1,'public',true),
(1005,'teamprojectX',false),
(1009,'teamprojectY',false),
(3000,'someotherrole',false)
(3000,'someotherrole',false),
(4000,'defaultteamproject',false)
;

insert into atlas.sec_permission
Expand All @@ -52,7 +53,10 @@ values
(1191, 'cohortdefinition:4:version:*:get', 'Get cohort version'),
(1192, 'cohortdefinition:4:info:get', 'no description'),
(1193, 'cohortdefinition:4:get', 'Get Cohort Definition by ID'),
(1194, 'cohortdefinition:4:version:get', 'Get list of cohort versions')
(1194, 'cohortdefinition:4:version:get', 'Get list of cohort versions'),
(2193, 'cohortdefinition:1:get', 'Get Cohort Definition by ID'),
(3193, 'cohortdefinition:3:get', 'Get Cohort Definition by ID'),
(4193, 'cohortdefinition:32:get', 'Get Cohort Definition by ID')
;

insert into atlas.sec_role_permission
Expand All @@ -71,5 +75,22 @@ values
(1464, 1009, 1191),
(1465, 1009, 1192),
(1466, 1009, 1193),
(1467, 1009, 1194)
(1467, 1009, 1194),
(2454, 4000, 1181),
(2455, 4000, 1182),
(2456, 4000, 1183),
(2457, 4000, 1184),
(2458, 4000, 1185),
(2459, 4000, 1186),
(2460, 4000, 1187),
(2461, 4000, 1188),
(2462, 4000, 1189),
(2463, 4000, 1190),
(2464, 4000, 1191),
(2465, 4000, 1192),
(2466, 4000, 1193),
(2467, 4000, 1194),
(2468, 4000, 2193),
(2469, 4000, 3193),
(2470, 4000, 4193)
;
13 changes: 11 additions & 2 deletions utils/parsing.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ func Pos(value int64, list []int64) int {
return -1
}

func Contains(list []int, value int) bool {
for _, item := range list {
if item == value {
return true
}
}
return false
}

func ParseInt64(strValue string) int64 {
value, error := strconv.ParseInt(strValue, 10, 64)
if error != nil {
Expand All @@ -52,8 +61,8 @@ func ParseInt64(strValue string) int64 {
}

func ContainsNonNil(errors []error) bool {
for _, v := range errors {
if v != nil {
for _, item := range errors {
if item != nil {
return true
}
}
Expand Down

0 comments on commit ff23235

Please sign in to comment.