survivalpredict.estimators.CoxPHElasticNet
- class survivalpredict.estimators.CoxPHElasticNet(*, alpha=0.0, l1_ratio=0.5, max_iter=100, tol=1e-09)
Cox Proportional Hazards with Elastic Net penalty and feature shrinkage.
A Cox Proportional Hazards model with an Elastic Net penalty, estimated via coordinate descent. When ‘l1_ratio’ is greater than 0, the coordinate descent algorithm for Elastic Net/Lasso shrinks individual features to exactly zero at different values of ‘alpha’ as that parameter increases. The Newton-Raphson-like coordinate descent procedure described in Simon et al. (2011) [1] is used.
Only ‘breslow’ handling of tied event times is available. Stratification is not supported, because the literature is currently unclear on how to add stratification to Simon’s algorithm.
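The procedure can be illustrated with the minimal pure-Python sketch below. It is a sketch of the Simon et al. (2011) algorithm, not the survivalpredict implementation. It assumes that the docstring's alpha and l1_ratio map to Simon et al.'s λ and α, uses the Breslow partial log-likelihood with only the diagonal of its Hessian, fits a single alpha without warm starts or feature standardization, and stops when coefficient updates fall below tol rather than performing the dual-gap check mentioned under tol. The helper names (soft_threshold, breslow_grad_hess_diag, fit_cox_elastic_net) and the inner_iter argument are hypothetical.

```python
import math


def soft_threshold(z, gamma):
    """S(z, gamma) = sign(z) * max(|z| - gamma, 0)."""
    if z > gamma:
        return z - gamma
    if z < -gamma:
        return z + gamma
    return 0.0


def breslow_grad_hess_diag(eta, times, events):
    """Gradient and Hessian diagonal of the Breslow partial log-likelihood in eta."""
    n = len(eta)
    risk = [math.exp(e) for e in eta]
    grad = [1.0 if d else 0.0 for d in events]
    hess = [0.0] * n
    for i in range(n):
        if not events[i]:
            continue
        # Breslow: tied events share the full risk set {j : times[j] >= times[i]}.
        at_risk = [j for j in range(n) if times[j] >= times[i]]
        denom = sum(risk[j] for j in at_risk)
        for k in at_risk:
            p = risk[k] / denom
            grad[k] -= p
            hess[k] -= p * (1.0 - p)
    return grad, hess


def fit_cox_elastic_net(X, times, events, alpha=0.0, l1_ratio=0.5,
                        max_iter=100, tol=1e-9, inner_iter=100):
    """Hypothetical sketch of Simon et al. (2011) coordinate descent."""
    n, n_features = len(X), len(X[0])
    beta = [0.0] * n_features
    for _ in range(max_iter):
        beta_prev = list(beta)
        # Quadratic approximation around the current linear predictor.
        eta = [sum(X[i][k] * beta[k] for k in range(n_features)) for i in range(n)]
        grad, hess = breslow_grad_hess_diag(eta, times, events)
        w = [-h for h in hess]  # diagonal weights
        # Working response; samples with zero weight do not affect the update.
        z = [eta[i] + grad[i] / w[i] if w[i] > 0 else eta[i] for i in range(n)]
        # Cyclic coordinate descent on the penalized weighted least-squares problem.
        for _ in range(inner_iter):
            max_delta = 0.0
            for k in range(n_features):
                num = sum(w[i] * X[i][k] * (z[i] - eta[i] + X[i][k] * beta[k])
                          for i in range(n)) / n
                den = (sum(w[i] * X[i][k] ** 2 for i in range(n)) / n
                       + alpha * (1.0 - l1_ratio))
                new_bk = soft_threshold(num, alpha * l1_ratio) / den if den > 0 else 0.0
                delta = new_bk - beta[k]
                for i in range(n):
                    eta[i] += X[i][k] * delta  # keep eta equal to X @ beta
                beta[k] = new_bk
                max_delta = max(max_delta, abs(delta))
            if max_delta <= tol:
                break
        if max(abs(b - b0) for b, b0 in zip(beta, beta_prev)) <= tol:
            break
    return beta
```

Each outer iteration rebuilds the weights w and the working response z around the current linear predictor, then runs coordinate updates against that fixed quadratic approximation. The soft-thresholding step is what sets a coefficient to exactly zero when l1_ratio > 0 and alpha is large enough.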
- Parameters:
alpha (float, default=0.0) – Constant that multiplies the penalty terms. Used to penalize coefficients during training.
l1_ratio (float, default=0.5) – The Elastic Net mixing parameter, with 0 <= l1_ratio <= 1. For l1_ratio = 0 the penalty is an L2 penalty. For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
max_iter (int, default=100) – The maximum number of iterations.
tol (float, default=1e-9) – The tolerance for the optimization: if the updates are smaller than or equal to tol, the optimization code checks the dual gap for optimality and continues until it is smaller than or equal to tol.
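To make the roles of alpha and l1_ratio concrete, the sketch below assumes the glmnet-style scaling used by Simon et al. (2011), alpha * (l1_ratio * sum|b| + (1 - l1_ratio) / 2 * sum b^2). Whether survivalpredict scales the penalty exactly this way is an assumption. The hypothetical shrink_coordinate helper solves the one-coordinate problem with unit weight and shows that only the L1 part can set a coefficient to exactly zero.

```python
def elastic_net_penalty(beta, alpha=0.0, l1_ratio=0.5):
    """Elastic Net penalty (assumed glmnet-style scaling)."""
    l1 = sum(abs(b) for b in beta)
    l2 = sum(b * b for b in beta)
    return alpha * (l1_ratio * l1 + 0.5 * (1.0 - l1_ratio) * l2)


def shrink_coordinate(z, alpha=0.0, l1_ratio=0.5):
    """Minimize 0.5 * (b - z)**2 + elastic_net_penalty([b], alpha, l1_ratio) over b."""
    gamma = alpha * l1_ratio
    # Soft-threshold by the L1 strength, then scale down by the L2 strength.
    soft = max(abs(z) - gamma, 0.0) * (1.0 if z >= 0 else -1.0)
    return soft / (1.0 + alpha * (1.0 - l1_ratio))


# Same input, three mixing parameters:
print(shrink_coordinate(0.3, alpha=0.5, l1_ratio=0.0))  # pure L2: shrunk, never exactly zero
print(shrink_coordinate(0.3, alpha=0.5, l1_ratio=0.5))  # mixed: smaller still
print(shrink_coordinate(0.3, alpha=0.5, l1_ratio=1.0))  # pure L1: exactly zero here
```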
References
[1] Simon N, Friedman J, Hastie T, Tibshirani R. Regularization Paths for Cox’s Proportional Hazards Model via Coordinate Descent. J Stat Softw. 2011 Mar;39(5):1-13. doi: 10.18637/jss.v039.i05. PMID: 27065756; PMCID: PMC4824408.
Methods
fit(X, times, events[, check_input, times_start]) – Fit model.
fit_predict(*args, **kwargs) – Fit model and build survival curves.
predict(X[, max_time]) – Build survival curves on an array of vectors X.
predict_risk(X) – Build relative risk on an array of vectors X.
- __init__(*, alpha=0.0, l1_ratio=0.5, max_iter=100, tol=1e-09)
- Parameters:
alpha (float)
l1_ratio (float)
max_iter (int | None)
tol (float)
Methods
__init__(*[, alpha, l1_ratio, max_iter, tol])
fit(X, times, events[, check_input, times_start]) – Fit model.
fit_predict(*args, **kwargs) – Fit model and build survival curves.
get_metadata_routing() – Get metadata routing of this object.
get_params([deep]) – Get parameters for this estimator.
predict(X[, max_time]) – Build survival curves on an array of vectors X.
predict_risk(X) – Build relative risk on an array of vectors X.
set_fit_request(*[, check_input, events, ...]) – Configure whether metadata should be requested to be passed to the fit method.
set_params(**params) – Set the parameters of this estimator.
set_predict_request(*[, max_time]) – Configure whether metadata should be requested to be passed to the predict method.
- fit(X, times, events, check_input=True, times_start=None)
Fit model.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training data.
times (array-like of shape (n_samples), dtype=np.int64) – Point in time last observed.
events (array-like of shape (n_samples), dtype=np.bool_) – Whether each sample experienced the event (True) or was censored (False).
check_input (bool, default=True) – If True, validates and casts inputs.
times_start (array-like of shape (n_samples), dtype=np.int64, default=None) – Starting point of observation for each sample. If not passed, all start times are assumed to be 0.
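The role of times_start can be pictured as a change to the risk set. The sketch below assumes the usual counting-process convention, in which sample i is at risk at time t when times_start[i] < t <= times[i]; whether survivalpredict uses exactly these interval endpoints is an assumption, and risk_set is a hypothetical helper. With times_start omitted, every start time is 0 and the risk set reduces to the samples with times[i] >= t.

```python
def risk_set(t, times, times_start=None):
    """Indices of samples at risk at time t under the (start, stop] convention."""
    if times_start is None:
        # Default described by the docstring: every observation starts at 0.
        times_start = [0] * len(times)
    return [
        i
        for i, (start, stop) in enumerate(zip(times_start, times))
        if start < t <= stop
    ]


times = [3, 5, 7]
times_start = [0, 4, 0]
print(risk_set(4, times, times_start))  # [2]: sample 0 has left, sample 1 not yet entered
print(risk_set(5, times, times_start))  # [1, 2]
print(risk_set(3, times))               # [0, 1, 2]
```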
- Returns:
Fitted Estimator.
- Return type:
object
- fit_predict(*args, **kwargs)
Fit model and build survival curves.
- predict(X, max_time=None)
Build survival curves on an array of vectors X.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Predicting data.
max_time (int, default=None) – Maximum time of the built survival curves. If None, the maximum time seen in the training data is used.
- Returns:
The estimated survival curves. The left-most column is the probability of survival at time 1, and the right-most column is the probability of survival at max_time.
- Return type:
ndarray of shape (n_samples, max_time), dtype=np.float64
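One common way to produce curves with this layout is the Breslow estimator of the cumulative baseline hazard H0, evaluated at the integer times 1, ..., max_time and combined with each sample's relative risk as S(t | x) = exp(-H0(t) * risk(x)). The sketch below assumes this construction and that relative risks are exp(x · beta) without centering; it is not taken from the survivalpredict source, and breslow_survival_curves is a hypothetical name.

```python
import math


def breslow_survival_curves(train_risk, times, events, risk, max_time=None):
    """Survival curves at integer times 1..max_time from Breslow's baseline hazard.

    train_risk: relative risk of each training sample.
    times, events: training labels.
    risk: relative risk of each sample to predict.
    """
    if max_time is None:
        # Default described by the docstring: max time seen in training data.
        max_time = max(times)
    cumulative_hazard = []
    total = 0.0
    for t in range(1, max_time + 1):
        n_events = sum(1 for ti, d in zip(times, events) if d and ti == t)
        if n_events:
            # Breslow: tied events at t share one risk-set denominator.
            at_risk = sum(r for ti, r in zip(times, train_risk) if ti >= t)
            total += n_events / at_risk
        cumulative_hazard.append(total)
    # Column t - 1 holds S(t | x); the last column is t = max_time.
    return [[math.exp(-h * r) for h in cumulative_hazard] for r in risk]
```

Times after the last training event add no hazard, so a max_time beyond the training range simply repeats the last survival probability.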
- predict_risk(X)
Build relative risk on an array of vectors X.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Predicting data.
- Returns:
The relative risk of each sample in X, used internally to build survival curves. Relative risk is the quantity the concordance index evaluates.
- Return type:
ndarray of shape (n_samples), dtype=np.float64
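The link between relative risk and the concordance index can be illustrated as follows. This sketch assumes relative risk is exp(x · beta) and uses one common form of Harrell's concordance index: a pair (i, j) is comparable when sample i had an event and times[i] < times[j], and it is concordant when the earlier-failing sample has the higher risk, with tied risks scored 0.5. Implementations differ in how tied times are handled, and relative_risk and concordance_index are hypothetical helpers.

```python
import math


def relative_risk(X, beta):
    """exp(x . beta) for each row x of X (assumes no centering of X)."""
    return [math.exp(sum(x * b for x, b in zip(row, beta))) for row in X]


def concordance_index(times, events, risk):
    """Harrell's C: fraction of comparable pairs ordered correctly by risk."""
    concordant = 0.0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue  # a censored sample cannot be the earlier failure in a pair
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


risk = relative_risk([[2.0], [1.0], [0.5], [0.0]], [1.0])
print(concordance_index([1, 2, 3, 4], [True, False, True, False], risk))  # 1.0
```

Because exp is monotone, ranking by exp(x · beta) and by x · beta gives the same concordance index.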