Skip to content

Commit

Permalink
Remove bmt.csv and use load_bmt for tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mvlvrd committed Nov 27, 2024
1 parent 3c975a9 commit 9fc899a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 52 deletions.
40 changes: 0 additions & 40 deletions tests/data/bmt.csv

This file was deleted.

25 changes: 13 additions & 12 deletions tests/test_nonparametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import pytest

from sksurv.datasets import load_bmt
from sksurv.nonparametric import (
CensoringDistributionEstimator,
SurvivalFunctionEstimator,
Expand All @@ -18,7 +19,6 @@
CHANNING_FILE = join(dirname(__file__), "data", "channing.csv")
AIDS_CHILDREN_FILE = join(dirname(__file__), "data", "Lagakos_AIDS_children.csv")
AIDS_ADULTS_FILE = join(dirname(__file__), "data", "Lagakos_AIDS_adults.csv")
BMT_FILE = join(dirname(__file__), "data", "bmt.csv")


class SimpleDataKMCases(FixtureParameterFactory):
Expand Down Expand Up @@ -6236,11 +6236,12 @@ def test_whas500(make_whas500, whas500_true_x):


class SimpleDataBMTCases(FixtureParameterFactory):
bmt_df = pd.read_csv(BMT_FILE, sep=";", skiprows=4)
dis_df, bmt = load_bmt()
dis_np = dis_df["dis"].values

def data_full(self):
event = self.bmt_df["status"].values
time = self.bmt_df["ftime"].values
event = self.bmt["status"]
time = self.bmt["ftime"]

true_x = np.array([0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 22, 26, 32, 35, 67, 68, 70, 72])
true_y = np.array(
Expand Down Expand Up @@ -6272,10 +6273,10 @@ def data_full(self):
return event, time, true_x, true_y

def data_ALL(self):
dis = 0

event = self.bmt_df[self.bmt_df["dis"] == dis]["status"].values
time = self.bmt_df[self.bmt_df["dis"] == dis]["ftime"].values
dis = "0"
dis_filter = self.dis_np == dis
event = self.bmt["status"][dis_filter]
time = self.bmt["ftime"][dis_filter]

true_x = np.array([0, 1, 3, 4, 5, 7, 8, 9, 12, 13, 14, 22, 26, 35, 72])
true_y = np.array(
Expand All @@ -6301,10 +6302,10 @@ def data_ALL(self):
return event, time, true_x, true_y

def data_AML(self):
dis = 1

event = self.bmt_df[self.bmt_df["dis"] == dis]["status"].values
time = self.bmt_df[self.bmt_df["dis"] == dis]["ftime"].values
dis = "1"
dis_filter = self.dis_np == dis
event = self.bmt["status"][dis_filter]
time = self.bmt["ftime"][dis_filter]

true_x = np.array([2, 3, 4, 7, 8, 10, 32, 35, 67, 68, 70])
true_y = np.array(
Expand Down

0 comments on commit 9fc899a

Please sign in to comment.