Support expanding partial-word cmpxchg to full-word cmpxchg in AtomicExpandPass.

Many CPUs can only perform a 4-byte cmpxchg (or ll/sc), not a 1- or 2-byte
one. On those targets, the 1- or 2-byte value must be masked and shifted
into its containing 4-byte word so the 4-byte instruction can be used.
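
To illustrate the technique, here is a minimal, self-contained C++ sketch.
It is not the code added by this patch; the helper name is made up, and it
assumes a little-endian target whose containing, naturally aligned 32-bit
word may legitimately be accessed atomically:

    #include <atomic>
    #include <cstdint>

    // Emulate a 1-byte CAS with a 4-byte CAS: align the address down,
    // shift the operands into the byte's lane, and retry while only the
    // surrounding bytes change.
    bool cas8_via_cas32(uint8_t *Addr, uint8_t Expected, uint8_t Desired) {
      auto *Aligned = reinterpret_cast<std::atomic<uint32_t> *>(
          reinterpret_cast<uintptr_t>(Addr) & ~uintptr_t(3));
      unsigned Shift = (reinterpret_cast<uintptr_t>(Addr) & 3) * 8;
      uint32_t Mask = uint32_t(0xFF) << Shift;
      uint32_t Observed = Aligned->load();
      for (;;) {
        uint32_t Rest = Observed & ~Mask; // Bytes we must leave untouched.
        uint32_t Cmp = Rest | (uint32_t(Expected) << Shift);
        uint32_t New = Rest | (uint32_t(Desired) << Shift);
        if (Aligned->compare_exchange_strong(Cmp, New))
          return true; // Our byte was swapped.
        // Cmp now holds the word actually observed in memory.
        if ((Cmp & ~Mask) == Rest)
          return false; // Other bytes unchanged: our byte really mismatched.
        Observed = Cmp; // Only the other bytes changed: rebuild and retry.
      }
    }

The IR emitted by the pass has the same shape; for a weak cmpxchg the
failure path simply exits instead of retrying.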

This change adds that expansion for cmpxchg-based instruction sets (in
LLVM, currently only SPARC). The support can be extended to the LL/SC-based
PPC and MIPS backends in the future, supplanting the ISel expansions those
architectures currently use.

Tests added for the IR transform and SPARCv9.

Differential Revision: http://reviews.llvm.org/D21029

llvm-svn: 273025
diff --git a/llvm/lib/CodeGen/AtomicExpandPass.cpp b/llvm/lib/CodeGen/AtomicExpandPass.cpp
index 4b26b64..bf5cf10 100644
--- a/llvm/lib/CodeGen/AtomicExpandPass.cpp
+++ b/llvm/lib/CodeGen/AtomicExpandPass.cpp
@@ -57,10 +57,25 @@
     StoreInst *convertAtomicStoreToIntegerType(StoreInst *SI);
     bool expandAtomicStore(StoreInst *SI);
     bool tryExpandAtomicRMW(AtomicRMWInst *AI);
-    bool expandAtomicOpToLLSC(
-        Instruction *I, Value *Addr, AtomicOrdering MemOpOrder,
+    Value *
+    insertRMWLLSCLoop(IRBuilder<> &Builder, Type *ResultTy, Value *Addr,
+                      AtomicOrdering MemOpOrder,
+                      function_ref<Value *(IRBuilder<> &, Value *)> PerformOp);
+    void expandAtomicOpToLLSC(
+        Instruction *I, Type *ResultTy, Value *Addr, AtomicOrdering MemOpOrder,
         function_ref<Value *(IRBuilder<> &, Value *)> PerformOp);
+    void expandPartwordAtomicRMW(
+        AtomicRMWInst *I,
+        TargetLoweringBase::AtomicExpansionKind ExpansionKind);
+    void expandPartwordCmpXchg(AtomicCmpXchgInst *I);
+
     AtomicCmpXchgInst *convertCmpXchgToIntegerType(AtomicCmpXchgInst *CI);
+    static Value *insertRMWCmpXchgLoop(
+        IRBuilder<> &Builder, Type *ResultType, Value *Addr,
+        AtomicOrdering MemOpOrder,
+        function_ref<Value *(IRBuilder<> &, Value *)> PerformOp,
+        CreateCmpXchgInstFun CreateCmpXchg);
+
     bool expandAtomicCmpXchg(AtomicCmpXchgInst *CI);
     bool isIdempotentRMW(AtomicRMWInst *AI);
     bool simplifyIdempotentRMW(AtomicRMWInst *AI);
@@ -74,6 +89,10 @@
     void expandAtomicStoreToLibcall(StoreInst *LI);
     void expandAtomicRMWToLibcall(AtomicRMWInst *I);
     void expandAtomicCASToLibcall(AtomicCmpXchgInst *I);
+
+    friend bool
+    llvm::expandAtomicRMWToCmpXchg(AtomicRMWInst *AI,
+                                   CreateCmpXchgInstFun CreateCmpXchg);
   };
 }
 
@@ -285,9 +304,17 @@
                "invariant broken");
         MadeChange = true;
       }
-      
-      if (TLI->shouldExpandAtomicCmpXchgInIR(CASI))
-        MadeChange |= expandAtomicCmpXchg(CASI);
+
+      unsigned MinCASSize = TLI->getMinCmpXchgSizeInBits() / 8;
+      unsigned ValueSize = getAtomicOpSize(CASI);
+      if (ValueSize < MinCASSize) {
+        assert(!TLI->shouldExpandAtomicCmpXchgInIR(CASI) &&
+               "MinCmpXchgSizeInBits not yet supported for LL/SC expansions.");
+        expandPartwordCmpXchg(CASI);
+      } else {
+        if (TLI->shouldExpandAtomicCmpXchgInIR(CASI))
+          MadeChange |= expandAtomicCmpXchg(CASI);
+      }
     }
   }
   return MadeChange;
@@ -355,9 +382,10 @@
   case TargetLoweringBase::AtomicExpansionKind::None:
     return false;
   case TargetLoweringBase::AtomicExpansionKind::LLSC:
-    return expandAtomicOpToLLSC(
-        LI, LI->getPointerOperand(), LI->getOrdering(),
+    expandAtomicOpToLLSC(
+        LI, LI->getType(), LI->getPointerOperand(), LI->getOrdering(),
         [](IRBuilder<> &Builder, Value *Loaded) { return Loaded; });
+    return true;
   case TargetLoweringBase::AtomicExpansionKind::LLOnly:
     return expandAtomicLoadToLL(LI);
   case TargetLoweringBase::AtomicExpansionKind::CmpXChg:
@@ -498,32 +526,353 @@
   switch (TLI->shouldExpandAtomicRMWInIR(AI)) {
   case TargetLoweringBase::AtomicExpansionKind::None:
     return false;
-  case TargetLoweringBase::AtomicExpansionKind::LLSC:
-    return expandAtomicOpToLLSC(AI, AI->getPointerOperand(), AI->getOrdering(),
-                                [&](IRBuilder<> &Builder, Value *Loaded) {
-                                  return performAtomicOp(AI->getOperation(),
-                                                         Builder, Loaded,
-                                                         AI->getValOperand());
-                                });
-  case TargetLoweringBase::AtomicExpansionKind::CmpXChg:
-    return expandAtomicRMWToCmpXchg(AI, createCmpXchgInstFun);
+  case TargetLoweringBase::AtomicExpansionKind::LLSC: {
+    unsigned MinCASSize = TLI->getMinCmpXchgSizeInBits() / 8;
+    unsigned ValueSize = getAtomicOpSize(AI);
+    if (ValueSize < MinCASSize) {
+      llvm_unreachable(
+          "MinCmpXchgSizeInBits not yet supported for LL/SC architectures.");
+    } else {
+      auto PerformOp = [&](IRBuilder<> &Builder, Value *Loaded) {
+        return performAtomicOp(AI->getOperation(), Builder, Loaded,
+                               AI->getValOperand());
+      };
+      expandAtomicOpToLLSC(AI, AI->getType(), AI->getPointerOperand(),
+                           AI->getOrdering(), PerformOp);
+    }
+    return true;
+  }
+  case TargetLoweringBase::AtomicExpansionKind::CmpXChg: {
+    unsigned MinCASSize = TLI->getMinCmpXchgSizeInBits() / 8;
+    unsigned ValueSize = getAtomicOpSize(AI);
+    if (ValueSize < MinCASSize) {
+      expandPartwordAtomicRMW(AI,
+                              TargetLoweringBase::AtomicExpansionKind::CmpXChg);
+    } else {
+      expandAtomicRMWToCmpXchg(AI, createCmpXchgInstFun);
+    }
+    return true;
+  }
   default:
     llvm_unreachable("Unhandled case in tryExpandAtomicRMW");
   }
 }
 
-bool AtomicExpand::expandAtomicOpToLLSC(
-    Instruction *I, Value *Addr, AtomicOrdering MemOpOrder,
-    function_ref<Value *(IRBuilder<> &, Value *)> PerformOp) {
+namespace {
+
+/// Result values from createMaskInstrs helper.
+struct PartwordMaskValues {
+  Type *WordType;
+  Type *ValueType;
+  Value *AlignedAddr;
+  Value *ShiftAmt;
+  Value *Mask;
+  Value *Inv_Mask;
+};
+} // end anonymous namespace
+
+/// This is a helper function which builds instructions to provide the
+/// values necessary for partword atomic operations. It takes an
+/// incoming address, Addr, and ValueType, and constructs the aligned
+/// address, shift amount, and masks needed to operate on the larger
+/// value of size WordSize.
+///
+/// AlignedAddr: Addr rounded down to a multiple of WordSize
+///
+/// ShiftAmt: Number of bits to right-shift a WordSize value loaded
+///           from AlignedAddr for it to have the same value as if
+///           ValueType was loaded from Addr.
+///
+/// Mask: Value to mask with the value loaded from AlignedAddr to
+///       include only the part that would've been loaded from Addr.
+///
+/// Inv_Mask: The inverse of Mask.
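+///
+/// For example, on a little-endian target with WordSize = 4 and a
+/// word-aligned Base, an i16 at Addr = Base + 2 gives AlignedAddr = Base,
+/// ShiftAmt = 16, Mask = 0xFFFF0000 and Inv_Mask = 0x0000FFFF; on a
+/// big-endian target the same address gives ShiftAmt = 0 and
+/// Mask = 0x0000FFFF.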
+
+static PartwordMaskValues createMaskInstrs(IRBuilder<> &Builder, Instruction *I,
+                                           Type *ValueType, Value *Addr,
+                                           unsigned WordSize) {
+  PartwordMaskValues Ret;
+
   BasicBlock *BB = I->getParent();
   Function *F = BB->getParent();
+  Module *M = I->getModule();
+
   LLVMContext &Ctx = F->getContext();
+  const DataLayout &DL = M->getDataLayout();
+
+  unsigned ValueSize = DL.getTypeStoreSize(ValueType);
+
+  assert(ValueSize < WordSize);
+
+  Ret.ValueType = ValueType;
+  Ret.WordType = Type::getIntNTy(Ctx, WordSize * 8);
+
+  Type *WordPtrType =
+      Ret.WordType->getPointerTo(Addr->getType()->getPointerAddressSpace());
+
+  Value *AddrInt = Builder.CreatePtrToInt(Addr, DL.getIntPtrType(Ctx));
+  Ret.AlignedAddr = Builder.CreateIntToPtr(
+      Builder.CreateAnd(AddrInt, ~(uint64_t)(WordSize - 1)), WordPtrType,
+      "AlignedAddr");
+
+  Value *PtrLSB = Builder.CreateAnd(AddrInt, WordSize - 1, "PtrLSB");
+  if (DL.isLittleEndian()) {
+    // turn bytes into bits
+    Ret.ShiftAmt = Builder.CreateShl(PtrLSB, 3);
+  } else {
+    // turn bytes into bits, and count from the other side.
+    Ret.ShiftAmt =
+        Builder.CreateShl(Builder.CreateXor(PtrLSB, WordSize - ValueSize), 3);
+  }
+
+  Ret.ShiftAmt = Builder.CreateTrunc(Ret.ShiftAmt, Ret.WordType, "ShiftAmt");
+  Ret.Mask = Builder.CreateShl(
+      ConstantInt::get(Ret.WordType, (1 << ValueSize * 8) - 1), Ret.ShiftAmt,
+      "Mask");
+  Ret.Inv_Mask = Builder.CreateNot(Ret.Mask, "Inv_Mask");
+
+  return Ret;
+}
+
+/// Emit IR to implement a masked version of a given atomicrmw
+/// operation. (That is, only the bits under the Mask should be
+/// affected by the operation)
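+///
+/// For example, a masked i8 'add' whose lane sits in bits 8..15 of a 32-bit
+/// word receives the increment pre-shifted into that lane (Shifted_Inc),
+/// adds it to the whole loaded word, ANDs the sum with Mask to drop any
+/// carry out of the lane, and ORs the untouched lanes back in from
+/// (Loaded & Inv_Mask).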
+static Value *performMaskedAtomicOp(AtomicRMWInst::BinOp Op,
+                                    IRBuilder<> &Builder, Value *Loaded,
+                                    Value *Shifted_Inc, Value *Inc,
+                                    const PartwordMaskValues &PMV) {
+  switch (Op) {
+  case AtomicRMWInst::Xchg: {
+    Value *Loaded_MaskOut = Builder.CreateAnd(Loaded, PMV.Inv_Mask);
+    Value *FinalVal = Builder.CreateOr(Loaded_MaskOut, Shifted_Inc);
+    return FinalVal;
+  }
+  case AtomicRMWInst::Or:
+  case AtomicRMWInst::Xor:
+    // Or/Xor won't affect any other bits, so can just be done
+    // directly.
+    return performAtomicOp(Op, Builder, Loaded, Shifted_Inc);
+  case AtomicRMWInst::Add:
+  case AtomicRMWInst::Sub:
+  case AtomicRMWInst::And:
+  case AtomicRMWInst::Nand: {
+    // The other arithmetic ops need to be masked into place.
+    Value *NewVal = performAtomicOp(Op, Builder, Loaded, Shifted_Inc);
+    Value *NewVal_Masked = Builder.CreateAnd(NewVal, PMV.Mask);
+    Value *Loaded_MaskOut = Builder.CreateAnd(Loaded, PMV.Inv_Mask);
+    Value *FinalVal = Builder.CreateOr(Loaded_MaskOut, NewVal_Masked);
+    return FinalVal;
+  }
+  case AtomicRMWInst::Max:
+  case AtomicRMWInst::Min:
+  case AtomicRMWInst::UMax:
+  case AtomicRMWInst::UMin: {
+    // Finally, comparison ops will operate on the full value, so
+    // truncate down to the original size, and expand out again after
+    // doing the operation.
+    Value *Loaded_Shiftdown = Builder.CreateTrunc(
+        Builder.CreateLShr(Loaded, PMV.ShiftAmt), PMV.ValueType);
+    Value *NewVal = performAtomicOp(Op, Builder, Loaded_Shiftdown, Inc);
+    Value *NewVal_Shiftup = Builder.CreateShl(
+        Builder.CreateZExt(NewVal, PMV.WordType), PMV.ShiftAmt);
+    Value *Loaded_MaskOut = Builder.CreateAnd(Loaded, PMV.Inv_Mask);
+    Value *FinalVal = Builder.CreateOr(Loaded_MaskOut, NewVal_Shiftup);
+    return FinalVal;
+  }
+  default:
+    llvm_unreachable("Unknown atomic op");
+  }
+}
+
+/// Expand a sub-word atomicrmw operation into an appropriate
+/// word-sized operation.
+///
+/// It will create an LL/SC or cmpxchg loop, as appropriate, the same
+/// way as a typical atomicrmw expansion. The only difference here is
+/// that the operation inside of the loop must operate only upon a
+/// part of the value.
+void AtomicExpand::expandPartwordAtomicRMW(
+    AtomicRMWInst *AI, TargetLoweringBase::AtomicExpansionKind ExpansionKind) {
+
+  assert(ExpansionKind == TargetLoweringBase::AtomicExpansionKind::CmpXChg);
+
+  AtomicOrdering MemOpOrder = AI->getOrdering();
+
+  IRBuilder<> Builder(AI);
+
+  PartwordMaskValues PMV =
+      createMaskInstrs(Builder, AI, AI->getType(), AI->getPointerOperand(),
+                       TLI->getMinCmpXchgSizeInBits() / 8);
+
+  Value *ValOperand_Shifted =
+      Builder.CreateShl(Builder.CreateZExt(AI->getValOperand(), PMV.WordType),
+                        PMV.ShiftAmt, "ValOperand_Shifted");
+
+  auto PerformPartwordOp = [&](IRBuilder<> &Builder, Value *Loaded) {
+    return performMaskedAtomicOp(AI->getOperation(), Builder, Loaded,
+                                 ValOperand_Shifted, AI->getValOperand(), PMV);
+  };
+
+  // TODO: When we're ready to support LLSC conversions too, use
+  // insertRMWLLSCLoop here for ExpansionKind==LLSC.
+  Value *OldResult =
+      insertRMWCmpXchgLoop(Builder, PMV.WordType, PMV.AlignedAddr, MemOpOrder,
+                           PerformPartwordOp, createCmpXchgInstFun);
+  Value *FinalOldResult = Builder.CreateTrunc(
+      Builder.CreateLShr(OldResult, PMV.ShiftAmt), PMV.ValueType);
+  AI->replaceAllUsesWith(FinalOldResult);
+  AI->eraseFromParent();
+}
+
+void AtomicExpand::expandPartwordCmpXchg(AtomicCmpXchgInst *CI) {
+  // The basic idea here is that we're expanding a cmpxchg of a
+  // smaller memory size up to a word-sized cmpxchg. To do this, we
+  // need to add a retry-loop for strong cmpxchg, so that
+  // modifications to other parts of the word don't cause a spurious
+  // failure.
+
+  // This generates code like the following:
+  //     [[Setup mask values PMV.*]]
+  //     %NewVal_Shifted = shl i32 %NewVal, %PMV.ShiftAmt
+  //     %Cmp_Shifted = shl i32 %Cmp, %PMV.ShiftAmt
+  //     %InitLoaded = load i32* %addr
+  //     %InitLoaded_MaskOut = and i32 %InitLoaded, %PMV.Inv_Mask
+  //     br partword.cmpxchg.loop
+  // partword.cmpxchg.loop:
+  //     %Loaded_MaskOut = phi i32 [ %InitLoaded_MaskOut, %entry ],
+  //        [ %OldVal_MaskOut, %partword.cmpxchg.failure ]
+  //     %FullWord_NewVal = or i32 %Loaded_MaskOut, %NewVal_Shifted
+  //     %FullWord_Cmp = or i32 %Loaded_MaskOut, %Cmp_Shifted
+  //     %NewCI = cmpxchg i32* %PMV.AlignedAddr, i32 %FullWord_Cmp,
+  //        i32 %FullWord_NewVal success_ordering failure_ordering
+  //     %OldVal = extractvalue { i32, i1 } %NewCI, 0
+  //     %Success = extractvalue { i32, i1 } %NewCI, 1
+  //     br i1 %Success, label %partword.cmpxchg.end,
+  //        label %partword.cmpxchg.failure
+  // partword.cmpxchg.failure:
+  //     %OldVal_MaskOut = and i32 %OldVal, %PMV.Inv_Mask
+  //     %ShouldContinue = icmp ne i32 %Loaded_MaskOut, %OldVal_MaskOut
+  //     br i1 %ShouldContinue, label %partword.cmpxchg.loop,
+  //         label %partword.cmpxchg.end
+  // partword.cmpxchg.end:
+  //    %tmp1 = lshr i32 %OldVal, %PMV.ShiftAmt
+  //    %FinalOldVal = trunc i32 %tmp1 to i8
+  //    %tmp2 = insertvalue { i8, i1 } undef, i8 %FinalOldVal, 0
+  //    %Res = insertvalue { i8, i1 } %tmp2, i1 %Success, 1
+
+  Value *Addr = CI->getPointerOperand();
+  Value *Cmp = CI->getCompareOperand();
+  Value *NewVal = CI->getNewValOperand();
+
+  BasicBlock *BB = CI->getParent();
+  Function *F = BB->getParent();
+  IRBuilder<> Builder(CI);
+  LLVMContext &Ctx = Builder.getContext();
+
+  const int WordSize = TLI->getMinCmpXchgSizeInBits() / 8;
+
+  BasicBlock *EndBB =
+      BB->splitBasicBlock(CI->getIterator(), "partword.cmpxchg.end");
+  auto FailureBB =
+      BasicBlock::Create(Ctx, "partword.cmpxchg.failure", F, EndBB);
+  auto LoopBB = BasicBlock::Create(Ctx, "partword.cmpxchg.loop", F, FailureBB);
+
+  // The split call above "helpfully" added a branch at the end of BB
+  // (to the wrong place).
+  std::prev(BB->end())->eraseFromParent();
+  Builder.SetInsertPoint(BB);
+
+  PartwordMaskValues PMV = createMaskInstrs(
+      Builder, CI, CI->getCompareOperand()->getType(), Addr, WordSize);
+
+  // Shift the incoming values over, into the right location in the word.
+  Value *NewVal_Shifted =
+      Builder.CreateShl(Builder.CreateZExt(NewVal, PMV.WordType), PMV.ShiftAmt);
+  Value *Cmp_Shifted =
+      Builder.CreateShl(Builder.CreateZExt(Cmp, PMV.WordType), PMV.ShiftAmt);
+
+  // Load the entire current word, and mask out the part that will be
+  // compared and swapped; the expected and new values are merged into
+  // the remaining bits inside the loop.
+  LoadInst *InitLoaded = Builder.CreateLoad(PMV.WordType, PMV.AlignedAddr);
+  InitLoaded->setVolatile(CI->isVolatile());
+  Value *InitLoaded_MaskOut = Builder.CreateAnd(InitLoaded, PMV.Inv_Mask);
+  Builder.CreateBr(LoopBB);
+
+  // partword.cmpxchg.loop:
+  Builder.SetInsertPoint(LoopBB);
+  PHINode *Loaded_MaskOut = Builder.CreatePHI(PMV.WordType, 2);
+  Loaded_MaskOut->addIncoming(InitLoaded_MaskOut, BB);
+
+  // Mask/Or the expected and new values into place in the loaded word.
+  Value *FullWord_NewVal = Builder.CreateOr(Loaded_MaskOut, NewVal_Shifted);
+  Value *FullWord_Cmp = Builder.CreateOr(Loaded_MaskOut, Cmp_Shifted);
+  AtomicCmpXchgInst *NewCI = Builder.CreateAtomicCmpXchg(
+      PMV.AlignedAddr, FullWord_Cmp, FullWord_NewVal, CI->getSuccessOrdering(),
+      CI->getFailureOrdering(), CI->getSynchScope());
+  NewCI->setVolatile(CI->isVolatile());
+  // When we're building a strong cmpxchg, we need a retry loop, so one
+  // might think a weak cmpxchg inside it would suffice. However, a strong
+  // cmpxchg enables the ShouldContinue comparison below, and the underlying
+  // cmpxchg is expected to lower to a machine instruction, which is strong
+  // anyway.
+  NewCI->setWeak(CI->isWeak());
+
+  Value *OldVal = Builder.CreateExtractValue(NewCI, 0);
+  Value *Success = Builder.CreateExtractValue(NewCI, 1);
+
+  if (CI->isWeak())
+    Builder.CreateBr(EndBB);
+  else
+    Builder.CreateCondBr(Success, EndBB, FailureBB);
+
+  // partword.cmpxchg.failure:
+  Builder.SetInsertPoint(FailureBB);
+  // Upon failure, check whether the masked-out part of the loaded value was
+  // modified. If it was not, the failure must have come from the masked-in
+  // part, so give up; otherwise, retry with the updated surrounding bits.
+  Value *OldVal_MaskOut = Builder.CreateAnd(OldVal, PMV.Inv_Mask);
+  Value *ShouldContinue = Builder.CreateICmpNE(Loaded_MaskOut, OldVal_MaskOut);
+  Builder.CreateCondBr(ShouldContinue, LoopBB, EndBB);
+
+  // Add the second value to the phi from above
+  Loaded_MaskOut->addIncoming(OldVal_MaskOut, FailureBB);
+
+  // partword.cmpxchg.end:
+  Builder.SetInsertPoint(CI);
+
+  Value *FinalOldVal = Builder.CreateTrunc(
+      Builder.CreateLShr(OldVal, PMV.ShiftAmt), PMV.ValueType);
+  Value *Res = UndefValue::get(CI->getType());
+  Res = Builder.CreateInsertValue(Res, FinalOldVal, 0);
+  Res = Builder.CreateInsertValue(Res, Success, 1);
+
+  CI->replaceAllUsesWith(Res);
+  CI->eraseFromParent();
+}
+
+void AtomicExpand::expandAtomicOpToLLSC(
+    Instruction *I, Type *ResultType, Value *Addr, AtomicOrdering MemOpOrder,
+    function_ref<Value *(IRBuilder<> &, Value *)> PerformOp) {
+  IRBuilder<> Builder(I);
+  Value *Loaded =
+      insertRMWLLSCLoop(Builder, ResultType, Addr, MemOpOrder, PerformOp);
+
+  I->replaceAllUsesWith(Loaded);
+  I->eraseFromParent();
+}
+
+Value *AtomicExpand::insertRMWLLSCLoop(
+    IRBuilder<> &Builder, Type *ResultTy, Value *Addr,
+    AtomicOrdering MemOpOrder,
+    function_ref<Value *(IRBuilder<> &, Value *)> PerformOp) {
+  LLVMContext &Ctx = Builder.getContext();
+  BasicBlock *BB = Builder.GetInsertBlock();
+  Function *F = BB->getParent();
 
   // Given: atomicrmw some_op iN* %addr, iN %incr ordering
   //
   // The standard expansion we produce is:
   //     [...]
-  //     fence?
   // atomicrmw.start:
   //     %loaded = @load.linked(%addr)
   //     %new = some_op iN %loaded, %incr
@@ -531,17 +880,13 @@
   //     %try_again = icmp i32 ne %stored, 0
   //     br i1 %try_again, label %loop, label %atomicrmw.end
   // atomicrmw.end:
-  //     fence?
   //     [...]
-  BasicBlock *ExitBB = BB->splitBasicBlock(I->getIterator(), "atomicrmw.end");
+  BasicBlock *ExitBB =
+      BB->splitBasicBlock(Builder.GetInsertPoint(), "atomicrmw.end");
   BasicBlock *LoopBB =  BasicBlock::Create(Ctx, "atomicrmw.start", F, ExitBB);
 
-  // This grabs the DebugLoc from I.
-  IRBuilder<> Builder(I);
-
   // The split call above "helpfully" added a branch at the end of BB (to the
-  // wrong place), but we might want a fence too. It's easiest to just remove
-  // the branch entirely.
+  // wrong place).
   std::prev(BB->end())->eraseFromParent();
   Builder.SetInsertPoint(BB);
   Builder.CreateBr(LoopBB);
@@ -559,11 +904,7 @@
   Builder.CreateCondBr(TryAgain, LoopBB, ExitBB);
 
   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
-
-  I->replaceAllUsesWith(Loaded);
-  I->eraseFromParent();
-
-  return true;
+  return Loaded;
 }
 
 /// Convert an atomic cmpxchg of a non-integral type to an integer cmpxchg of
@@ -867,17 +1208,14 @@
   return false;
 }
 
-bool llvm::expandAtomicRMWToCmpXchg(AtomicRMWInst *AI,
-                                    CreateCmpXchgInstFun CreateCmpXchg) {
-  assert(AI);
-
-  AtomicOrdering MemOpOrder = AI->getOrdering() == AtomicOrdering::Unordered
-                                  ? AtomicOrdering::Monotonic
-                                  : AI->getOrdering();
-  Value *Addr = AI->getPointerOperand();
-  BasicBlock *BB = AI->getParent();
+Value *AtomicExpand::insertRMWCmpXchgLoop(
+    IRBuilder<> &Builder, Type *ResultTy, Value *Addr,
+    AtomicOrdering MemOpOrder,
+    function_ref<Value *(IRBuilder<> &, Value *)> PerformOp,
+    CreateCmpXchgInstFun CreateCmpXchg) {
+  LLVMContext &Ctx = Builder.getContext();
+  BasicBlock *BB = Builder.GetInsertBlock();
   Function *F = BB->getParent();
-  LLVMContext &Ctx = F->getContext();
 
   // Given: atomicrmw some_op iN* %addr, iN %incr ordering
   //
@@ -894,34 +1232,34 @@
   //     br i1 %success, label %atomicrmw.end, label %loop
   // atomicrmw.end:
   //     [...]
-  BasicBlock *ExitBB = BB->splitBasicBlock(AI->getIterator(), "atomicrmw.end");
+  BasicBlock *ExitBB =
+      BB->splitBasicBlock(Builder.GetInsertPoint(), "atomicrmw.end");
   BasicBlock *LoopBB = BasicBlock::Create(Ctx, "atomicrmw.start", F, ExitBB);
 
-  // This grabs the DebugLoc from AI.
-  IRBuilder<> Builder(AI);
-
   // The split call above "helpfully" added a branch at the end of BB (to the
   // wrong place), but we want a load. It's easiest to just remove
   // the branch entirely.
   std::prev(BB->end())->eraseFromParent();
   Builder.SetInsertPoint(BB);
-  LoadInst *InitLoaded = Builder.CreateLoad(Addr);
+  LoadInst *InitLoaded = Builder.CreateLoad(ResultTy, Addr);
   // Atomics require at least natural alignment.
-  InitLoaded->setAlignment(AI->getType()->getPrimitiveSizeInBits() / 8);
+  InitLoaded->setAlignment(ResultTy->getPrimitiveSizeInBits() / 8);
   Builder.CreateBr(LoopBB);
 
   // Start the main loop block now that we've taken care of the preliminaries.
   Builder.SetInsertPoint(LoopBB);
-  PHINode *Loaded = Builder.CreatePHI(AI->getType(), 2, "loaded");
+  PHINode *Loaded = Builder.CreatePHI(ResultTy, 2, "loaded");
   Loaded->addIncoming(InitLoaded, BB);
 
-  Value *NewVal =
-      performAtomicOp(AI->getOperation(), Builder, Loaded, AI->getValOperand());
+  Value *NewVal = PerformOp(Builder, Loaded);
 
   Value *NewLoaded = nullptr;
   Value *Success = nullptr;
 
-  CreateCmpXchg(Builder, Addr, Loaded, NewVal, MemOpOrder,
+  CreateCmpXchg(Builder, Addr, Loaded, NewVal,
+                MemOpOrder == AtomicOrdering::Unordered
+                    ? AtomicOrdering::Monotonic
+                    : MemOpOrder,
                 Success, NewLoaded);
   assert(Success && NewLoaded);
 
@@ -930,10 +1268,23 @@
   Builder.CreateCondBr(Success, ExitBB, LoopBB);
 
   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
+  return NewLoaded;
+}
 
-  AI->replaceAllUsesWith(NewLoaded);
+// Note: This function is exposed externally by AtomicExpandUtils.h
+bool llvm::expandAtomicRMWToCmpXchg(AtomicRMWInst *AI,
+                                    CreateCmpXchgInstFun CreateCmpXchg) {
+  IRBuilder<> Builder(AI);
+  Value *Loaded = AtomicExpand::insertRMWCmpXchgLoop(
+      Builder, AI->getType(), AI->getPointerOperand(), AI->getOrdering(),
+      [&](IRBuilder<> &Builder, Value *Loaded) {
+        return performAtomicOp(AI->getOperation(), Builder, Loaded,
+                               AI->getValOperand());
+      },
+      CreateCmpXchg);
+
+  AI->replaceAllUsesWith(Loaded);
   AI->eraseFromParent();
-
   return true;
 }