Module auton_survival.estimators

Utilities to train survival regression models and estimate survival.

Classes

class SurvivalModel (model, random_seed=0, **hyperparams)

Universal interface to train multiple different survival models.

Parameters

model : str

A string that selects the survival analysis model. Survival model choices include:

  • dsm : Deep Survival Machines [3] model
  • dcph : Deep Cox Proportional Hazards [2] model
  • dcm : Deep Cox Mixtures [4] model
  • rsf : Random Survival Forests [1] model
  • cph : Cox Proportional Hazards [2] model

random_seed : int
Controls the reproducibility of called functions.

References

[1] Hemant Ishwaran et al. Random survival forests. The Annals of Applied Statistics, 2(3):841–860, 2008.

[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).

[3] Chirag Nagpal, Xinyu Li, and Artur Dubrawski. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. 2020.

[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.

Methods

def fit(self, features, outcomes, vsize=0.15, val_data=None, weights=None, weights_val=None, resample_size=1.0)

Train an instance of the survival model.

Parameters

features : pd.DataFrame
A pandas DataFrame with rows corresponding to individual samples and columns as covariates.
outcomes : pd.DataFrame
A pandas DataFrame with columns 'time' and 'event'.
vsize : float, default=0.15
Fraction of the data to set aside as the validation set. Not applicable to 'rsf' and 'cph' models.
val_data : tuple
A tuple of the validation dataset features and outcomes of 'time' and 'event'. If passed, vsize is ignored. Not applicable to 'rsf' and 'cph' models.
weights : list or np.array
A list or numpy array of importance weights for each sample.
weights_val : list or np.array
A list or numpy array of importance weights for each validation set sample. Ignored if val_data is None.
resample_size : float, default=1.0
A float between 0 and 1 that controls the size of the resampled dataset.

Returns

self
Trained instance of a survival model.

def predict_survival(self, features, times)

Predict survival probabilities at specified time(s).

Parameters

features : pd.DataFrame
A pandas DataFrame with rows corresponding to individual samples and columns as covariates.
times : float or list
A float or list of the times at which to compute the survival probability.

Returns

np.array
numpy array of the survival probabilities at each time point in times.
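As a usage illustration, the sketch below builds a small synthetic dataset, trains a 'cph' model with fit, and queries predict_survival. The toy data generator, the covariate names age and biomarker, and the helper names make_toy_records and fit_and_predict are hypothetical and not part of auton_survival; only the SurvivalModel calls follow the signatures documented above. pandas and auton_survival are imported inside the function so that the data helper can be used on its own.

```python
import math
import random


def make_toy_records(n=200, seed=0):
    """Generate toy covariates with right-censored outcomes (illustrative only)."""
    rng = random.Random(seed)
    records = []
    for _ in range(n):
        age = rng.uniform(40, 80)
        biomarker = rng.gauss(0.0, 1.0)
        # A higher biomarker value means a higher hazard and a shorter event time.
        hazard = 0.05 * math.exp(0.5 * biomarker)
        event_time = rng.expovariate(hazard)
        censor_time = rng.uniform(1, 30)
        records.append({
            "age": age,
            "biomarker": biomarker,
            # Observed time is whichever comes first: the event or censoring.
            "time": min(event_time, censor_time),
            "event": int(event_time <= censor_time),
        })
    return records


def fit_and_predict(records, times):
    """Fit a 'cph' SurvivalModel on the records and predict survival at `times`."""
    # Lazy imports: these packages are only needed when this function is called.
    import pandas as pd
    from auton_survival.estimators import SurvivalModel

    frame = pd.DataFrame(records)
    features = frame[["age", "biomarker"]]   # rows are samples, columns are covariates
    outcomes = frame[["time", "event"]]      # the two columns that fit expects

    model = SurvivalModel("cph", random_seed=0)
    model.fit(features, outcomes)            # vsize and val_data do not apply to 'cph'
    return model.predict_survival(features, times)
```

With auton_survival installed, a call such as fit_and_predict(make_toy_records(), [5.0, 10.0, 20.0]) should return one row of survival probabilities per sample.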

def predict_risk(self, features, times)

Predict risk of an outcome occurring within the specified time(s).

Parameters

features : pd.DataFrame
A pandas DataFrame with rows corresponding to individual samples and columns as covariates.
times : float or list
A float or list of the times at which to compute the risk.

Returns

np.array
numpy array of the outcome risks at each time point in times.
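
For a single event type, the risk of the outcome occurring by time t is the complement of surviving past t, that is 1 - S(t). The sketch below applies that identity to a table of survival probabilities. It assumes predict_risk follows this convention, which the description above suggests but does not state; the helper name risk_from_survival and the numbers are made up for illustration.

```python
def risk_from_survival(survival_probs):
    """Convert survival probabilities S(t) into risks 1 - S(t).

    survival_probs: list of rows, one per sample, each row holding
    S(t) for every requested time point.
    """
    return [[1.0 - s for s in row] for row in survival_probs]


# Two samples evaluated at three time points (made-up numbers).
survival = [
    [0.95, 0.80, 0.60],
    [0.90, 0.65, 0.30],
]
risk = risk_from_survival(survival)

# Survival never increases over time, so the derived risk never decreases.
for row in risk:
    assert all(a <= b for a, b in zip(row, row[1:]))
```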

class CounterfactualSurvivalModel (treated_model, control_model)

Universal interface to train multiple different counterfactual survival models.

Methods

def predict_counterfactual_survival(self, features, times)
def predict_counterfactual_risk(self, features, times)
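
Neither method is described above. One plausible reading, given that the constructor takes a treated_model and a control_model, is that every sample is scored by both models so the two potential outcomes can be compared. The sketch below illustrates that idea with stand-in objects. Everything in it is an assumption, including the stub class, the helper names, and the order of the returned pair; it is not the library's implementation.

```python
import math


class ConstantSurvivalStub:
    """Hypothetical stand-in for a fitted model that exposes predict_survival."""

    def __init__(self, rate):
        self.rate = rate

    def predict_survival(self, features, times):
        # Exponential survival exp(-rate * t), identical for every sample.
        return [[math.exp(-self.rate * t) for t in times] for _ in features]


def counterfactual_survival(treated_model, control_model, features, times):
    """Score every sample under both models (assumed behavior, assumed order)."""
    treated = treated_model.predict_survival(features, times)
    control = control_model.predict_survival(features, times)
    return treated, control


def survival_difference(treated, control):
    """Per-sample, per-time difference: treated survival minus control survival."""
    return [[t - c for t, c in zip(t_row, c_row)]
            for t_row, c_row in zip(treated, control)]


features = [{"age": 60}, {"age": 72}]      # placeholder samples
times = [1.0, 5.0, 10.0]
treated, control = counterfactual_survival(
    ConstantSurvivalStub(rate=0.02),       # lower hazard when treated
    ConstantSurvivalStub(rate=0.05),
    features,
    times,
)
effect = survival_difference(treated, control)
```

Here a positive entry in effect means the treated stub predicts higher survival than the control stub for that sample and time.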