blob: c4787f4afd64d27bf612e7060ab2afc2df9d822e [file] [log] [blame]
Arnaud A. de Grandmaisonc75dbbb2014-09-10 14:06:10 +00001//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9// This file contains the AArch64 / Cortex-A57 specific register allocation
10// constraints for use by the PBQP register allocator.
11//
12// It is essentially a transcription of what is contained in
13// AArch64A57FPLoadBalancing, which tries to use a balanced
14// mix of odd and even D-registers when performing a critical sequence of
15// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16//===----------------------------------------------------------------------===//
17
18#define DEBUG_TYPE "aarch64-pbqp"
19
20#include "AArch64.h"
21#include "AArch64RegisterInfo.h"
22
23#include "llvm/ADT/SetVector.h"
24#include "llvm/CodeGen/LiveIntervalAnalysis.h"
25#include "llvm/CodeGen/MachineBasicBlock.h"
26#include "llvm/CodeGen/MachineFunction.h"
27#include "llvm/CodeGen/MachineRegisterInfo.h"
28#include "llvm/CodeGen/RegAllocPBQP.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
32
33#define PBQP_BUILDER PBQPBuilderWithCoalescing
34//#define PBQP_BUILDER PBQPBuilder
35
36using namespace llvm;
37
38namespace {
39
40bool isFPReg(unsigned reg) {
41 return AArch64::FPR32RegClass.contains(reg) ||
42 AArch64::FPR64RegClass.contains(reg) ||
43 AArch64::FPR128RegClass.contains(reg);
44};
45
46bool isOdd(unsigned reg) {
47 switch (reg) {
48 default:
49 llvm_unreachable("Register is not from the expected class !");
50 case AArch64::S1:
51 case AArch64::S3:
52 case AArch64::S5:
53 case AArch64::S7:
54 case AArch64::S9:
55 case AArch64::S11:
56 case AArch64::S13:
57 case AArch64::S15:
58 case AArch64::S17:
59 case AArch64::S19:
60 case AArch64::S21:
61 case AArch64::S23:
62 case AArch64::S25:
63 case AArch64::S27:
64 case AArch64::S29:
65 case AArch64::S31:
66 case AArch64::D1:
67 case AArch64::D3:
68 case AArch64::D5:
69 case AArch64::D7:
70 case AArch64::D9:
71 case AArch64::D11:
72 case AArch64::D13:
73 case AArch64::D15:
74 case AArch64::D17:
75 case AArch64::D19:
76 case AArch64::D21:
77 case AArch64::D23:
78 case AArch64::D25:
79 case AArch64::D27:
80 case AArch64::D29:
81 case AArch64::D31:
82 case AArch64::Q1:
83 case AArch64::Q3:
84 case AArch64::Q5:
85 case AArch64::Q7:
86 case AArch64::Q9:
87 case AArch64::Q11:
88 case AArch64::Q13:
89 case AArch64::Q15:
90 case AArch64::Q17:
91 case AArch64::Q19:
92 case AArch64::Q21:
93 case AArch64::Q23:
94 case AArch64::Q25:
95 case AArch64::Q27:
96 case AArch64::Q29:
97 case AArch64::Q31:
98 return true;
99 case AArch64::S0:
100 case AArch64::S2:
101 case AArch64::S4:
102 case AArch64::S6:
103 case AArch64::S8:
104 case AArch64::S10:
105 case AArch64::S12:
106 case AArch64::S14:
107 case AArch64::S16:
108 case AArch64::S18:
109 case AArch64::S20:
110 case AArch64::S22:
111 case AArch64::S24:
112 case AArch64::S26:
113 case AArch64::S28:
114 case AArch64::S30:
115 case AArch64::D0:
116 case AArch64::D2:
117 case AArch64::D4:
118 case AArch64::D6:
119 case AArch64::D8:
120 case AArch64::D10:
121 case AArch64::D12:
122 case AArch64::D14:
123 case AArch64::D16:
124 case AArch64::D18:
125 case AArch64::D20:
126 case AArch64::D22:
127 case AArch64::D24:
128 case AArch64::D26:
129 case AArch64::D28:
130 case AArch64::D30:
131 case AArch64::Q0:
132 case AArch64::Q2:
133 case AArch64::Q4:
134 case AArch64::Q6:
135 case AArch64::Q8:
136 case AArch64::Q10:
137 case AArch64::Q12:
138 case AArch64::Q14:
139 case AArch64::Q16:
140 case AArch64::Q18:
141 case AArch64::Q20:
142 case AArch64::Q22:
143 case AArch64::Q24:
144 case AArch64::Q26:
145 case AArch64::Q28:
146 case AArch64::Q30:
147 return false;
148
149 }
150}
151
152bool haveSameParity(unsigned reg1, unsigned reg2) {
153 assert(isFPReg(reg1) && "Expecting an FP register for reg1");
154 assert(isFPReg(reg2) && "Expecting an FP register for reg2");
155
156 return isOdd(reg1) == isOdd(reg2);
157}
158
159class A57PBQPBuilder : public PBQP_BUILDER {
160public:
161 A57PBQPBuilder() : PBQP_BUILDER(), TRI(nullptr), LIs(nullptr), Chains() {}
162
163 // Build a PBQP instance to represent the register allocation problem for
164 // the given MachineFunction.
165 std::unique_ptr<PBQPRAProblem>
166 build(MachineFunction *MF, const LiveIntervals *LI,
167 const MachineBlockFrequencyInfo *blockInfo,
168 const RegSet &VRegs) override;
169
170private:
171 const AArch64RegisterInfo *TRI;
172 const LiveIntervals *LIs;
173 SmallSetVector<unsigned, 32> Chains;
174
175 // Return true if reg is a physical register
176 bool isPhysicalReg(unsigned reg) const {
177 return TRI->isPhysicalRegister(reg);
178 }
179
180 // Add the accumulator chaining constraint, inside the chain, i.e. so that
181 // parity(Rd) == parity(Ra).
182 // \return true if a constraint was added
183 bool addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
184
185 // Add constraints between existing chains
186 void addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
187};
188} // Anonymous namespace
189
190bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd,
191 unsigned Ra) {
192 if (Rd == Ra)
193 return false;
194
195 if (isPhysicalReg(Rd) || isPhysicalReg(Ra)) {
196 dbgs() << "Rd is a physical reg:" << isPhysicalReg(Rd) << '\n';
197 dbgs() << "Ra is a physical reg:" << isPhysicalReg(Ra) << '\n';
198 return false;
199 }
200
201 const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
202 const PBQPRAProblem::AllowedSet *vRaAllowed = &p->getAllowedSet(Ra);
203
204 PBQPRAGraph &g = p->getGraph();
205 PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
206 PBQPRAGraph::NodeId node2 = p->getNodeForVReg(Ra);
207 PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
208
209 // The edge does not exist. Create one with the appropriate interference
210 // costs.
211 if (edge == g.invalidEdgeId()) {
212 const LiveInterval &ld = LIs->getInterval(Rd);
213 const LiveInterval &la = LIs->getInterval(Ra);
214 bool livesOverlap = ld.overlaps(la);
215
216 PBQP::Matrix costs(vRdAllowed->size() + 1, vRaAllowed->size() + 1, 0);
217 for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
218 unsigned pRd = (*vRdAllowed)[i];
219 for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
220 unsigned pRa = (*vRaAllowed)[j];
221 if (livesOverlap && TRI->regsOverlap(pRd, pRa))
222 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
223 else
224 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
225 }
226 }
227 g.addEdge(node1, node2, std::move(costs));
228 return true;
229 }
230
231 if (g.getEdgeNode1Id(edge) == node2) {
232 std::swap(node1, node2);
233 std::swap(vRdAllowed, vRaAllowed);
234 }
235
236 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
237 PBQP::Matrix costs(g.getEdgeCosts(edge));
238 for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
239 unsigned pRd = (*vRdAllowed)[i];
240
241 // Get the maximum cost (excluding unallocatable reg) for same parity
242 // registers
243 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
244 for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
245 unsigned pRa = (*vRaAllowed)[j];
246 if (haveSameParity(pRd, pRa))
247 if (costs[i + 1][j + 1] !=
248 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
249 costs[i + 1][j + 1] > sameParityMax)
250 sameParityMax = costs[i + 1][j + 1];
251 }
252
253 // Ensure all registers with a different parity have a higher cost
254 // than sameParityMax
255 for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
256 unsigned pRa = (*vRaAllowed)[j];
257 if (!haveSameParity(pRd, pRa))
258 if (sameParityMax > costs[i + 1][j + 1])
259 costs[i + 1][j + 1] = sameParityMax + 1.0;
260 }
261 }
262 g.setEdgeCosts(edge, costs);
263
264 return true;
265}
266
267void
268A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd,
269 unsigned Ra) {
270 // Do some Chain management
271 if (Chains.count(Ra)) {
272 if (Rd != Ra) {
273 DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to "
274 << PrintReg(Rd, TRI) << '\n';);
275 Chains.remove(Ra);
276 Chains.insert(Rd);
277 }
278 } else {
279 DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI)
280 << '\n';);
281 Chains.insert(Rd);
282 }
283
284 const LiveInterval &ld = LIs->getInterval(Rd);
285 for (auto r : Chains) {
286 // Skip self
287 if (r == Rd)
288 continue;
289
290 const LiveInterval &lr = LIs->getInterval(r);
291 if (ld.overlaps(lr)) {
292 const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
293 const PBQPRAProblem::AllowedSet *vRrAllowed = &p->getAllowedSet(r);
294
295 PBQPRAGraph &g = p->getGraph();
296 PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
297 PBQPRAGraph::NodeId node2 = p->getNodeForVReg(r);
298 PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
299 assert(edge != g.invalidEdgeId() &&
300 "PBQP error ! The edge should exist !");
301
302 DEBUG(dbgs() << "Refining constraint !\n";);
303
304 if (g.getEdgeNode1Id(edge) == node2) {
305 std::swap(node1, node2);
306 std::swap(vRdAllowed, vRrAllowed);
307 }
308
309 // Enforce that cost is higher with all other Chains of the same parity
310 PBQP::Matrix costs(g.getEdgeCosts(edge));
311 for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
312 unsigned pRd = (*vRdAllowed)[i];
313
314 // Get the maximum cost (excluding unallocatable reg) for all other
315 // parity registers
316 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
317 for (unsigned j = 0; j != vRrAllowed->size(); ++j) {
318 unsigned pRa = (*vRrAllowed)[j];
319 if (!haveSameParity(pRd, pRa))
320 if (costs[i + 1][j + 1] !=
321 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
322 costs[i + 1][j + 1] > sameParityMax)
323 sameParityMax = costs[i + 1][j + 1];
324 }
325
326 // Ensure all registers with same parity have a higher cost
327 // than sameParityMax
328 for (unsigned j = 0; j != vRrAllowed->size(); ++j) {
329 unsigned pRa = (*vRrAllowed)[j];
330 if (haveSameParity(pRd, pRa))
331 if (sameParityMax > costs[i + 1][j + 1])
332 costs[i + 1][j + 1] = sameParityMax + 1.0;
333 }
334 }
335 g.setEdgeCosts(edge, costs);
336 }
337 }
338}
339
340std::unique_ptr<PBQPRAProblem>
341A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI,
342 const MachineBlockFrequencyInfo *blockInfo,
343 const RegSet &VRegs) {
344 std::unique_ptr<PBQPRAProblem> p =
345 PBQP_BUILDER::build(MF, LI, blockInfo, VRegs);
346
347 TRI = static_cast<const AArch64RegisterInfo *>(
348 MF->getTarget().getSubtargetImpl()->getRegisterInfo());
349 LIs = LI;
350
351 DEBUG(MF->dump(););
352
353 for (MachineFunction::const_iterator mbbItr = MF->begin(), mbbEnd = MF->end();
354 mbbItr != mbbEnd; ++mbbItr) {
355 const MachineBasicBlock *MBB = &*mbbItr;
356 Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
357
358 for (MachineBasicBlock::const_iterator miItr = MBB->begin(),
359 miEnd = MBB->end();
360 miItr != miEnd; ++miItr) {
361 const MachineInstr *MI = &*miItr;
362 switch (MI->getOpcode()) {
363 case AArch64::FMSUBSrrr:
364 case AArch64::FMADDSrrr:
365 case AArch64::FNMSUBSrrr:
366 case AArch64::FNMADDSrrr:
367 case AArch64::FMSUBDrrr:
368 case AArch64::FMADDDrrr:
369 case AArch64::FNMSUBDrrr:
370 case AArch64::FNMADDDrrr: {
371 unsigned Rd = MI->getOperand(0).getReg();
372 unsigned Ra = MI->getOperand(3).getReg();
373
374 if (addIntraChainConstraint(p.get(), Rd, Ra))
375 addInterChainConstraint(p.get(), Rd, Ra);
376 break;
377 }
378
379 case AArch64::FMLAv2f32:
380 case AArch64::FMLSv2f32: {
381 unsigned Rd = MI->getOperand(0).getReg();
382 addInterChainConstraint(p.get(), Rd, Rd);
383 break;
384 }
385
386 default:
387 // Forget Chains which have been killed
388 for (auto r : Chains) {
389 SmallVector<unsigned, 8> toDel;
390 if (MI->killsRegister(r)) {
391 DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
392 MI->print(dbgs()););
393 toDel.push_back(r);
394 }
395
396 while (!toDel.empty()) {
397 Chains.remove(toDel.back());
398 toDel.pop_back();
399 }
400 }
401 }
402 }
403 }
404
405 return p;
406}
407
408// Factory function used by AArch64TargetMachine to add the pass to the
409// passmanager.
410FunctionPass *llvm::createAArch64A57PBQPRegAlloc() {
411 std::unique_ptr<PBQP_BUILDER> builder = llvm::make_unique<A57PBQPBuilder>();
412 return createPBQPRegisterAllocator(std::move(builder), nullptr);
413}