Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add "team project" filtering to /cohortdefinition-stats endpoint #81

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions controllers/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,18 @@ func (u CohortDefinitionController) RetriveAll(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions})
}

func (u CohortDefinitionController) RetriveStatsBySourceId(c *gin.Context) {
func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) {
// This method returns ALL cohortdefinition entries with cohort size statistics (for a given source)

sourceId, err1 := utils.ParseNumericArg(c, "sourceid")
teamProject := c.Param("teamproject")
if teamProject == "" {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error while parsing request", "error": "team-project is a mandatory parameter but was found to be empty!"})
c.Abort()
return
}
if err1 == nil {
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId)
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
c.Abort()
Expand Down
29 changes: 25 additions & 4 deletions models/cohortdefinition.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type CohortDefinitionI interface {
GetCohortDefinitionById(id int) (*CohortDefinition, error)
GetCohortDefinitionByName(name string) (*CohortDefinition, error)
GetAllCohortDefinitions() ([]*CohortDefinition, error)
GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error)
GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error)
GetCohortName(cohortId int) (string, error)
}

Expand Down Expand Up @@ -67,7 +67,19 @@ func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error)
return cohortDefinition, meta_result.Error
}

func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*CohortDefinitionStats, error) {
// Get the list of cohort_definition ids for a given "team project" (where "team project" is basically
// a security role name of one of the roles in Atlas/WebAPI database).
func (h CohortDefinition) GetCohortDefinitionIdsForTeamProject(teamProject string) ([]int, error) {
db2 := db.GetAtlasDB().Db
var cohortDefinitionIds []int
query := db2.Table(db.GetAtlasDB().Schema+".cohort_definition_sec_role").
Select("cohort_definition_id").
Where("sec_role_name = ?", teamProject).
Scan(&cohortDefinitionIds)
return cohortDefinitionIds, query.Error
}

func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*CohortDefinitionStats, error) {

// Connect to source db and gather stats:
var dataSourceModel = new(Source)
Expand All @@ -81,6 +93,11 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI
defer cancel()
meta_result := query.Scan(&cohortDefinitionStats)

// get (from separate Atlas DB - hence not using JOIN above) the list of cohort_definition_ids
// that are allowed for the given teamProject:
allowedCohortDefinitionIds, _ := h.GetCohortDefinitionIdsForTeamProject(teamProject)
log.Printf("INFO: found %d cohorts for this team project", len(allowedCohortDefinitionIds))

// add name details:
finalList := []*CohortDefinitionStats{}
for _, cohortDefinitionStat := range cohortDefinitionStats {
Expand All @@ -91,10 +108,14 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI
cohortDefinitionStat.CohortSize)
continue
} else {
cohortDefinitionStat.Name = cohortDefinition.Name
finalList = append(finalList, cohortDefinitionStat)
// filter to keep only the allowed ones:
if utils.Contains(allowedCohortDefinitionIds, cohortDefinitionStat.Id) {
cohortDefinitionStat.Name = cohortDefinition.Name
finalList = append(finalList, cohortDefinitionStat)
}
}
}

return finalList, meta_result.Error
}

Expand Down
2 changes: 1 addition & 1 deletion server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func NewRouter() *gin.Engine {
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)
authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName)
authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid", cohortdefinitions.RetriveStatsBySourceId)
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

// concept endpoints:
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition))
Expand Down
32 changes: 25 additions & 7 deletions tests/controllers_tests/controllers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (h dummyCohortDefinitionDataModel) GetCohortName(cohortId int) (string, err
return "dummy cohort name", nil
}

