blob: 0d01f772a3bd0d2ac9c9d657cdf18b3e514ec27c [file] [log] [blame]
Sascha Haeberling1d2624a2013-07-23 19:00:21 -07001// Ceres Solver - A fast non-linear least squares minimizer
2// Copyright 2013 Google Inc. All rights reserved.
3// http://code.google.com/p/ceres-solver/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7//
8// * Redistributions of source code must retain the above copyright notice,
9// this list of conditions and the following disclaimer.
10// * Redistributions in binary form must reproduce the above copyright notice,
11// this list of conditions and the following disclaimer in the documentation
12// and/or other materials provided with the distribution.
13// * Neither the name of Google Inc. nor the names of its contributors may be
14// used to endorse or promote products derived from this software without
15// specific prior written permission.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27// POSSIBILITY OF SUCH DAMAGE.
28//
29// Author: sameeragarwal@google.com (Sameer Agarwal)
30//
31// CostFunctionToFunctor is an adapter class that allows users to use
32// CostFunction objects in templated functors which are to be used for
33// automatic differentiation. This allows the user to seamlessly mix
34// analytic, numeric and automatic differentiation.
35//
36// For example, let us assume that
37//
38// class IntrinsicProjection : public SizedCostFunction<2, 5, 3> {
39// public:
40// IntrinsicProjection(const double* observations);
41// virtual bool Evaluate(double const* const* parameters,
42// double* residuals,
43// double** jacobians) const;
44// };
45//
// is a cost function that implements the projection of a point in its
// local coordinate system onto its image plane and subtracts it from
// the observed point projection. It can compute its residual and, via
// either analytic or numerical differentiation, its jacobians.
51//
52// Now we would like to compose the action of this CostFunction with
53// the action of camera extrinsics, i.e., rotation and
54// translation. Say we have a templated function
55//
56// template<typename T>
57// void RotateAndTranslatePoint(const T* rotation,
58// const T* translation,
59// const T* point,
60// T* result);
61//
62// Then we can now do the following,
63//
64// struct CameraProjection {
//   CameraProjection(double* observation) {
//     intrinsic_projection_.reset(
//         new CostFunctionToFunctor<2, 5, 3>(
//             new IntrinsicProjection(observation)));
69// }
70// template <typename T>
71// bool operator()(const T* rotation,
72// const T* translation,
73// const T* intrinsics,
74// const T* point,
75// T* residual) const {
76// T transformed_point[3];
77// RotateAndTranslatePoint(rotation, translation, point, transformed_point);
78//
//     // Note that we call intrinsic_projection_ just as if it were
//     // any other templated functor.
81//
82// return (*intrinsic_projection_)(intrinsics, transformed_point, residual);
83// }
84//
85// private:
86// scoped_ptr<CostFunctionToFunctor<2,5,3> > intrinsic_projection_;
87// };
88
89#ifndef CERES_PUBLIC_COST_FUNCTION_TO_FUNCTOR_H_
90#define CERES_PUBLIC_COST_FUNCTION_TO_FUNCTOR_H_
91
92#include <numeric>
93#include <vector>
94
95#include "ceres/cost_function.h"
96#include "ceres/internal/fixed_array.h"
97#include "ceres/internal/port.h"
98#include "ceres/internal/scoped_ptr.h"
99
100namespace ceres {
101
102template <int kNumResiduals,
103 int N0, int N1 = 0, int N2 = 0, int N3 = 0, int N4 = 0,
104 int N5 = 0, int N6 = 0, int N7 = 0, int N8 = 0, int N9 = 0>
105class CostFunctionToFunctor {
106 public:
107 explicit CostFunctionToFunctor(CostFunction* cost_function)
108 : cost_function_(cost_function) {
109 CHECK_NOTNULL(cost_function);
110
111 CHECK_GE(kNumResiduals, 0);
112 CHECK_EQ(cost_function->num_residuals(), kNumResiduals);
113
114 // This block breaks the 80 column rule to keep it somewhat readable.
115 CHECK((!N1 && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
116 ((N1 > 0) && !N2 && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
117 ((N1 > 0) && (N2 > 0) && !N3 && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
118 ((N1 > 0) && (N2 > 0) && (N3 > 0) && !N4 && !N5 && !N6 && !N7 && !N8 && !N9) ||
119 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && !N5 && !N6 && !N7 && !N8 && !N9) ||
120 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && !N6 && !N7 && !N8 && !N9) ||
121 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && !N7 && !N8 && !N9) ||
122 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && !N8 && !N9) ||
123 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && !N9) ||
124 ((N1 > 0) && (N2 > 0) && (N3 > 0) && (N4 > 0) && (N5 > 0) && (N6 > 0) && (N7 > 0) && (N8 > 0) && (N9 > 0)))
125 << "Zero block cannot precede a non-zero block. Block sizes are "
126 << "(ignore trailing 0s): " << N0 << ", " << N1 << ", " << N2 << ", "
127 << N3 << ", " << N4 << ", " << N5 << ", " << N6 << ", " << N7 << ", "
128 << N8 << ", " << N9;
129
Carlos Hernandez79397c22014-08-07 17:51:38 -0700130 const vector<int32>& parameter_block_sizes =
Sascha Haeberling1d2624a2013-07-23 19:00:21 -0700131 cost_function->parameter_block_sizes();
132 const int num_parameter_blocks =
133 (N0 > 0) + (N1 > 0) + (N2 > 0) + (N3 > 0) + (N4 > 0) +
134 (N5 > 0) + (N6 > 0) + (N7 > 0) + (N8 > 0) + (N9 > 0);
135 CHECK_EQ(parameter_block_sizes.size(), num_parameter_blocks);
136
137 CHECK_EQ(N0, parameter_block_sizes[0]);
138 if (parameter_block_sizes.size() > 1) CHECK_EQ(N1, parameter_block_sizes[1]); // NOLINT
139 if (parameter_block_sizes.size() > 2) CHECK_EQ(N2, parameter_block_sizes[2]); // NOLINT
140 if (parameter_block_sizes.size() > 3) CHECK_EQ(N3, parameter_block_sizes[3]); // NOLINT
141 if (parameter_block_sizes.size() > 4) CHECK_EQ(N4, parameter_block_sizes[4]); // NOLINT
142 if (parameter_block_sizes.size() > 5) CHECK_EQ(N5, parameter_block_sizes[5]); // NOLINT
143 if (parameter_block_sizes.size() > 6) CHECK_EQ(N6, parameter_block_sizes[6]); // NOLINT
144 if (parameter_block_sizes.size() > 7) CHECK_EQ(N7, parameter_block_sizes[7]); // NOLINT
145 if (parameter_block_sizes.size() > 8) CHECK_EQ(N8, parameter_block_sizes[8]); // NOLINT
146 if (parameter_block_sizes.size() > 9) CHECK_EQ(N9, parameter_block_sizes[9]); // NOLINT
147
148 CHECK_EQ(accumulate(parameter_block_sizes.begin(),
149 parameter_block_sizes.end(), 0),
150 N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9);
151 }
152
153 bool operator()(const double* x0, double* residuals) const {
154 CHECK_NE(N0, 0);
155 CHECK_EQ(N1, 0);
156 CHECK_EQ(N2, 0);
157 CHECK_EQ(N3, 0);
158 CHECK_EQ(N4, 0);
159 CHECK_EQ(N5, 0);
160 CHECK_EQ(N6, 0);
161 CHECK_EQ(N7, 0);
162 CHECK_EQ(N8, 0);
163 CHECK_EQ(N9, 0);
164
165 return cost_function_->Evaluate(&x0, residuals, NULL);
166 }
167
168 bool operator()(const double* x0,
169 const double* x1,
170 double* residuals) const {
171 CHECK_NE(N0, 0);
172 CHECK_NE(N1, 0);
173 CHECK_EQ(N2, 0);
174 CHECK_EQ(N3, 0);
175 CHECK_EQ(N4, 0);
176 CHECK_EQ(N5, 0);
177 CHECK_EQ(N6, 0);
178 CHECK_EQ(N7, 0);
179 CHECK_EQ(N8, 0);
180 CHECK_EQ(N9, 0);
181 internal::FixedArray<const double*> parameter_blocks(2);
182 parameter_blocks[0] = x0;
183 parameter_blocks[1] = x1;
184 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
185 }
186
187 bool operator()(const double* x0,
188 const double* x1,
189 const double* x2,
190 double* residuals) const {
191 CHECK_NE(N0, 0);
192 CHECK_NE(N1, 0);
193 CHECK_NE(N2, 0);
194 CHECK_EQ(N3, 0);
195 CHECK_EQ(N4, 0);
196 CHECK_EQ(N5, 0);
197 CHECK_EQ(N6, 0);
198 CHECK_EQ(N7, 0);
199 CHECK_EQ(N8, 0);
200 CHECK_EQ(N9, 0);
201 internal::FixedArray<const double*> parameter_blocks(3);
202 parameter_blocks[0] = x0;
203 parameter_blocks[1] = x1;
204 parameter_blocks[2] = x2;
205 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
206 }
207
208 bool operator()(const double* x0,
209 const double* x1,
210 const double* x2,
211 const double* x3,
212 double* residuals) const {
213 CHECK_NE(N0, 0);
214 CHECK_NE(N1, 0);
215 CHECK_NE(N2, 0);
216 CHECK_NE(N3, 0);
217 CHECK_EQ(N4, 0);
218 CHECK_EQ(N5, 0);
219 CHECK_EQ(N6, 0);
220 CHECK_EQ(N7, 0);
221 CHECK_EQ(N8, 0);
222 CHECK_EQ(N9, 0);
223 internal::FixedArray<const double*> parameter_blocks(4);
224 parameter_blocks[0] = x0;
225 parameter_blocks[1] = x1;
226 parameter_blocks[2] = x2;
227 parameter_blocks[3] = x3;
228 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
229 }
230
231 bool operator()(const double* x0,
232 const double* x1,
233 const double* x2,
234 const double* x3,
235 const double* x4,
236 double* residuals) const {
237 CHECK_NE(N0, 0);
238 CHECK_NE(N1, 0);
239 CHECK_NE(N2, 0);
240 CHECK_NE(N3, 0);
241 CHECK_NE(N4, 0);
242 CHECK_EQ(N5, 0);
243 CHECK_EQ(N6, 0);
244 CHECK_EQ(N7, 0);
245 CHECK_EQ(N8, 0);
246 CHECK_EQ(N9, 0);
247 internal::FixedArray<const double*> parameter_blocks(5);
248 parameter_blocks[0] = x0;
249 parameter_blocks[1] = x1;
250 parameter_blocks[2] = x2;
251 parameter_blocks[3] = x3;
252 parameter_blocks[4] = x4;
253 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
254 }
255
256 bool operator()(const double* x0,
257 const double* x1,
258 const double* x2,
259 const double* x3,
260 const double* x4,
261 const double* x5,
262 double* residuals) const {
263 CHECK_NE(N0, 0);
264 CHECK_NE(N1, 0);
265 CHECK_NE(N2, 0);
266 CHECK_NE(N3, 0);
267 CHECK_NE(N4, 0);
268 CHECK_NE(N5, 0);
269 CHECK_EQ(N6, 0);
270 CHECK_EQ(N7, 0);
271 CHECK_EQ(N8, 0);
272 CHECK_EQ(N9, 0);
273 internal::FixedArray<const double*> parameter_blocks(6);
274 parameter_blocks[0] = x0;
275 parameter_blocks[1] = x1;
276 parameter_blocks[2] = x2;
277 parameter_blocks[3] = x3;
278 parameter_blocks[4] = x4;
279 parameter_blocks[5] = x5;
280 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
281 }
282
283 bool operator()(const double* x0,
284 const double* x1,
285 const double* x2,
286 const double* x3,
287 const double* x4,
288 const double* x5,
289 const double* x6,
290 double* residuals) const {
291 CHECK_NE(N0, 0);
292 CHECK_NE(N1, 0);
293 CHECK_NE(N2, 0);
294 CHECK_NE(N3, 0);
295 CHECK_NE(N4, 0);
296 CHECK_NE(N5, 0);
297 CHECK_NE(N6, 0);
298 CHECK_EQ(N7, 0);
299 CHECK_EQ(N8, 0);
300 CHECK_EQ(N9, 0);
301 internal::FixedArray<const double*> parameter_blocks(7);
302 parameter_blocks[0] = x0;
303 parameter_blocks[1] = x1;
304 parameter_blocks[2] = x2;
305 parameter_blocks[3] = x3;
306 parameter_blocks[4] = x4;
307 parameter_blocks[5] = x5;
308 parameter_blocks[6] = x6;
309 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
310 }
311
312 bool operator()(const double* x0,
313 const double* x1,
314 const double* x2,
315 const double* x3,
316 const double* x4,
317 const double* x5,
318 const double* x6,
319 const double* x7,
320 double* residuals) const {
321 CHECK_NE(N0, 0);
322 CHECK_NE(N1, 0);
323 CHECK_NE(N2, 0);
324 CHECK_NE(N3, 0);
325 CHECK_NE(N4, 0);
326 CHECK_NE(N5, 0);
327 CHECK_NE(N6, 0);
328 CHECK_NE(N7, 0);
329 CHECK_EQ(N8, 0);
330 CHECK_EQ(N9, 0);
331 internal::FixedArray<const double*> parameter_blocks(8);
332 parameter_blocks[0] = x0;
333 parameter_blocks[1] = x1;
334 parameter_blocks[2] = x2;
335 parameter_blocks[3] = x3;
336 parameter_blocks[4] = x4;
337 parameter_blocks[5] = x5;
338 parameter_blocks[6] = x6;
339 parameter_blocks[7] = x7;
340 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
341 }
342
343 bool operator()(const double* x0,
344 const double* x1,
345 const double* x2,
346 const double* x3,
347 const double* x4,
348 const double* x5,
349 const double* x6,
350 const double* x7,
351 const double* x8,
352 double* residuals) const {
353 CHECK_NE(N0, 0);
354 CHECK_NE(N1, 0);
355 CHECK_NE(N2, 0);
356 CHECK_NE(N3, 0);
357 CHECK_NE(N4, 0);
358 CHECK_NE(N5, 0);
359 CHECK_NE(N6, 0);
360 CHECK_NE(N7, 0);
361 CHECK_NE(N8, 0);
362 CHECK_EQ(N9, 0);
363 internal::FixedArray<const double*> parameter_blocks(9);
364 parameter_blocks[0] = x0;
365 parameter_blocks[1] = x1;
366 parameter_blocks[2] = x2;
367 parameter_blocks[3] = x3;
368 parameter_blocks[4] = x4;
369 parameter_blocks[5] = x5;
370 parameter_blocks[6] = x6;
371 parameter_blocks[7] = x7;
372 parameter_blocks[8] = x8;
373 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
374 }
375
376 bool operator()(const double* x0,
377 const double* x1,
378 const double* x2,
379 const double* x3,
380 const double* x4,
381 const double* x5,
382 const double* x6,
383 const double* x7,
384 const double* x8,
385 const double* x9,
386 double* residuals) const {
387 CHECK_NE(N0, 0);
388 CHECK_NE(N1, 0);
389 CHECK_NE(N2, 0);
390 CHECK_NE(N3, 0);
391 CHECK_NE(N4, 0);
392 CHECK_NE(N5, 0);
393 CHECK_NE(N6, 0);
394 CHECK_NE(N7, 0);
395 CHECK_NE(N8, 0);
396 CHECK_NE(N9, 0);
397 internal::FixedArray<const double*> parameter_blocks(10);
398 parameter_blocks[0] = x0;
399 parameter_blocks[1] = x1;
400 parameter_blocks[2] = x2;
401 parameter_blocks[3] = x3;
402 parameter_blocks[4] = x4;
403 parameter_blocks[5] = x5;
404 parameter_blocks[6] = x6;
405 parameter_blocks[7] = x7;
406 parameter_blocks[8] = x8;
407 parameter_blocks[9] = x9;
408 return cost_function_->Evaluate(parameter_blocks.get(), residuals, NULL);
409 }
410
411 template <typename JetT>
412 bool operator()(const JetT* x0, JetT* residuals) const {
413 CHECK_NE(N0, 0);
414 CHECK_EQ(N1, 0);
415 CHECK_EQ(N2, 0);
416 CHECK_EQ(N3, 0);
417 CHECK_EQ(N4, 0);
418 CHECK_EQ(N5, 0);
419 CHECK_EQ(N6, 0);
420 CHECK_EQ(N7, 0);
421 CHECK_EQ(N8, 0);
422 CHECK_EQ(N9, 0);
423 return EvaluateWithJets(&x0, residuals);
424 }
425
426 template <typename JetT>
427 bool operator()(const JetT* x0,
428 const JetT* x1,
429 JetT* residuals) const {
430 CHECK_NE(N0, 0);
431 CHECK_NE(N1, 0);
432 CHECK_EQ(N2, 0);
433 CHECK_EQ(N3, 0);
434 CHECK_EQ(N4, 0);
435 CHECK_EQ(N5, 0);
436 CHECK_EQ(N6, 0);
437 CHECK_EQ(N7, 0);
438 CHECK_EQ(N8, 0);
439 CHECK_EQ(N9, 0);
440 internal::FixedArray<const JetT*> jets(2);
441 jets[0] = x0;
442 jets[1] = x1;
443 return EvaluateWithJets(jets.get(), residuals);
444 }
445
446 template <typename JetT>
447 bool operator()(const JetT* x0,
448 const JetT* x1,
449 const JetT* x2,
450 JetT* residuals) const {
451 CHECK_NE(N0, 0);
452 CHECK_NE(N1, 0);
453 CHECK_NE(N2, 0);
454 CHECK_EQ(N3, 0);
455 CHECK_EQ(N4, 0);
456 CHECK_EQ(N5, 0);
457 CHECK_EQ(N6, 0);
458 CHECK_EQ(N7, 0);
459 CHECK_EQ(N8, 0);
460 CHECK_EQ(N9, 0);
461 internal::FixedArray<const JetT*> jets(3);
462 jets[0] = x0;
463 jets[1] = x1;
464 jets[2] = x2;
465 return EvaluateWithJets(jets.get(), residuals);
466 }
467
468 template <typename JetT>
469 bool operator()(const JetT* x0,
470 const JetT* x1,
471 const JetT* x2,
472 const JetT* x3,
473 JetT* residuals) const {
474 CHECK_NE(N0, 0);
475 CHECK_NE(N1, 0);
476 CHECK_NE(N2, 0);
477 CHECK_NE(N3, 0);
478 CHECK_EQ(N4, 0);
479 CHECK_EQ(N5, 0);
480 CHECK_EQ(N6, 0);
481 CHECK_EQ(N7, 0);
482 CHECK_EQ(N8, 0);
483 CHECK_EQ(N9, 0);
484 internal::FixedArray<const JetT*> jets(4);
485 jets[0] = x0;
486 jets[1] = x1;
487 jets[2] = x2;
488 jets[3] = x3;
489 return EvaluateWithJets(jets.get(), residuals);
490 }
491
492 template <typename JetT>
493 bool operator()(const JetT* x0,
494 const JetT* x1,
495 const JetT* x2,
496 const JetT* x3,
497 const JetT* x4,
498 JetT* residuals) const {
499 CHECK_NE(N0, 0);
500 CHECK_NE(N1, 0);
501 CHECK_NE(N2, 0);
502 CHECK_NE(N3, 0);
503 CHECK_NE(N4, 0);
504 CHECK_EQ(N5, 0);
505 CHECK_EQ(N6, 0);
506 CHECK_EQ(N7, 0);
507 CHECK_EQ(N8, 0);
508 CHECK_EQ(N9, 0);
509 internal::FixedArray<const JetT*> jets(5);
510 jets[0] = x0;
511 jets[1] = x1;
512 jets[2] = x2;
513 jets[3] = x3;
514 jets[4] = x4;
515 return EvaluateWithJets(jets.get(), residuals);
516 }
517
518 template <typename JetT>
519 bool operator()(const JetT* x0,
520 const JetT* x1,
521 const JetT* x2,
522 const JetT* x3,
523 const JetT* x4,
524 const JetT* x5,
525 JetT* residuals) const {
526 CHECK_NE(N0, 0);
527 CHECK_NE(N1, 0);
528 CHECK_NE(N2, 0);
529 CHECK_NE(N3, 0);
530 CHECK_NE(N4, 0);
531 CHECK_NE(N5, 0);
532 CHECK_EQ(N6, 0);
533 CHECK_EQ(N7, 0);
534 CHECK_EQ(N8, 0);
535 CHECK_EQ(N9, 0);
536 internal::FixedArray<const JetT*> jets(6);
537 jets[0] = x0;
538 jets[1] = x1;
539 jets[2] = x2;
540 jets[3] = x3;
541 jets[4] = x4;
542 jets[5] = x5;
543 return EvaluateWithJets(jets.get(), residuals);
544 }
545
546 template <typename JetT>
547 bool operator()(const JetT* x0,
548 const JetT* x1,
549 const JetT* x2,
550 const JetT* x3,
551 const JetT* x4,
552 const JetT* x5,
553 const JetT* x6,
554 JetT* residuals) const {
555 CHECK_NE(N0, 0);
556 CHECK_NE(N1, 0);
557 CHECK_NE(N2, 0);
558 CHECK_NE(N3, 0);
559 CHECK_NE(N4, 0);
560 CHECK_NE(N5, 0);
561 CHECK_NE(N6, 0);
562 CHECK_EQ(N7, 0);
563 CHECK_EQ(N8, 0);
564 CHECK_EQ(N9, 0);
565 internal::FixedArray<const JetT*> jets(7);
566 jets[0] = x0;
567 jets[1] = x1;
568 jets[2] = x2;
569 jets[3] = x3;
570 jets[4] = x4;
571 jets[5] = x5;
572 jets[6] = x6;
573 return EvaluateWithJets(jets.get(), residuals);
574 }
575
576 template <typename JetT>
577 bool operator()(const JetT* x0,
578 const JetT* x1,
579 const JetT* x2,
580 const JetT* x3,
581 const JetT* x4,
582 const JetT* x5,
583 const JetT* x6,
584 const JetT* x7,
585 JetT* residuals) const {
586 CHECK_NE(N0, 0);
587 CHECK_NE(N1, 0);
588 CHECK_NE(N2, 0);
589 CHECK_NE(N3, 0);
590 CHECK_NE(N4, 0);
591 CHECK_NE(N5, 0);
592 CHECK_NE(N6, 0);
593 CHECK_NE(N7, 0);
594 CHECK_EQ(N8, 0);
595 CHECK_EQ(N9, 0);
596 internal::FixedArray<const JetT*> jets(8);
597 jets[0] = x0;
598 jets[1] = x1;
599 jets[2] = x2;
600 jets[3] = x3;
601 jets[4] = x4;
602 jets[5] = x5;
603 jets[6] = x6;
604 jets[7] = x7;
605 return EvaluateWithJets(jets.get(), residuals);
606 }
607
608 template <typename JetT>
609 bool operator()(const JetT* x0,
610 const JetT* x1,
611 const JetT* x2,
612 const JetT* x3,
613 const JetT* x4,
614 const JetT* x5,
615 const JetT* x6,
616 const JetT* x7,
617 const JetT* x8,
618 JetT* residuals) const {
619 CHECK_NE(N0, 0);
620 CHECK_NE(N1, 0);
621 CHECK_NE(N2, 0);
622 CHECK_NE(N3, 0);
623 CHECK_NE(N4, 0);
624 CHECK_NE(N5, 0);
625 CHECK_NE(N6, 0);
626 CHECK_NE(N7, 0);
627 CHECK_NE(N8, 0);
628 CHECK_EQ(N9, 0);
629 internal::FixedArray<const JetT*> jets(9);
630 jets[0] = x0;
631 jets[1] = x1;
632 jets[2] = x2;
633 jets[3] = x3;
634 jets[4] = x4;
635 jets[5] = x5;
636 jets[6] = x6;
637 jets[7] = x7;
638 jets[8] = x8;
639 return EvaluateWithJets(jets.get(), residuals);
640 }
641
642 template <typename JetT>
643 bool operator()(const JetT* x0,
644 const JetT* x1,
645 const JetT* x2,
646 const JetT* x3,
647 const JetT* x4,
648 const JetT* x5,
649 const JetT* x6,
650 const JetT* x7,
651 const JetT* x8,
652 const JetT* x9,
653 JetT* residuals) const {
654 CHECK_NE(N0, 0);
655 CHECK_NE(N1, 0);
656 CHECK_NE(N2, 0);
657 CHECK_NE(N3, 0);
658 CHECK_NE(N4, 0);
659 CHECK_NE(N5, 0);
660 CHECK_NE(N6, 0);
661 CHECK_NE(N7, 0);
662 CHECK_NE(N8, 0);
663 CHECK_NE(N9, 0);
664 internal::FixedArray<const JetT*> jets(10);
665 jets[0] = x0;
666 jets[1] = x1;
667 jets[2] = x2;
668 jets[3] = x3;
669 jets[4] = x4;
670 jets[5] = x5;
671 jets[6] = x6;
672 jets[7] = x7;
673 jets[8] = x8;
674 jets[9] = x9;
675 return EvaluateWithJets(jets.get(), residuals);
676 }
677
678 private:
679 template <typename JetT>
680 bool EvaluateWithJets(const JetT** inputs, JetT* output) const {
681 const int kNumParameters = N0 + N1 + N2 + N3 + N4 + N5 + N6 + N7 + N8 + N9;
Carlos Hernandez79397c22014-08-07 17:51:38 -0700682 const vector<int32>& parameter_block_sizes =
Sascha Haeberling1d2624a2013-07-23 19:00:21 -0700683 cost_function_->parameter_block_sizes();
684 const int num_parameter_blocks = parameter_block_sizes.size();
685 const int num_residuals = cost_function_->num_residuals();
686
687 internal::FixedArray<double> parameters(kNumParameters);
688 internal::FixedArray<double*> parameter_blocks(num_parameter_blocks);
689 internal::FixedArray<double> jacobians(num_residuals * kNumParameters);
690 internal::FixedArray<double*> jacobian_blocks(num_parameter_blocks);
691 internal::FixedArray<double> residuals(num_residuals);
692
693 // Build a set of arrays to get the residuals and jacobians from
694 // the CostFunction wrapped by this functor.
695 double* parameter_ptr = parameters.get();
696 double* jacobian_ptr = jacobians.get();
697 for (int i = 0; i < num_parameter_blocks; ++i) {
698 parameter_blocks[i] = parameter_ptr;
699 jacobian_blocks[i] = jacobian_ptr;
700 for (int j = 0; j < parameter_block_sizes[i]; ++j) {
701 *parameter_ptr++ = inputs[i][j].a;
702 }
703 jacobian_ptr += num_residuals * parameter_block_sizes[i];
704 }
705
706 if (!cost_function_->Evaluate(parameter_blocks.get(),
707 residuals.get(),
708 jacobian_blocks.get())) {
709 return false;
710 }
711
712 // Now that we have the incoming Jets, which are carrying the
713 // partial derivatives of each of the inputs w.r.t to some other
714 // underlying parameters. The derivative of the outputs of the
715 // cost function w.r.t to the same underlying parameters can now
716 // be computed by applying the chain rule.
717 //
718 // d output[i] d output[i] d input[j]
719 // -------------- = sum_j ----------- * ------------
720 // d parameter[k] d input[j] d parameter[k]
721 //
722 // d input[j]
723 // -------------- = inputs[j], so
724 // d parameter[k]
725 //
726 // outputJet[i] = sum_k jacobian[i][k] * inputJet[k]
727 //
728 // The following loop, iterates over the residuals, computing one
729 // output jet at a time.
730 for (int i = 0; i < num_residuals; ++i) {
731 output[i].a = residuals[i];
732 output[i].v.setZero();
733
734 for (int j = 0; j < num_parameter_blocks; ++j) {
Carlos Hernandez79397c22014-08-07 17:51:38 -0700735 const int32 block_size = parameter_block_sizes[j];
Sascha Haeberling1d2624a2013-07-23 19:00:21 -0700736 for (int k = 0; k < parameter_block_sizes[j]; ++k) {
737 output[i].v +=
738 jacobian_blocks[j][i * block_size + k] * inputs[j][k].v;
739 }
740 }
741 }
742
743 return true;
744 }
745
746 private:
747 internal::scoped_ptr<CostFunction> cost_function_;
748};
749
750} // namespace ceres
751
752#endif // CERES_PUBLIC_COST_FUNCTION_TO_FUNCTOR_H_