Skip to content

Commit

Permalink
Use X | Y in isinstance call instead of (X, Y)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebp committed Jan 1, 2025
1 parent e0232dd commit 2e784db
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 10 deletions.
4 changes: 2 additions & 2 deletions sksurv/ensemble/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def fit(self, X, y, sample_weight=None):
Xi = np.column_stack((np.ones(n_samples), X))

self._loss = LOSS_FUNCTIONS[self.loss]()
if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
if isinstance(self._loss, CensoredSquaredLoss | IPCWLeastSquaresError):
time = np.log(time)

if not self._is_fitted():
Expand Down Expand Up @@ -1262,7 +1262,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
# self.loss is guaranteed to be a string
self._loss = self._get_loss(sample_weight=sample_weight)

if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
if isinstance(self._loss, CensoredSquaredLoss | IPCWLeastSquaresError):
time = np.log(time)

if self.n_iter_no_change is not None:
Expand Down
2 changes: 1 addition & 1 deletion sksurv/linear_model/coxph.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def fit(self, X, y):
X = validate_data(self, X, ensure_min_samples=2, dtype=np.float64)
event, time = check_array_survival(X, y)

if isinstance(self.alpha, (numbers.Real, numbers.Integral)):
if isinstance(self.alpha, numbers.Real | numbers.Integral):
alphas = np.empty(X.shape[1], dtype=float)
alphas[:] = self.alpha
else:
Expand Down
8 changes: 1 addition & 7 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,7 @@ def attr_labels(self):
return ["event", "time"]

def _to_data_frame(self, data, columns):
if isinstance(
data,
(
tuple,
list,
),
):
if isinstance(data, tuple | list):
data = np.column_stack(data)
return pd.DataFrame(data, columns=columns)

Expand Down

0 comments on commit 2e784db

Please sign in to comment.