blob: dc6ef0d477864ea9a30907c7e8e5d28503c1e8e8 [file] [log] [blame]
Raymonddee08492015-04-02 10:43:13 -07001/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17package org.apache.commons.math.stat.regression;
18
19import org.apache.commons.math.linear.LUDecompositionImpl;
20import org.apache.commons.math.linear.RealMatrix;
21import org.apache.commons.math.linear.Array2DRowRealMatrix;
22import org.apache.commons.math.linear.RealVector;
23
24/**
25 * The GLS implementation of the multiple linear regression.
26 *
27 * GLS assumes a general covariance matrix Omega of the error
28 * <pre>
29 * u ~ N(0, Omega)
30 * </pre>
31 *
32 * Estimated by GLS,
33 * <pre>
34 * b=(X' Omega^-1 X)^-1X'Omega^-1 y
35 * </pre>
36 * whose variance is
37 * <pre>
38 * Var(b)=(X' Omega^-1 X)^-1
39 * </pre>
40 * @version $Revision: 1073460 $ $Date: 2011-02-22 20:22:39 +0100 (mar. 22 févr. 2011) $
41 * @since 2.0
42 */
43public class GLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
44
45 /** Covariance matrix. */
46 private RealMatrix Omega;
47
48 /** Inverse of covariance matrix. */
49 private RealMatrix OmegaInverse;
50
51 /** Replace sample data, overriding any previous sample.
52 * @param y y values of the sample
53 * @param x x values of the sample
54 * @param covariance array representing the covariance matrix
55 */
56 public void newSampleData(double[] y, double[][] x, double[][] covariance) {
57 validateSampleData(x, y);
58 newYSampleData(y);
59 newXSampleData(x);
60 validateCovarianceData(x, covariance);
61 newCovarianceData(covariance);
62 }
63
64 /**
65 * Add the covariance data.
66 *
67 * @param omega the [n,n] array representing the covariance
68 */
69 protected void newCovarianceData(double[][] omega){
70 this.Omega = new Array2DRowRealMatrix(omega);
71 this.OmegaInverse = null;
72 }
73
74 /**
75 * Get the inverse of the covariance.
76 * <p>The inverse of the covariance matrix is lazily evaluated and cached.</p>
77 * @return inverse of the covariance
78 */
79 protected RealMatrix getOmegaInverse() {
80 if (OmegaInverse == null) {
81 OmegaInverse = new LUDecompositionImpl(Omega).getSolver().getInverse();
82 }
83 return OmegaInverse;
84 }
85
86 /**
87 * Calculates beta by GLS.
88 * <pre>
89 * b=(X' Omega^-1 X)^-1X'Omega^-1 y
90 * </pre>
91 * @return beta
92 */
93 @Override
94 protected RealVector calculateBeta() {
95 RealMatrix OI = getOmegaInverse();
96 RealMatrix XT = X.transpose();
97 RealMatrix XTOIX = XT.multiply(OI).multiply(X);
98 RealMatrix inverse = new LUDecompositionImpl(XTOIX).getSolver().getInverse();
99 return inverse.multiply(XT).multiply(OI).operate(Y);
100 }
101
102 /**
103 * Calculates the variance on the beta.
104 * <pre>
105 * Var(b)=(X' Omega^-1 X)^-1
106 * </pre>
107 * @return The beta variance matrix
108 */
109 @Override
110 protected RealMatrix calculateBetaVariance() {
111 RealMatrix OI = getOmegaInverse();
112 RealMatrix XTOIX = X.transpose().multiply(OI).multiply(X);
113 return new LUDecompositionImpl(XTOIX).getSolver().getInverse();
114 }
115
116
117 /**
118 * Calculates the estimated variance of the error term using the formula
119 * <pre>
120 * Var(u) = Tr(u' Omega^-1 u)/(n-k)
121 * </pre>
122 * where n and k are the row and column dimensions of the design
123 * matrix X.
124 *
125 * @return error variance
126 * @since 2.2
127 */
128 @Override
129 protected double calculateErrorVariance() {
130 RealVector residuals = calculateResiduals();
131 double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
132 return t / (X.getRowDimension() - X.getColumnDimension());
133
134 }
135
136}