Skip to content

Commit

Permalink
Make some changes to cohorts to keep compatibility between Analysis a…
Browse files Browse the repository at this point in the history
…nd Experience

Rebase, update citation and contributors
  • Loading branch information
MaryanMorel committed Nov 12, 2020
1 parent 75c0291 commit 91da649
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ The _SCALPEL-Analysis_ package was initially implemented by researchers, develop

- Youcef Sebiat
- Maryan Morel
- Dinh Phong Nguyen
- Dinh Phong Nguyen
- Dian Sun
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ If you use a library part of _SCALPEL3_ in a scientific publication, we would ap
year={2020},
publisher={Elsevier}
}


## Contributing
The development cycle is opinionated. Each time you commit, git will
Expand Down
6 changes: 6 additions & 0 deletions scalpel/core/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,12 @@ def load(input: Dict) -> "Cohort":
def from_description(description: str) -> "Cohort":
raise NotImplementedError

def cache(self) -> "Cohort":
self.subjects = self.subjects.cache()
if self.events is not None:
self.events = self.events.cache()
return self


def _union(a: Cohort, b: Cohort) -> Cohort:
if a.events is None or b.events is None:
Expand Down
16 changes: 16 additions & 0 deletions tests/core/cohort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,19 @@ def test_save_cohort(self, mock_method):
},
cohort_2.save_cohort("../../output"),
)

def test_cache(self):
patients_pd = pd.DataFrame({"patientID": [1, 2, 3]})
events_pd = pd.DataFrame({"patientID": [1, 2, 3], "value": ["DP", "DAS", "DR"]})

patients = self.spark.createDataFrame(patients_pd)

events = self.spark.createDataFrame(events_pd)
cohort = Cohort("liberal_fractures", "liberal_fractures", patients, events)
cohort.cache()
assert cohort.subjects.storageLevel.useMemory
assert cohort.events.storageLevel.useMemory

def test_from_description(self):
self.assertRaises(NotImplementedError, Cohort.from_description,
description='some string')

0 comments on commit 91da649

Please sign in to comment.