From bb9a4daf6cb81e30d3fad81b1f1721aadf74c7aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= Date: Sun, 29 Dec 2024 15:00:59 +0100 Subject: [PATCH] Add estimators_samples_ property to _BaseSurvivalForest --- sksurv/ensemble/forest.py | 4 +++- tests/test_forest.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sksurv/ensemble/forest.py b/sksurv/ensemble/forest.py index 34583054..1958653e 100644 --- a/sksurv/ensemble/forest.py +++ b/sksurv/ensemble/forest.py @@ -117,7 +117,7 @@ def fit(self, X, y, sample_weight=None): X, estimator_name=self.__class__.__name__ ) - self.n_features_in_ = X.shape[1] + self._n_samples, self.n_features_in_ = X.shape time = time.astype(np.float64) self.unique_times_, self.is_event_time_ = get_unique_times(time, event) self.n_outputs_ = self.unique_times_.shape[0] @@ -138,6 +138,8 @@ def fit(self, X, y, sample_weight=None): else: n_samples_bootstrap = None + self._n_samples_bootstrap = n_samples_bootstrap + # Check parameters self._validate_estimator() diff --git a/tests/test_forest.py b/tests/test_forest.py index b057af1e..d6387ca6 100644 --- a/tests/test_forest.py +++ b/tests/test_forest.py @@ -251,6 +251,18 @@ def test_max_samples_without_bootstrap(make_whas500, forest_cls): est.fit(whas500.x, whas500.y) +@pytest.mark.parametrize("forest_cls", FORESTS) +def test_estimators_samples(make_whas500, forest_cls): + whas500 = make_whas500(to_numeric=True) + + est = forest_cls(n_estimators=10, max_samples=333, random_state=1, low_memory=True) + est.fit(whas500.x, whas500.y) + + n_samples = [len(np.unique(arr)) for arr in est.estimators_samples_] + expected = np.array([255, 227, 245, 247, 246, 239, 254, 252, 245, 248]) + assert_array_equal(n_samples, expected) + + @pytest.mark.parametrize("forest_cls", FORESTS) @pytest.mark.parametrize("func", ["predict_survival_function", "predict_cumulative_hazard_function"]) def test_pipeline_predict(breast_cancer, forest_cls, func):