Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for scikit-learn 1.6 #504

Merged
merged 20 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ repos:
hooks:
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
rev: v0.8.4
hooks:
- id: ruff
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Requirements
- numpy
- osqp
- pandas 1.4.0 or later
- scikit-learn 1.4 or 1.5
- scikit-learn 1.6
- scipy
- C/C++ compiler

Expand Down
2 changes: 1 addition & 1 deletion ci/appveyor/py310.ps1
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.10.*"
$env:CI_PANDAS_VERSION="1.5.*"
$env:CI_NUMPY_VERSION="1.25.*"
$env:CI_SKLEARN_VERSION="1.4.*"
$env:CI_SKLEARN_VERSION="1.6.*"
2 changes: 1 addition & 1 deletion ci/appveyor/py311.ps1
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.11.*"
$env:CI_PANDAS_VERSION="2.0.*"
$env:CI_NUMPY_VERSION="1.26.*"
$env:CI_SKLEARN_VERSION="1.5.*"
$env:CI_SKLEARN_VERSION="1.6.*"
2 changes: 1 addition & 1 deletion ci/appveyor/py312.ps1
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.12.*"
$env:CI_PANDAS_VERSION="2.2.*"
$env:CI_NUMPY_VERSION="2.0.*"
$env:CI_SKLEARN_VERSION="1.5.*"
$env:CI_SKLEARN_VERSION="1.6.*"
2 changes: 1 addition & 1 deletion ci/appveyor/py313.ps1
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.13.*"
$env:CI_PANDAS_VERSION="2.2.*"
$env:CI_NUMPY_VERSION="2.1.*"
$env:CI_SKLEARN_VERSION="1.5.*"
$env:CI_SKLEARN_VERSION="1.6.*"
2 changes: 1 addition & 1 deletion ci/deps/py310.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.10.*'
export CI_PANDAS_VERSION='1.5.*'
export CI_NUMPY_VERSION='1.25.*'
export CI_SKLEARN_VERSION='1.4.*'
export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=false
2 changes: 1 addition & 1 deletion ci/deps/py311.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.11.*'
export CI_PANDAS_VERSION='2.0.*'
export CI_NUMPY_VERSION='1.26.*'
export CI_SKLEARN_VERSION='1.5.*'
export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=true
2 changes: 1 addition & 1 deletion ci/deps/py312.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.12.*'
export CI_PANDAS_VERSION='2.2.*'
export CI_NUMPY_VERSION='2.0.*'
export CI_SKLEARN_VERSION='1.5.*'
export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=true
2 changes: 1 addition & 1 deletion ci/deps/py313.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.13.*'
export CI_PANDAS_VERSION='2.2.*'
export CI_NUMPY_VERSION='2.1.*'
export CI_SKLEARN_VERSION='1.5.*'
export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=false
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@
}