func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int) ([]*models.CohortDefinitionStats, error) {
func (h dummyCohortDefinitionDataModel) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId int, teamProject string) ([]*models.CohortDefinitionStats, error) {
pieterlukasse marked this conversation as resolved.
Show resolved Hide resolved
cohortDefinitionStats := []*models.CohortDefinitionStats{
{Id: 1, CohortSize: 10, Name: "name1"},
{Id: 2, CohortSize: 22, Name: "name2"},
Expand Down Expand Up @@ -340,22 +340,23 @@ func TestGenerateCSV(t *testing.T) {
}
}

func TestRetriveStatsBySourceIdWrongParams(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProjectWrongParams(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "Abc", Value: "def"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceId(requestContext)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
// Params above are wrong, so request should abort:
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
}

func TestRetriveStatsBySourceIdDbPanic(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProjectDbPanic(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"})
requestContext.Writer = new(tests.CustomResponseWriter)

defer func() {
Expand All @@ -366,16 +367,33 @@ func TestRetriveStatsBySourceIdDbPanic(t *testing.T) {
}
}
}()
cohortDefinitionControllerNeedsDb.RetriveStatsBySourceId(requestContext)
cohortDefinitionControllerNeedsDb.RetriveStatsBySourceIdAndTeamProject(requestContext)
t.Errorf("Expected error")
}

func TestRetriveStatsBySourceId(t *testing.T) {
func TestRetriveStatsBySourceIdAndTeamProjectCheckMandatoryTeamProject(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceId(requestContext)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
result := requestContext.Writer.(*tests.CustomResponseWriter)
// Params above are wrong, so request should abort:
if !requestContext.IsAborted() {
t.Errorf("Expected aborted request")
}
if !strings.Contains(result.CustomResponseWriterOut, "team-project is a mandatory parameter") {
t.Errorf("Expected error about mandatory team-project")
}
}

func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) {
setUp(t)
requestContext := new(gin.Context)
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"})
requestContext.Writer = new(tests.CustomResponseWriter)
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
pieterlukasse marked this conversation as resolved.
Show resolved Hide resolved
result := requestContext.Writer.(*tests.CustomResponseWriter)
log.Printf("result: %s", result)
// expect result with all of the dummy data:
Expand Down
51 changes: 44 additions & 7 deletions tests/models_tests/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var allConceptIds []int64
var dummyContinuousConceptId = tests.GetTestDummyContinuousConceptId()
var hareConceptId = tests.GetTestHareConceptId()
var histogramConceptId = tests.GetTestHistogramConceptId()
var defaultTeamProject = "defaultteamproject"

func TestMain(m *testing.M) {
setupSuite()
Expand All @@ -47,7 +48,7 @@ func setupSuite() {

// initialize some handy variables to use in tests below:
// (see also tests/setup_local_db/test_data_results_and_cdm.sql for these test cohort details)
allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
allCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
largestCohort = allCohortDefinitions[0]
secondLargestCohort = allCohortDefinitions[2]
extendedCopyOfSecondLargestCohort = allCohortDefinitions[1]
Expand Down Expand Up @@ -563,11 +564,29 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH
}
}

func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) {
setUp(t)
testTeamProject := "teamprojectX"
allowedCohortDefinitionIds, _ := cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject)
if len(allowedCohortDefinitionIds) != 1 {
t.Errorf("Expected teamProject '%s' to have one cohort, but found %d",
testTeamProject, len(allowedCohortDefinitionIds))
}
// test data is crafted in such a way that the default "team project" has access to all
// the cohorts. Check if this is indeed the case:
testTeamProject = defaultTeamProject
allowedCohortDefinitionIds, _ = cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject)
allCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitions()
if len(allCohortDefinitions) != len(allowedCohortDefinitionIds) && len(allCohortDefinitions) > 1 {
t.Errorf("Found %d, expected %d", len(allowedCohortDefinitionIds), len(allCohortDefinitions))
}
}

func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions) {
t.Errorf("Found %d", len(cohortDefinitions))
t.Errorf("Found %d, expected %d", len(cohortDefinitions), len(allCohortDefinitions))
}
// check if stats fields are filled and if order is as expected:
previousSize := 1000000
Expand All @@ -580,14 +599,32 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) {
}
previousSize = cohortDefinition.CohortSize
}

