Arnaud A. de Grandmaison | c75dbbb | 2014-09-10 14:06:10 +0000 | [diff] [blame^] | 1 | //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===// |
| 2 | // |
| 3 | // The LLVM Compiler Infrastructure |
| 4 | // |
| 5 | // This file is distributed under the University of Illinois Open Source |
| 6 | // License. See LICENSE.TXT for details. |
| 7 | // |
| 8 | //===----------------------------------------------------------------------===// |
| 9 | // This file contains the AArch64 / Cortex-A57 specific register allocation |
| 10 | // constraints for use by the PBQP register allocator. |
| 11 | // |
| 12 | // It is essentially a transcription of what is contained in |
| 13 | // AArch64A57FPLoadBalancing, which tries to use a balanced |
| 14 | // mix of odd and even D-registers when performing a critical sequence of |
| 15 | // independent, non-quadword FP/ASIMD floating-point multiply-accumulates. |
| 16 | //===----------------------------------------------------------------------===// |
| 17 | |
| 18 | #define DEBUG_TYPE "aarch64-pbqp" |
| 19 | |
| 20 | #include "AArch64.h" |
| 21 | #include "AArch64RegisterInfo.h" |
| 22 | |
| 23 | #include "llvm/ADT/SetVector.h" |
| 24 | #include "llvm/CodeGen/LiveIntervalAnalysis.h" |
| 25 | #include "llvm/CodeGen/MachineBasicBlock.h" |
| 26 | #include "llvm/CodeGen/MachineFunction.h" |
| 27 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
| 28 | #include "llvm/CodeGen/RegAllocPBQP.h" |
| 29 | #include "llvm/Support/Debug.h" |
| 30 | #include "llvm/Support/ErrorHandling.h" |
| 31 | #include "llvm/Support/raw_ostream.h" |
| 32 | |
| 33 | #define PBQP_BUILDER PBQPBuilderWithCoalescing |
| 34 | //#define PBQP_BUILDER PBQPBuilder |
| 35 | |
| 36 | using namespace llvm; |
| 37 | |
| 38 | namespace { |
| 39 | |
| 40 | bool isFPReg(unsigned reg) { |
| 41 | return AArch64::FPR32RegClass.contains(reg) || |
| 42 | AArch64::FPR64RegClass.contains(reg) || |
| 43 | AArch64::FPR128RegClass.contains(reg); |
| 44 | }; |
| 45 | |
| 46 | bool isOdd(unsigned reg) { |
| 47 | switch (reg) { |
| 48 | default: |
| 49 | llvm_unreachable("Register is not from the expected class !"); |
| 50 | case AArch64::S1: |
| 51 | case AArch64::S3: |
| 52 | case AArch64::S5: |
| 53 | case AArch64::S7: |
| 54 | case AArch64::S9: |
| 55 | case AArch64::S11: |
| 56 | case AArch64::S13: |
| 57 | case AArch64::S15: |
| 58 | case AArch64::S17: |
| 59 | case AArch64::S19: |
| 60 | case AArch64::S21: |
| 61 | case AArch64::S23: |
| 62 | case AArch64::S25: |
| 63 | case AArch64::S27: |
| 64 | case AArch64::S29: |
| 65 | case AArch64::S31: |
| 66 | case AArch64::D1: |
| 67 | case AArch64::D3: |
| 68 | case AArch64::D5: |
| 69 | case AArch64::D7: |
| 70 | case AArch64::D9: |
| 71 | case AArch64::D11: |
| 72 | case AArch64::D13: |
| 73 | case AArch64::D15: |
| 74 | case AArch64::D17: |
| 75 | case AArch64::D19: |
| 76 | case AArch64::D21: |
| 77 | case AArch64::D23: |
| 78 | case AArch64::D25: |
| 79 | case AArch64::D27: |
| 80 | case AArch64::D29: |
| 81 | case AArch64::D31: |
| 82 | case AArch64::Q1: |
| 83 | case AArch64::Q3: |
| 84 | case AArch64::Q5: |
| 85 | case AArch64::Q7: |
| 86 | case AArch64::Q9: |
| 87 | case AArch64::Q11: |
| 88 | case AArch64::Q13: |
| 89 | case AArch64::Q15: |
| 90 | case AArch64::Q17: |
| 91 | case AArch64::Q19: |
| 92 | case AArch64::Q21: |
| 93 | case AArch64::Q23: |
| 94 | case AArch64::Q25: |
| 95 | case AArch64::Q27: |
| 96 | case AArch64::Q29: |
| 97 | case AArch64::Q31: |
| 98 | return true; |
| 99 | case AArch64::S0: |
| 100 | case AArch64::S2: |
| 101 | case AArch64::S4: |
| 102 | case AArch64::S6: |
| 103 | case AArch64::S8: |
| 104 | case AArch64::S10: |
| 105 | case AArch64::S12: |
| 106 | case AArch64::S14: |
| 107 | case AArch64::S16: |
| 108 | case AArch64::S18: |
| 109 | case AArch64::S20: |
| 110 | case AArch64::S22: |
| 111 | case AArch64::S24: |
| 112 | case AArch64::S26: |
| 113 | case AArch64::S28: |
| 114 | case AArch64::S30: |
| 115 | case AArch64::D0: |
| 116 | case AArch64::D2: |
| 117 | case AArch64::D4: |
| 118 | case AArch64::D6: |
| 119 | case AArch64::D8: |
| 120 | case AArch64::D10: |
| 121 | case AArch64::D12: |
| 122 | case AArch64::D14: |
| 123 | case AArch64::D16: |
| 124 | case AArch64::D18: |
| 125 | case AArch64::D20: |
| 126 | case AArch64::D22: |
| 127 | case AArch64::D24: |
| 128 | case AArch64::D26: |
| 129 | case AArch64::D28: |
| 130 | case AArch64::D30: |
| 131 | case AArch64::Q0: |
| 132 | case AArch64::Q2: |
| 133 | case AArch64::Q4: |
| 134 | case AArch64::Q6: |
| 135 | case AArch64::Q8: |
| 136 | case AArch64::Q10: |
| 137 | case AArch64::Q12: |
| 138 | case AArch64::Q14: |
| 139 | case AArch64::Q16: |
| 140 | case AArch64::Q18: |
| 141 | case AArch64::Q20: |
| 142 | case AArch64::Q22: |
| 143 | case AArch64::Q24: |
| 144 | case AArch64::Q26: |
| 145 | case AArch64::Q28: |
| 146 | case AArch64::Q30: |
| 147 | return false; |
| 148 | |
| 149 | } |
| 150 | } |
| 151 | |
| 152 | bool haveSameParity(unsigned reg1, unsigned reg2) { |
| 153 | assert(isFPReg(reg1) && "Expecting an FP register for reg1"); |
| 154 | assert(isFPReg(reg2) && "Expecting an FP register for reg2"); |
| 155 | |
| 156 | return isOdd(reg1) == isOdd(reg2); |
| 157 | } |
| 158 | |
| 159 | class A57PBQPBuilder : public PBQP_BUILDER { |
| 160 | public: |
| 161 | A57PBQPBuilder() : PBQP_BUILDER(), TRI(nullptr), LIs(nullptr), Chains() {} |
| 162 | |
| 163 | // Build a PBQP instance to represent the register allocation problem for |
| 164 | // the given MachineFunction. |
| 165 | std::unique_ptr<PBQPRAProblem> |
| 166 | build(MachineFunction *MF, const LiveIntervals *LI, |
| 167 | const MachineBlockFrequencyInfo *blockInfo, |
| 168 | const RegSet &VRegs) override; |
| 169 | |
| 170 | private: |
| 171 | const AArch64RegisterInfo *TRI; |
| 172 | const LiveIntervals *LIs; |
| 173 | SmallSetVector<unsigned, 32> Chains; |
| 174 | |
| 175 | // Return true if reg is a physical register |
| 176 | bool isPhysicalReg(unsigned reg) const { |
| 177 | return TRI->isPhysicalRegister(reg); |
| 178 | } |
| 179 | |
| 180 | // Add the accumulator chaining constraint, inside the chain, i.e. so that |
| 181 | // parity(Rd) == parity(Ra). |
| 182 | // \return true if a constraint was added |
| 183 | bool addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra); |
| 184 | |
| 185 | // Add constraints between existing chains |
| 186 | void addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra); |
| 187 | }; |
| 188 | } // Anonymous namespace |
| 189 | |
| 190 | bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, |
| 191 | unsigned Ra) { |
| 192 | if (Rd == Ra) |
| 193 | return false; |
| 194 | |
| 195 | if (isPhysicalReg(Rd) || isPhysicalReg(Ra)) { |
| 196 | dbgs() << "Rd is a physical reg:" << isPhysicalReg(Rd) << '\n'; |
| 197 | dbgs() << "Ra is a physical reg:" << isPhysicalReg(Ra) << '\n'; |
| 198 | return false; |
| 199 | } |
| 200 | |
| 201 | const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd); |
| 202 | const PBQPRAProblem::AllowedSet *vRaAllowed = &p->getAllowedSet(Ra); |
| 203 | |
| 204 | PBQPRAGraph &g = p->getGraph(); |
| 205 | PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd); |
| 206 | PBQPRAGraph::NodeId node2 = p->getNodeForVReg(Ra); |
| 207 | PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2); |
| 208 | |
| 209 | // The edge does not exist. Create one with the appropriate interference |
| 210 | // costs. |
| 211 | if (edge == g.invalidEdgeId()) { |
| 212 | const LiveInterval &ld = LIs->getInterval(Rd); |
| 213 | const LiveInterval &la = LIs->getInterval(Ra); |
| 214 | bool livesOverlap = ld.overlaps(la); |
| 215 | |
| 216 | PBQP::Matrix costs(vRdAllowed->size() + 1, vRaAllowed->size() + 1, 0); |
| 217 | for (unsigned i = 0; i != vRdAllowed->size(); ++i) { |
| 218 | unsigned pRd = (*vRdAllowed)[i]; |
| 219 | for (unsigned j = 0; j != vRaAllowed->size(); ++j) { |
| 220 | unsigned pRa = (*vRaAllowed)[j]; |
| 221 | if (livesOverlap && TRI->regsOverlap(pRd, pRa)) |
| 222 | costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity(); |
| 223 | else |
| 224 | costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0; |
| 225 | } |
| 226 | } |
| 227 | g.addEdge(node1, node2, std::move(costs)); |
| 228 | return true; |
| 229 | } |
| 230 | |
| 231 | if (g.getEdgeNode1Id(edge) == node2) { |
| 232 | std::swap(node1, node2); |
| 233 | std::swap(vRdAllowed, vRaAllowed); |
| 234 | } |
| 235 | |
| 236 | // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass)) |
| 237 | PBQP::Matrix costs(g.getEdgeCosts(edge)); |
| 238 | for (unsigned i = 0; i != vRdAllowed->size(); ++i) { |
| 239 | unsigned pRd = (*vRdAllowed)[i]; |
| 240 | |
| 241 | // Get the maximum cost (excluding unallocatable reg) for same parity |
| 242 | // registers |
| 243 | PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); |
| 244 | for (unsigned j = 0; j != vRaAllowed->size(); ++j) { |
| 245 | unsigned pRa = (*vRaAllowed)[j]; |
| 246 | if (haveSameParity(pRd, pRa)) |
| 247 | if (costs[i + 1][j + 1] != |
| 248 | std::numeric_limits<PBQP::PBQPNum>::infinity() && |
| 249 | costs[i + 1][j + 1] > sameParityMax) |
| 250 | sameParityMax = costs[i + 1][j + 1]; |
| 251 | } |
| 252 | |
| 253 | // Ensure all registers with a different parity have a higher cost |
| 254 | // than sameParityMax |
| 255 | for (unsigned j = 0; j != vRaAllowed->size(); ++j) { |
| 256 | unsigned pRa = (*vRaAllowed)[j]; |
| 257 | if (!haveSameParity(pRd, pRa)) |
| 258 | if (sameParityMax > costs[i + 1][j + 1]) |
| 259 | costs[i + 1][j + 1] = sameParityMax + 1.0; |
| 260 | } |
| 261 | } |
| 262 | g.setEdgeCosts(edge, costs); |
| 263 | |
| 264 | return true; |
| 265 | } |
| 266 | |
| 267 | void |
| 268 | A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, |
| 269 | unsigned Ra) { |
| 270 | // Do some Chain management |
| 271 | if (Chains.count(Ra)) { |
| 272 | if (Rd != Ra) { |
| 273 | DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to " |
| 274 | << PrintReg(Rd, TRI) << '\n';); |
| 275 | Chains.remove(Ra); |
| 276 | Chains.insert(Rd); |
| 277 | } |
| 278 | } else { |
| 279 | DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI) |
| 280 | << '\n';); |
| 281 | Chains.insert(Rd); |
| 282 | } |
| 283 | |
| 284 | const LiveInterval &ld = LIs->getInterval(Rd); |
| 285 | for (auto r : Chains) { |
| 286 | // Skip self |
| 287 | if (r == Rd) |
| 288 | continue; |
| 289 | |
| 290 | const LiveInterval &lr = LIs->getInterval(r); |
| 291 | if (ld.overlaps(lr)) { |
| 292 | const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd); |
| 293 | const PBQPRAProblem::AllowedSet *vRrAllowed = &p->getAllowedSet(r); |
| 294 | |
| 295 | PBQPRAGraph &g = p->getGraph(); |
| 296 | PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd); |
| 297 | PBQPRAGraph::NodeId node2 = p->getNodeForVReg(r); |
| 298 | PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2); |
| 299 | assert(edge != g.invalidEdgeId() && |
| 300 | "PBQP error ! The edge should exist !"); |
| 301 | |
| 302 | DEBUG(dbgs() << "Refining constraint !\n";); |
| 303 | |
| 304 | if (g.getEdgeNode1Id(edge) == node2) { |
| 305 | std::swap(node1, node2); |
| 306 | std::swap(vRdAllowed, vRrAllowed); |
| 307 | } |
| 308 | |
| 309 | // Enforce that cost is higher with all other Chains of the same parity |
| 310 | PBQP::Matrix costs(g.getEdgeCosts(edge)); |
| 311 | for (unsigned i = 0; i != vRdAllowed->size(); ++i) { |
| 312 | unsigned pRd = (*vRdAllowed)[i]; |
| 313 | |
| 314 | // Get the maximum cost (excluding unallocatable reg) for all other |
| 315 | // parity registers |
| 316 | PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); |
| 317 | for (unsigned j = 0; j != vRrAllowed->size(); ++j) { |
| 318 | unsigned pRa = (*vRrAllowed)[j]; |
| 319 | if (!haveSameParity(pRd, pRa)) |
| 320 | if (costs[i + 1][j + 1] != |
| 321 | std::numeric_limits<PBQP::PBQPNum>::infinity() && |
| 322 | costs[i + 1][j + 1] > sameParityMax) |
| 323 | sameParityMax = costs[i + 1][j + 1]; |
| 324 | } |
| 325 | |
| 326 | // Ensure all registers with same parity have a higher cost |
| 327 | // than sameParityMax |
| 328 | for (unsigned j = 0; j != vRrAllowed->size(); ++j) { |
| 329 | unsigned pRa = (*vRrAllowed)[j]; |
| 330 | if (haveSameParity(pRd, pRa)) |
| 331 | if (sameParityMax > costs[i + 1][j + 1]) |
| 332 | costs[i + 1][j + 1] = sameParityMax + 1.0; |
| 333 | } |
| 334 | } |
| 335 | g.setEdgeCosts(edge, costs); |
| 336 | } |
| 337 | } |
| 338 | } |
| 339 | |
| 340 | std::unique_ptr<PBQPRAProblem> |
| 341 | A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI, |
| 342 | const MachineBlockFrequencyInfo *blockInfo, |
| 343 | const RegSet &VRegs) { |
| 344 | std::unique_ptr<PBQPRAProblem> p = |
| 345 | PBQP_BUILDER::build(MF, LI, blockInfo, VRegs); |
| 346 | |
| 347 | TRI = static_cast<const AArch64RegisterInfo *>( |
| 348 | MF->getTarget().getSubtargetImpl()->getRegisterInfo()); |
| 349 | LIs = LI; |
| 350 | |
| 351 | DEBUG(MF->dump();); |
| 352 | |
| 353 | for (MachineFunction::const_iterator mbbItr = MF->begin(), mbbEnd = MF->end(); |
| 354 | mbbItr != mbbEnd; ++mbbItr) { |
| 355 | const MachineBasicBlock *MBB = &*mbbItr; |
| 356 | Chains.clear(); // FIXME: really needed ? Could not work at MF level ? |
| 357 | |
| 358 | for (MachineBasicBlock::const_iterator miItr = MBB->begin(), |
| 359 | miEnd = MBB->end(); |
| 360 | miItr != miEnd; ++miItr) { |
| 361 | const MachineInstr *MI = &*miItr; |
| 362 | switch (MI->getOpcode()) { |
| 363 | case AArch64::FMSUBSrrr: |
| 364 | case AArch64::FMADDSrrr: |
| 365 | case AArch64::FNMSUBSrrr: |
| 366 | case AArch64::FNMADDSrrr: |
| 367 | case AArch64::FMSUBDrrr: |
| 368 | case AArch64::FMADDDrrr: |
| 369 | case AArch64::FNMSUBDrrr: |
| 370 | case AArch64::FNMADDDrrr: { |
| 371 | unsigned Rd = MI->getOperand(0).getReg(); |
| 372 | unsigned Ra = MI->getOperand(3).getReg(); |
| 373 | |
| 374 | if (addIntraChainConstraint(p.get(), Rd, Ra)) |
| 375 | addInterChainConstraint(p.get(), Rd, Ra); |
| 376 | break; |
| 377 | } |
| 378 | |
| 379 | case AArch64::FMLAv2f32: |
| 380 | case AArch64::FMLSv2f32: { |
| 381 | unsigned Rd = MI->getOperand(0).getReg(); |
| 382 | addInterChainConstraint(p.get(), Rd, Rd); |
| 383 | break; |
| 384 | } |
| 385 | |
| 386 | default: |
| 387 | // Forget Chains which have been killed |
| 388 | for (auto r : Chains) { |
| 389 | SmallVector<unsigned, 8> toDel; |
| 390 | if (MI->killsRegister(r)) { |
| 391 | DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at "; |
| 392 | MI->print(dbgs());); |
| 393 | toDel.push_back(r); |
| 394 | } |
| 395 | |
| 396 | while (!toDel.empty()) { |
| 397 | Chains.remove(toDel.back()); |
| 398 | toDel.pop_back(); |
| 399 | } |
| 400 | } |
| 401 | } |
| 402 | } |
| 403 | } |
| 404 | |
| 405 | return p; |
| 406 | } |
| 407 | |
| 408 | // Factory function used by AArch64TargetMachine to add the pass to the |
| 409 | // passmanager. |
| 410 | FunctionPass *llvm::createAArch64A57PBQPRegAlloc() { |
| 411 | std::unique_ptr<PBQP_BUILDER> builder = llvm::make_unique<A57PBQPBuilder>(); |
| 412 | return createPBQPRegisterAllocator(std::move(builder), nullptr); |
| 413 | } |