Skip to content

Commit

Permalink
Add estimators_samples_ property to _BaseSurvivalForest
Browse files Browse the repository at this point in the history
  • Loading branch information
sebp committed Dec 29, 2024
1 parent 23abaae commit bb9a4da
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sksurv/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def fit(self, X, y, sample_weight=None):
X, estimator_name=self.__class__.__name__
)

self.n_features_in_ = X.shape[1]
self._n_samples, self.n_features_in_ = X.shape
time = time.astype(np.float64)
self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
self.n_outputs_ = self.unique_times_.shape[0]
Expand All @@ -138,6 +138,8 @@ def fit(self, X, y, sample_weight=None):
else:
n_samples_bootstrap = None

self._n_samples_bootstrap = n_samples_bootstrap

# Check parameters
self._validate_estimator()

Expand Down
12 changes: 12 additions & 0 deletions tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,18 @@ def test_max_samples_without_bootstrap(make_whas500, forest_cls):
est.fit(whas500.x, whas500.y)


@pytest.mark.parametrize("forest_cls", FORESTS)
def test_estimators_samples(make_whas500, forest_cls):
whas500 = make_whas500(to_numeric=True)

est = forest_cls(n_estimators=10, max_samples=333, random_state=1, low_memory=True)
est.fit(whas500.x, whas500.y)

n_samples = [len(np.unique(arr)) for arr in est.estimators_samples_]
expected = np.array([255, 227, 245, 247, 246, 239, 254, 252, 245, 248])
assert_array_equal(n_samples, expected)


@pytest.mark.parametrize("forest_cls", FORESTS)
@pytest.mark.parametrize("func", ["predict_survival_function", "predict_cumulative_hazard_function"])
def test_pipeline_predict(breast_cancer, forest_cls, func):
Expand Down

0 comments on commit bb9a4da

Please sign in to comment.