blob: a36ca78e2f3e2a9992fa6775b254464320a88d0a [file] [log] [blame]
Lang Hamescaaf1202009-08-07 00:25:12 +00001//===-- HeuristicSolver.h - Heuristic PBQP Solver --------------*- C++ --*-===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Heuristic PBQP solver. This solver is able to perform optimal reductions for
11// nodes of degree 0, 1 or 2. For nodes of degree >2 a pluggable heuristic is
12// used to select a node for reduction.
13//
14//===----------------------------------------------------------------------===//
15
Lang Hames6699fb22009-08-06 23:32:48 +000016#ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
17#define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
18
19#include "Solver.h"
20#include "AnnotatedGraph.h"
21
22#include <limits>
23#include <iostream>
24
25namespace PBQP {
26
27/// \brief Important types for the HeuristicSolverImpl.
28///
29/// Declared seperately to allow access to heuristic classes before the solver
30/// is fully constructed.
31template <typename HeuristicNodeData, typename HeuristicEdgeData>
32class HSITypes {
33public:
34
35 class NodeData;
36 class EdgeData;
37
38 typedef AnnotatedGraph<NodeData, EdgeData> SolverGraph;
39 typedef typename SolverGraph::NodeIterator GraphNodeIterator;
40 typedef typename SolverGraph::EdgeIterator GraphEdgeIterator;
41 typedef typename SolverGraph::AdjEdgeIterator GraphAdjEdgeIterator;
42
43 typedef std::list<GraphNodeIterator> NodeList;
44 typedef typename NodeList::iterator NodeListIterator;
45
46 typedef std::vector<GraphNodeIterator> NodeStack;
47 typedef typename NodeStack::iterator NodeStackIterator;
48
49 class NodeData {
50 friend class EdgeData;
51
52 private:
53
54 typedef std::list<GraphEdgeIterator> LinksList;
55
56 unsigned numLinks;
57 LinksList links, solvedLinks;
58 NodeListIterator bucketItr;
59 HeuristicNodeData heuristicData;
60
61 public:
62
63 typedef typename LinksList::iterator AdjLinkIterator;
64
65 private:
66
67 AdjLinkIterator addLink(const GraphEdgeIterator &edgeItr) {
68 ++numLinks;
69 return links.insert(links.end(), edgeItr);
70 }
71
72 void delLink(const AdjLinkIterator &adjLinkItr) {
73 --numLinks;
74 links.erase(adjLinkItr);
75 }
76
77 public:
78
79 NodeData() : numLinks(0) {}
80
81 unsigned getLinkDegree() const { return numLinks; }
82
83 HeuristicNodeData& getHeuristicData() { return heuristicData; }
84 const HeuristicNodeData& getHeuristicData() const {
85 return heuristicData;
86 }
87
88 void setBucketItr(const NodeListIterator &bucketItr) {
89 this->bucketItr = bucketItr;
90 }
91
92 const NodeListIterator& getBucketItr() const {
93 return bucketItr;
94 }
95
96 AdjLinkIterator adjLinksBegin() {
97 return links.begin();
98 }
99
100 AdjLinkIterator adjLinksEnd() {
101 return links.end();
102 }
103
104 void addSolvedLink(const GraphEdgeIterator &solvedLinkItr) {
105 solvedLinks.push_back(solvedLinkItr);
106 }
107
108 AdjLinkIterator solvedLinksBegin() {
109 return solvedLinks.begin();
110 }
111
112 AdjLinkIterator solvedLinksEnd() {
113 return solvedLinks.end();
114 }
115
116 };
117
118 class EdgeData {
119 private:
120
121 SolverGraph &g;
122 GraphNodeIterator node1Itr, node2Itr;
123 HeuristicEdgeData heuristicData;
124 typename NodeData::AdjLinkIterator node1ThisEdgeItr, node2ThisEdgeItr;
125
126 public:
127
128 EdgeData(SolverGraph &g) : g(g) {}
129
130 HeuristicEdgeData& getHeuristicData() { return heuristicData; }
131 const HeuristicEdgeData& getHeuristicData() const {
132 return heuristicData;
133 }
134
135 void setup(const GraphEdgeIterator &thisEdgeItr) {
136 node1Itr = g.getEdgeNode1Itr(thisEdgeItr);
137 node2Itr = g.getEdgeNode2Itr(thisEdgeItr);
138
139 node1ThisEdgeItr = g.getNodeData(node1Itr).addLink(thisEdgeItr);
140 node2ThisEdgeItr = g.getNodeData(node2Itr).addLink(thisEdgeItr);
141 }
142
143 void unlink() {
144 g.getNodeData(node1Itr).delLink(node1ThisEdgeItr);
145 g.getNodeData(node2Itr).delLink(node2ThisEdgeItr);
146 }
147
148 };
149
150};
151
152template <typename Heuristic>
153class HeuristicSolverImpl {
154public:
155 // Typedefs to make life easier:
156 typedef HSITypes<typename Heuristic::NodeData,
157 typename Heuristic::EdgeData> HSIT;
158 typedef typename HSIT::SolverGraph SolverGraph;
159 typedef typename HSIT::NodeData NodeData;
160 typedef typename HSIT::EdgeData EdgeData;
161 typedef typename HSIT::GraphNodeIterator GraphNodeIterator;
162 typedef typename HSIT::GraphEdgeIterator GraphEdgeIterator;
163 typedef typename HSIT::GraphAdjEdgeIterator GraphAdjEdgeIterator;
164
165 typedef typename HSIT::NodeList NodeList;
166 typedef typename HSIT::NodeListIterator NodeListIterator;
167
168 typedef std::vector<GraphNodeIterator> NodeStack;
169 typedef typename NodeStack::iterator NodeStackIterator;
170
Lang Hamescaaf1202009-08-07 00:25:12 +0000171 /// \brief Constructor, which performs all the actual solver work.
Lang Hames6699fb22009-08-06 23:32:48 +0000172 HeuristicSolverImpl(const SimpleGraph &orig) :
173 solution(orig.getNumNodes(), true)
174 {
175 copyGraph(orig);
176 simplify();
177 setup();
178 computeSolution();
179 computeSolutionCost(orig);
180 }
181
Lang Hamescaaf1202009-08-07 00:25:12 +0000182 /// \brief Returns the graph for this solver.
Lang Hames6699fb22009-08-06 23:32:48 +0000183 SolverGraph& getGraph() { return g; }
184
Lang Hamescaaf1202009-08-07 00:25:12 +0000185 /// \brief Return the solution found by this solver.
Lang Hames6699fb22009-08-06 23:32:48 +0000186 const Solution& getSolution() const { return solution; }
187
188private:
189
Lang Hamescaaf1202009-08-07 00:25:12 +0000190 /// \brief Add the given node to the appropriate bucket for its link
191 /// degree.
Lang Hames6699fb22009-08-06 23:32:48 +0000192 void addToBucket(const GraphNodeIterator &nodeItr) {
193 NodeData &nodeData = g.getNodeData(nodeItr);
194
195 switch (nodeData.getLinkDegree()) {
196 case 0: nodeData.setBucketItr(
197 r0Bucket.insert(r0Bucket.end(), nodeItr));
198 break;
199 case 1: nodeData.setBucketItr(
200 r1Bucket.insert(r1Bucket.end(), nodeItr));
201 break;
202 case 2: nodeData.setBucketItr(
203 r2Bucket.insert(r2Bucket.end(), nodeItr));
204 break;
205 default: heuristic.addToRNBucket(nodeItr);
206 break;
207 }
208 }
209
Lang Hamescaaf1202009-08-07 00:25:12 +0000210 /// \brief Remove the given node from the appropriate bucket for its link
211 /// degree.
Lang Hames6699fb22009-08-06 23:32:48 +0000212 void removeFromBucket(const GraphNodeIterator &nodeItr) {
213 NodeData &nodeData = g.getNodeData(nodeItr);
214
215 switch (nodeData.getLinkDegree()) {
216 case 0: r0Bucket.erase(nodeData.getBucketItr()); break;
217 case 1: r1Bucket.erase(nodeData.getBucketItr()); break;
218 case 2: r2Bucket.erase(nodeData.getBucketItr()); break;
219 default: heuristic.removeFromRNBucket(nodeItr); break;
220 }
221 }
222
223public:
224
Lang Hamescaaf1202009-08-07 00:25:12 +0000225 /// \brief Add a link.
Lang Hames6699fb22009-08-06 23:32:48 +0000226 void addLink(const GraphEdgeIterator &edgeItr) {
227 g.getEdgeData(edgeItr).setup(edgeItr);
228
229 if ((g.getNodeData(g.getEdgeNode1Itr(edgeItr)).getLinkDegree() > 2) ||
230 (g.getNodeData(g.getEdgeNode2Itr(edgeItr)).getLinkDegree() > 2)) {
231 heuristic.handleAddLink(edgeItr);
232 }
233 }
234
Lang Hamescaaf1202009-08-07 00:25:12 +0000235 /// \brief Remove link, update info for node.
236 ///
237 /// Only updates information for the given node, since usually the other
238 /// is about to be removed.
Lang Hames6699fb22009-08-06 23:32:48 +0000239 void removeLink(const GraphEdgeIterator &edgeItr,
240 const GraphNodeIterator &nodeItr) {
241
242 if (g.getNodeData(nodeItr).getLinkDegree() > 2) {
243 heuristic.handleRemoveLink(edgeItr, nodeItr);
244 }
245 g.getEdgeData(edgeItr).unlink();
246 }
247
Lang Hamescaaf1202009-08-07 00:25:12 +0000248 /// \brief Remove link, update info for both nodes. Useful for R2 only.
Lang Hames6699fb22009-08-06 23:32:48 +0000249 void removeLinkR2(const GraphEdgeIterator &edgeItr) {
250 GraphNodeIterator node1Itr = g.getEdgeNode1Itr(edgeItr);
251
252 if (g.getNodeData(node1Itr).getLinkDegree() > 2) {
253 heuristic.handleRemoveLink(edgeItr, node1Itr);
254 }
255 removeLink(edgeItr, g.getEdgeNode2Itr(edgeItr));
256 }
257
Lang Hamescaaf1202009-08-07 00:25:12 +0000258 /// \brief Removes all links connected to the given node.
Lang Hames6699fb22009-08-06 23:32:48 +0000259 void unlinkNode(const GraphNodeIterator &nodeItr) {
260 NodeData &nodeData = g.getNodeData(nodeItr);
261
262 typedef std::vector<GraphEdgeIterator> TempEdgeList;
263
264 TempEdgeList edgesToUnlink;
265 edgesToUnlink.reserve(nodeData.getLinkDegree());
266
267 // Copy adj edges into a temp vector. We want to destroy them during
268 // the unlink, and we can't do that while we're iterating over them.
269 std::copy(nodeData.adjLinksBegin(), nodeData.adjLinksEnd(),
270 std::back_inserter(edgesToUnlink));
271
272 for (typename TempEdgeList::iterator
273 edgeItr = edgesToUnlink.begin(), edgeEnd = edgesToUnlink.end();
274 edgeItr != edgeEnd; ++edgeItr) {
275
276 GraphNodeIterator otherNode = g.getEdgeOtherNode(*edgeItr, nodeItr);
277
278 removeFromBucket(otherNode);
279 removeLink(*edgeItr, otherNode);
280 addToBucket(otherNode);
281 }
282 }
283
Lang Hamescaaf1202009-08-07 00:25:12 +0000284 /// \brief Push the given node onto the stack to be solved with
285 /// backpropagation.
Lang Hames6699fb22009-08-06 23:32:48 +0000286 void pushStack(const GraphNodeIterator &nodeItr) {
287 stack.push_back(nodeItr);
288 }
289
Lang Hamescaaf1202009-08-07 00:25:12 +0000290 /// \brief Set the solution of the given node.
Lang Hames6699fb22009-08-06 23:32:48 +0000291 void setSolution(const GraphNodeIterator &nodeItr, unsigned solIndex) {
292 solution.setSelection(g.getNodeID(nodeItr), solIndex);
293
294 for (GraphAdjEdgeIterator adjEdgeItr = g.adjEdgesBegin(nodeItr),
295 adjEdgeEnd = g.adjEdgesEnd(nodeItr);
296 adjEdgeItr != adjEdgeEnd; ++adjEdgeItr) {
297 GraphEdgeIterator edgeItr(*adjEdgeItr);
298 GraphNodeIterator adjNodeItr(g.getEdgeOtherNode(edgeItr, nodeItr));
299 g.getNodeData(adjNodeItr).addSolvedLink(edgeItr);
300 }
301 }
302
303private:
304
305 SolverGraph g;
306 Heuristic heuristic;
307 Solution solution;
308
309 NodeList r0Bucket,
310 r1Bucket,
311 r2Bucket;
312
313 NodeStack stack;
314
315 // Copy the SimpleGraph into an annotated graph which we can use for reduction.
316 void copyGraph(const SimpleGraph &orig) {
317
318 assert((g.getNumEdges() == 0) && (g.getNumNodes() == 0) &&
319 "Graph should be empty prior to solver setup.");
320
321 assert(orig.areNodeIDsValid() &&
322 "Cannot copy from a graph with invalid node IDs.");
323
324 std::vector<GraphNodeIterator> newNodeItrs;
325
326 for (unsigned nodeID = 0; nodeID < orig.getNumNodes(); ++nodeID) {
327 newNodeItrs.push_back(
328 g.addNode(orig.getNodeCosts(orig.getNodeItr(nodeID)), NodeData()));
329 }
330
331 for (SimpleGraph::ConstEdgeIterator
332 origEdgeItr = orig.edgesBegin(), origEdgeEnd = orig.edgesEnd();
333 origEdgeItr != origEdgeEnd; ++origEdgeItr) {
334
335 unsigned id1 = orig.getNodeID(orig.getEdgeNode1Itr(origEdgeItr)),
336 id2 = orig.getNodeID(orig.getEdgeNode2Itr(origEdgeItr));
337
338 g.addEdge(newNodeItrs[id1], newNodeItrs[id2],
339 orig.getEdgeCosts(origEdgeItr), EdgeData(g));
340 }
341
342 // Assign IDs to the new nodes using the ordering from the old graph,
343 // this will lead to nodes in the new graph getting the same ID as the
344 // corresponding node in the old graph.
345 g.assignNodeIDs(newNodeItrs);
346 }
347
348 // Simplify the annotated graph by eliminating independent edges and trivial
349 // nodes.
350 void simplify() {
351 disconnectTrivialNodes();
352 eliminateIndependentEdges();
353 }
354
355 // Eliminate trivial nodes.
356 void disconnectTrivialNodes() {
357 for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
358 nodeItr != nodeEnd; ++nodeItr) {
359
360 if (g.getNodeCosts(nodeItr).getLength() == 1) {
361
362 std::vector<GraphEdgeIterator> edgesToRemove;
363
364 for (GraphAdjEdgeIterator adjEdgeItr = g.adjEdgesBegin(nodeItr),
365 adjEdgeEnd = g.adjEdgesEnd(nodeItr);
366 adjEdgeItr != adjEdgeEnd; ++adjEdgeItr) {
367
368 GraphEdgeIterator edgeItr = *adjEdgeItr;
369
370 if (g.getEdgeNode1Itr(edgeItr) == nodeItr) {
371 GraphNodeIterator otherNodeItr = g.getEdgeNode2Itr(edgeItr);
372 g.getNodeCosts(otherNodeItr) +=
373 g.getEdgeCosts(edgeItr).getRowAsVector(0);
374 }
375 else {
376 GraphNodeIterator otherNodeItr = g.getEdgeNode1Itr(edgeItr);
377 g.getNodeCosts(otherNodeItr) +=
378 g.getEdgeCosts(edgeItr).getColAsVector(0);
379 }
380
381 edgesToRemove.push_back(edgeItr);
382 }
383
384 while (!edgesToRemove.empty()) {
385 g.removeEdge(edgesToRemove.back());
386 edgesToRemove.pop_back();
387 }
388 }
389 }
390 }
391
392 void eliminateIndependentEdges() {
393 std::vector<GraphEdgeIterator> edgesToProcess;
394
395 for (GraphEdgeIterator edgeItr = g.edgesBegin(), edgeEnd = g.edgesEnd();
396 edgeItr != edgeEnd; ++edgeItr) {
397 edgesToProcess.push_back(edgeItr);
398 }
399
400 while (!edgesToProcess.empty()) {
401 tryToEliminateEdge(edgesToProcess.back());
402 edgesToProcess.pop_back();
403 }
404 }
405
406 void tryToEliminateEdge(const GraphEdgeIterator &edgeItr) {
407 if (tryNormaliseEdgeMatrix(edgeItr)) {
408 g.removeEdge(edgeItr);
409 }
410 }
411
412 bool tryNormaliseEdgeMatrix(const GraphEdgeIterator &edgeItr) {
413
414 Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
415 Vector &uCosts = g.getNodeCosts(g.getEdgeNode1Itr(edgeItr)),
416 &vCosts = g.getNodeCosts(g.getEdgeNode2Itr(edgeItr));
417
418 for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
419 PBQPNum rowMin = edgeCosts.getRowMin(r);
420 uCosts[r] += rowMin;
421 if (rowMin != std::numeric_limits<PBQPNum>::infinity()) {
422 edgeCosts.subFromRow(r, rowMin);
423 }
424 else {
425 edgeCosts.setRow(r, 0);
426 }
427 }
428
429 for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
430 PBQPNum colMin = edgeCosts.getColMin(c);
431 vCosts[c] += colMin;
432 if (colMin != std::numeric_limits<PBQPNum>::infinity()) {
433 edgeCosts.subFromCol(c, colMin);
434 }
435 else {
436 edgeCosts.setCol(c, 0);
437 }
438 }
439
440 return edgeCosts.isZero();
441 }
442
443 void setup() {
444 setupLinks();
445 heuristic.initialise(*this);
446 setupBuckets();
447 }
448
449 void setupLinks() {
450 for (GraphEdgeIterator edgeItr = g.edgesBegin(), edgeEnd = g.edgesEnd();
451 edgeItr != edgeEnd; ++edgeItr) {
452 g.getEdgeData(edgeItr).setup(edgeItr);
453 }
454 }
455
456 void setupBuckets() {
457 for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
458 nodeItr != nodeEnd; ++nodeItr) {
459 addToBucket(nodeItr);
460 }
461 }
462
463 void computeSolution() {
464 assert(g.areNodeIDsValid() &&
465 "Nodes cannot be added/removed during reduction.");
466
467 reduce();
468 computeTrivialSolutions();
469 backpropagate();
470 }
471
472 void printNode(const GraphNodeIterator &nodeItr) {
473
474 std::cerr << "Node " << g.getNodeID(nodeItr) << " (" << &*nodeItr << "):\n"
475 << " costs = " << g.getNodeCosts(nodeItr) << "\n"
476 << " link degree = " << g.getNodeData(nodeItr).getLinkDegree() << "\n"
477 << " links = [ ";
478
479 for (typename HSIT::NodeData::AdjLinkIterator
480 aeItr = g.getNodeData(nodeItr).adjLinksBegin(),
481 aeEnd = g.getNodeData(nodeItr).adjLinksEnd();
482 aeItr != aeEnd; ++aeItr) {
483 std::cerr << "(" << g.getNodeID(g.getEdgeNode1Itr(*aeItr))
484 << ", " << g.getNodeID(g.getEdgeNode2Itr(*aeItr))
485 << ") ";
486 }
487 std::cout << "]\n";
488 }
489
490 void dumpState() {
491
492 std::cerr << "\n";
493
494 for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
495 nodeItr != nodeEnd; ++nodeItr) {
496 printNode(nodeItr);
497 }
498
499 NodeList* buckets[] = { &r0Bucket, &r1Bucket, &r2Bucket };
500
501 for (unsigned b = 0; b < 3; ++b) {
502 NodeList &bucket = *buckets[b];
503
504 std::cerr << "Bucket " << b << ": [ ";
505
506 for (NodeListIterator nItr = bucket.begin(), nEnd = bucket.end();
507 nItr != nEnd; ++nItr) {
508 std::cerr << g.getNodeID(*nItr) << " ";
509 }
510
511 std::cerr << "]\n";
512 }
513
514 std::cerr << "Stack: [ ";
515 for (NodeStackIterator nsItr = stack.begin(), nsEnd = stack.end();
516 nsItr != nsEnd; ++nsItr) {
517 std::cerr << g.getNodeID(*nsItr) << " ";
518 }
519 std::cerr << "]\n";
520 }
521
522 void reduce() {
523 bool reductionFinished = r1Bucket.empty() && r2Bucket.empty() &&
524 heuristic.rNBucketEmpty();
525
526 while (!reductionFinished) {
527
528 if (!r1Bucket.empty()) {
529 processR1();
530 }
531 else if (!r2Bucket.empty()) {
532 processR2();
533 }
534 else if (!heuristic.rNBucketEmpty()) {
535 solution.setProvedOptimal(false);
536 solution.incRNReductions();
537 heuristic.processRN();
538 }
539 else reductionFinished = true;
540 }
541
542 };
543
544 void processR1() {
545
546 // Remove the first node in the R0 bucket:
547 GraphNodeIterator xNodeItr = r1Bucket.front();
548 r1Bucket.pop_front();
549
550 solution.incR1Reductions();
551
552 //std::cerr << "Applying R1 to " << g.getNodeID(xNodeItr) << "\n";
553
554 assert((g.getNodeData(xNodeItr).getLinkDegree() == 1) &&
555 "Node in R1 bucket has degree != 1");
556
557 GraphEdgeIterator edgeItr = *g.getNodeData(xNodeItr).adjLinksBegin();
558
559 const Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
560
561 const Vector &xCosts = g.getNodeCosts(xNodeItr);
562 unsigned xLen = xCosts.getLength();
563
564 // Duplicate a little code to avoid transposing matrices:
565 if (xNodeItr == g.getEdgeNode1Itr(edgeItr)) {
566 GraphNodeIterator yNodeItr = g.getEdgeNode2Itr(edgeItr);
567 Vector &yCosts = g.getNodeCosts(yNodeItr);
568 unsigned yLen = yCosts.getLength();
569
570 for (unsigned j = 0; j < yLen; ++j) {
571 PBQPNum min = edgeCosts[0][j] + xCosts[0];
572 for (unsigned i = 1; i < xLen; ++i) {
573 PBQPNum c = edgeCosts[i][j] + xCosts[i];
574 if (c < min)
575 min = c;
576 }
577 yCosts[j] += min;
578 }
579 }
580 else {
581 GraphNodeIterator yNodeItr = g.getEdgeNode1Itr(edgeItr);
582 Vector &yCosts = g.getNodeCosts(yNodeItr);
583 unsigned yLen = yCosts.getLength();
584
585 for (unsigned i = 0; i < yLen; ++i) {
586 PBQPNum min = edgeCosts[i][0] + xCosts[0];
587
588 for (unsigned j = 1; j < xLen; ++j) {
589 PBQPNum c = edgeCosts[i][j] + xCosts[j];
590 if (c < min)
591 min = c;
592 }
593 yCosts[i] += min;
594 }
595 }
596
597 unlinkNode(xNodeItr);
598 pushStack(xNodeItr);
599 }
600
601 void processR2() {
602
603 GraphNodeIterator xNodeItr = r2Bucket.front();
604 r2Bucket.pop_front();
605
606 solution.incR2Reductions();
607
608 // Unlink is unsafe here. At some point it may optimistically more a node
609 // to a lower-degree list when its degree will later rise, or vice versa,
610 // violating the assumption that node degrees monotonically decrease
611 // during the reduction phase. Instead we'll bucket shuffle manually.
612 pushStack(xNodeItr);
613
614 assert((g.getNodeData(xNodeItr).getLinkDegree() == 2) &&
615 "Node in R2 bucket has degree != 2");
616
617 const Vector &xCosts = g.getNodeCosts(xNodeItr);
618
619 typename NodeData::AdjLinkIterator tempItr =
620 g.getNodeData(xNodeItr).adjLinksBegin();
621
622 GraphEdgeIterator yxEdgeItr = *tempItr,
623 zxEdgeItr = *(++tempItr);
624
625 GraphNodeIterator yNodeItr = g.getEdgeOtherNode(yxEdgeItr, xNodeItr),
626 zNodeItr = g.getEdgeOtherNode(zxEdgeItr, xNodeItr);
627
628 removeFromBucket(yNodeItr);
629 removeFromBucket(zNodeItr);
630
631 removeLink(yxEdgeItr, yNodeItr);
632 removeLink(zxEdgeItr, zNodeItr);
633
634 // Graph some of the costs:
635 bool flipEdge1 = (g.getEdgeNode1Itr(yxEdgeItr) == xNodeItr),
636 flipEdge2 = (g.getEdgeNode1Itr(zxEdgeItr) == xNodeItr);
637
638 const Matrix *yxCosts = flipEdge1 ?
639 new Matrix(g.getEdgeCosts(yxEdgeItr).transpose()) :
640 &g.getEdgeCosts(yxEdgeItr),
641 *zxCosts = flipEdge2 ?
642 new Matrix(g.getEdgeCosts(zxEdgeItr).transpose()) :
643 &g.getEdgeCosts(zxEdgeItr);
644
645 unsigned xLen = xCosts.getLength(),
646 yLen = yxCosts->getRows(),
647 zLen = zxCosts->getRows();
648
649 // Compute delta:
650 Matrix delta(yLen, zLen);
651
652 for (unsigned i = 0; i < yLen; ++i) {
653 for (unsigned j = 0; j < zLen; ++j) {
654 PBQPNum min = (*yxCosts)[i][0] + (*zxCosts)[j][0] + xCosts[0];
655 for (unsigned k = 1; k < xLen; ++k) {
656 PBQPNum c = (*yxCosts)[i][k] + (*zxCosts)[j][k] + xCosts[k];
657 if (c < min) {
658 min = c;
659 }
660 }
661 delta[i][j] = min;
662 }
663 }
664
665 if (flipEdge1)
666 delete yxCosts;
667
668 if (flipEdge2)
669 delete zxCosts;
670
671 // Deal with the potentially induced yz edge.
672 GraphEdgeIterator yzEdgeItr = g.findEdge(yNodeItr, zNodeItr);
673 if (yzEdgeItr == g.edgesEnd()) {
674 yzEdgeItr = g.addEdge(yNodeItr, zNodeItr, delta, EdgeData(g));
675 }
676 else {
677 // There was an edge, but we're going to screw with it. Delete the old
678 // link, update the costs. We'll re-link it later.
679 removeLinkR2(yzEdgeItr);
680 g.getEdgeCosts(yzEdgeItr) +=
681 (yNodeItr == g.getEdgeNode1Itr(yzEdgeItr)) ?
682 delta : delta.transpose();
683 }
684
685 bool nullCostEdge = tryNormaliseEdgeMatrix(yzEdgeItr);
686
687 // Nulled the edge, remove it entirely.
688 if (nullCostEdge) {
689 g.removeEdge(yzEdgeItr);
690 }
691 else {
692 // Edge remains - re-link it.
693 addLink(yzEdgeItr);
694 }
695
696 addToBucket(yNodeItr);
697 addToBucket(zNodeItr);
698 }
699
700 void computeTrivialSolutions() {
701
702 for (NodeListIterator r0Itr = r0Bucket.begin(), r0End = r0Bucket.end();
703 r0Itr != r0End; ++r0Itr) {
704 GraphNodeIterator nodeItr = *r0Itr;
705
706 solution.incR0Reductions();
707 setSolution(nodeItr, g.getNodeCosts(nodeItr).minIndex());
708 }
709
710 }
711
712 void backpropagate() {
713 while (!stack.empty()) {
714 computeSolution(stack.back());
715 stack.pop_back();
716 }
717 }
718
719 void computeSolution(const GraphNodeIterator &nodeItr) {
720
721 NodeData &nodeData = g.getNodeData(nodeItr);
722
723 Vector v(g.getNodeCosts(nodeItr));
724
725 // Solve based on existing links.
726 for (typename NodeData::AdjLinkIterator
727 solvedLinkItr = nodeData.solvedLinksBegin(),
728 solvedLinkEnd = nodeData.solvedLinksEnd();
729 solvedLinkItr != solvedLinkEnd; ++solvedLinkItr) {
730
731 GraphEdgeIterator solvedEdgeItr(*solvedLinkItr);
732 Matrix &edgeCosts = g.getEdgeCosts(solvedEdgeItr);
733
734 if (nodeItr == g.getEdgeNode1Itr(solvedEdgeItr)) {
735 GraphNodeIterator adjNode(g.getEdgeNode2Itr(solvedEdgeItr));
736 unsigned adjSolution =
737 solution.getSelection(g.getNodeID(adjNode));
738 v += edgeCosts.getColAsVector(adjSolution);
739 }
740 else {
741 GraphNodeIterator adjNode(g.getEdgeNode1Itr(solvedEdgeItr));
742 unsigned adjSolution =
743 solution.getSelection(g.getNodeID(adjNode));
744 v += edgeCosts.getRowAsVector(adjSolution);
745 }
746
747 }
748
749 setSolution(nodeItr, v.minIndex());
750 }
751
752 void computeSolutionCost(const SimpleGraph &orig) {
753 PBQPNum cost = 0.0;
754
755 for (SimpleGraph::ConstNodeIterator
756 nodeItr = orig.nodesBegin(), nodeEnd = orig.nodesEnd();
757 nodeItr != nodeEnd; ++nodeItr) {
758
759 unsigned nodeId = orig.getNodeID(nodeItr);
760
761 cost += orig.getNodeCosts(nodeItr)[solution.getSelection(nodeId)];
762 }
763
764 for (SimpleGraph::ConstEdgeIterator
765 edgeItr = orig.edgesBegin(), edgeEnd = orig.edgesEnd();
766 edgeItr != edgeEnd; ++edgeItr) {
767
768 SimpleGraph::ConstNodeIterator n1 = orig.getEdgeNode1Itr(edgeItr),
769 n2 = orig.getEdgeNode2Itr(edgeItr);
770 unsigned sol1 = solution.getSelection(orig.getNodeID(n1)),
771 sol2 = solution.getSelection(orig.getNodeID(n2));
772
773 cost += orig.getEdgeCosts(edgeItr)[sol1][sol2];
774 }
775
776 solution.setSolutionCost(cost);
777 }
778
779};
780
781template <typename Heuristic>
782class HeuristicSolver : public Solver {
783public:
784 Solution solve(const SimpleGraph &g) const {
785 HeuristicSolverImpl<Heuristic> solverImpl(g);
786 return solverImpl.getSolution();
787 }
788};
789
790}
791
792#endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H