Skip to content

Commit

Permalink
Refactor the calculation of confidence intervals.
Browse files Browse the repository at this point in the history
  • Loading branch information
mvlvrd committed Jan 9, 2025
1 parent 6d389db commit 32f4bc0
Showing 1 changed file with 16 additions and 27 deletions.
43 changes: 16 additions & 27 deletions sksurv/nonparametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,20 @@ def _compute_counts_truncated(event, time_enter, time_exit):
return uniq_times, event_counts, total_counts


def _ci_logmlog(prob_survival, sigma_t, z):
"""Compute the pointwise log-minus-log transformed confidence intervals"""
eps = np.finfo(prob_survival.dtype).eps
log_p = np.zeros_like(prob_survival)
np.log(prob_survival, where=prob_survival > eps, out=log_p)
theta = np.zeros_like(prob_survival)
def _ci_logmlog(s, sigma_t, z):
"""Compute the pointwise log-minus-log transformed confidence intervals.
s refers to the prob_survival or the cum_inc (for the competing risks case).
sigma_t is the square root of the estimator of the log of the variance of s.
"""
eps = np.finfo(s.dtype).eps
mask = s > eps
log_p = np.zeros_like(s)
np.log(s, where=mask, out=log_p)
theta = np.zeros_like(s)
np.true_divide(sigma_t, log_p, where=log_p < -eps, out=theta)
theta = np.array([[-1], [1]]) * theta * z
theta = z * np.multiply.outer([-1, 1], theta)
ci = np.exp(np.exp(theta) * log_p)
ci[:, prob_survival <= eps] = 0.0
ci[:, 1.0 - prob_survival <= eps] = 1.0
ci[:, ~mask] = 0.0
return ci


Expand Down Expand Up @@ -601,31 +604,17 @@ def predict_ipcw(self, y):
return weights


def _cr_ci_logmlog(cum_inc, sigma_t, z):
"""Compute the pointwise log-minus-log transformed confidence intervals"""
eps = np.finfo(cum_inc.dtype).eps
non_zero_count = cum_inc > eps
log_cum_i = np.zeros_like(cum_inc)
np.log(cum_inc, where=non_zero_count, out=log_cum_i)
theta = np.zeros_like(cum_inc)
den = cum_inc * log_cum_i
np.divide(sigma_t, den, where=non_zero_count, out=theta)
theta = z * np.multiply.outer(np.array([-1, 1]), theta)
ci = np.exp(log_cum_i * np.exp(theta))
ci[:, ~non_zero_count] = 0.0
return ci


def _cum_inc_cr_ci_estimator(cum_inc, var, conf_level, conf_type):
if conf_type not in {"log-log"}:
raise ValueError(f"conf_type must be None or a str among {{'log-log'}}, but was {conf_type!r}")

if not isinstance(conf_level, numbers.Real) or not np.isfinite(conf_level) or conf_level <= 0 or conf_level >= 1.0:
raise ValueError(f"conf_level must be a float in the range (0.0, 1.0), but was {conf_level!r}")

eps = np.finfo(var.dtype).eps
z = stats.norm.isf((1.0 - conf_level) / 2.0)
sigma = np.sqrt(var)
ci = _cr_ci_logmlog(cum_inc[1:], sigma, z)
sigma = np.zeros_like(var)
np.divide(np.sqrt(var), cum_inc[1:], where=var > eps, out=sigma)
ci = _ci_logmlog(cum_inc[1:], sigma, z)
return ci


Expand Down

0 comments on commit 32f4bc0

Please sign in to comment.