// some extra tests to cover also the teamProject option for this method:
testTeamProject := "teamprojectX"
allowedCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, testTeamProject)
if len(allowedCohortDefinitions) != 1 {
t.Errorf("Expected teamProject '%s' to have one cohort, but found %d",
testTeamProject, len(allowedCohortDefinitions))
}
if len(cohortDefinitions) <= len(allowedCohortDefinitions) {
t.Errorf("Expected list of projects for '%s' to be larger than for %s",
defaultTeamProject, testTeamProject)
}
testTeamProject = "teamprojectNonExisting"
allowedCohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, testTeamProject)
if len(allowedCohortDefinitions) != 0 {
t.Errorf("Expected teamProject '%s' to have NO cohort, but found %d",
testTeamProject, len(allowedCohortDefinitions))
}
}

// Tests whether the code deals correctly with the (error) situation where
// the `cohort_definition` and `cohort` tables are not in sync (more specifically
// the situation where a cohort still exists in `cohort` table but not in `cohort_definition`).
func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMissing(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions) {
t.Errorf("Found %d", len(cohortDefinitions))
}
Expand All @@ -596,7 +633,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDescWhenCohortDefinitionIsMis
firstCohort := cohortDefinitions[0]
tests.ExecAtlasSQLString(fmt.Sprintf("delete from %s.cohort_definition where id = %d",
db.GetAtlasDB().Schema, firstCohort.Id))
cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ = cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
if len(cohortDefinitions) != len(allCohortDefinitions)-1 {
t.Errorf("Number of cohor_definition records expected to be %d, found %d",
len(allCohortDefinitions)-1, len(cohortDefinitions))
Expand Down Expand Up @@ -685,7 +722,7 @@ func TestQueryFilterByConceptIdsHelper(t *testing.T) {

func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) {
setUp(t)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)
var sumNumeric float32 = 0
textConcat := ""
classIdConcat := ""
Expand Down Expand Up @@ -737,7 +774,7 @@ func TestRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *test
func TestErrorForRetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(t *testing.T) {
// Tests if the method returns an error when query fails.

cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId)
cohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, defaultTeamProject)

// break something in the Results schema to cause a query failure in the next method:
tests.BreakSomething(models.Results, "cohort", "cohort_definition_id")
Expand Down
49 changes: 49 additions & 0 deletions tests/setup_local_db/ddl_atlas.sql
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,52 @@ CREATE TABLE atlas.cohort_definition
modified_by_id integer,
CONSTRAINT PK_cohort_definition PRIMARY KEY (id)
);

CREATE TABLE atlas.sec_role
(
id integer NOT NULL,
name varchar(255) ,
system_role boolean NOT NULL DEFAULT false,
CONSTRAINT pk_sec_role PRIMARY KEY (id),
CONSTRAINT sec_role_name_uq UNIQUE (name, system_role)
);

CREATE TABLE atlas.sec_permission
(
id integer NOT NULL,
value varchar(255) NOT NULL,
description varchar(255),
CONSTRAINT pk_sec_permission PRIMARY KEY (id),
CONSTRAINT permission_unique UNIQUE (value)
);

CREATE TABLE atlas.sec_role_permission
(
id integer NOT NULL,
role_id integer NOT NULL,
permission_id integer NOT NULL,
status varchar(255),
CONSTRAINT pk_sec_role_permission PRIMARY KEY (id),
CONSTRAINT role_permission_unique UNIQUE (role_id, permission_id),
CONSTRAINT fk_role_permission_to_permission FOREIGN KEY (permission_id)
REFERENCES atlas.sec_permission (id) MATCH SIMPLE
ON UPDATE NO ACTION
ON DELETE NO ACTION,
CONSTRAINT fk_role_permission_to_role FOREIGN KEY (role_id)
REFERENCES atlas.sec_role (id) MATCH SIMPLE
ON UPDATE NO ACTION
ON DELETE NO ACTION
);

