FIX LogrankCriterion does not consider sample_weights
Fixes #443
sebp committed Jun 28, 2024
1 parent bceb53e · commit d2d7eb0
Showing 1 changed file with 35 additions and 31 deletions.
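For context, a minimal sketch of the code path this commit affects (not part of the commit; dataset, hyperparameters, and preprocessing are illustrative assumptions): SurvivalTree.fit accepts sample_weight and hands it down to the Cython LogrankCriterion changed below, which previously ignored it when scoring candidate splits.

    import numpy as np
    from sksurv.datasets import load_whas500
    from sksurv.tree import SurvivalTree

    # load a small survival dataset; keep only numeric features for brevity
    X, y = load_whas500()
    X = X.select_dtypes("number").to_numpy()

    # arbitrary per-sample weights
    rng = np.random.default_rng(0)
    weights = rng.uniform(0.5, 2.0, size=X.shape[0])

    # before this fix, the log-rank split criterion silently ignored `weights`
    tree = SurvivalTree(max_depth=3)
    tree.fit(X, y, sample_weight=weights)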
sksurv/tree/_criterion.pyx: 66 changes (35 additions & 31 deletions)
@@ -42,16 +42,17 @@ cpdef get_unique_times(cnp.ndarray[float64_t, ndim=1] time, cnp.ndarray[cnp.npy_
 cdef class RisksetCounter:
     cdef:
         const float64_t[:] unique_times
-        cnp.npy_int64 * n_events
-        cnp.npy_int64 * n_at_risk
+        float64_t * n_events
+        float64_t * n_at_risk
         const float64_t[:, ::1] data
+        const float64_t[:] sample_weight
         intp_t nbytes

     def __cinit__(self, const float64_t[:] unique_times):
         cdef intp_t n_unique_times = unique_times.shape[0]
-        self.nbytes = n_unique_times * sizeof(cnp.npy_int64)
-        self.n_events = <cnp.npy_int64 *> malloc(self.nbytes)
-        self.n_at_risk = <cnp.npy_int64 *> malloc(self.nbytes)
+        self.nbytes = n_unique_times * sizeof(float64_t)
+        self.n_events = <float64_t *> malloc(self.nbytes)
+        self.n_at_risk = <float64_t *> malloc(self.nbytes)
         self.unique_times = unique_times

     def __dealloc__(self):
@@ -63,8 +64,9 @@ cdef class RisksetCounter:
         memset(self.n_events, 0, self.nbytes)
         memset(self.n_at_risk, 0, self.nbytes)

-    cdef void set_data(self, const float64_t[:, ::1] data) noexcept nogil:
+    cdef void set_data(self, const float64_t[:, ::1] data, const float64_t[:] sample_weight) noexcept nogil:
         self.data = data
+        self.sample_weight = sample_weight

     cdef void update(self, const intp_t[:] samples, intp_t start, intp_t end) noexcept nogil:
         cdef:
@@ -73,6 +75,7 @@
             intp_t ti
             float64_t time
             float64_t event
+            float64_t w = 1.0
             const float64_t[:] unique_times = self.unique_times
             intp_t n_times = unique_times.shape[0]
             const float64_t[:, ::1] y = self.data
@@ -83,22 +86,25 @@
             idx = samples[i]
             time, event = y[idx, 0], y[idx, 1]

+            if self.sample_weight is not None:
+                w = self.sample_weight[idx]
+
             # i-th sample is in all risk sets with time <= i-th time
             ti = 0
             while ti < n_times and unique_times[ti] < time:
-                self.n_at_risk[ti] += 1
+                self.n_at_risk[ti] += w
                 ti += 1

             if ti < n_times:  # unique_times[ti] == time
-                self.n_at_risk[ti] += 1
+                self.n_at_risk[ti] += w
                 if event != 0.0:
-                    self.n_events[ti] += 1
+                    self.n_events[ti] += w

     cdef inline void at(self, intp_t index, float64_t * at_risk, float64_t * events) noexcept nogil:
         if at_risk != NULL:
-            at_risk[0] = <float64_t> self.n_at_risk[index]
+            at_risk[0] = self.n_at_risk[index]
         if events != NULL:
-            events[0] = <float64_t> self.n_events[index]
+            events[0] = self.n_events[index]


 cdef int argbinsearch(const float64_t[:] arr, float64_t key_val, intp_t * ret) except -1 nogil:
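For illustration (not part of this diff), a pure-Python sketch of what the weighted RisksetCounter.update above computes; it assumes, as the Cython code does, that every observed time occurs in unique_times:

    import numpy as np

    def weighted_riskset(unique_times, time, event, sample_weight=None):
        # unit weights reproduce the previous integer counting
        if sample_weight is None:
            sample_weight = np.ones_like(time, dtype=np.float64)
        n_at_risk = np.zeros(unique_times.shape[0])
        n_events = np.zeros(unique_times.shape[0])
        for t, e, w in zip(time, event, sample_weight):
            # a sample is in every risk set with time <= its own time
            n_at_risk[unique_times <= t] += w
            if e != 0.0:
                # events are attributed to the exact event time
                n_events[np.searchsorted(unique_times, t)] += w
        return n_at_risk, n_events

With unit weights this reduces to the old behaviour, which is why the rest of the change is essentially a widening of the counters from cnp.npy_int64 to float64_t.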
@@ -135,10 +141,9 @@ cdef class LogrankCriterion(Criterion):
         intp_t n_unique_times
         intp_t nbytes
         RisksetCounter riskset_total
-        cnp.npy_int64 * delta_n_at_risk_left
-        cnp.npy_int64 * n_events_left
+        float64_t * weighted_n_events_left
+        float64_t * weighted_delta_n_at_risk_left
         intp_t * samples_time_idx
-        intp_t n_samples_left

     def __cinit__(self, intp_t n_outputs, intp_t n_samples, const float64_t[::1] unique_times, const cnp.npy_bool[::1] is_event_time):
         # Default values
@@ -151,21 +156,21 @@
         self.unique_times = unique_times
         self.is_event_time = is_event_time
         self.n_unique_times = unique_times.shape[0]
-        self.nbytes = self.n_unique_times * sizeof(cnp.npy_int64)
+        self.nbytes = self.n_unique_times * sizeof(float64_t)
         self.n_node_samples = 0
         self.weighted_n_node_samples = 0.0
         self.weighted_n_left = 0.0
         self.weighted_n_right = 0.0

         self.riskset_total = RisksetCounter(unique_times)
-        self.delta_n_at_risk_left = <cnp.npy_int64 *> malloc(self.nbytes)
-        self.n_events_left = <cnp.npy_int64 *> malloc(self.nbytes)
+        self.weighted_delta_n_at_risk_left = <float64_t *> malloc(self.nbytes)
+        self.weighted_n_events_left = <float64_t *> malloc(self.nbytes)
         self.samples_time_idx = <intp_t *> malloc(n_samples * sizeof(intp_t))

     def __dealloc__(self):
         """Destructor."""
-        free(self.delta_n_at_risk_left)
-        free(self.n_events_left)
+        free(self.weighted_delta_n_at_risk_left)
+        free(self.weighted_n_events_left)
         free(self.samples_time_idx)

     def __reduce__(self):
@@ -199,7 +204,7 @@
             float64_t w = 1.0
             const float64_t[::1] unique_times = self.unique_times

-        self.riskset_total.set_data(y)
+        self.riskset_total.set_data(y, sample_weight)
         self.riskset_total.update(sample_indices, start, end)

         for i in range(start, end):
@@ -244,9 +249,8 @@
             intp_t time_idx
             float64_t w = 1.0

-        self.n_samples_left = new_pos - pos
-        memset(self.delta_n_at_risk_left, 0, self.nbytes)
-        memset(self.n_events_left, 0, self.nbytes)
+        memset(self.weighted_delta_n_at_risk_left, 0, self.nbytes)
+        memset(self.weighted_n_events_left, 0, self.nbytes)

         # Update statistics up to new_pos
         self.weighted_n_left = 0.0
@@ -255,13 +259,13 @@
             event = y[idx, 1]
             time_idx = self.samples_time_idx[idx]

-            self.delta_n_at_risk_left[time_idx] += 1
-            if event != 0.0:
-                self.n_events_left[time_idx] += 1
-
             if sample_weight is not None:
                 w = sample_weight[idx]

+            self.weighted_delta_n_at_risk_left[time_idx] += w
+            if event != 0.0:
+                self.weighted_n_events_left[time_idx] += w
+
             self.weighted_n_left += w

         self.weighted_n_right = (self.weighted_n_node_samples -
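A property this fix restores, sketched as a check (not part of the diff; agreement up to floating point is my assumption): weighting a sample by 2 should now behave like duplicating it, since both add the same amount to every weighted counter.

    import numpy as np
    from sksurv.datasets import load_whas500
    from sksurv.tree import SurvivalTree

    X, y = load_whas500()
    X = X.select_dtypes("number").to_numpy()

    def risk_scores(X, y, sw=None):
        return SurvivalTree(max_depth=3).fit(X, y, sample_weight=sw).predict(X)

    # weight the first sample twice vs. appending a literal copy of it
    w = np.ones(X.shape[0])
    w[0] = 2.0
    X_dup = np.vstack([X, X[:1]])
    y_dup = np.concatenate([y, y[:1]])

    assert np.allclose(risk_scores(X, y, w), risk_scores(X_dup, y_dup)[: X.shape[0]])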
@@ -283,7 +287,7 @@

         cdef:
             intp_t i
-            float64_t at_risk = <float64_t> self.n_samples_left
+            float64_t weighted_at_risk = self.weighted_n_left
             float64_t events
             float64_t total_at_risk
             float64_t total_events
@@ -293,19 +297,19 @@
             float64_t numer = 0.0

         for i in range(self.n_unique_times):
-            events = <float64_t> self.n_events_left[i]
+            events = self.weighted_n_events_left[i]
             self.riskset_total.at(i, &total_at_risk, &total_events)

             if total_at_risk == 0:
                 break  # we reached the end
-            ratio = at_risk / total_at_risk
+            ratio = weighted_at_risk / total_at_risk
             numer += events - total_events * ratio
             if total_at_risk > 1.0:
                 v = (total_at_risk - total_events) / (total_at_risk - 1.0) * total_events
                 denom += ratio * (1.0 - ratio) * v

             # Update number of samples at risk for next bigger timepoint
-            at_risk -= <float64_t> self.delta_n_at_risk_left[i]
+            weighted_at_risk -= self.weighted_delta_n_at_risk_left[i]

         if denom != 0.0:
             # absolute value is the measure of node separation
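For reference, the loop above accumulates the standard log-rank numerator and variance, now over weighted counts (my transcription, with L denoting the left child, i indexing unique times, Y the at-risk sums, and d the event sums):

    O - E = \sum_i \left( d_{L,i} - d_i \, \frac{Y_{L,i}}{Y_i} \right),
    \qquad
    \widehat{\mathrm{Var}} = \sum_i \frac{Y_{L,i}}{Y_i}
        \left( 1 - \frac{Y_{L,i}}{Y_i} \right) \frac{Y_i - d_i}{Y_i - 1} \, d_i

In the code, numer accumulates O - E and denom the variance; with this fix, each Y and d is a weighted sum over samples rather than an integer count. The final score is computed just after the shown lines, where, per the trailing comment, its absolute value measures node separation.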