Sam Parker | 312409e | 2019-09-06 08:24:41 +0000 | [diff] [blame^] | 1 | //===- MVETailPredication.cpp - MVE Tail Predication ----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | /// \file |
| 10 | /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead |
| 11 | /// branches to help accelerate DSP applications. These two extensions can be |
| 12 | /// combined to provide implicit vector predication within a low-overhead loop. |
| 13 | /// The HardwareLoops pass inserts intrinsics identifying loops that the |
| 14 | /// backend will attempt to convert into a low-overhead loop. The vectorizer is |
| 15 | /// responsible for generating a vectorized loop in which the lanes are |
| 16 | /// predicated upon the iteration counter. This pass looks at these predicated |
/// vector loops, that are targets for low-overhead loops, and prepares them
/// for code generation. Once the vectorizer has produced a masked loop, there
/// are a couple of final forms:
/// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCTP instructions, predicating multiple VPT
///   blocks of instructions operating on different vector types.
| 23 | |
| 24 | #include "llvm/Analysis/LoopInfo.h" |
| 25 | #include "llvm/Analysis/LoopPass.h" |
| 26 | #include "llvm/Analysis/ScalarEvolution.h" |
| 27 | #include "llvm/Analysis/ScalarEvolutionExpander.h" |
| 28 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
| 29 | #include "llvm/Analysis/TargetTransformInfo.h" |
| 30 | #include "llvm/CodeGen/TargetPassConfig.h" |
| 31 | #include "llvm/IR/Instructions.h" |
| 32 | #include "llvm/IR/IRBuilder.h" |
| 33 | #include "llvm/IR/PatternMatch.h" |
| 34 | #include "llvm/Support/Debug.h" |
| 35 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 36 | #include "ARM.h" |
| 37 | #include "ARMSubtarget.h" |
| 38 | |
| 39 | using namespace llvm; |
| 40 | |
| 41 | #define DEBUG_TYPE "mve-tail-predication" |
| 42 | #define DESC "Transform predicated vector loops to use MVE tail predication" |
| 43 | |
| 44 | static cl::opt<bool> |
| 45 | DisableTailPredication("disable-mve-tail-predication", cl::Hidden, |
| 46 | cl::init(true), |
| 47 | cl::desc("Disable MVE Tail Predication")); |
| 48 | namespace { |
| 49 | |
| 50 | class MVETailPredication : public LoopPass { |
| 51 | SmallVector<IntrinsicInst*, 4> MaskedInsts; |
| 52 | Loop *L = nullptr; |
| 53 | ScalarEvolution *SE = nullptr; |
| 54 | TargetTransformInfo *TTI = nullptr; |
| 55 | |
| 56 | public: |
| 57 | static char ID; |
| 58 | |
| 59 | MVETailPredication() : LoopPass(ID) { } |
| 60 | |
| 61 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 62 | AU.addRequired<ScalarEvolutionWrapperPass>(); |
| 63 | AU.addRequired<LoopInfoWrapperPass>(); |
| 64 | AU.addRequired<TargetPassConfig>(); |
| 65 | AU.addRequired<TargetTransformInfoWrapperPass>(); |
| 66 | AU.addPreserved<LoopInfoWrapperPass>(); |
| 67 | AU.setPreservesCFG(); |
| 68 | } |
| 69 | |
| 70 | bool runOnLoop(Loop *L, LPPassManager&) override; |
| 71 | |
| 72 | private: |
| 73 | |
| 74 | /// Perform the relevant checks on the loop and convert if possible. |
| 75 | bool TryConvert(Value *TripCount); |
| 76 | |
| 77 | /// Return whether this is a vectorized loop, that contains masked |
| 78 | /// load/stores. |
| 79 | bool IsPredicatedVectorLoop(); |
| 80 | |
| 81 | /// Compute a value for the total number of elements that the predicated |
| 82 | /// loop will process. |
| 83 | Value *ComputeElements(Value *TripCount, VectorType *VecTy); |
| 84 | |
| 85 | /// Is the icmp that generates an i1 vector, based upon a loop counter |
| 86 | /// and a limit that is defined outside the loop. |
| 87 | bool isTailPredicate(Value *Predicate, Value *NumElements); |
| 88 | }; |
| 89 | |
| 90 | } // end namespace |
| 91 | |
| 92 | static bool IsDecrement(Instruction &I) { |
| 93 | auto *Call = dyn_cast<IntrinsicInst>(&I); |
| 94 | if (!Call) |
| 95 | return false; |
| 96 | |
| 97 | Intrinsic::ID ID = Call->getIntrinsicID(); |
| 98 | return ID == Intrinsic::loop_decrement_reg; |
| 99 | } |
| 100 | |
| 101 | static bool IsMasked(Instruction *I) { |
| 102 | auto *Call = dyn_cast<IntrinsicInst>(I); |
| 103 | if (!Call) |
| 104 | return false; |
| 105 | |
| 106 | Intrinsic::ID ID = Call->getIntrinsicID(); |
| 107 | // TODO: Support gather/scatter expand/compress operations. |
| 108 | return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load; |
| 109 | } |
| 110 | |
| 111 | bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { |
| 112 | if (skipLoop(L) || DisableTailPredication) |
| 113 | return false; |
| 114 | |
| 115 | Function &F = *L->getHeader()->getParent(); |
| 116 | auto &TPC = getAnalysis<TargetPassConfig>(); |
| 117 | auto &TM = TPC.getTM<TargetMachine>(); |
| 118 | auto *ST = &TM.getSubtarget<ARMSubtarget>(F); |
| 119 | TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
| 120 | SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); |
| 121 | this->L = L; |
| 122 | |
| 123 | // The MVE and LOB extensions are combined to enable tail-predication, but |
| 124 | // there's nothing preventing us from generating VCTP instructions for v8.1m. |
| 125 | if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { |
| 126 | LLVM_DEBUG(dbgs() << "TP: Not a v8.1m.main+mve target.\n"); |
| 127 | return false; |
| 128 | } |
| 129 | |
| 130 | BasicBlock *Preheader = L->getLoopPreheader(); |
| 131 | if (!Preheader) |
| 132 | return false; |
| 133 | |
| 134 | auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { |
| 135 | for (auto &I : *BB) { |
| 136 | auto *Call = dyn_cast<IntrinsicInst>(&I); |
| 137 | if (!Call) |
| 138 | continue; |
| 139 | |
| 140 | Intrinsic::ID ID = Call->getIntrinsicID(); |
| 141 | if (ID == Intrinsic::set_loop_iterations || |
| 142 | ID == Intrinsic::test_set_loop_iterations) |
| 143 | return cast<IntrinsicInst>(&I); |
| 144 | } |
| 145 | return nullptr; |
| 146 | }; |
| 147 | |
| 148 | // Look for the hardware loop intrinsic that sets the iteration count. |
| 149 | IntrinsicInst *Setup = FindLoopIterations(Preheader); |
| 150 | |
| 151 | // The test.set iteration could live in the pre- preheader. |
| 152 | if (!Setup) { |
| 153 | if (!Preheader->getSinglePredecessor()) |
| 154 | return false; |
| 155 | Setup = FindLoopIterations(Preheader->getSinglePredecessor()); |
| 156 | if (!Setup) |
| 157 | return false; |
| 158 | } |
| 159 | |
| 160 | // Search for the hardware loop intrinic that decrements the loop counter. |
| 161 | IntrinsicInst *Decrement = nullptr; |
| 162 | for (auto *BB : L->getBlocks()) { |
| 163 | for (auto &I : *BB) { |
| 164 | if (IsDecrement(I)) { |
| 165 | Decrement = cast<IntrinsicInst>(&I); |
| 166 | break; |
| 167 | } |
| 168 | } |
| 169 | } |
| 170 | |
| 171 | if (!Decrement) |
| 172 | return false; |
| 173 | |
| 174 | LLVM_DEBUG(dbgs() << "TP: Running on Loop: " << *L |
| 175 | << *Setup << "\n" |
| 176 | << *Decrement << "\n"); |
| 177 | bool Changed = TryConvert(Setup->getArgOperand(0)); |
| 178 | return Changed; |
| 179 | } |
| 180 | |
| 181 | bool MVETailPredication::isTailPredicate(Value *V, Value *NumElements) { |
| 182 | // Look for the following: |
| 183 | |
| 184 | // %trip.count.minus.1 = add i32 %N, -1 |
| 185 | // %broadcast.splatinsert10 = insertelement <4 x i32> undef, |
| 186 | // i32 %trip.count.minus.1, i32 0 |
| 187 | // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10, |
| 188 | // <4 x i32> undef, |
| 189 | // <4 x i32> zeroinitializer |
| 190 | // ... |
| 191 | // ... |
| 192 | // %index = phi i32 |
| 193 | // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 |
| 194 | // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, |
| 195 | // <4 x i32> undef, |
| 196 | // <4 x i32> zeroinitializer |
| 197 | // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3> |
| 198 | // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11 |
| 199 | |
| 200 | // And return whether V == %pred. |
| 201 | |
| 202 | using namespace PatternMatch; |
| 203 | |
| 204 | CmpInst::Predicate Pred; |
| 205 | Instruction *Shuffle = nullptr; |
| 206 | Instruction *Induction = nullptr; |
| 207 | |
| 208 | // The vector icmp |
| 209 | if (!match(V, m_ICmp(Pred, m_Instruction(Induction), |
| 210 | m_Instruction(Shuffle))) || |
| 211 | Pred != ICmpInst::ICMP_ULE || !L->isLoopInvariant(Shuffle)) |
| 212 | return false; |
| 213 | |
| 214 | // First find the stuff outside the loop which is setting up the limit |
| 215 | // vector.... |
| 216 | // The invariant shuffle that broadcast the limit into a vector. |
| 217 | Instruction *Insert = nullptr; |
| 218 | if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(), |
| 219 | m_Zero()))) |
| 220 | return false; |
| 221 | |
| 222 | // Insert the limit into a vector. |
| 223 | Instruction *BECount = nullptr; |
| 224 | if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount), |
| 225 | m_Zero()))) |
| 226 | return false; |
| 227 | |
| 228 | // The limit calculation, backedge count. |
| 229 | Value *TripCount = nullptr; |
| 230 | if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes()))) |
| 231 | return false; |
| 232 | |
| 233 | if (TripCount != NumElements) |
| 234 | return false; |
| 235 | |
| 236 | // Now back to searching inside the loop body... |
| 237 | // Find the add with takes the index iv and adds a constant vector to it. |
| 238 | Instruction *BroadcastSplat = nullptr; |
| 239 | Constant *Const = nullptr; |
| 240 | if (!match(Induction, m_Add(m_Instruction(BroadcastSplat), |
| 241 | m_Constant(Const)))) |
| 242 | return false; |
| 243 | |
| 244 | // Check that we're adding <0, 1, 2, 3... |
| 245 | if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) { |
| 246 | for (unsigned i = 0; i < CDS->getNumElements(); ++i) { |
| 247 | if (CDS->getElementAsInteger(i) != i) |
| 248 | return false; |
| 249 | } |
| 250 | } else |
| 251 | return false; |
| 252 | |
| 253 | // The shuffle which broadcasts the index iv into a vector. |
| 254 | if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(), |
| 255 | m_Zero()))) |
| 256 | return false; |
| 257 | |
| 258 | // The insert element which initialises a vector with the index iv. |
| 259 | Instruction *IV = nullptr; |
| 260 | if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero()))) |
| 261 | return false; |
| 262 | |
| 263 | // The index iv. |
| 264 | auto *Phi = dyn_cast<PHINode>(IV); |
| 265 | if (!Phi) |
| 266 | return false; |
| 267 | |
| 268 | // TODO: Don't think we need to check the entry value. |
| 269 | Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader()); |
| 270 | if (!match(OnEntry, m_Zero())) |
| 271 | return false; |
| 272 | |
| 273 | Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch()); |
| 274 | unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements(); |
| 275 | |
| 276 | Instruction *LHS = nullptr; |
| 277 | if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes)))) |
| 278 | return false; |
| 279 | |
| 280 | return LHS == Phi; |
| 281 | } |
| 282 | |
| 283 | static VectorType* getVectorType(IntrinsicInst *I) { |
| 284 | unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1; |
| 285 | auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType()); |
| 286 | return cast<VectorType>(PtrTy->getElementType()); |
| 287 | } |
| 288 | |
| 289 | bool MVETailPredication::IsPredicatedVectorLoop() { |
| 290 | // Check that the loop contains at least one masked load/store intrinsic. |
| 291 | // We only support 'normal' vector instructions - other than masked |
| 292 | // load/stores. |
| 293 | for (auto *BB : L->getBlocks()) { |
| 294 | for (auto &I : *BB) { |
| 295 | if (IsMasked(&I)) { |
| 296 | VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I)); |
| 297 | unsigned Lanes = VecTy->getNumElements(); |
| 298 | unsigned ElementWidth = VecTy->getScalarSizeInBits(); |
| 299 | // MVE vectors are 128-bit, but don't support 128 x i1. |
| 300 | // TODO: Can we support vectors larger than 128-bits? |
| 301 | unsigned MaxWidth = TTI->getRegisterBitWidth(true); |
| 302 | if (Lanes * ElementWidth != MaxWidth || Lanes == MaxWidth) |
| 303 | return false; |
| 304 | MaskedInsts.push_back(cast<IntrinsicInst>(&I)); |
| 305 | } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) { |
| 306 | for (auto &U : Int->args()) { |
| 307 | if (isa<VectorType>(U->getType())) |
| 308 | return false; |
| 309 | } |
| 310 | } |
| 311 | } |
| 312 | } |
| 313 | |
| 314 | return !MaskedInsts.empty(); |
| 315 | } |
| 316 | |
| 317 | Value* MVETailPredication::ComputeElements(Value *TripCount, |
| 318 | VectorType *VecTy) { |
| 319 | const SCEV *TripCountSE = SE->getSCEV(TripCount); |
| 320 | ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()), |
| 321 | VecTy->getNumElements()); |
| 322 | |
| 323 | if (VF->equalsInt(1)) |
| 324 | return nullptr; |
| 325 | |
| 326 | // TODO: Support constant trip counts. |
| 327 | auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* { |
| 328 | if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { |
| 329 | if (Const->getAPInt() != -VF->getValue()) |
| 330 | return nullptr; |
| 331 | } else |
| 332 | return nullptr; |
| 333 | return dyn_cast<SCEVMulExpr>(S->getOperand(1)); |
| 334 | }; |
| 335 | |
| 336 | auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* { |
| 337 | if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { |
| 338 | if (Const->getValue() != VF) |
| 339 | return nullptr; |
| 340 | } else |
| 341 | return nullptr; |
| 342 | return dyn_cast<SCEVUDivExpr>(S->getOperand(1)); |
| 343 | }; |
| 344 | |
| 345 | auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* { |
| 346 | if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) { |
| 347 | if (Const->getValue() != VF) |
| 348 | return nullptr; |
| 349 | } else |
| 350 | return nullptr; |
| 351 | |
| 352 | if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) { |
| 353 | if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) { |
| 354 | if (Const->getAPInt() != (VF->getValue() - 1)) |
| 355 | return nullptr; |
| 356 | } else |
| 357 | return nullptr; |
| 358 | |
| 359 | return RoundUp->getOperand(1); |
| 360 | } |
| 361 | return nullptr; |
| 362 | }; |
| 363 | |
| 364 | // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to |
| 365 | // determine the numbers of elements instead? Looks like this is what is used |
| 366 | // for delinearization, but I'm not sure if it can be applied to the |
| 367 | // vectorized form - at least not without a bit more work than I feel |
| 368 | // comfortable with. |
| 369 | |
| 370 | // Search for Elems in the following SCEV: |
| 371 | // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw> |
| 372 | const SCEV *Elems = nullptr; |
| 373 | if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE)) |
| 374 | if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1))) |
| 375 | if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS())) |
| 376 | if (auto *Mul = VisitAdd(Add)) |
| 377 | if (auto *Div = VisitMul(Mul)) |
| 378 | if (auto *Res = VisitDiv(Div)) |
| 379 | Elems = Res; |
| 380 | |
| 381 | if (!Elems) |
| 382 | return nullptr; |
| 383 | |
| 384 | Instruction *InsertPt = L->getLoopPreheader()->getTerminator(); |
| 385 | if (!isSafeToExpandAt(Elems, InsertPt, *SE)) |
| 386 | return nullptr; |
| 387 | |
| 388 | auto DL = L->getHeader()->getModule()->getDataLayout(); |
| 389 | SCEVExpander Expander(*SE, DL, "elements"); |
| 390 | return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt); |
| 391 | } |
| 392 | |
| 393 | bool MVETailPredication::TryConvert(Value *TripCount) { |
| 394 | if (!IsPredicatedVectorLoop()) |
| 395 | return false; |
| 396 | |
| 397 | LLVM_DEBUG(dbgs() << "TP: Found predicated vector loop.\n"); |
| 398 | |
| 399 | // Walk through the masked intrinsics and try to find whether the predicate |
| 400 | // operand is generated from an induction variable. |
| 401 | Module *M = L->getHeader()->getModule(); |
| 402 | Type *Ty = IntegerType::get(M->getContext(), 32); |
| 403 | SmallPtrSet<Value*, 4> Predicates; |
| 404 | |
| 405 | for (auto *I : MaskedInsts) { |
| 406 | Intrinsic::ID ID = I->getIntrinsicID(); |
| 407 | unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3; |
| 408 | Value *Predicate = I->getArgOperand(PredOp); |
| 409 | if (Predicates.count(Predicate)) |
| 410 | continue; |
| 411 | |
| 412 | VectorType *VecTy = getVectorType(I); |
| 413 | Value *NumElements = ComputeElements(TripCount, VecTy); |
| 414 | if (!NumElements) |
| 415 | continue; |
| 416 | |
| 417 | if (!isTailPredicate(Predicate, NumElements)) { |
| 418 | LLVM_DEBUG(dbgs() << "TP: Not tail predicate: " << *Predicate << "\n"); |
| 419 | continue; |
| 420 | } |
| 421 | |
| 422 | LLVM_DEBUG(dbgs() << "TP: Found tail predicate: " << *Predicate << "\n"); |
| 423 | Predicates.insert(Predicate); |
| 424 | |
| 425 | // Insert a phi to count the number of elements processed by the loop. |
| 426 | IRBuilder<> Builder(L->getHeader()->getFirstNonPHI()); |
| 427 | PHINode *Processed = Builder.CreatePHI(Ty, 2); |
| 428 | Processed->addIncoming(NumElements, L->getLoopPreheader()); |
| 429 | |
| 430 | // Insert the intrinsic to represent the effect of tail predication. |
| 431 | Builder.SetInsertPoint(cast<Instruction>(Predicate)); |
| 432 | ConstantInt *Factor = |
| 433 | ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements()); |
| 434 | Intrinsic::ID VCTPID; |
| 435 | switch (VecTy->getNumElements()) { |
| 436 | default: |
| 437 | llvm_unreachable("unexpected number of lanes"); |
| 438 | case 2: VCTPID = Intrinsic::arm_vctp64; break; |
| 439 | case 4: VCTPID = Intrinsic::arm_vctp32; break; |
| 440 | case 8: VCTPID = Intrinsic::arm_vctp16; break; |
| 441 | case 16: VCTPID = Intrinsic::arm_vctp8; break; |
| 442 | } |
| 443 | Function *VCTP = Intrinsic::getDeclaration(M, VCTPID); |
| 444 | // TODO: This add likely already exists in the loop. |
| 445 | Value *Remaining = Builder.CreateSub(Processed, Factor); |
| 446 | Value *TailPredicate = Builder.CreateCall(VCTP, Remaining); |
| 447 | Predicate->replaceAllUsesWith(TailPredicate); |
| 448 | |
| 449 | // Add the incoming value to the new phi. |
| 450 | Processed->addIncoming(Remaining, L->getLoopLatch()); |
| 451 | LLVM_DEBUG(dbgs() << "TP: Insert processed elements phi: " |
| 452 | << *Processed << "\n" |
| 453 | << "TP: Inserted VCTP: " << *TailPredicate << "\n"); |
| 454 | } |
| 455 | |
| 456 | for (auto I : L->blocks()) |
| 457 | DeleteDeadPHIs(I); |
| 458 | |
| 459 | return true; |
| 460 | } |
| 461 | |
| 462 | Pass *llvm::createMVETailPredicationPass() { |
| 463 | return new MVETailPredication(); |
| 464 | } |
| 465 | |
| 466 | char MVETailPredication::ID = 0; |
| 467 | |
| 468 | INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false) |
| 469 | INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false) |