CREATE VIEW atlas.COHORT_DEFINITION_SEC_ROLE AS
select
distinct cast(regexp_replace(sec_permission.value,
'^cohortdefinition:([0-9]+):.*','\1') as integer) as cohort_definition_id,
sec_role.name as sec_role_name
from
atlas.sec_role
inner join atlas.sec_role_permission on sec_role.id = sec_role_permission.role_id
inner join atlas.sec_permission on sec_role_permission.permission_id = sec_permission.id
where
sec_permission.value ~ 'cohortdefinition:[0-9]+'
;
pieterlukasse marked this conversation as resolved.
Show resolved Hide resolved
68 changes: 68 additions & 0 deletions tests/setup_local_db/test_data_atlas.sql
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,71 @@ values
(32,'Test cohort3b','Copy of Larger cohort'),
(4,'Test cohort4','Extra Larger cohort')
;

insert into atlas.sec_role
(id, name, system_role)
values
(1,'public',true),
(1005,'teamprojectX',false),
(1009,'teamprojectY',false),
(3000,'someotherrole',false),
(4000,'defaultteamproject',false)
;

insert into atlas.sec_permission
(id, value, description)
values
(1181, 'cohortdefinition:2:check:post', 'Fix Cohort Definition with ID = 2'),
(1182, 'cohortdefinition:2:put', 'Update Cohort Definition with ID = 2'),
(1183, 'cohortdefinition:2:delete', 'Delete Cohort Definition with ID = 2'),
(1184, 'cohortdefinition:2:version:*:get', 'Get cohort version'),
(1185, 'cohortdefinition:2:info:get', 'no description'),
(1186, 'cohortdefinition:2:get', 'Get Cohort Definition by ID'),
(1187, 'cohortdefinition:2:version:get', 'Get list of cohort versions'),
(1188, 'cohortdefinition:4:check:post', 'Fix Cohort Definition with ID = 4'),
(1189, 'cohortdefinition:4:put', 'Update Cohort Definition with ID = 4'),
(1190, 'cohortdefinition:4:delete', 'Delete Cohort Definition with ID = 4'),
(1191, 'cohortdefinition:4:version:*:get', 'Get cohort version'),
(1192, 'cohortdefinition:4:info:get', 'no description'),
(1193, 'cohortdefinition:4:get', 'Get Cohort Definition by ID'),
(1194, 'cohortdefinition:4:version:get', 'Get list of cohort versions'),
(2193, 'cohortdefinition:1:get', 'Get Cohort Definition by ID'),
(3193, 'cohortdefinition:3:get', 'Get Cohort Definition by ID'),
(4193, 'cohortdefinition:32:get', 'Get Cohort Definition by ID')
;

insert into atlas.sec_role_permission
(id, role_id, permission_id)
values
(1454, 1005, 1181),
(1455, 1005, 1182),
(1456, 1005, 1183),
(1457, 1005, 1184),
(1458, 1005, 1185),
(1459, 1005, 1186),
(1460, 1005, 1187),
(1461, 1009, 1188),
(1462, 1009, 1189),
(1463, 1009, 1190),
(1464, 1009, 1191),
(1465, 1009, 1192),
(1466, 1009, 1193),
(1467, 1009, 1194),
(2454, 4000, 1181),
(2455, 4000, 1182),
(2456, 4000, 1183),
(2457, 4000, 1184),
(2458, 4000, 1185),
(2459, 4000, 1186),
(2460, 4000, 1187),
(2461, 4000, 1188),
(2462, 4000, 1189),
(2463, 4000, 1190),
(2464, 4000, 1191),
(2465, 4000, 1192),
(2466, 4000, 1193),
(2467, 4000, 1194),
(2468, 4000, 2193),
(2469, 4000, 3193),
(2470, 4000, 4193)
;
Loading
Loading