diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bbaed5a1..e561a090 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -43,6 +43,6 @@ repos:
hooks:
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.7.1
+ rev: v0.8.4
hooks:
- id: ruff
diff --git a/README.rst b/README.rst
index 71ef1cd5..fd1223f2 100644
--- a/README.rst
+++ b/README.rst
@@ -39,7 +39,7 @@ Requirements
- numpy
- osqp
- pandas 1.4.0 or later
-- scikit-learn 1.4 or 1.5
+- scikit-learn 1.6
- scipy
- C/C++ compiler
diff --git a/ci/appveyor/py310.ps1 b/ci/appveyor/py310.ps1
index df7f0bfd..665d3214 100644
--- a/ci/appveyor/py310.ps1
+++ b/ci/appveyor/py310.ps1
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.10.*"
$env:CI_PANDAS_VERSION="1.5.*"
$env:CI_NUMPY_VERSION="1.25.*"
-$env:CI_SKLEARN_VERSION="1.4.*"
+$env:CI_SKLEARN_VERSION="1.6.*"
diff --git a/ci/appveyor/py311.ps1 b/ci/appveyor/py311.ps1
index 3cd8ad40..1d1e2456 100644
--- a/ci/appveyor/py311.ps1
+++ b/ci/appveyor/py311.ps1
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.11.*"
$env:CI_PANDAS_VERSION="2.0.*"
$env:CI_NUMPY_VERSION="1.26.*"
-$env:CI_SKLEARN_VERSION="1.5.*"
+$env:CI_SKLEARN_VERSION="1.6.*"
diff --git a/ci/appveyor/py312.ps1 b/ci/appveyor/py312.ps1
index d1ba90d1..eb7ee138 100644
--- a/ci/appveyor/py312.ps1
+++ b/ci/appveyor/py312.ps1
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.12.*"
$env:CI_PANDAS_VERSION="2.2.*"
$env:CI_NUMPY_VERSION="2.0.*"
-$env:CI_SKLEARN_VERSION="1.5.*"
+$env:CI_SKLEARN_VERSION="1.6.*"
diff --git a/ci/appveyor/py313.ps1 b/ci/appveyor/py313.ps1
index 32f2664a..cd6ef92d 100644
--- a/ci/appveyor/py313.ps1
+++ b/ci/appveyor/py313.ps1
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.13.*"
$env:CI_PANDAS_VERSION="2.2.*"
$env:CI_NUMPY_VERSION="2.1.*"
-$env:CI_SKLEARN_VERSION="1.5.*"
+$env:CI_SKLEARN_VERSION="1.6.*"
diff --git a/ci/deps/py310.sh b/ci/deps/py310.sh
index 57ef4d7b..968d5684 100644
--- a/ci/deps/py310.sh
+++ b/ci/deps/py310.sh
@@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.10.*'
export CI_PANDAS_VERSION='1.5.*'
export CI_NUMPY_VERSION='1.25.*'
-export CI_SKLEARN_VERSION='1.4.*'
+export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=false
diff --git a/ci/deps/py311.sh b/ci/deps/py311.sh
index 28c83a40..13edeb85 100644
--- a/ci/deps/py311.sh
+++ b/ci/deps/py311.sh
@@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.11.*'
export CI_PANDAS_VERSION='2.0.*'
export CI_NUMPY_VERSION='1.26.*'
-export CI_SKLEARN_VERSION='1.5.*'
+export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=true
diff --git a/ci/deps/py312.sh b/ci/deps/py312.sh
index 380ffca2..14b1da21 100644
--- a/ci/deps/py312.sh
+++ b/ci/deps/py312.sh
@@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.12.*'
export CI_PANDAS_VERSION='2.2.*'
export CI_NUMPY_VERSION='2.0.*'
-export CI_SKLEARN_VERSION='1.5.*'
+export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=true
diff --git a/ci/deps/py313.sh b/ci/deps/py313.sh
index bce18de1..569fc5e8 100644
--- a/ci/deps/py313.sh
+++ b/ci/deps/py313.sh
@@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.13.*'
export CI_PANDAS_VERSION='2.2.*'
export CI_NUMPY_VERSION='2.1.*'
-export CI_SKLEARN_VERSION='1.5.*'
+export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=false
diff --git a/doc/conf.py b/doc/conf.py
index 08f3670b..7b2ce914 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -217,7 +217,7 @@
}
intersphinx_mapping = {
- "sklearn": ("https://scikit-learn.org/1.5", None),
+ "sklearn": ("https://scikit-learn.org/1.6", None),
"cython": ("https://cython.readthedocs.io/en/latest/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
diff --git a/doc/install.rst b/doc/install.rst
index c4e10363..ea70927a 100644
--- a/doc/install.rst
+++ b/doc/install.rst
@@ -100,6 +100,6 @@ The current minimum dependencies to run scikit-survival are:
- numpy
- osqp
- pandas 1.4.0 or later
-- scikit-learn 1.4 or 1.5
+- scikit-learn 1.6
- scipy
- C/C++ compiler
diff --git a/pyproject.toml b/pyproject.toml
index 7179e31e..3a6cad39 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,8 +9,7 @@ requires = [
"numpy>=2.0.0",
# scikit-learn requirements
- "scikit-learn~=1.4.0; python_version<='3.12'",
- "scikit-learn~=1.5.0; python_version=='3.13'",
+ "scikit-learn~=1.6.1; python_version<='3.13'",
"scikit-learn; python_version>'3.13'",
]
build-backend = "setuptools.build_meta"
@@ -51,7 +50,7 @@ dependencies = [
"osqp !=0.6.0,!=0.6.1",
"pandas >=1.4.0",
"scipy >=1.3.2",
- "scikit-learn >=1.4.0,<1.6",
+ "scikit-learn >=1.6.1,<1.7",
]
dynamic = ["version"]
@@ -188,13 +187,21 @@ target-version = "py310"
ignore = ["C408"]
ignore-init-module-imports = true
select = [
- "C4",
- "C9",
+ # pycodestyle
"E",
+ "W",
+ # mccabe
+ "C90",
+ # pyflakes
"F",
+ # isort
"I",
+ # flake8-builtins
+ "A",
+ # flake8-comprehensions
+ "C4",
+ # flake8-pytest-style
"PT",
- "W",
]
[tool.ruff.lint.flake8-pytest-style]
diff --git a/sksurv/__init__.py b/sksurv/__init__.py
index 2405e27c..b75c6b37 100644
--- a/sksurv/__init__.py
+++ b/sksurv/__init__.py
@@ -44,7 +44,7 @@ def show_versions():
max(map(len, deps)),
max(map(len, sys_info.keys())),
)
- fmt = "{0:<%ds}: {1}" % minwidth
+ fmt = f"{{0:<{minwidth}s}}: {{1}}"
print("SYSTEM")
print("------")
diff --git a/sksurv/base.py b/sksurv/base.py
index 1c608723..ab12eb46 100644
--- a/sksurv/base.py
+++ b/sksurv/base.py
@@ -99,5 +99,7 @@ def score(self, X, y):
result = concordance_index_censored(y[name_event], y[name_time], risk_score)
return result[0]
- def _more_tags(self):
- return {"requires_y": True}
+ def __sklearn_tags__(self):
+ tags = super().__sklearn_tags__()
+ tags.target_tags.required = True
+ return tags
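
Note: scikit-learn 1.6 replaced the dict-returning `_more_tags`/`_get_tags` protocol with `__sklearn_tags__`, which returns a mutable `Tags` dataclass; the old `requires_y` key becomes `target_tags.required`. A minimal sketch, assuming scikit-learn >= 1.6 (the `Needy` estimator is hypothetical):

    from sklearn.base import BaseEstimator

    class Needy(BaseEstimator):
        def __sklearn_tags__(self):
            tags = super().__sklearn_tags__()
            tags.target_tags.required = True  # was: {"requires_y": True}
            return tags

    assert Needy().__sklearn_tags__().target_tags.required
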
diff --git a/sksurv/bintrees/_binarytrees.pyx b/sksurv/bintrees/_binarytrees.pyx
index bdeb02eb..36f19a1d 100644
--- a/sksurv/bintrees/_binarytrees.pyx
+++ b/sksurv/bintrees/_binarytrees.pyx
@@ -10,6 +10,7 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
+cimport cython
from libcpp cimport bool
from libcpp.cast cimport dynamic_cast
@@ -76,18 +77,21 @@ cdef class BaseTree:
return self.count_larger(key)
+@cython.final
cdef class RBTree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
raise ValueError('size must be greater zero')
self.treeptr = new rbtree(size)
+@cython.final
cdef class AVLTree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
raise ValueError('size must be greater zero')
self.treeptr = dynamic_cast[rbtree_ptr](new avl(size))
+@cython.final
cdef class AATree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
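
Note: `@cython.final` marks an extension type as non-subclassable, which lets Cython devirtualize calls to its cdef methods. A rough illustration in Cython's pure-Python mode (assumes the `cython` package is installed; in the `.pyx` sources above the decorator applies directly to the cdef classes):

    import cython

    @cython.final  # no subclasses allowed, so method calls can be inlined
    @cython.cclass
    class IntBox:
        value: cython.int

        def __init__(self, value: int) -> None:
            self.value = value
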
diff --git a/sksurv/ensemble/boosting.py b/sksurv/ensemble/boosting.py
index 39686226..e2af8d04 100644
--- a/sksurv/ensemble/boosting.py
+++ b/sksurv/ensemble/boosting.py
@@ -21,10 +21,15 @@
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree._tree import DTYPE
-from sklearn.utils import check_random_state
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.extmath import squared_norm
-from sklearn.utils.validation import _check_sample_weight, check_array, check_is_fitted
+from sklearn.utils.validation import (
+ _check_sample_weight,
+ check_array,
+ check_is_fitted,
+ check_random_state,
+ validate_data,
+)
from ..base import SurvivalAnalysisMixin
from ..linear_model.coxph import BreslowEstimator
@@ -389,7 +394,7 @@ def fit(self, X, y, sample_weight=None):
if not self.warm_start:
self._clear_state()
- X = self._validate_data(X, ensure_min_samples=2)
+ X = validate_data(self, X, ensure_min_samples=2)
event, time = check_array_survival(X, y)
sample_weight = _check_sample_weight(sample_weight, X)
@@ -398,7 +403,7 @@ def fit(self, X, y, sample_weight=None):
Xi = np.column_stack((np.ones(n_samples), X))
self._loss = LOSS_FUNCTIONS[self.loss]()
- if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
+ if isinstance(self._loss, CensoredSquaredLoss | IPCWLeastSquaresError):
time = np.log(time)
if not self._is_fitted():
@@ -470,7 +475,7 @@ def predict(self, X):
Predicted risk scores.
"""
check_is_fitted(self, "estimators_")
- X = self._validate_data(X, reset=False)
+ X = validate_data(self, X, reset=False)
return self._predict(X)
@@ -957,7 +962,7 @@ def _set_max_features(self):
max_features = max(1, int(np.log2(self.n_features_in_)))
elif self.max_features is None:
max_features = self.n_features_in_
- elif isinstance(self.max_features, (numbers.Integral, np.integer)):
+ elif isinstance(self.max_features, numbers.Integral):
max_features = self.max_features
else: # float
max_features = max(1, int(self.max_features * self.n_features_in_))
@@ -1234,7 +1239,8 @@ def fit(self, X, y, sample_weight=None, monitor=None):
if not self.warm_start:
self._clear_state()
- X = self._validate_data(
+ X = validate_data(
+ self,
X,
ensure_min_samples=2,
order="C",
@@ -1256,7 +1262,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
# self.loss is guaranteed to be a string
self._loss = self._get_loss(sample_weight=sample_weight)
- if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
+ if isinstance(self._loss, CensoredSquaredLoss | IPCWLeastSquaresError):
time = np.log(time)
if self.n_iter_no_change is not None:
@@ -1315,13 +1321,13 @@ def fit(self, X, y, sample_weight=None, monitor=None):
begin_at_stage = self.estimators_.shape[0]
# The requirements of _raw_predict
# are more constrained than fit. It accepts only CSR
- # matrices. Finite values have already been checked in _validate_data.
+ # matrices. Finite values have already been checked in validate_data.
X_train = check_array(
X_train,
dtype=DTYPE,
order="C",
accept_sparse="csr",
- force_all_finite=False,
+ ensure_all_finite=False,
)
raw_predictions = self._raw_predict(X_train)
self._resize_state()
@@ -1390,7 +1396,7 @@ def _dropout_raw_predict(self, X):
return raw_predictions
def _dropout_staged_raw_predict(self, X):
- X = self._validate_data(X, dtype=DTYPE, order="C", accept_sparse="csr")
+ X = validate_data(self, X, dtype=DTYPE, order="C", accept_sparse="csr")
raw_predictions = self._raw_predict_init(X)
n_estimators, K = self.estimators_.shape
@@ -1438,7 +1444,7 @@ def predict(self, X):
"""
check_is_fitted(self, "estimators_")
- X = self._validate_data(X, reset=False, order="C", accept_sparse="csr", dtype=DTYPE)
+ X = validate_data(self, X, reset=False, order="C", accept_sparse="csr", dtype=DTYPE)
return self._predict(X)
def staged_predict(self, X):
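
Note: two renames recur throughout this file: the private `BaseEstimator._validate_data` method became the public `sklearn.utils.validation.validate_data` function, which takes the estimator as its first argument, and the `force_all_finite` keyword of the validation helpers became `ensure_all_finite`. A minimal sketch, assuming scikit-learn >= 1.6 (the `Identity` estimator is hypothetical):

    import numpy as np
    from sklearn.base import BaseEstimator
    from sklearn.utils.validation import validate_data

    class Identity(BaseEstimator):
        def fit(self, X, y=None):
            # reset=True (the default) records n_features_in_ on self
            X = validate_data(self, X, ensure_all_finite=True)
            return self

        def predict(self, X):
            # reset=False validates X against the recorded n_features_in_
            return validate_data(self, X, reset=False)

    Identity().fit(np.eye(3)).predict(np.eye(3))
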
diff --git a/sksurv/ensemble/forest.py b/sksurv/ensemble/forest.py
index 1dc96a05..d43c242b 100644
--- a/sksurv/ensemble/forest.py
+++ b/sksurv/ensemble/forest.py
@@ -14,8 +14,8 @@
_parallel_build_trees,
)
from sklearn.tree._tree import DTYPE
-from sklearn.utils._tags import _safe_tags
-from sklearn.utils.validation import check_is_fitted, check_random_state
+from sklearn.utils._tags import get_tags
+from sklearn.utils.validation import check_is_fitted, check_random_state, validate_data
from ..base import SurvivalAnalysisMixin
from ..metrics import concordance_index_censored
@@ -29,18 +29,20 @@
MAX_INT = np.iinfo(np.int32).max
-def _more_tags_patch(self):
- # BaseForest._more_tags calls
+def _sklearn_tags_patch(self):
+ # BaseForest.__sklearn_tags__ calls
# type(self.estimator)(criterion=self.criterions),
# which is incompatible with LogrankCriterion
if isinstance(self, _BaseSurvivalForest):
estimator = type(self.estimator)()
else:
estimator = type(self.estimator)(criterion=self.criterion)
- return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
+ tags = super(BaseForest, self).__sklearn_tags__()
+ tags.input_tags.allow_nan = get_tags(estimator).input_tags.allow_nan
+ return tags
-BaseForest._more_tags = _more_tags_patch
+BaseForest.__sklearn_tags__ = _sklearn_tags_patch
class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
@@ -104,7 +106,7 @@ def fit(self, X, y, sample_weight=None):
"""
self._validate_params()
- X = self._validate_data(X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, force_all_finite=False)
+ X = validate_data(self, X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, ensure_all_finite=False)
event, time = check_array_survival(X, y)
# _compute_missing_values_in_feature_mask checks if X has missing values and
@@ -115,7 +117,7 @@ def fit(self, X, y, sample_weight=None):
X, estimator_name=self.__class__.__name__
)
- self.n_features_in_ = X.shape[1]
+ self._n_samples, self.n_features_in_ = X.shape
time = time.astype(np.float64)
self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
self.n_outputs_ = self.unique_times_.shape[0]
@@ -125,7 +127,18 @@ def fit(self, X, y, sample_weight=None):
y_numeric[:, 1] = event.astype(np.float64)
# Get bootstrap sample size
- n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
+ if not self.bootstrap and self.max_samples is not None: # pylint: disable=no-else-raise
+ raise ValueError(
+ "`max_sample` cannot be set if `bootstrap=False`. "
+ "Either switch to `bootstrap=True` or set "
+ "`max_sample=None`."
+ )
+ elif self.bootstrap:
+ n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
+ else:
+ n_samples_bootstrap = None
+
+ self._n_samples_bootstrap = n_samples_bootstrap
# Check parameters
self._validate_estimator()
@@ -141,13 +154,12 @@ def fit(self, X, y, sample_weight=None):
n_more_estimators = self.n_estimators - len(self.estimators_)
- if n_more_estimators < 0:
+ if n_more_estimators < 0: # pylint: disable=no-else-raise
raise ValueError(
f"n_estimators={self.n_estimators} must be larger or equal to "
f"len(estimators_)={len(self.estimators_)} when warm_start==True"
)
-
- if n_more_estimators == 0:
+ elif n_more_estimators == 0:
warnings.warn("Warm-start fitting without increasing n_estimators does not fit new trees.", stacklevel=2)
else:
if self.warm_start and len(self.estimators_) > 0:
@@ -442,7 +454,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
`min_impurity_decrease` or `min_impurity_split` are absent.
In addition, the `feature_importances_` attribute is not available.
It is recommended to estimate feature importances via
- `permutation-based methods <https://scikit-learn.org/stable/modules/permutation_importance.html>`_.
+ :func:`sklearn.inspection.permutation_importance`.
The features are always randomly permuted at each split. Therefore,
the best found split may vary, even with the same training data,
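
Note: `_safe_tags(estimator, key="allow_nan")` no longer exists in scikit-learn 1.6; `get_tags` resolves the whole `__sklearn_tags__` chain and returns the `Tags` dataclass. A minimal sketch, assuming scikit-learn >= 1.6 and its private `_tags` module:

    from sklearn.tree import DecisionTreeRegressor
    from sklearn.utils._tags import get_tags

    # was: _safe_tags(estimator, key="allow_nan")
    allow_nan = get_tags(DecisionTreeRegressor()).input_tags.allow_nan
    print(allow_nan)
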
diff --git a/sksurv/functions.py b/sksurv/functions.py
index ee978cc8..154f0f75 100644
--- a/sksurv/functions.py
+++ b/sksurv/functions.py
@@ -12,7 +12,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
-from sklearn.utils import check_consistent_length
+from sklearn.utils.validation import check_consistent_length
__all__ = ["StepFunction"]
diff --git a/sksurv/kernels/clinical.py b/sksurv/kernels/clinical.py
index 0301f29f..306b2232 100644
--- a/sksurv/kernels/clinical.py
+++ b/sksurv/kernels/clinical.py
@@ -14,7 +14,7 @@
import pandas as pd
from pandas.api.types import CategoricalDtype, is_numeric_dtype
from sklearn.base import BaseEstimator, TransformerMixin
-from sklearn.utils.validation import check_is_fitted
+from sklearn.utils.validation import _check_feature_names, _check_n_features, check_is_fitted
from ._clinical_kernel import (
continuous_ordinal_kernel,
@@ -227,8 +227,8 @@ def fit(self, X, y=None, **kwargs): # pylint: disable=unused-argument
if X.ndim != 2:
raise ValueError(f"expected 2d array, but got {X.ndim}")
- self._check_feature_names(X, reset=True)
- self._check_n_features(X, reset=True)
+ _check_feature_names(self, X, reset=True)
+ _check_n_features(self, X, reset=True)
if self.fit_once:
self.X_fit_ = X
@@ -251,8 +251,8 @@ def transform(self, Y):
"""
check_is_fitted(self, "X_fit_")
- self._check_feature_names(Y, reset=False)
- self._check_n_features(Y, reset=False)
+ _check_feature_names(self, Y, reset=False)
+ _check_n_features(self, Y, reset=False)
n_samples_x = self.X_fit_.shape[0]
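
Note: the estimator methods `self._check_feature_names`/`self._check_n_features` were likewise turned into module-level helpers that take the estimator as their first argument. A minimal sketch, assuming scikit-learn >= 1.6 and its private validation helpers:

    import pandas as pd
    from sklearn.base import BaseEstimator
    from sklearn.utils.validation import _check_feature_names, _check_n_features

    est = BaseEstimator()
    X = pd.DataFrame({"age": [60.0, 72.0], "bmi": [24.1, 31.7]})
    _check_feature_names(est, X, reset=True)  # records feature_names_in_
    _check_n_features(est, X, reset=True)     # records n_features_in_
    assert est.n_features_in_ == 2
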
diff --git a/sksurv/linear_model/coxnet.py b/sksurv/linear_model/coxnet.py
index bd7a9711..f53000a0 100644
--- a/sksurv/linear_model/coxnet.py
+++ b/sksurv/linear_model/coxnet.py
@@ -18,7 +18,13 @@
from sklearn.exceptions import ConvergenceWarning
from sklearn.preprocessing import normalize as f_normalize
from sklearn.utils._param_validation import Interval, StrOptions
-from sklearn.utils.validation import assert_all_finite, check_is_fitted, check_non_negative, column_or_1d
+from sklearn.utils.validation import (
+ assert_all_finite,
+ check_is_fitted,
+ check_non_negative,
+ column_or_1d,
+ validate_data,
+)
from ..base import SurvivalAnalysisMixin
from ..util import check_array_survival
@@ -178,7 +184,7 @@ def __init__(
self._baseline_models = None
def _pre_fit(self, X, y):
- X = self._validate_data(X, ensure_min_samples=2, dtype=np.float64, copy=self.copy_X)
+ X = validate_data(self, X, ensure_min_samples=2, dtype=np.float64, copy=self.copy_X)
event, time = check_array_survival(X, y)
# center feature matrix
X_offset = np.average(X, axis=0)
@@ -212,9 +218,6 @@ def _check_penalty_factor(self, n_features):
def _check_alphas(self):
create_path = self.alphas is None
if create_path:
- if self.n_alphas <= 0:
- raise ValueError("n_alphas must be a positive integer")
-
alphas = np.empty(int(self.n_alphas), dtype=np.float64)
else:
alphas = column_or_1d(self.alphas, warn=True)
@@ -239,9 +242,6 @@ def _check_params(self, n_samples, n_features):
alphas, create_path = self._check_alphas()
- if self.max_iter <= 0:
- raise ValueError("max_iter must be a positive integer")
-
alpha_min_ratio = self._check_alpha_min_ratio(n_samples, n_features)
return create_path, alphas.astype(np.float64), penalty_factor.astype(np.float64), alpha_min_ratio
@@ -366,7 +366,7 @@ def predict(self, X, alpha=None):
T : array, shape = (n_samples,)
The predicted decision function
"""
- X = self._validate_data(X, reset=False)
+ X = validate_data(self, X, reset=False)
coef, offset = self._get_coef(alpha)
return np.dot(X, coef) - offset
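
Note: the deleted `n_alphas`/`max_iter` checks are redundant once the estimator declares them in `_parameter_constraints`, which `self._validate_params()` enforces at fit time. A minimal sketch of that mechanism, assuming scikit-learn's private `_param_validation` module (the `Toy` estimator is hypothetical):

    from numbers import Integral
    from sklearn.base import BaseEstimator
    from sklearn.utils._param_validation import Interval

    class Toy(BaseEstimator):
        _parameter_constraints = {"n_alphas": [Interval(Integral, 1, None, closed="left")]}

        def __init__(self, n_alphas=100):
            self.n_alphas = n_alphas

        def fit(self, X=None, y=None):
            self._validate_params()  # InvalidParameterError for, e.g., n_alphas=0
            return self
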
diff --git a/sksurv/linear_model/coxph.py b/sksurv/linear_model/coxph.py
index eb09cf31..3d8cb466 100644
--- a/sksurv/linear_model/coxph.py
+++ b/sksurv/linear_model/coxph.py
@@ -18,7 +18,7 @@
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils._param_validation import Interval, StrOptions
-from sklearn.utils.validation import check_array, check_is_fitted
+from sklearn.utils.validation import check_array, check_is_fitted, validate_data
from ..base import SurvivalAnalysisMixin
from ..functions import StepFunction
@@ -413,10 +413,10 @@ def fit(self, X, y):
"""
self._validate_params()
- X = self._validate_data(X, ensure_min_samples=2, dtype=np.float64)
+ X = validate_data(self, X, ensure_min_samples=2, dtype=np.float64)
event, time = check_array_survival(X, y)
- if isinstance(self.alpha, (numbers.Real, numbers.Integral)):
+ if isinstance(self.alpha, numbers.Real | numbers.Integral):
alphas = np.empty(X.shape[1], dtype=float)
alphas[:] = self.alpha
else:
@@ -494,7 +494,7 @@ def predict(self, X):
"""
check_is_fitted(self, "coef_")
- X = self._validate_data(X, reset=False)
+ X = validate_data(self, X, reset=False)
return np.dot(X, self.coef_)
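
Note: two simplifications recur in this change set: NumPy scalars are registered with the `numbers` ABCs, so spelling out `np.integer` alongside `numbers.Integral` was redundant, and on Python 3.10+ `isinstance` accepts a PEP 604 union instead of a tuple. A small sketch:

    import numbers

    import numpy as np

    assert isinstance(np.int64(3), numbers.Integral)  # np.integer is redundant
    assert isinstance(0.5, numbers.Real | numbers.Integral)  # PEP 604 union
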
diff --git a/sksurv/metrics.py b/sksurv/metrics.py
index d797607d..af6bc4d6 100644
--- a/sksurv/metrics.py
+++ b/sksurv/metrics.py
@@ -12,9 +12,8 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
from sklearn.base import BaseEstimator
-from sklearn.utils import check_array, check_consistent_length
from sklearn.utils.metaestimators import available_if
-from sklearn.utils.validation import check_is_fitted
+from sklearn.utils.validation import check_array, check_consistent_length, check_is_fitted
from .exceptions import NoComparablePairException
from .nonparametric import CensoringDistributionEstimator, SurvivalFunctionEstimator
diff --git a/sksurv/preprocessing.py b/sksurv/preprocessing.py
index 1f39e612..e4ad69b3 100644
--- a/sksurv/preprocessing.py
+++ b/sksurv/preprocessing.py
@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from sklearn.base import BaseEstimator, TransformerMixin
-from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
+from sklearn.utils.validation import _check_feature_names, _check_feature_names_in, _check_n_features, check_is_fitted
from .column import encode_categorical
@@ -97,8 +97,8 @@ def fit_transform(self, X, y=None, **fit_params): # pylint: disable=unused-argu
Xt : pandas.DataFrame
Encoded data.
"""
- self._check_feature_names(X, reset=True)
- self._check_n_features(X, reset=True)
+ _check_feature_names(self, X, reset=True)
+ _check_n_features(self, X, reset=True)
columns_to_encode = X.select_dtypes(include=["object", "category"]).columns
x_dummy = self._encode(X, columns_to_encode)
@@ -121,7 +121,7 @@ def transform(self, X):
Encoded data.
"""
check_is_fitted(self, "encoded_columns_")
- self._check_n_features(X, reset=False)
+ _check_n_features(self, X, reset=False)
check_columns_exist(X.columns, self.feature_names_)
Xt = X.copy()
diff --git a/sksurv/svm/minlip.py b/sksurv/svm/minlip.py
index 8a61c6de..7121f316 100644
--- a/sksurv/svm/minlip.py
+++ b/sksurv/svm/minlip.py
@@ -8,6 +8,7 @@
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
from sklearn.utils._param_validation import Interval, StrOptions
+from sklearn.utils.validation import validate_data
from ..base import SurvivalAnalysisMixin
from ..exceptions import NoComparablePairException
@@ -330,9 +331,11 @@ def __init__(
self.timeit = timeit
self.max_iter = max_iter
- def _more_tags(self):
+ def __sklearn_tags__(self):
# tell sklearn.utils.metaestimators._safe_split function that we expect kernel matrix
- return {"pairwise": self.kernel == "precomputed"}
+ tags = super().__sklearn_tags__()
+ tags.input_tags.pairwise = self.kernel == "precomputed"
+ return tags
def _get_kernel(self, X, Y=None):
if callable(self.kernel):
@@ -411,7 +414,7 @@ def fit(self, X, y):
self
"""
self._validate_params()
- X = self._validate_data(X, ensure_min_samples=2)
+ X = validate_data(self, X, ensure_min_samples=2)
event, time = check_array_survival(X, y)
self._fit(X, event, time)
@@ -433,7 +436,7 @@ def predict(self, X):
y : ndarray, shape = (n_samples,)
Predicted risk.
"""
- X = self._validate_data(X, reset=False)
+ X = validate_data(self, X, reset=False)
K = self._get_kernel(X, self.X_fit_)
pred = -np.dot(self.coef_, K.T)
return pred.ravel()
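
Note: the `pairwise` tag moved into `input_tags`; helpers such as `sklearn.utils.metaestimators._safe_split` read it to decide whether rows and columns of a precomputed kernel matrix must both be indexed. A minimal sketch, assuming scikit-learn >= 1.6 (the estimator is hypothetical):

    from sklearn.base import BaseEstimator

    class KernelModel(BaseEstimator):
        def __init__(self, kernel="precomputed"):
            self.kernel = kernel

        def __sklearn_tags__(self):
            tags = super().__sklearn_tags__()
            tags.input_tags.pairwise = self.kernel == "precomputed"
            return tags

    assert KernelModel().__sklearn_tags__().input_tags.pairwise
    assert not KernelModel(kernel="rbf").__sklearn_tags__().input_tags.pairwise
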
diff --git a/sksurv/svm/naive_survival_svm.py b/sksurv/svm/naive_survival_svm.py
index 52cf3114..26d6c7da 100644
--- a/sksurv/svm/naive_survival_svm.py
+++ b/sksurv/svm/naive_survival_svm.py
@@ -16,8 +16,7 @@
import pandas as pd
from scipy.special import comb
from sklearn.svm import LinearSVC
-from sklearn.utils import check_random_state
-from sklearn.utils.validation import _get_feature_names
+from sklearn.utils.validation import _get_feature_names, check_random_state, validate_data
from ..base import SurvivalAnalysisMixin
from ..exceptions import NoComparablePairException
@@ -141,7 +140,7 @@ def __init__(
def _get_survival_pairs(self, X, y, random_state): # pylint: disable=no-self-use
feature_names = _get_feature_names(X)
- X = self._validate_data(X, ensure_min_samples=2)
+ X = validate_data(self, X, ensure_min_samples=2)
event, time = check_array_survival(X, y)
idx = np.arange(X.shape[0], dtype=int)
diff --git a/sksurv/svm/survival_svm.py b/sksurv/svm/survival_svm.py
index d04908a7..89cf9534 100644
--- a/sksurv/svm/survival_svm.py
+++ b/sksurv/svm/survival_svm.py
@@ -20,10 +20,16 @@
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
-from sklearn.utils import check_array, check_consistent_length, check_random_state, check_X_y
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.extmath import safe_sparse_dot, squared_norm
-from sklearn.utils.validation import check_is_fitted
+from sklearn.utils.validation import (
+ check_array,
+ check_consistent_length,
+ check_is_fitted,
+ check_random_state,
+ check_X_y,
+ validate_data,
+)
from ..base import SurvivalAnalysisMixin
from ..bintrees import AVLTree, RBTree
@@ -727,7 +733,7 @@ def predict(self, X):
"""Predict risk score"""
def _validate_for_fit(self, X):
- return self._validate_data(X, ensure_min_samples=2)
+ return validate_data(self, X, ensure_min_samples=2)
def fit(self, X, y):
"""Build a survival support vector machine model from training data.
@@ -962,7 +968,7 @@ def predict(self, X):
Predicted ranks.
"""
check_is_fitted(self, "coef_")
- X = self._validate_data(X, reset=False)
+ X = validate_data(self, X, reset=False)
val = np.dot(X, self.coef_)
if hasattr(self, "intercept_"):
@@ -1133,9 +1139,11 @@ def __init__(
self.coef0 = coef0
self.kernel_params = kernel_params
- def _more_tags(self):
+ def __sklearn_tags__(self):
# tell sklearn.utils.metaestimators._safe_split function that we expect kernel matrix
- return {"pairwise": self.kernel == "precomputed"}
+ tags = super().__sklearn_tags__()
+ tags.input_tags.pairwise = self.kernel == "precomputed"
+ return tags
def _get_kernel(self, X, Y=None):
if callable(self.kernel):
@@ -1211,7 +1219,7 @@ def predict(self, X):
y : ndarray, shape = (n_samples,)
Predicted ranks.
"""
- X = self._validate_data(X, reset=False)
+ X = validate_data(self, X, reset=False)
kernel_mat = self._get_kernel(X, self.fit_X_)
val = np.dot(kernel_mat, self.coef_)
diff --git a/sksurv/tree/_criterion.pyx b/sksurv/tree/_criterion.pyx
index e636bc9d..57b66ec7 100644
--- a/sksurv/tree/_criterion.pyx
+++ b/sksurv/tree/_criterion.pyx
@@ -2,6 +2,7 @@
# cython: boundscheck=False
# cython: wraparound=False
+cimport cython
from libc.math cimport INFINITY, NAN, fabs, sqrt
from libc.stdlib cimport free, malloc
from libc.string cimport memset
@@ -39,6 +40,7 @@ cpdef get_unique_times(cnp.ndarray[float64_t, ndim=1] time, cnp.ndarray[cnp.npy_
return np.asarray(unique_values), np.asarray(has_event, dtype=np.bool_)
+@cython.final
cdef class RisksetCounter:
cdef:
const float64_t[:] unique_times
@@ -132,6 +134,7 @@ cdef int argbinsearch(const float64_t[:] arr, float64_t key_val, intp_t * ret) e
return 0
+@cython.final
cdef class LogrankCriterion(Criterion):
cdef:
diff --git a/sksurv/tree/tree.py b/sksurv/tree/tree.py
index 750fb518..1205a0c8 100644
--- a/sksurv/tree/tree.py
+++ b/sksurv/tree/tree.py
@@ -9,12 +9,14 @@
from sklearn.tree._splitter import Splitter
from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
from sklearn.tree._utils import _any_isnan_axis0
-from sklearn.utils._param_validation import Interval, StrOptions
+from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
from sklearn.utils.validation import (
_assert_all_finite_element_wise,
+ _check_n_features,
assert_all_finite,
check_is_fitted,
check_random_state,
+ validate_data,
)
from ..base import SurvivalAnalysisMixin
@@ -159,16 +161,16 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
"max_depth": [Interval(Integral, 1, None, closed="left"), None],
"min_samples_split": [
Interval(Integral, 2, None, closed="left"),
- Interval(Real, 0.0, 1.0, closed="neither"),
+ Interval(RealNotInt, 0.0, 1.0, closed="neither"),
],
"min_samples_leaf": [
Interval(Integral, 1, None, closed="left"),
- Interval(Real, 0.0, 0.5, closed="right"),
+ Interval(RealNotInt, 0.0, 0.5, closed="right"),
],
"min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
"max_features": [
Interval(Integral, 1, None, closed="left"),
- Interval(Real, 0.0, 1.0, closed="right"),
+ Interval(RealNotInt, 0.0, 1.0, closed="right"),
StrOptions({"sqrt", "log2"}),
None,
],
@@ -202,12 +204,13 @@ def __init__(
self.max_leaf_nodes = max_leaf_nodes
self.low_memory = low_memory
- def _more_tags(self):
- allow_nan = self.splitter == "best"
- return {"allow_nan": allow_nan}
+ def __sklearn_tags__(self):
+ tags = super().__sklearn_tags__()
+ tags.input_tags.allow_nan = self.splitter in ("best", "random")
+ return tags
def _support_missing_values(self, X):
- return not issparse(X) and self._get_tags()["allow_nan"]
+ return not issparse(X) and self.__sklearn_tags__().input_tags.allow_nan
def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
"""Return boolean mask denoting if there are missing values for each feature.
@@ -283,7 +286,7 @@ def _fit(self, X, y, sample_weight=None, check_input=True, missing_values_in_fea
random_state = check_random_state(self.random_state)
if check_input:
- X = self._validate_data(X, dtype=DTYPE, ensure_min_samples=2, accept_sparse="csc", force_all_finite=False)
+ X = validate_data(self, X, dtype=DTYPE, ensure_min_samples=2, accept_sparse="csc", ensure_all_finite=False)
event, time = check_array_survival(X, y)
time = time.astype(np.float64)
self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
@@ -360,7 +363,7 @@ def _check_params(self, n_samples):
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes
- if isinstance(self.min_samples_leaf, (Integral, np.integer)):
+ if isinstance(self.min_samples_leaf, Integral):
min_samples_leaf = self.min_samples_leaf
else: # float
min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))
@@ -375,9 +378,6 @@ def _check_params(self, n_samples):
self._check_max_features()
- if not 0 <= self.min_weight_fraction_leaf <= 0.5:
- raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
-
min_weight_leaf = self.min_weight_fraction_leaf * n_samples
return {
@@ -397,13 +397,13 @@ def _check_max_features(self):
elif self.max_features is None:
max_features = self.n_features_in_
- elif isinstance(self.max_features, (Integral, np.integer)):
+ elif isinstance(self.max_features, Integral):
max_features = self.max_features
else: # float
if self.max_features > 0.0:
max_features = max(1, int(self.max_features * self.n_features_in_))
else:
- max_features = 0
+ max_features = 0 # pragma: no cover
if not 0 < max_features <= self.n_features_in_:
raise ValueError("max_features must be in (0, n_features]")
@@ -422,19 +422,20 @@ def _validate_X_predict(self, X, check_input, accept_sparse="csr"):
"""Validate X whenever one tries to predict"""
if check_input:
if self._support_missing_values(X):
- force_all_finite = "allow-nan"
+ ensure_all_finite = "allow-nan"
else:
- force_all_finite = True
- X = self._validate_data(
+ ensure_all_finite = True
+ X = validate_data(
+ self,
X,
dtype=DTYPE,
accept_sparse=accept_sparse,
reset=False,
- force_all_finite=force_all_finite,
+ ensure_all_finite=ensure_all_finite,
)
else:
# The number of features is checked regardless of `check_input`
- self._check_n_features(X, reset=False)
+ _check_n_features(self, X, reset=False)
return X
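
Note: `RealNotInt` tightens the float intervals so an integer can only match the dedicated `Integral` constraint, e.g. `min_samples_split=1` no longer slips into the `(0, 1)` float branch. A minimal sketch, assuming scikit-learn's private `_param_validation` module:

    from sklearn.utils._param_validation import Interval, RealNotInt

    float_branch = Interval(RealNotInt, 0.0, 1.0, closed="neither")
    assert float_branch.is_satisfied_by(0.5)
    assert not float_branch.is_satisfied_by(1)  # ints take the Integral branch
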
diff --git a/sksurv/util.py b/sksurv/util.py
index 1580d901..87e2a73e 100644
--- a/sksurv/util.py
+++ b/sksurv/util.py
@@ -13,7 +13,7 @@
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
-from sklearn.utils import check_array, check_consistent_length
+from sklearn.utils.validation import check_array, check_consistent_length
__all__ = ["check_array_survival", "check_y_survival", "safe_concat", "Surv"]
diff --git a/tests/test_coxnet.py b/tests/test_coxnet.py
index 9e418942..53c9d0f4 100644
--- a/tests/test_coxnet.py
+++ b/tests/test_coxnet.py
@@ -51,7 +51,7 @@ def nan_float_array():
def assert_columns_almost_equal(actual, expected, decimal=6):
for i, col in enumerate(expected.columns):
assert_array_almost_equal(
- expected.loc[:, col].values, actual.loc[:, col].values, decimal=decimal, err_msg="Column %d: %s" % (i, col)
+ expected.loc[:, col].values, actual.loc[:, col].values, decimal=decimal, err_msg=f"Column {i:d}: {col}"
)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 42c547c2..2cff9cb6 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -90,13 +90,7 @@ def attr_labels(self):
return ["event", "time"]
def _to_data_frame(self, data, columns):
- if isinstance(
- data,
- (
- tuple,
- list,
- ),
- ):
+ if isinstance(data, tuple | list):
data = np.column_stack(data)
return pd.DataFrame(data, columns=columns)
diff --git a/tests/test_forest.py b/tests/test_forest.py
index 9bd43a0b..d6387ca6 100644
--- a/tests/test_forest.py
+++ b/tests/test_forest.py
@@ -40,7 +40,10 @@ def test_fit_predict(make_whas500, forest_cls, expected_c):
assert_cindex_almost_equal(whas500.y["fstat"], whas500.y["lenfol"], pred, expected_c)
-def test_fit_missing_values(make_whas500):
+@pytest.mark.parametrize(
+ "forest_cls,expected_cindex", [(ExtraSurvivalTrees, 0.7486232588273405), (RandomSurvivalForest, 0.7444120505344995)]
+)
+def test_fit_missing_values(make_whas500, forest_cls, expected_cindex):
whas500 = make_whas500(to_numeric=True)
rng = np.random.RandomState(42)
@@ -52,43 +55,26 @@ def test_fit_missing_values(make_whas500):
X_train, y_train = X[:400], whas500.y[:400]
X_test, y_test = X[400:], whas500.y[400:]
- forest = RandomSurvivalForest(random_state=42)
+ forest = forest_cls(random_state=42)
forest.fit(X_train, y_train)
- tags = forest._get_tags()
- assert tags["allow_nan"]
+ tags = forest.__sklearn_tags__()
+ assert tags.input_tags.allow_nan
cindex = forest.score(X_test, y_test)
- assert cindex == pytest.approx(0.7444120505344995)
-
-
-def test_fit_missing_values_not_supported(make_whas500):
- whas500 = make_whas500(to_numeric=True)
-
- rng = np.random.RandomState(42)
- mask = rng.binomial(n=1, p=0.15, size=whas500.x.shape)
- mask = mask.astype(bool)
- X = whas500.x.copy()
- X[mask] = np.nan
+ assert cindex == pytest.approx(expected_cindex)
- forest = ExtraSurvivalTrees(random_state=42)
- with pytest.raises(ValueError, match="Input X contains NaN"):
- forest.fit(X, whas500.y)
- tags = forest._get_tags()
- assert not tags["allow_nan"]
-
-
-@pytest.mark.parametrize("forst_cls,allows_nan", [(ExtraTreesClassifier, False), (RandomForestClassifier, True)])
-def test_sklearn_random_forest_tags(forst_cls, allows_nan):
+@pytest.mark.parametrize("forst_cls", [ExtraTreesClassifier, RandomForestClassifier])
+def test_sklearn_random_forest_tags(forst_cls):
est = forst_cls()
# https://scikit-learn.org/stable/developers/develop.html#estimator-tags
- tags = est._get_tags()
- assert tags["multioutput"]
- assert tags["requires_fit"]
- assert tags["requires_y"]
- assert tags["allow_nan"] is allows_nan
+ tags = est.__sklearn_tags__()
+ assert tags.target_tags.multi_output
+ assert tags.requires_fit
+ assert tags.target_tags.required
+ assert tags.input_tags.allow_nan
@pytest.mark.parametrize("forest_cls", FORESTS)
@@ -252,6 +238,31 @@ def test_fit_with_small_max_samples(make_whas500, forest_cls):
assert tree1.node_count > tree2.node_count, msg
+@pytest.mark.parametrize("forest_cls", FORESTS)
+def test_max_samples_without_bootstrap(make_whas500, forest_cls):
+ whas500 = make_whas500(to_numeric=True)
+
+ est = forest_cls(n_estimators=1, random_state=1, bootstrap=False, max_samples=10)
+ msg = (
+ r"`max_sample` cannot be set if `bootstrap=False`\. "
+ r"Either switch to `bootstrap=True` or set `max_sample=None`\."
+ )
+ with pytest.raises(ValueError, match=msg):
+ est.fit(whas500.x, whas500.y)
+
+
+@pytest.mark.parametrize("forest_cls", FORESTS)
+def test_estimators_samples(make_whas500, forest_cls):
+ whas500 = make_whas500(to_numeric=True)
+
+ est = forest_cls(n_estimators=10, max_samples=333, random_state=1, low_memory=True)
+ est.fit(whas500.x, whas500.y)
+
+ n_samples = [len(np.unique(arr)) for arr in est.estimators_samples_]
+ expected = np.array([255, 227, 245, 247, 246, 239, 254, 252, 245, 248])
+ assert_array_equal(n_samples, expected)
+
+
@pytest.mark.parametrize("forest_cls", FORESTS)
@pytest.mark.parametrize("func", ["predict_survival_function", "predict_cumulative_hazard_function"])
def test_pipeline_predict(breast_cancer, forest_cls, func):
diff --git a/tests/test_minlip.py b/tests/test_minlip.py
index 186b4f2a..fec35e47 100644
--- a/tests/test_minlip.py
+++ b/tests/test_minlip.py
@@ -7,6 +7,7 @@
from sksurv.column import encode_categorical
from sksurv.datasets import load_gbsg2
from sksurv.exceptions import NoComparablePairException
+from sksurv.kernels import ClinicalKernelTransform
from sksurv.svm._minlip import create_difference_matrix
from sksurv.svm.minlip import HingeLossSurvivalSVM, MinlipSurvivalAnalysis
from sksurv.testing import FixtureParameterFactory, assert_cindex_almost_equal
@@ -396,6 +397,22 @@ def test_kernel_precomputed(gbsg2_scaled, solver):
p = m.predict(X_test)
assert_cindex_almost_equal(y_test["cens"], y_test["time"], p, (0.6518928901200369, 8472, 4524, 0, 3))
+ @staticmethod
+ @pytest.mark.slow()
+ def test_fit_clinical_kernel(make_whas500):
+ whas500 = make_whas500(with_mean=False, with_std=False)
+
+ trans = ClinicalKernelTransform()
+ trans.fit(whas500.x_data_frame)
+
+ m = MinlipSurvivalAnalysis(kernel=trans.pairwise_kernel)
+ m.fit(whas500.x, whas500.y)
+
+ assert not m.__sklearn_tags__().input_tags.pairwise
+
+ c = m.score(whas500.x, whas500.y)
+ assert c == pytest.approx(0.7314135916645598)
+
@staticmethod
@pytest.mark.parametrize("solver", ["osqp", "ecos"])
def test_max_iter(gbsg2_scaled, solver):
diff --git a/tests/test_survival_svm.py b/tests/test_survival_svm.py
index adb6f6f3..44f2896f 100644
--- a/tests/test_survival_svm.py
+++ b/tests/test_survival_svm.py
@@ -517,7 +517,7 @@ def test_fit_and_predict_linear(kernel, make_whas500):
x = whas500.x
ssvm.fit(x, whas500.y)
- assert ssvm._more_tags()["pairwise"] is (kernel == "precomputed")
+ assert ssvm.__sklearn_tags__().input_tags.pairwise is (kernel == "precomputed")
assert whas500.x.shape[0] == ssvm.coef_.shape[0]
@@ -545,7 +545,7 @@ def test_fit_and_predict_linear_regression(kernel, make_whas500):
x = whas500.x
ssvm.fit(x, whas500.y)
- assert ssvm._get_tags()["pairwise"] is (kernel == "precomputed")
+ assert ssvm.__sklearn_tags__().input_tags.pairwise is (kernel == "precomputed")
assert float(ssvm.intercept_) == pytest.approx(6.416017539824949, 1e-5)
@@ -581,7 +581,7 @@ def test_fit_and_predict_rbf(make_whas500, optimizer):
ssvm = FastKernelSurvivalSVM(optimizer=optimizer, kernel="rbf", tol=2e-6, max_iter=75, random_state=0)
ssvm.fit(whas500.x, whas500.y)
- assert not ssvm._get_tags()["pairwise"]
+ assert not ssvm.__sklearn_tags__().input_tags.pairwise
assert whas500.x.shape[0] == ssvm.coef_.shape[0]
c = ssvm.score(whas500.x, whas500.y)
@@ -597,7 +597,7 @@ def test_fit_and_predict_regression_rbf(make_whas500):
)
ssvm.fit(whas500.x, whas500.y)
- assert not ssvm._get_tags()["pairwise"]
+ assert not ssvm.__sklearn_tags__().input_tags.pairwise
assert ssvm.intercept_ == pytest.approx(4.9267218894089533, 1e-7)
pred = ssvm.predict(whas500.x)
@@ -623,7 +623,7 @@ def test_fit_and_predict_hybrid_polynomial(make_whas500):
)
ssvm.fit(X, whas500.y)
- assert not ssvm._get_tags()["pairwise"]
+ assert not ssvm.__sklearn_tags__().input_tags.pairwise
assert pytest.approx(6.482593184472981, 1e-5) == ssvm.intercept_
pred = ssvm.predict(X)
@@ -644,7 +644,7 @@ def test_fit_and_predict_clinical_kernel(make_whas500):
)
ssvm.fit(whas500.x, whas500.y)
- assert not ssvm._get_tags()["pairwise"]
+ assert not ssvm.__sklearn_tags__().input_tags.pairwise
assert whas500.x.shape[0] == ssvm.coef_.shape[0]
c = ssvm.score(whas500.x, whas500.y)
diff --git a/tests/test_tree.py b/tests/test_tree.py
index 4178bfd1..fc2a214a 100644
--- a/tests/test_tree.py
+++ b/tests/test_tree.py
@@ -14,7 +14,7 @@
from sksurv.compare import compare_survival
from sksurv.datasets import load_breast_cancer, load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator, nelson_aalen_estimator
-from sksurv.tree import ExtraSurvivalTree, SurvivalTree
+from sksurv.tree import SurvivalTree
from sksurv.util import Surv
@@ -837,19 +837,3 @@ def test_missing_values_best_splitter_to_right():
# missing values go to the right
y_expected = tree.tree_.value[4]
assert_array_almost_equal(y_pred, y_expected)
-
-
-@pytest.mark.parametrize("is_sparse", [False, True])
-def test_missing_value_random_splitter_errors(is_sparse):
- X = np.array([[3, 5, 7, 11, np.nan, 13, 17, np.nan, 19]], dtype=np.float32).T
- y = Surv.from_arrays(
- event=np.array([True, True, True, False, True, False, False, False, True]),
- time=np.array([90, 80, 70, 60, 50, 40, 30, 20, 10]),
- )
-
- if is_sparse:
- X = sparse.csr_matrix(X)
-
- tree = ExtraSurvivalTree()
- with pytest.raises(ValueError, match="Input X contains NaN"):
- tree.fit(X, y)
diff --git a/tox.ini b/tox.ini
index 513bce32..510cde41 100644
--- a/tox.ini
+++ b/tox.ini
@@ -11,7 +11,7 @@ deps =
description = Run linters
skip_install = true
deps =
- ruff~=0.7.1
+ ruff~=0.8.4
commands = ruff check sksurv/ tests/ setup.py
pass_env = RUFF_*