intersphinx_mapping = {
"sklearn": ("https://scikit-learn.org/1.5", None),
"sklearn": ("https://scikit-learn.org/1.6", None),
"cython": ("https://cython.readthedocs.io/en/latest/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
Expand Down
2 changes: 1 addition & 1 deletion doc/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,6 @@ The current minimum dependencies to run scikit-survival are:
- numpy
- osqp
- pandas 1.4.0 or later
- scikit-learn 1.4 or 1.5
- scikit-learn 1.6
- scipy
- C/C++ compiler
19 changes: 13 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ requires = [
"numpy>=2.0.0",

# scikit-learn requirements
"scikit-learn~=1.4.0; python_version<='3.12'",
"scikit-learn~=1.5.0; python_version=='3.13'",
"scikit-learn~=1.6.1; python_version<='3.13'",
"scikit-learn; python_version>'3.13'",
]
build-backend = "setuptools.build_meta"
Expand Down Expand Up @@ -51,7 +50,7 @@ dependencies = [
"osqp !=0.6.0,!=0.6.1",
"pandas >=1.4.0",
"scipy >=1.3.2",
"scikit-learn >=1.4.0,<1.6",
"scikit-learn >=1.6.1,<1.7",
]
dynamic = ["version"]

Expand Down Expand Up @@ -188,13 +187,21 @@ target-version = "py310"
ignore = ["C408"]
ignore-init-module-imports = true
select = [
"C4",
"C9",
# pycodestyle
"E",
"W",
# mccabe
"C90",
# pyflakes
"F",
# isort
"I",
# flake8-builtins
"A",
# flake8-comprehensions
"C4",
# flake8-pytest-style
"PT",
"W",
]

[tool.ruff.lint.flake8-pytest-style]
Expand Down
2 changes: 1 addition & 1 deletion sksurv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def show_versions():
max(map(len, deps)),
max(map(len, sys_info.keys())),
)
fmt = "{0:<%ds}: {1}" % minwidth
fmt = f"{{0:<{minwidth}s}}: {{1}}"

print("SYSTEM")
print("------")
Expand Down
6 changes: 4 additions & 2 deletions sksurv/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,7 @@ def score(self, X, y):
result = concordance_index_censored(y[name_event], y[name_time], risk_score)
return result[0]

def _more_tags(self):
return {"requires_y": True}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.target_tags.required = True
return tags
4 changes: 4 additions & 0 deletions sksurv/bintrees/_binarytrees.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
cimport cython
from libcpp cimport bool
from libcpp.cast cimport dynamic_cast

Expand Down Expand Up @@ -76,18 +77,21 @@ cdef class BaseTree:
return self.count_larger(key)


@cython.final
cdef class RBTree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
raise ValueError('size must be greater zero')
self.treeptr = new rbtree(size)

@cython.final
cdef class AVLTree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
raise ValueError('size must be greater zero')
self.treeptr = dynamic_cast[rbtree_ptr](new avl(size))

@cython.final
cdef class AATree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
Expand Down
30 changes: 18 additions & 12 deletions sksurv/ensemble/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree._tree import DTYPE
from sklearn.utils import check_random_state
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.extmath import squared_norm
from sklearn.utils.validation import _check_sample_weight, check_array, check_is_fitted
from sklearn.utils.validation import (
_check_sample_weight,
check_array,
check_is_fitted,
check_random_state,
validate_data,
)

from ..base import SurvivalAnalysisMixin
from ..linear_model.coxph import BreslowEstimator
Expand Down Expand Up @@ -389,7 +394,7 @@ def fit(self, X, y, sample_weight=None):
if not self.warm_start:
self._clear_state()

X = self._validate_data(X, ensure_min_samples=2)
X = validate_data(self, X, ensure_min_samples=2)
event, time = check_array_survival(X, y)

sample_weight = _check_sample_weight(sample_weight, X)
Expand All @@ -398,7 +403,7 @@ def fit(self, X, y, sample_weight=None):
Xi = np.column_stack((np.ones(n_samples), X))

self._loss = LOSS_FUNCTIONS[self.loss]()
if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
if isinstance(self._loss, CensoredSquaredLoss | IPCWLeastSquaresError):
time = np.log(time)

if not self._is_fitted():
Expand Down Expand Up @@ -470,7 +475,7 @@ def predict(self, X):
Predicted risk scores.
"""
check_is_fitted(self, "estimators_")
X = self._validate_data(X, reset=False)
X = validate_data(self, X, reset=False)

return self._predict(X)

Expand Down Expand Up @@ -957,7 +962,7 @@ def _set_max_features(self):
max_features = max(1, int(np.log2(self.n_features_in_)))
elif self.max_features is None:
max_features = self.n_features_in_
elif isinstance(self.max_features, (numbers.Integral, np.integer)):
elif isinstance(self.max_features, numbers.Integral):
max_features = self.max_features
else: # float
max_features = max(1, int(self.max_features * self.n_features_in_))
Expand Down Expand Up @@ -1234,7 +1239,8 @@ def fit(self, X, y, sample_weight=None, monitor=None):
if not self.warm_start:
self._clear_state()

X = self._validate_data(
X = validate_data(
self,
X,
ensure_min_samples=2,
order="C",
Expand All @@ -1256,7 +1262,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
# self.loss is guaranteed to be a string
self._loss = self._get_loss(sample_weight=sample_weight)

if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
if isinstance(self._loss, CensoredSquaredLoss | IPCWLeastSquaresError):
time = np.log(time)

if self.n_iter_no_change is not None:
Expand Down Expand Up @@ -1315,13 +1321,13 @@ def fit(self, X, y, sample_weight=None, monitor=None):
begin_at_stage = self.estimators_.shape[0]
# The requirements of _raw_predict
# are more constrained than fit. It accepts only CSR
# matrices. Finite values have already been checked in _validate_data.
# matrices. Finite values have already been checked in validate_data.
X_train = check_array(
X_train,
dtype=DTYPE,
order="C",
accept_sparse="csr",
force_all_finite=False,
ensure_all_finite=False,
)
raw_predictions = self._raw_predict(X_train)
self._resize_state()
Expand Down Expand Up @@ -1390,7 +1396,7 @@ def _dropout_raw_predict(self, X):
return raw_predictions

def _dropout_staged_raw_predict(self, X):
X = self._validate_data(X, dtype=DTYPE, order="C", accept_sparse="csr")
X = validate_data(self, X, dtype=DTYPE, order="C", accept_sparse="csr")
raw_predictions = self._raw_predict_init(X)

n_estimators, K = self.estimators_.shape
Expand Down Expand Up @@ -1438,7 +1444,7 @@ def predict(self, X):
"""
check_is_fitted(self, "estimators_")

X = self._validate_data(X, reset=False, order="C", accept_sparse="csr", dtype=DTYPE)
X = validate_data(self, X, reset=False, order="C", accept_sparse="csr", dtype=DTYPE)
return self._predict(X)

def staged_predict(self, X):
Expand Down
38 changes: 25 additions & 13 deletions sksurv/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
_parallel_build_trees,
)
from sklearn.tree._tree import DTYPE
from sklearn.utils._tags import _safe_tags
from sklearn.utils.validation import check_is_fitted, check_random_state
from sklearn.utils._tags import get_tags
from sklearn.utils.validation import check_is_fitted, check_random_state, validate_data

