Skip to content

Commit

Permalink
Fix allowed values for kernel
Browse files Browse the repository at this point in the history
FastSurvivalSVM and MinlipSurvivalAnalysis

Closes #439
  • Loading branch information
sebp committed Apr 2, 2024
1 parent 764ef79 commit 6984da2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions sksurv/svm/minlip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scipy import linalg, sparse
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
from sklearn.utils._param_validation import Interval, StrOptions

from ..base import SurvivalAnalysisMixin
Expand Down Expand Up @@ -207,7 +207,7 @@ class MinlipSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
solver : {'ecos', 'osqp'}, optional, default: 'ecos'
Which quadratic program solver to use.
kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'cosine', 'precomputed'} or callable, default: 'linear'.
kernel : str or callable, default: 'linear'.
Kernel mapping used internally. This parameter is directly passed to
:func:`sklearn.metrics.pairwise.pairwise_kernels`.
If `kernel` is a string, it must be one of the metrics
Expand Down Expand Up @@ -290,7 +290,7 @@ class MinlipSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
"solver": [StrOptions({"ecos", "osqp"})],
"alpha": [Interval(numbers.Real, 0, None, closed="neither")],
"kernel": [
StrOptions({"linear", "poly", "rbf", "sigmoid", "precomputed"}),
StrOptions(set(PAIRWISE_KERNEL_FUNCTIONS.keys()) | {"precomputed"}),
callable,
],
"degree": [Interval(numbers.Integral, 0, None, closed="left")],
Expand Down
6 changes: 3 additions & 3 deletions sksurv/svm/survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from scipy.optimize import minimize
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
from sklearn.utils import check_array, check_consistent_length, check_random_state, check_X_y
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.extmath import safe_sparse_dot, squared_norm
Expand Down Expand Up @@ -998,7 +998,7 @@ class FastKernelSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin):
Whether to calculate an intercept for the regression model. If set to ``False``, no intercept
will be calculated. Has no effect if ``rank_ratio = 1``, i.e., only ranking is performed.
kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'cosine', 'precomputed'} or callable, default: 'linear'.
kernel : str or callable, default: 'linear'.
Kernel mapping used internally. This parameter is directly passed to
:func:`sklearn.metrics.pairwise.pairwise_kernels`.
If `kernel` is a string, it must be one of the metrics
Expand Down Expand Up @@ -1088,7 +1088,7 @@ class FastKernelSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin):
_parameter_constraints = {
**FastSurvivalSVM._parameter_constraints,
"kernel": [
StrOptions({"linear", "poly", "rbf", "sigmoid", "precomputed"}),
StrOptions(set(PAIRWISE_KERNEL_FUNCTIONS.keys()) | {"precomputed"}),
callable,
],
"gamma": [Interval(Real, 0.0, None, closed="left"), None],
Expand Down

0 comments on commit 6984da2

Please sign in to comment.