delicatessen.utilities.regression_predictions
- regression_predictions(X, theta, covariance, offset=None, alpha=0.05)
Generate predicted values of the outcome given a design matrix, point estimates, and covariance matrix. This functionality computes \(\hat{Y}\), \(\hat{Var}\left(\hat{Y}\right)\), and corresponding Wald-type \((1 - \alpha) \times\) 100% confidence intervals from estimated coefficients and covariance of a regression model given a set of specific covariate values.
This function is a helper for computing the predictions from a regression model for a given set of \(X\) values. Importantly, it allows the variance of \(\hat{Y}\) to be estimated without having to expand the estimating equations. As such, it is meant to be used after MEstimator has been used to estimate the coefficients (i.e., after the M-estimator has computed the results for the chosen regression model). Therefore, this function should be viewed as post-processing functionality for generating plots or tables.
Note
No transformations are applied by this function. So, input from a logistic model will generate the log-odds of the outcome (not the probability).
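As a rough illustration of the calculation described above, the following is a minimal NumPy sketch of Wald-type predictions. It assumes the linear predictor is the design matrix times the coefficients (plus any offset) and that the variance of each prediction is the corresponding diagonal element of \(X \Sigma X^T\). The helper name wald_predictions is hypothetical and is not part of delicatessen; the library's own implementation may differ in its details.

```python
from statistics import NormalDist

import numpy as np


def wald_predictions(X, theta, covariance, offset=None, alpha=0.05):
    """Hypothetical stand-in that mirrors the documented inputs and outputs."""
    X = np.asarray(X, dtype=float)
    theta = np.asarray(theta, dtype=float)
    covariance = np.asarray(covariance, dtype=float)

    # Linear predictor; no inverse link is applied (e.g., log-odds for a logistic model)
    yhat = X @ theta
    if offset is not None:
        yhat = yhat + np.asarray(offset, dtype=float)

    # Row-wise quadratic form x_i' Sigma x_i, i.e., the diagonal of X Sigma X^T
    var = np.einsum('ij,jk,ik->i', X, covariance, X)

    # Two-sided standard normal critical value for the chosen alpha
    z = NormalDist().inv_cdf(1 - alpha / 2)
    se = np.sqrt(var)

    # Columns: prediction, variance, lower limit, upper limit
    return np.column_stack([yhat, var, yhat - z * se, yhat + z * se])
```

Computing only the diagonal with einsum avoids forming the full n-by-n matrix, which matters when predictions are requested on a fine grid of covariate values.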
- Parameters
X (ndarray, list, vector) – 2-dimensional array of covariate values to generate predicted values and variances for. The number of columns must match the number of coefficients / parameters in theta.
theta (ndarray) – Estimated coefficients from MEstimator.theta.
covariance (ndarray) – Estimated covariance matrix from MEstimator.variance.
offset (ndarray, None, optional) – A 1-dimensional offset to be included in the model. Default is None, which applies no offset term.
alpha (float, optional) – The \(\alpha\) level for the corresponding confidence intervals. Default is 0.05, which calculates the 95% confidence intervals. Note that \(0<\alpha<1\).
- Returns
Returns an n-by-4 NumPy array whose 4 columns correspond to the predicted outcome, variance of the predicted outcome, lower confidence limit, and upper confidence limit.
- Return type
array
Examples
The following is a simple example demonstrating how this function can be used to plot a regression line and corresponding 95% confidence intervals.
>>> import numpy as np
>>> import pandas as pd
>>> import matplotlib.pyplot as plt
>>> from delicatessen import MEstimator
>>> from delicatessen.estimating_equations import ee_regression
>>> from delicatessen.utilities import regression_predictions
Some generic data to estimate the regression model with
>>> n = 50
>>> data = pd.DataFrame()
>>> data['X'] = np.random.normal(size=n)
>>> data['Z'] = np.random.normal(size=n)
>>> data['Y'] = 0.5 + 2*data['X'] - 1*data['Z'] + np.random.normal(loc=0, size=n)
>>> data['C'] = 1
Estimating the corresponding regression model parameters
>>> def psi(theta):
...     return ee_regression(theta=theta, X=data[['C', 'X', 'Z']], y=data['Y'], model='linear')
>>> estr = MEstimator(stacked_equations=psi, init=[0., 0., 0.,])
>>> estr.estimate(solver='lm')
To create a line plot of our regression line, we first need to create a new set of covariate values that are evenly spaced across the range of the predictor values. Here, we will plot the relationship between Z and Y while holding X=0.
>>> pred = pd.DataFrame()
>>> pred['Z'] = np.linspace(np.min(data['Z']), np.max(data['Z']), 200)
>>> pred['X'] = 0
>>> pred['C'] = 1
Now the predicted values of the outcome, and confidence intervals to plot
>>> Xp = pred[['C', 'X', 'Z']]
>>> yhat = regression_predictions(X=Xp, theta=estr.theta, covariance=estr.variance)
Finally, the predicted values can be plotted (using matplotlib)
>>> plt.plot(pred['Z'], yhat[:, 0], '-', color='blue')
>>> plt.fill_between(pred['Z'], yhat[:, 2], yhat[:, 3], alpha=0.25, color='blue')
>>> plt.show()
For predicting with a Poisson or logistic model, one may want to transform the predicted values and confidence intervals to another measure. For the logistic model, the predicted log-odds can easily be transformed using delicatessen.utilities.inverse_logit. For the Poisson model, predictions can easily be transformed using numpy.exp.
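The transformation step can be sketched as follows. This is only an illustration: the input array is built from made-up values arranged in the documented column order, and inverse_logit_np is a hypothetical NumPy stand-in used so the snippet runs without delicatessen installed (in practice, delicatessen.utilities.inverse_logit would take its place). Only the point estimate and the confidence limits are transformed; the variance column remains on the linear-predictor scale.

```python
import numpy as np


def inverse_logit_np(log_odds):
    """Hypothetical stand-in for the inverse logit: 1 / (1 + exp(-x))."""
    return 1 / (1 + np.exp(-np.asarray(log_odds, dtype=float)))


# Made-up predictions on the linear-predictor scale, for illustration only
point = np.array([-1.0, 0.0, 1.5])
var = np.array([0.04, 0.01, 0.09])
z = 1.96  # approximate critical value for 95% confidence intervals

# Same column order as the documented return value:
# prediction, variance, lower limit, upper limit
yhat = np.column_stack([point, var, point - z * np.sqrt(var), point + z * np.sqrt(var)])

# Logistic model: log-odds -> probability (columns 0, 2, and 3 only)
prob = inverse_logit_np(yhat[:, [0, 2, 3]])

# Poisson model: log-mean -> mean (columns 0, 2, and 3 only)
rate = np.exp(yhat[:, [0, 2, 3]])
```

Because both transformations are monotone increasing, the transformed lower and upper limits remain valid confidence limits on the new scale.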