blob: 9757682bc9e2d181b53d2a856bb771ab0a008cbc [file] [log] [blame]
Raymonddee08492015-04-02 10:43:13 -07001/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17package org.apache.commons.math.stat.regression;
18
19import org.apache.commons.math.MathRuntimeException;
20import org.apache.commons.math.exception.util.LocalizedFormats;
21import org.apache.commons.math.linear.RealMatrix;
22import org.apache.commons.math.linear.Array2DRowRealMatrix;
23import org.apache.commons.math.linear.RealVector;
24import org.apache.commons.math.linear.ArrayRealVector;
25import org.apache.commons.math.stat.descriptive.moment.Variance;
26import org.apache.commons.math.util.FastMath;
27
28/**
29 * Abstract base class for implementations of MultipleLinearRegression.
30 * @version $Revision: 1073459 $ $Date: 2011-02-22 20:18:12 +0100 (mar. 22 févr. 2011) $
31 * @since 2.0
32 */
33public abstract class AbstractMultipleLinearRegression implements
34 MultipleLinearRegression {
35
36 /** X sample data. */
37 protected RealMatrix X;
38
39 /** Y sample data. */
40 protected RealVector Y;
41
42 /** Whether or not the regression model includes an intercept. True means no intercept. */
43 private boolean noIntercept = false;
44
45 /**
46 * @return true if the model has no intercept term; false otherwise
47 * @since 2.2
48 */
49 public boolean isNoIntercept() {
50 return noIntercept;
51 }
52
53 /**
54 * @param noIntercept true means the model is to be estimated without an intercept term
55 * @since 2.2
56 */
57 public void setNoIntercept(boolean noIntercept) {
58 this.noIntercept = noIntercept;
59 }
60
61 /**
62 * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
63 * </p>
64 * <p>Assumes that rows are concatenated with y values first in each row. For example, an input
65 * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
66 * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
67 * independent variables, as below:
68 * <pre>
69 * y x[0] x[1]
70 * --------------
71 * 1 2 3
72 * 4 5 6
73 * 7 8 9
74 * </pre>
75 * </p>
76 * <p>Note that there is no need to add an initial unitary column (column of 1's) when
77 * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>,
78 * the X matrix will be created without an initial column of "1"s; otherwise this column will
79 * be added.
80 * </p>
81 * <p>Throws IllegalArgumentException if any of the following preconditions fail:
82 * <ul><li><code>data</code> cannot be null</li>
83 * <li><code>data.length = nobs * (nvars + 1)</li>
84 * <li><code>nobs > nvars</code></li></ul>
85 * </p>
86 *
87 * @param data input data array
88 * @param nobs number of observations (rows)
89 * @param nvars number of independent variables (columns, not counting y)
90 * @throws IllegalArgumentException if the preconditions are not met
91 */
92 public void newSampleData(double[] data, int nobs, int nvars) {
93 if (data == null) {
94 throw MathRuntimeException.createIllegalArgumentException(
95 LocalizedFormats.NULL_NOT_ALLOWED);
96 }
97 if (data.length != nobs * (nvars + 1)) {
98 throw MathRuntimeException.createIllegalArgumentException(
99 LocalizedFormats.INVALID_REGRESSION_ARRAY, data.length, nobs, nvars);
100 }
101 if (nobs <= nvars) {
102 throw MathRuntimeException.createIllegalArgumentException(
103 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS);
104 }
105 double[] y = new double[nobs];
106 final int cols = noIntercept ? nvars: nvars + 1;
107 double[][] x = new double[nobs][cols];
108 int pointer = 0;
109 for (int i = 0; i < nobs; i++) {
110 y[i] = data[pointer++];
111 if (!noIntercept) {
112 x[i][0] = 1.0d;
113 }
114 for (int j = noIntercept ? 0 : 1; j < cols; j++) {
115 x[i][j] = data[pointer++];
116 }
117 }
118 this.X = new Array2DRowRealMatrix(x);
119 this.Y = new ArrayRealVector(y);
120 }
121
122 /**
123 * Loads new y sample data, overriding any previous data.
124 *
125 * @param y the array representing the y sample
126 * @throws IllegalArgumentException if y is null or empty
127 */
128 protected void newYSampleData(double[] y) {
129 if (y == null) {
130 throw MathRuntimeException.createIllegalArgumentException(
131 LocalizedFormats.NULL_NOT_ALLOWED);
132 }
133 if (y.length == 0) {
134 throw MathRuntimeException.createIllegalArgumentException(
135 LocalizedFormats.NO_DATA);
136 }
137 this.Y = new ArrayRealVector(y);
138 }
139
140 /**
141 * <p>Loads new x sample data, overriding any previous data.
142 * </p>
143 * The input <code>x</code> array should have one row for each sample
144 * observation, with columns corresponding to independent variables.
145 * For example, if <pre>
146 * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
147 * then <code>setXSampleData(x) </code> results in a model with two independent
148 * variables and 3 observations:
149 * <pre>
150 * x[0] x[1]
151 * ----------
152 * 1 2
153 * 3 4
154 * 5 6
155 * </pre>
156 * </p>
157 * <p>Note that there is no need to add an initial unitary column (column of 1's) when
158 * specifying a model including an intercept term.
159 * </p>
160 * @param x the rectangular array representing the x sample
161 * @throws IllegalArgumentException if x is null, empty or not rectangular
162 */
163 protected void newXSampleData(double[][] x) {
164 if (x == null) {
165 throw MathRuntimeException.createIllegalArgumentException(
166 LocalizedFormats.NULL_NOT_ALLOWED);
167 }
168 if (x.length == 0) {
169 throw MathRuntimeException.createIllegalArgumentException(
170 LocalizedFormats.NO_DATA);
171 }
172 if (noIntercept) {
173 this.X = new Array2DRowRealMatrix(x, true);
174 } else { // Augment design matrix with initial unitary column
175 final int nVars = x[0].length;
176 final double[][] xAug = new double[x.length][nVars + 1];
177 for (int i = 0; i < x.length; i++) {
178 if (x[i].length != nVars) {
179 throw MathRuntimeException.createIllegalArgumentException(
180 LocalizedFormats.DIFFERENT_ROWS_LENGTHS,
181 x[i].length, nVars);
182 }
183 xAug[i][0] = 1.0d;
184 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
185 }
186 this.X = new Array2DRowRealMatrix(xAug, false);
187 }
188 }
189
190 /**
191 * Validates sample data. Checks that
192 * <ul><li>Neither x nor y is null or empty;</li>
193 * <li>The length (i.e. number of rows) of x equals the length of y</li>
194 * <li>x has at least one more row than it has columns (i.e. there is
195 * sufficient data to estimate regression coefficients for each of the
196 * columns in x plus an intercept.</li>
197 * </ul>
198 *
199 * @param x the [n,k] array representing the x data
200 * @param y the [n,1] array representing the y data
201 * @throws IllegalArgumentException if any of the checks fail
202 *
203 */
204 protected void validateSampleData(double[][] x, double[] y) {
205 if ((x == null) || (y == null) || (x.length != y.length)) {
206 throw MathRuntimeException.createIllegalArgumentException(
207 LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE,
208 (x == null) ? 0 : x.length,
209 (y == null) ? 0 : y.length);
210 }
211 if (x.length == 0) { // Must be no y data either
212 throw MathRuntimeException.createIllegalArgumentException(
213 LocalizedFormats.NO_DATA);
214 }
215 if (x[0].length + 1 > x.length) {
216 throw MathRuntimeException.createIllegalArgumentException(
217 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
218 x.length, x[0].length);
219 }
220 }
221
222 /**
223 * Validates that the x data and covariance matrix have the same
224 * number of rows and that the covariance matrix is square.
225 *
226 * @param x the [n,k] array representing the x sample
227 * @param covariance the [n,n] array representing the covariance matrix
228 * @throws IllegalArgumentException if the number of rows in x is not equal
229 * to the number of rows in covariance or covariance is not square.
230 */
231 protected void validateCovarianceData(double[][] x, double[][] covariance) {
232 if (x.length != covariance.length) {
233 throw MathRuntimeException.createIllegalArgumentException(
234 LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE, x.length, covariance.length);
235 }
236 if (covariance.length > 0 && covariance.length != covariance[0].length) {
237 throw MathRuntimeException.createIllegalArgumentException(
238 LocalizedFormats.NON_SQUARE_MATRIX,
239 covariance.length, covariance[0].length);
240 }
241 }
242
243 /**
244 * {@inheritDoc}
245 */
246 public double[] estimateRegressionParameters() {
247 RealVector b = calculateBeta();
248 return b.getData();
249 }
250
251 /**
252 * {@inheritDoc}
253 */
254 public double[] estimateResiduals() {
255 RealVector b = calculateBeta();
256 RealVector e = Y.subtract(X.operate(b));
257 return e.getData();
258 }
259
260 /**
261 * {@inheritDoc}
262 */
263 public double[][] estimateRegressionParametersVariance() {
264 return calculateBetaVariance().getData();
265 }
266
267 /**
268 * {@inheritDoc}
269 */
270 public double[] estimateRegressionParametersStandardErrors() {
271 double[][] betaVariance = estimateRegressionParametersVariance();
272 double sigma = calculateErrorVariance();
273 int length = betaVariance[0].length;
274 double[] result = new double[length];
275 for (int i = 0; i < length; i++) {
276 result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
277 }
278 return result;
279 }
280
281 /**
282 * {@inheritDoc}
283 */
284 public double estimateRegressandVariance() {
285 return calculateYVariance();
286 }
287
288 /**
289 * Estimates the variance of the error.
290 *
291 * @return estimate of the error variance
292 * @since 2.2
293 */
294 public double estimateErrorVariance() {
295 return calculateErrorVariance();
296
297 }
298
299 /**
300 * Estimates the standard error of the regression.
301 *
302 * @return regression standard error
303 * @since 2.2
304 */
305 public double estimateRegressionStandardError() {
306 return Math.sqrt(estimateErrorVariance());
307 }
308
309 /**
310 * Calculates the beta of multiple linear regression in matrix notation.
311 *
312 * @return beta
313 */
314 protected abstract RealVector calculateBeta();
315
316 /**
317 * Calculates the beta variance of multiple linear regression in matrix
318 * notation.
319 *
320 * @return beta variance
321 */
322 protected abstract RealMatrix calculateBetaVariance();
323
324
325 /**
326 * Calculates the variance of the y values.
327 *
328 * @return Y variance
329 */
330 protected double calculateYVariance() {
331 return new Variance().evaluate(Y.getData());
332 }
333
334 /**
335 * <p>Calculates the variance of the error term.</p>
336 * Uses the formula <pre>
337 * var(u) = u &middot; u / (n - k)
338 * </pre>
339 * where n and k are the row and column dimensions of the design
340 * matrix X.
341 *
342 * @return error variance estimate
343 * @since 2.2
344 */
345 protected double calculateErrorVariance() {
346 RealVector residuals = calculateResiduals();
347 return residuals.dotProduct(residuals) /
348 (X.getRowDimension() - X.getColumnDimension());
349 }
350
351 /**
352 * Calculates the residuals of multiple linear regression in matrix
353 * notation.
354 *
355 * <pre>
356 * u = y - X * b
357 * </pre>
358 *
359 * @return The residuals [n,1] matrix
360 */
361 protected RealVector calculateResiduals() {
362 RealVector b = calculateBeta();
363 return Y.subtract(X.operate(b));
364 }
365
366}