Module auton_survival.experiments

Utilities to perform cross-validation.

Classes

class SurvivalRegressionCV (model='dcph', folds=None, num_folds=5, random_seed=0, hyperparam_grid={})

Universal interface to train Survival Analysis models in a cross-validation fashion.

The model is trained in a CV fashion over the user-specified hyperparameter grid. Model hyperparameters are selected based on the user-specified metric.
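The selection step described above can be sketched in plain Python. This is a minimal illustration, not the class's internal implementation: `select_best_config` and the `score_fn(config, fold)` callable are hypothetical names. `score_fn` stands in for training a survival model on all folds but one and scoring it on the held-out fold with the chosen metric.

```python
from statistics import mean


def select_best_config(configs, num_folds, score_fn, lower_is_better=True):
    """Return (config, mean_score) for the config with the best mean validation score.

    score_fn(config, fold) is a hypothetical callable that trains on every fold
    except `fold` and returns the validation metric computed on `fold`.
    """
    best_config, best_score = None, None
    for config in configs:
        # Average the validation metric over all held-out folds.
        avg = mean(score_fn(config, fold) for fold in range(num_folds))
        if best_score is None:
            improved = True
        elif lower_is_better:
            improved = avg < best_score
        else:
            improved = avg > best_score
        if improved:
            best_config, best_score = config, avg
    return best_config, best_score


# Toy example: the (fake) validation score depends only on 'lr' and the fold.
configs = [{"lr": 1e-3}, {"lr": 1e-2}, {"lr": 1e-1}]
toy_scores = {1e-3: 0.20, 1e-2: 0.15, 1e-1: 0.25}


def toy_score_fn(config, fold):
    # A real score_fn would fit a survival model and compute e.g. the IBS.
    return toy_scores[config["lr"]] + 0.01 * fold


best, score = select_best_config(configs, num_folds=5, score_fn=toy_score_fn)
```

With a lower-is-better metric such as the integrated Brier score, the toy run selects `{"lr": 1e-2}`.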

Parameters

model : str, default='dcph'
A string that determines the choice of the survival regression model. Survival model choices include:
- 'dsm' : Deep Survival Machines [3] model
- 'dcph' : Deep Cox Proportional Hazards [2] model
- 'dcm' : Deep Cox Mixtures [4] model
- 'rsf' : Random Survival Forests [1] model
- 'cph' : Cox Proportional Hazards [2] model
folds : list, default=None
A list of fold assignment values for each sample. For regular (unnested) cross-validation, folds correspond to training and validation sets. For nested cross-validation, folds correspond to training and test sets.
num_folds : int, default=5
The number of folds. Ignored if folds is specified.
random_seed : int, default=0
Controls reproducibility of results.
hyperparam_grid : dict, default={}
A dictionary that contains the hyperparameters for grid search. The keys of the dictionary are the hyperparameter names and the values are lists of hyperparameter values.
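To make the `folds`, `num_folds`, `random_seed`, and `hyperparam_grid` parameters concrete, the sketch below expands a grid dictionary into individual configurations (the Cartesian product of the value lists) and builds a seeded, balanced list of fold labels, one per sample. The helper names `expand_grid` and `assign_folds` and the hyperparameter names in `grid` are hypothetical; the library's own fold assignment may differ.

```python
import itertools
import random


def expand_grid(hyperparam_grid):
    """Expand {name: [values, ...]} into a list of {name: value} dicts."""
    names = sorted(hyperparam_grid)
    value_lists = [hyperparam_grid[name] for name in names]
    # itertools.product yields every combination of one value per name.
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]


def assign_folds(n_samples, num_folds=5, random_seed=0):
    """Return one fold label in 0..num_folds-1 per sample, balanced and shuffled."""
    labels = [i % num_folds for i in range(n_samples)]
    # A dedicated Random instance makes the assignment reproducible per seed.
    random.Random(random_seed).shuffle(labels)
    return labels


# Hypothetical hyperparameter names, for illustration only.
grid = {"layers": [[], [100]], "learning_rate": [1e-4, 1e-3], "epochs": [50]}
configs = expand_grid(grid)  # 2 * 2 * 1 = 4 configurations
folds = assign_folds(n_samples=10, num_folds=5, random_seed=0)
```

In this sketch, an empty grid expands to a single empty configuration, and calling `assign_folds` again with the same seed reproduces the same labels.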

References

[1] Ishwaran, H., et al. (2008). Random survival forests. The Annals of Applied Statistics, 2(3):841–860.

[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).

[3] Nagpal, C., Li, X., and Dubrawski, A. (2020). Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks.

[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.

Methods

def fit(self, features, outcomes, horizons, metric='ibs')

Fits the survival regression model to the data in a cross-validation or nested cross-validation fashion.
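In nested cross-validation, each outer fold serves once as a test set, and hyperparameters are selected using only the remaining samples. The following is a minimal sketch of that split structure, assuming a per-sample list of fold labels as in the `folds` parameter. For brevity it reuses the outer labels for the inner splits, which is an assumption rather than documented behavior; `nested_cv_splits` is a hypothetical helper.

```python
def nested_cv_splits(folds):
    """Yield (outer_fold, train_val_idx, test_idx, inner_splits) for nested CV.

    `folds` is a per-sample list of fold labels. Each outer fold is held out
    as the test set; the remaining samples are split again by their labels
    into inner (train_idx, val_idx) pairs used for hyperparameter selection.
    """
    labels = sorted(set(folds))
    for outer in labels:
        test_idx = [i for i, f in enumerate(folds) if f == outer]
        train_val_idx = [i for i, f in enumerate(folds) if f != outer]
        inner_splits = []
        for inner in labels:
            if inner == outer:
                continue  # the test fold never takes part in model selection
            val_idx = [i for i in train_val_idx if folds[i] == inner]
            train_idx = [i for i in train_val_idx if folds[i] != inner]
            inner_splits.append((train_idx, val_idx))
        yield outer, train_val_idx, test_idx, inner_splits


folds = [0, 1, 2, 0, 1, 2]
splits = list(nested_cv_splits(folds))
```

For outer fold 0, samples 0 and 3 form the test set, and the inner splits only ever see samples 1, 2, 4, and 5, so the test data cannot influence hyperparameter selection.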

Parameters

features : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples and columns as covariates.
outcomes : pd.DataFrame
A pandas dataframe with columns 'time' and 'event' that contain the survival time and the censoring status, respectively, where $\delta_i = 1$ if the event is observed.
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters. Options include:
- 'auc' : Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
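When tuning hyperparameters, the metrics differ in which direction counts as better. The Brier score and integrated Brier score measure squared prediction error, so lower is better; the dynamic AUC and the concordance index measure discrimination, so higher is better. Here is a small sketch of that comparison, where `LOWER_IS_BETTER` and `is_better` are hypothetical names, not part of the library.

```python
# Assumed optimization direction for each metric option listed above.
LOWER_IS_BETTER = {"brs": True, "ibs": True, "auc": False, "ctd": False}


def is_better(metric, candidate, incumbent):
    """Return True if `candidate` improves on `incumbent` for `metric`."""
    if metric not in LOWER_IS_BETTER:
        raise ValueError(f"Unknown metric: {metric!r}")
    if LOWER_IS_BETTER[metric]:
        return candidate < incumbent  # error-type metrics: smaller wins
    return candidate > incumbent  # discrimination metrics: larger wins
```

For example, `is_better("ibs", 0.10, 0.20)` is `True`, while `is_better("ctd", 0.60, 0.70)` is `False`.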

Returns

Trained survival regression model(s).

class CounterfactualSurvivalRegressionCV (model, cv_folds=5, random_seed=0, hyperparam_grid={})

Universal interface to train Counterfactual Survival Analysis models in a cross-validation fashion.

Each model is trained in a cross-validation fashion over the user-specified hyperparameter grid. The best model (in terms of integrated Brier score) is then selected.
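The integrated Brier score condenses the Brier scores at several time points into a single number by integrating BS(t) over the evaluation window and dividing by the window's length. The sketch below approximates that integral with the trapezoidal rule. It assumes the Brier scores at each horizon have already been computed (for censored data these are typically inverse-probability-of-censoring weighted). `integrated_brier_score` here is a hypothetical helper, and the library's exact time grid and integration scheme may differ.

```python
def integrated_brier_score(times, brier_scores):
    """Trapezoidal integral of BS(t) over [times[0], times[-1]], divided by the range."""
    if len(times) != len(brier_scores) or len(times) < 2:
        raise ValueError("need at least two matching (time, score) pairs")
    if any(t1 >= t2 for t1, t2 in zip(times, times[1:])):
        raise ValueError("times must be strictly increasing")
    # Sum the area of each trapezoid between consecutive time points.
    area = sum(
        (t2 - t1) * (b1 + b2) / 2.0
        for t1, t2, b1, b2 in zip(times, times[1:], brier_scores, brier_scores[1:])
    )
    # Normalize by the window length so the result stays on the Brier scale.
    return area / (times[-1] - times[0])


ibs = integrated_brier_score([1.0, 2.0, 3.0], [0.10, 0.20, 0.30])
```

Because the integral is divided by the window length, a constant Brier score of 0.25 yields an integrated Brier score of 0.25 regardless of the window.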

Parameters

model : str
A string that determines the choice of the survival analysis model. Survival model choices include:
- 'dsm' : Deep Survival Machines [3] model
- 'dcph' : Deep Cox Proportional Hazards [2] model
- 'dcm' : Deep Cox Mixtures [4] model
- 'rsf' : Random Survival Forests [1] model
- 'cph' : Cox Proportional Hazards [2] model
cv_folds : int, default=5
Number of folds in the cross-validation.
random_seed : int, default=0
Random seed for reproducibility.
hyperparam_grid : dict
A dictionary that contains the hyperparameters for grid search. The keys of the dictionary are the hyperparameter names and the values are lists of hyperparameter values.

References

[1] Ishwaran, H., et al. (2008). Random survival forests. The Annals of Applied Statistics, 2(3):841–860.

[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).

[3] Nagpal, C., Li, X., and Dubrawski, A. (2020). Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks.

[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.

Methods

def fit(self, features, outcomes, interventions, horizons, metric)

Fits the counterfactual survival regression model to the data in a cross-validation fashion.

Parameters

features : pandas.DataFrame
A pandas dataframe containing the features to use as covariates.
outcomes : pandas.DataFrame
A pandas dataframe containing the survival outcomes. Its index should match the index of the features dataframe. It should contain a column named 'time' with the survival time and a column named 'event' with the censoring status, where $\delta_i = 1$ if the event is observed.
interventions : pandas.Series
A pandas series containing the treatment status of each subject: $a_i = 1$ if the subject is treated; otherwise, the subject is considered a control.
horizons : int or float or list
Event-horizons at which to evaluate model performance.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters. Options include:
- 'auc' : Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
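One common design for counterfactual survival estimation fits one survival model on treated subjects and a separate one on controls. Assuming that design, which is not confirmed here, the first step is to partition the data by $a_i$. The sketch uses parallel Python lists instead of the pandas objects described above so that it has no dependencies; `split_by_intervention` is a hypothetical helper.

```python
def split_by_intervention(features, outcomes, interventions):
    """Partition rows into treated (a_i == 1) and control (a_i != 1) groups.

    Inputs are parallel lists: features[i], outcomes[i], interventions[i].
    Returns {"treated": (features, outcomes), "control": (features, outcomes)}.
    """
    if not (len(features) == len(outcomes) == len(interventions)):
        raise ValueError("inputs must have the same length")
    groups = {"treated": ([], []), "control": ([], [])}
    for x, y, a in zip(features, outcomes, interventions):
        # Any status other than 1 is treated as control, as in the docs above.
        key = "treated" if a == 1 else "control"
        groups[key][0].append(x)
        groups[key][1].append(y)
    return groups


features = [[0.1], [0.2], [0.3]]
outcomes = [
    {"time": 5.0, "event": 1},
    {"time": 7.0, "event": 0},
    {"time": 2.0, "event": 1},
]
interventions = [1, 0, 1]
groups = split_by_intervention(features, outcomes, interventions)
```

Each group could then be passed to its own cross-validated survival model, keeping hyperparameter selection separate for the two arms.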

Returns

auton_survival.estimators.CounterfactualSurvivalModel:
The trained counterfactual survival model.