from ..base import SurvivalAnalysisMixin
from ..metrics import concordance_index_censored
Expand All @@ -29,18 +29,20 @@
MAX_INT = np.iinfo(np.int32).max


def _more_tags_patch(self):
# BaseForest._more_tags calls
def _sklearn_tags_patch(self):
# BaseForest.__sklearn_tags__ calls
# type(self.estimator)(criterion=self.criterions),
# which is incompatible with LogrankCriterion
if isinstance(self, _BaseSurvivalForest):
estimator = type(self.estimator)()
else:
estimator = type(self.estimator)(criterion=self.criterion)
return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
tags = super(BaseForest, self).__sklearn_tags__()
tags.input_tags.allow_nan = get_tags(estimator).input_tags.allow_nan
return tags


BaseForest._more_tags = _more_tags_patch
BaseForest.__sklearn_tags__ = _sklearn_tags_patch


class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
Expand Down Expand Up @@ -104,7 +106,7 @@ def fit(self, X, y, sample_weight=None):
"""
self._validate_params()

X = self._validate_data(X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, force_all_finite=False)
X = validate_data(self, X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, ensure_all_finite=False)
event, time = check_array_survival(X, y)

# _compute_missing_values_in_feature_mask checks if X has missing values and
Expand All @@ -115,7 +117,7 @@ def fit(self, X, y, sample_weight=None):
X, estimator_name=self.__class__.__name__
)

self.n_features_in_ = X.shape[1]
self._n_samples, self.n_features_in_ = X.shape
time = time.astype(np.float64)
self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
self.n_outputs_ = self.unique_times_.shape[0]
Expand All @@ -125,7 +127,18 @@ def fit(self, X, y, sample_weight=None):
y_numeric[:, 1] = event.astype(np.float64)

# Get bootstrap sample size
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
if not self.bootstrap and self.max_samples is not None: # pylint: disable=no-else-raise
raise ValueError(
"`max_sample` cannot be set if `bootstrap=False`. "
"Either switch to `bootstrap=True` or set "
"`max_sample=None`."
)
elif self.bootstrap:
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
else:
n_samples_bootstrap = None

self._n_samples_bootstrap = n_samples_bootstrap

# Check parameters
self._validate_estimator()
Expand All @@ -141,13 +154,12 @@ def fit(self, X, y, sample_weight=None):

n_more_estimators = self.n_estimators - len(self.estimators_)

if n_more_estimators < 0:
if n_more_estimators < 0: # pylint: disable=no-else-raise
raise ValueError(
f"n_estimators={self.n_estimators} must be larger or equal to "
f"len(estimators_)={len(self.estimators_)} when warm_start==True"
)

if n_more_estimators == 0:
elif n_more_estimators == 0:
warnings.warn("Warm-start fitting without increasing n_estimators does not fit new trees.", stacklevel=2)
else:
if self.warm_start and len(self.estimators_) > 0:
Expand Down Expand Up @@ -442,7 +454,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
`min_impurity_decrease` or `min_impurity_split` are absent.
In addition, the `feature_importances_` attribute is not available.
It is recommended to estimate feature importances via
`permutation-based methods <https://eli5.readthedocs.io>`_.
:func:`sklearn.inspection.permutation_importance`.

The features are always randomly permuted at each split. Therefore,
the best found split may vary, even with the same training data,
Expand Down
2 changes: 1 addition & 1 deletion sksurv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
from sklearn.utils import check_consistent_length
from sklearn.utils.validation import check_consistent_length

__all__ = ["StepFunction"]

Expand Down
Loading
Loading