blob: 7088f364568d862a96acf814fcaa770e5a48bd75 [file] [log] [blame]
Lang Hames6699fb22009-08-06 23:32:48 +00001#ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
2#define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
3
#include "Solver.h"
#include "AnnotatedGraph.h"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <vector>
9
10namespace PBQP {
11
12/// \brief Important types for the HeuristicSolverImpl.
13///
14/// Declared seperately to allow access to heuristic classes before the solver
15/// is fully constructed.
16template <typename HeuristicNodeData, typename HeuristicEdgeData>
17class HSITypes {
18public:
19
20 class NodeData;
21 class EdgeData;
22
23 typedef AnnotatedGraph<NodeData, EdgeData> SolverGraph;
24 typedef typename SolverGraph::NodeIterator GraphNodeIterator;
25 typedef typename SolverGraph::EdgeIterator GraphEdgeIterator;
26 typedef typename SolverGraph::AdjEdgeIterator GraphAdjEdgeIterator;
27
28 typedef std::list<GraphNodeIterator> NodeList;
29 typedef typename NodeList::iterator NodeListIterator;
30
31 typedef std::vector<GraphNodeIterator> NodeStack;
32 typedef typename NodeStack::iterator NodeStackIterator;
33
34 class NodeData {
35 friend class EdgeData;
36
37 private:
38
39 typedef std::list<GraphEdgeIterator> LinksList;
40
41 unsigned numLinks;
42 LinksList links, solvedLinks;
43 NodeListIterator bucketItr;
44 HeuristicNodeData heuristicData;
45
46 public:
47
48 typedef typename LinksList::iterator AdjLinkIterator;
49
50 private:
51
52 AdjLinkIterator addLink(const GraphEdgeIterator &edgeItr) {
53 ++numLinks;
54 return links.insert(links.end(), edgeItr);
55 }
56
57 void delLink(const AdjLinkIterator &adjLinkItr) {
58 --numLinks;
59 links.erase(adjLinkItr);
60 }
61
62 public:
63
64 NodeData() : numLinks(0) {}
65
66 unsigned getLinkDegree() const { return numLinks; }
67
68 HeuristicNodeData& getHeuristicData() { return heuristicData; }
69 const HeuristicNodeData& getHeuristicData() const {
70 return heuristicData;
71 }
72
73 void setBucketItr(const NodeListIterator &bucketItr) {
74 this->bucketItr = bucketItr;
75 }
76
77 const NodeListIterator& getBucketItr() const {
78 return bucketItr;
79 }
80
81 AdjLinkIterator adjLinksBegin() {
82 return links.begin();
83 }
84
85 AdjLinkIterator adjLinksEnd() {
86 return links.end();
87 }
88
89 void addSolvedLink(const GraphEdgeIterator &solvedLinkItr) {
90 solvedLinks.push_back(solvedLinkItr);
91 }
92
93 AdjLinkIterator solvedLinksBegin() {
94 return solvedLinks.begin();
95 }
96
97 AdjLinkIterator solvedLinksEnd() {
98 return solvedLinks.end();
99 }
100
101 };
102
103 class EdgeData {
104 private:
105
106 SolverGraph &g;
107 GraphNodeIterator node1Itr, node2Itr;
108 HeuristicEdgeData heuristicData;
109 typename NodeData::AdjLinkIterator node1ThisEdgeItr, node2ThisEdgeItr;
110
111 public:
112
113 EdgeData(SolverGraph &g) : g(g) {}
114
115 HeuristicEdgeData& getHeuristicData() { return heuristicData; }
116 const HeuristicEdgeData& getHeuristicData() const {
117 return heuristicData;
118 }
119
120 void setup(const GraphEdgeIterator &thisEdgeItr) {
121 node1Itr = g.getEdgeNode1Itr(thisEdgeItr);
122 node2Itr = g.getEdgeNode2Itr(thisEdgeItr);
123
124 node1ThisEdgeItr = g.getNodeData(node1Itr).addLink(thisEdgeItr);
125 node2ThisEdgeItr = g.getNodeData(node2Itr).addLink(thisEdgeItr);
126 }
127
128 void unlink() {
129 g.getNodeData(node1Itr).delLink(node1ThisEdgeItr);
130 g.getNodeData(node2Itr).delLink(node2ThisEdgeItr);
131 }
132
133 };
134
135};
136
137template <typename Heuristic>
138class HeuristicSolverImpl {
139public:
140 // Typedefs to make life easier:
141 typedef HSITypes<typename Heuristic::NodeData,
142 typename Heuristic::EdgeData> HSIT;
143 typedef typename HSIT::SolverGraph SolverGraph;
144 typedef typename HSIT::NodeData NodeData;
145 typedef typename HSIT::EdgeData EdgeData;
146 typedef typename HSIT::GraphNodeIterator GraphNodeIterator;
147 typedef typename HSIT::GraphEdgeIterator GraphEdgeIterator;
148 typedef typename HSIT::GraphAdjEdgeIterator GraphAdjEdgeIterator;
149
150 typedef typename HSIT::NodeList NodeList;
151 typedef typename HSIT::NodeListIterator NodeListIterator;
152
153 typedef std::vector<GraphNodeIterator> NodeStack;
154 typedef typename NodeStack::iterator NodeStackIterator;
155
156 /*!
157 * \brief Constructor, which performs all the actual solver work.
158 */
159 HeuristicSolverImpl(const SimpleGraph &orig) :
160 solution(orig.getNumNodes(), true)
161 {
162 copyGraph(orig);
163 simplify();
164 setup();
165 computeSolution();
166 computeSolutionCost(orig);
167 }
168
169 /*!
170 * \brief Returns the graph for this solver.
171 */
172 SolverGraph& getGraph() { return g; }
173
174 /*!
175 * \brief Return the solution found by this solver.
176 */
177 const Solution& getSolution() const { return solution; }
178
179private:
180
181 /*!
182 * \brief Add the given node to the appropriate bucket for its link
183 * degree.
184 */
185 void addToBucket(const GraphNodeIterator &nodeItr) {
186 NodeData &nodeData = g.getNodeData(nodeItr);
187
188 switch (nodeData.getLinkDegree()) {
189 case 0: nodeData.setBucketItr(
190 r0Bucket.insert(r0Bucket.end(), nodeItr));
191 break;
192 case 1: nodeData.setBucketItr(
193 r1Bucket.insert(r1Bucket.end(), nodeItr));
194 break;
195 case 2: nodeData.setBucketItr(
196 r2Bucket.insert(r2Bucket.end(), nodeItr));
197 break;
198 default: heuristic.addToRNBucket(nodeItr);
199 break;
200 }
201 }
202
203 /*!
204 * \brief Remove the given node from the appropriate bucket for its link
205 * degree.
206 */
207 void removeFromBucket(const GraphNodeIterator &nodeItr) {
208 NodeData &nodeData = g.getNodeData(nodeItr);
209
210 switch (nodeData.getLinkDegree()) {
211 case 0: r0Bucket.erase(nodeData.getBucketItr()); break;
212 case 1: r1Bucket.erase(nodeData.getBucketItr()); break;
213 case 2: r2Bucket.erase(nodeData.getBucketItr()); break;
214 default: heuristic.removeFromRNBucket(nodeItr); break;
215 }
216 }
217
218public:
219
220 /*!
221 * \brief Add a link.
222 */
223 void addLink(const GraphEdgeIterator &edgeItr) {
224 g.getEdgeData(edgeItr).setup(edgeItr);
225
226 if ((g.getNodeData(g.getEdgeNode1Itr(edgeItr)).getLinkDegree() > 2) ||
227 (g.getNodeData(g.getEdgeNode2Itr(edgeItr)).getLinkDegree() > 2)) {
228 heuristic.handleAddLink(edgeItr);
229 }
230 }
231
232 /*!
233 * \brief Remove link, update info for node.
234 *
235 * Only updates information for the given node, since usually the other
236 * is about to be removed.
237 */
238 void removeLink(const GraphEdgeIterator &edgeItr,
239 const GraphNodeIterator &nodeItr) {
240
241 if (g.getNodeData(nodeItr).getLinkDegree() > 2) {
242 heuristic.handleRemoveLink(edgeItr, nodeItr);
243 }
244 g.getEdgeData(edgeItr).unlink();
245 }
246
247 /*!
248 * \brief Remove link, update info for both nodes. Useful for R2 only.
249 */
250 void removeLinkR2(const GraphEdgeIterator &edgeItr) {
251 GraphNodeIterator node1Itr = g.getEdgeNode1Itr(edgeItr);
252
253 if (g.getNodeData(node1Itr).getLinkDegree() > 2) {
254 heuristic.handleRemoveLink(edgeItr, node1Itr);
255 }
256 removeLink(edgeItr, g.getEdgeNode2Itr(edgeItr));
257 }
258
259 /*!
260 * \brief Removes all links connected to the given node.
261 */
262 void unlinkNode(const GraphNodeIterator &nodeItr) {
263 NodeData &nodeData = g.getNodeData(nodeItr);
264
265 typedef std::vector<GraphEdgeIterator> TempEdgeList;
266
267 TempEdgeList edgesToUnlink;
268 edgesToUnlink.reserve(nodeData.getLinkDegree());
269
270 // Copy adj edges into a temp vector. We want to destroy them during
271 // the unlink, and we can't do that while we're iterating over them.
272 std::copy(nodeData.adjLinksBegin(), nodeData.adjLinksEnd(),
273 std::back_inserter(edgesToUnlink));
274
275 for (typename TempEdgeList::iterator
276 edgeItr = edgesToUnlink.begin(), edgeEnd = edgesToUnlink.end();
277 edgeItr != edgeEnd; ++edgeItr) {
278
279 GraphNodeIterator otherNode = g.getEdgeOtherNode(*edgeItr, nodeItr);
280
281 removeFromBucket(otherNode);
282 removeLink(*edgeItr, otherNode);
283 addToBucket(otherNode);
284 }
285 }
286
287 /*!
288 * \brief Push the given node onto the stack to be solved with
289 * backpropagation.
290 */
291 void pushStack(const GraphNodeIterator &nodeItr) {
292 stack.push_back(nodeItr);
293 }
294
295 /*!
296 * \brief Set the solution of the given node.
297 */
298 void setSolution(const GraphNodeIterator &nodeItr, unsigned solIndex) {
299 solution.setSelection(g.getNodeID(nodeItr), solIndex);
300
301 for (GraphAdjEdgeIterator adjEdgeItr = g.adjEdgesBegin(nodeItr),
302 adjEdgeEnd = g.adjEdgesEnd(nodeItr);
303 adjEdgeItr != adjEdgeEnd; ++adjEdgeItr) {
304 GraphEdgeIterator edgeItr(*adjEdgeItr);
305 GraphNodeIterator adjNodeItr(g.getEdgeOtherNode(edgeItr, nodeItr));
306 g.getNodeData(adjNodeItr).addSolvedLink(edgeItr);
307 }
308 }
309
310private:
311
312 SolverGraph g;
313 Heuristic heuristic;
314 Solution solution;
315
316 NodeList r0Bucket,
317 r1Bucket,
318 r2Bucket;
319
320 NodeStack stack;
321
322 // Copy the SimpleGraph into an annotated graph which we can use for reduction.
323 void copyGraph(const SimpleGraph &orig) {
324
325 assert((g.getNumEdges() == 0) && (g.getNumNodes() == 0) &&
326 "Graph should be empty prior to solver setup.");
327
328 assert(orig.areNodeIDsValid() &&
329 "Cannot copy from a graph with invalid node IDs.");
330
331 std::vector<GraphNodeIterator> newNodeItrs;
332
333 for (unsigned nodeID = 0; nodeID < orig.getNumNodes(); ++nodeID) {
334 newNodeItrs.push_back(
335 g.addNode(orig.getNodeCosts(orig.getNodeItr(nodeID)), NodeData()));
336 }
337
338 for (SimpleGraph::ConstEdgeIterator
339 origEdgeItr = orig.edgesBegin(), origEdgeEnd = orig.edgesEnd();
340 origEdgeItr != origEdgeEnd; ++origEdgeItr) {
341
342 unsigned id1 = orig.getNodeID(orig.getEdgeNode1Itr(origEdgeItr)),
343 id2 = orig.getNodeID(orig.getEdgeNode2Itr(origEdgeItr));
344
345 g.addEdge(newNodeItrs[id1], newNodeItrs[id2],
346 orig.getEdgeCosts(origEdgeItr), EdgeData(g));
347 }
348
349 // Assign IDs to the new nodes using the ordering from the old graph,
350 // this will lead to nodes in the new graph getting the same ID as the
351 // corresponding node in the old graph.
352 g.assignNodeIDs(newNodeItrs);
353 }
354
355 // Simplify the annotated graph by eliminating independent edges and trivial
356 // nodes.
357 void simplify() {
358 disconnectTrivialNodes();
359 eliminateIndependentEdges();
360 }
361
362 // Eliminate trivial nodes.
363 void disconnectTrivialNodes() {
364 for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
365 nodeItr != nodeEnd; ++nodeItr) {
366
367 if (g.getNodeCosts(nodeItr).getLength() == 1) {
368
369 std::vector<GraphEdgeIterator> edgesToRemove;
370
371 for (GraphAdjEdgeIterator adjEdgeItr = g.adjEdgesBegin(nodeItr),
372 adjEdgeEnd = g.adjEdgesEnd(nodeItr);
373 adjEdgeItr != adjEdgeEnd; ++adjEdgeItr) {
374
375 GraphEdgeIterator edgeItr = *adjEdgeItr;
376
377 if (g.getEdgeNode1Itr(edgeItr) == nodeItr) {
378 GraphNodeIterator otherNodeItr = g.getEdgeNode2Itr(edgeItr);
379 g.getNodeCosts(otherNodeItr) +=
380 g.getEdgeCosts(edgeItr).getRowAsVector(0);
381 }
382 else {
383 GraphNodeIterator otherNodeItr = g.getEdgeNode1Itr(edgeItr);
384 g.getNodeCosts(otherNodeItr) +=
385 g.getEdgeCosts(edgeItr).getColAsVector(0);
386 }
387
388 edgesToRemove.push_back(edgeItr);
389 }
390
391 while (!edgesToRemove.empty()) {
392 g.removeEdge(edgesToRemove.back());
393 edgesToRemove.pop_back();
394 }
395 }
396 }
397 }
398
399 void eliminateIndependentEdges() {
400 std::vector<GraphEdgeIterator> edgesToProcess;
401
402 for (GraphEdgeIterator edgeItr = g.edgesBegin(), edgeEnd = g.edgesEnd();
403 edgeItr != edgeEnd; ++edgeItr) {
404 edgesToProcess.push_back(edgeItr);
405 }
406
407 while (!edgesToProcess.empty()) {
408 tryToEliminateEdge(edgesToProcess.back());
409 edgesToProcess.pop_back();
410 }
411 }
412
413 void tryToEliminateEdge(const GraphEdgeIterator &edgeItr) {
414 if (tryNormaliseEdgeMatrix(edgeItr)) {
415 g.removeEdge(edgeItr);
416 }
417 }
418
419 bool tryNormaliseEdgeMatrix(const GraphEdgeIterator &edgeItr) {
420
421 Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
422 Vector &uCosts = g.getNodeCosts(g.getEdgeNode1Itr(edgeItr)),
423 &vCosts = g.getNodeCosts(g.getEdgeNode2Itr(edgeItr));
424
425 for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
426 PBQPNum rowMin = edgeCosts.getRowMin(r);
427 uCosts[r] += rowMin;
428 if (rowMin != std::numeric_limits<PBQPNum>::infinity()) {
429 edgeCosts.subFromRow(r, rowMin);
430 }
431 else {
432 edgeCosts.setRow(r, 0);
433 }
434 }
435
436 for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
437 PBQPNum colMin = edgeCosts.getColMin(c);
438 vCosts[c] += colMin;
439 if (colMin != std::numeric_limits<PBQPNum>::infinity()) {
440 edgeCosts.subFromCol(c, colMin);
441 }
442 else {
443 edgeCosts.setCol(c, 0);
444 }
445 }
446
447 return edgeCosts.isZero();
448 }
449
450 void setup() {
451 setupLinks();
452 heuristic.initialise(*this);
453 setupBuckets();
454 }
455
456 void setupLinks() {
457 for (GraphEdgeIterator edgeItr = g.edgesBegin(), edgeEnd = g.edgesEnd();
458 edgeItr != edgeEnd; ++edgeItr) {
459 g.getEdgeData(edgeItr).setup(edgeItr);
460 }
461 }
462
463 void setupBuckets() {
464 for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
465 nodeItr != nodeEnd; ++nodeItr) {
466 addToBucket(nodeItr);
467 }
468 }
469
470 void computeSolution() {
471 assert(g.areNodeIDsValid() &&
472 "Nodes cannot be added/removed during reduction.");
473
474 reduce();
475 computeTrivialSolutions();
476 backpropagate();
477 }
478
479 void printNode(const GraphNodeIterator &nodeItr) {
480
481 std::cerr << "Node " << g.getNodeID(nodeItr) << " (" << &*nodeItr << "):\n"
482 << " costs = " << g.getNodeCosts(nodeItr) << "\n"
483 << " link degree = " << g.getNodeData(nodeItr).getLinkDegree() << "\n"
484 << " links = [ ";
485
486 for (typename HSIT::NodeData::AdjLinkIterator
487 aeItr = g.getNodeData(nodeItr).adjLinksBegin(),
488 aeEnd = g.getNodeData(nodeItr).adjLinksEnd();
489 aeItr != aeEnd; ++aeItr) {
490 std::cerr << "(" << g.getNodeID(g.getEdgeNode1Itr(*aeItr))
491 << ", " << g.getNodeID(g.getEdgeNode2Itr(*aeItr))
492 << ") ";
493 }
494 std::cout << "]\n";
495 }
496
497 void dumpState() {
498
499 std::cerr << "\n";
500
501 for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
502 nodeItr != nodeEnd; ++nodeItr) {
503 printNode(nodeItr);
504 }
505
506 NodeList* buckets[] = { &r0Bucket, &r1Bucket, &r2Bucket };
507
508 for (unsigned b = 0; b < 3; ++b) {
509 NodeList &bucket = *buckets[b];
510
511 std::cerr << "Bucket " << b << ": [ ";
512
513 for (NodeListIterator nItr = bucket.begin(), nEnd = bucket.end();
514 nItr != nEnd; ++nItr) {
515 std::cerr << g.getNodeID(*nItr) << " ";
516 }
517
518 std::cerr << "]\n";
519 }
520
521 std::cerr << "Stack: [ ";
522 for (NodeStackIterator nsItr = stack.begin(), nsEnd = stack.end();
523 nsItr != nsEnd; ++nsItr) {
524 std::cerr << g.getNodeID(*nsItr) << " ";
525 }
526 std::cerr << "]\n";
527 }
528
529 void reduce() {
530 bool reductionFinished = r1Bucket.empty() && r2Bucket.empty() &&
531 heuristic.rNBucketEmpty();
532
533 while (!reductionFinished) {
534
535 if (!r1Bucket.empty()) {
536 processR1();
537 }
538 else if (!r2Bucket.empty()) {
539 processR2();
540 }
541 else if (!heuristic.rNBucketEmpty()) {
542 solution.setProvedOptimal(false);
543 solution.incRNReductions();
544 heuristic.processRN();
545 }
546 else reductionFinished = true;
547 }
548
549 };
550
551 void processR1() {
552
553 // Remove the first node in the R0 bucket:
554 GraphNodeIterator xNodeItr = r1Bucket.front();
555 r1Bucket.pop_front();
556
557 solution.incR1Reductions();
558
559 //std::cerr << "Applying R1 to " << g.getNodeID(xNodeItr) << "\n";
560
561 assert((g.getNodeData(xNodeItr).getLinkDegree() == 1) &&
562 "Node in R1 bucket has degree != 1");
563
564 GraphEdgeIterator edgeItr = *g.getNodeData(xNodeItr).adjLinksBegin();
565
566 const Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
567
568 const Vector &xCosts = g.getNodeCosts(xNodeItr);
569 unsigned xLen = xCosts.getLength();
570
571 // Duplicate a little code to avoid transposing matrices:
572 if (xNodeItr == g.getEdgeNode1Itr(edgeItr)) {
573 GraphNodeIterator yNodeItr = g.getEdgeNode2Itr(edgeItr);
574 Vector &yCosts = g.getNodeCosts(yNodeItr);
575 unsigned yLen = yCosts.getLength();
576
577 for (unsigned j = 0; j < yLen; ++j) {
578 PBQPNum min = edgeCosts[0][j] + xCosts[0];
579 for (unsigned i = 1; i < xLen; ++i) {
580 PBQPNum c = edgeCosts[i][j] + xCosts[i];
581 if (c < min)
582 min = c;
583 }
584 yCosts[j] += min;
585 }
586 }
587 else {
588 GraphNodeIterator yNodeItr = g.getEdgeNode1Itr(edgeItr);
589 Vector &yCosts = g.getNodeCosts(yNodeItr);
590 unsigned yLen = yCosts.getLength();
591
592 for (unsigned i = 0; i < yLen; ++i) {
593 PBQPNum min = edgeCosts[i][0] + xCosts[0];
594
595 for (unsigned j = 1; j < xLen; ++j) {
596 PBQPNum c = edgeCosts[i][j] + xCosts[j];
597 if (c < min)
598 min = c;
599 }
600 yCosts[i] += min;
601 }
602 }
603
604 unlinkNode(xNodeItr);
605 pushStack(xNodeItr);
606 }
607
608 void processR2() {
609
610 GraphNodeIterator xNodeItr = r2Bucket.front();
611 r2Bucket.pop_front();
612
613 solution.incR2Reductions();
614
615 // Unlink is unsafe here. At some point it may optimistically more a node
616 // to a lower-degree list when its degree will later rise, or vice versa,
617 // violating the assumption that node degrees monotonically decrease
618 // during the reduction phase. Instead we'll bucket shuffle manually.
619 pushStack(xNodeItr);
620
621 assert((g.getNodeData(xNodeItr).getLinkDegree() == 2) &&
622 "Node in R2 bucket has degree != 2");
623
624 const Vector &xCosts = g.getNodeCosts(xNodeItr);
625
626 typename NodeData::AdjLinkIterator tempItr =
627 g.getNodeData(xNodeItr).adjLinksBegin();
628
629 GraphEdgeIterator yxEdgeItr = *tempItr,
630 zxEdgeItr = *(++tempItr);
631
632 GraphNodeIterator yNodeItr = g.getEdgeOtherNode(yxEdgeItr, xNodeItr),
633 zNodeItr = g.getEdgeOtherNode(zxEdgeItr, xNodeItr);
634
635 removeFromBucket(yNodeItr);
636 removeFromBucket(zNodeItr);
637
638 removeLink(yxEdgeItr, yNodeItr);
639 removeLink(zxEdgeItr, zNodeItr);
640
641 // Graph some of the costs:
642 bool flipEdge1 = (g.getEdgeNode1Itr(yxEdgeItr) == xNodeItr),
643 flipEdge2 = (g.getEdgeNode1Itr(zxEdgeItr) == xNodeItr);
644
645 const Matrix *yxCosts = flipEdge1 ?
646 new Matrix(g.getEdgeCosts(yxEdgeItr).transpose()) :
647 &g.getEdgeCosts(yxEdgeItr),
648 *zxCosts = flipEdge2 ?
649 new Matrix(g.getEdgeCosts(zxEdgeItr).transpose()) :
650 &g.getEdgeCosts(zxEdgeItr);
651
652 unsigned xLen = xCosts.getLength(),
653 yLen = yxCosts->getRows(),
654 zLen = zxCosts->getRows();
655
656 // Compute delta:
657 Matrix delta(yLen, zLen);
658
659 for (unsigned i = 0; i < yLen; ++i) {
660 for (unsigned j = 0; j < zLen; ++j) {
661 PBQPNum min = (*yxCosts)[i][0] + (*zxCosts)[j][0] + xCosts[0];
662 for (unsigned k = 1; k < xLen; ++k) {
663 PBQPNum c = (*yxCosts)[i][k] + (*zxCosts)[j][k] + xCosts[k];
664 if (c < min) {
665 min = c;
666 }
667 }
668 delta[i][j] = min;
669 }
670 }
671
672 if (flipEdge1)
673 delete yxCosts;
674
675 if (flipEdge2)
676 delete zxCosts;
677
678 // Deal with the potentially induced yz edge.
679 GraphEdgeIterator yzEdgeItr = g.findEdge(yNodeItr, zNodeItr);
680 if (yzEdgeItr == g.edgesEnd()) {
681 yzEdgeItr = g.addEdge(yNodeItr, zNodeItr, delta, EdgeData(g));
682 }
683 else {
684 // There was an edge, but we're going to screw with it. Delete the old
685 // link, update the costs. We'll re-link it later.
686 removeLinkR2(yzEdgeItr);
687 g.getEdgeCosts(yzEdgeItr) +=
688 (yNodeItr == g.getEdgeNode1Itr(yzEdgeItr)) ?
689 delta : delta.transpose();
690 }
691
692 bool nullCostEdge = tryNormaliseEdgeMatrix(yzEdgeItr);
693
694 // Nulled the edge, remove it entirely.
695 if (nullCostEdge) {
696 g.removeEdge(yzEdgeItr);
697 }
698 else {
699 // Edge remains - re-link it.
700 addLink(yzEdgeItr);
701 }
702
703 addToBucket(yNodeItr);
704 addToBucket(zNodeItr);
705 }
706
707 void computeTrivialSolutions() {
708
709 for (NodeListIterator r0Itr = r0Bucket.begin(), r0End = r0Bucket.end();
710 r0Itr != r0End; ++r0Itr) {
711 GraphNodeIterator nodeItr = *r0Itr;
712
713 solution.incR0Reductions();
714 setSolution(nodeItr, g.getNodeCosts(nodeItr).minIndex());
715 }
716
717 }
718
719 void backpropagate() {
720 while (!stack.empty()) {
721 computeSolution(stack.back());
722 stack.pop_back();
723 }
724 }
725
726 void computeSolution(const GraphNodeIterator &nodeItr) {
727
728 NodeData &nodeData = g.getNodeData(nodeItr);
729
730 Vector v(g.getNodeCosts(nodeItr));
731
732 // Solve based on existing links.
733 for (typename NodeData::AdjLinkIterator
734 solvedLinkItr = nodeData.solvedLinksBegin(),
735 solvedLinkEnd = nodeData.solvedLinksEnd();
736 solvedLinkItr != solvedLinkEnd; ++solvedLinkItr) {
737
738 GraphEdgeIterator solvedEdgeItr(*solvedLinkItr);
739 Matrix &edgeCosts = g.getEdgeCosts(solvedEdgeItr);
740
741 if (nodeItr == g.getEdgeNode1Itr(solvedEdgeItr)) {
742 GraphNodeIterator adjNode(g.getEdgeNode2Itr(solvedEdgeItr));
743 unsigned adjSolution =
744 solution.getSelection(g.getNodeID(adjNode));
745 v += edgeCosts.getColAsVector(adjSolution);
746 }
747 else {
748 GraphNodeIterator adjNode(g.getEdgeNode1Itr(solvedEdgeItr));
749 unsigned adjSolution =
750 solution.getSelection(g.getNodeID(adjNode));
751 v += edgeCosts.getRowAsVector(adjSolution);
752 }
753
754 }
755
756 setSolution(nodeItr, v.minIndex());
757 }
758
759 void computeSolutionCost(const SimpleGraph &orig) {
760 PBQPNum cost = 0.0;
761
762 for (SimpleGraph::ConstNodeIterator
763 nodeItr = orig.nodesBegin(), nodeEnd = orig.nodesEnd();
764 nodeItr != nodeEnd; ++nodeItr) {
765
766 unsigned nodeId = orig.getNodeID(nodeItr);
767
768 cost += orig.getNodeCosts(nodeItr)[solution.getSelection(nodeId)];
769 }
770
771 for (SimpleGraph::ConstEdgeIterator
772 edgeItr = orig.edgesBegin(), edgeEnd = orig.edgesEnd();
773 edgeItr != edgeEnd; ++edgeItr) {
774
775 SimpleGraph::ConstNodeIterator n1 = orig.getEdgeNode1Itr(edgeItr),
776 n2 = orig.getEdgeNode2Itr(edgeItr);
777 unsigned sol1 = solution.getSelection(orig.getNodeID(n1)),
778 sol2 = solution.getSelection(orig.getNodeID(n2));
779
780 cost += orig.getEdgeCosts(edgeItr)[sol1][sol2];
781 }
782
783 solution.setSolutionCost(cost);
784 }
785
786};
787
788template <typename Heuristic>
789class HeuristicSolver : public Solver {
790public:
791 Solution solve(const SimpleGraph &g) const {
792 HeuristicSolverImpl<Heuristic> solverImpl(g);
793 return solverImpl.getSolution();
794 }
795};
796
797}
798
799#endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H