[RISCV] Implement codegen for cmpxchg on RV32IA

Utilise a similar ('late') lowering strategy to D47882. The changes to 
AtomicExpandPass allow this strategy to be utilised by other targets which 
implement shouldExpandAtomicCmpXchgInIR.

All cmpxchg operations are currently lowered as 'strong', and the failure
ordering is ignored. This is conservative but correct.

Differential Revision: https://reviews.llvm.org/D48131

llvm-svn: 347914
diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
index 1c23680..35c185a 100644
--- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
@@ -52,6 +52,9 @@
                             MachineBasicBlock::iterator MBBI,
                             AtomicRMWInst::BinOp, bool IsMasked, int Width,
                             MachineBasicBlock::iterator &NextMBBI);
+  bool expandAtomicCmpXchg(MachineBasicBlock &MBB,
+                           MachineBasicBlock::iterator MBBI, bool IsMasked,
+                           int Width, MachineBasicBlock::iterator &NextMBBI);
 };
 
 char RISCVExpandPseudo::ID = 0;
@@ -106,6 +109,10 @@
   case RISCV::PseudoMaskedAtomicLoadUMin32:
     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMin, true, 32,
                                 NextMBBI);
+  case RISCV::PseudoCmpXchg32:
+    return expandAtomicCmpXchg(MBB, MBBI, false, 32, NextMBBI);
+  case RISCV::PseudoMaskedCmpXchg32:
+    return expandAtomicCmpXchg(MBB, MBBI, true, 32, NextMBBI);
   }
 
   return false;
@@ -441,6 +448,103 @@
   return true;
 }
 
+bool RISCVExpandPseudo::expandAtomicCmpXchg( // cmpxchg pseudo -> LR/SC loop
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, bool IsMasked,
+    int Width, MachineBasicBlock::iterator &NextMBBI) {
+  assert(Width == 32 && "RV64 atomic expansion currently unsupported");
+  MachineInstr &MI = *MBBI;
+  DebugLoc DL = MI.getDebugLoc();
+  MachineFunction *MF = MBB.getParent();
+  auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
+  auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
+  auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
+
+  // Insert loophead, looptail and done directly after MBB, in that order.
+  MF->insert(++MBB.getIterator(), LoopHeadMBB);
+  MF->insert(++LoopHeadMBB->getIterator(), LoopTailMBB);
+  MF->insert(++LoopTailMBB->getIterator(), DoneMBB);
+
+  // Wire the CFG; MI and everything after it move from MBB to DoneMBB.
+  LoopHeadMBB->addSuccessor(LoopTailMBB);
+  LoopHeadMBB->addSuccessor(DoneMBB);
+  LoopTailMBB->addSuccessor(DoneMBB);
+  LoopTailMBB->addSuccessor(LoopHeadMBB);
+  DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
+  DoneMBB->transferSuccessors(&MBB);
+  MBB.addSuccessor(LoopHeadMBB);
+
+  unsigned DestReg = MI.getOperand(0).getReg();    // $res
+  unsigned ScratchReg = MI.getOperand(1).getReg(); // $scratch
+  unsigned AddrReg = MI.getOperand(2).getReg();    // $addr
+  unsigned CmpValReg = MI.getOperand(3).getReg();  // $cmpval
+  unsigned NewValReg = MI.getOperand(4).getReg();  // $newval
+  AtomicOrdering Ordering =
+      static_cast<AtomicOrdering>(MI.getOperand(IsMasked ? 6 : 5).getImm());
+
+  if (!IsMasked) {
+    // .loophead:
+    //   lr.w dest, (addr)
+    //   bne dest, cmpval, done        (mismatch: leave without storing)
+    BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW32(Ordering)), DestReg)
+        .addReg(AddrReg);
+    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
+        .addReg(DestReg)
+        .addReg(CmpValReg)
+        .addMBB(DoneMBB);
+    // .looptail:
+    //   sc.w scratch, newval, (addr)
+    //   bnez scratch, loophead        (non-zero sc.w result: retry)
+    BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW32(Ordering)), ScratchReg)
+        .addReg(AddrReg)
+        .addReg(NewValReg);
+    BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
+        .addReg(ScratchReg)
+        .addReg(RISCV::X0)
+        .addMBB(LoopHeadMBB);
+  } else {
+    // .loophead:
+    //   lr.w dest, (addr)
+    //   and scratch, dest, mask
+    //   bne scratch, cmpval, done     (mismatch: leave without storing)
+    unsigned MaskReg = MI.getOperand(5).getReg(); // $mask
+    BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW32(Ordering)), DestReg)
+        .addReg(AddrReg);
+    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), ScratchReg)
+        .addReg(DestReg)
+        .addReg(MaskReg);
+    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
+        .addReg(ScratchReg)
+        .addReg(CmpValReg) // NOTE(review): presumably pre-shifted under mask
+        .addMBB(DoneMBB);
+
+    // .looptail:
+    //   xor scratch, dest, newval
+    //   and scratch, scratch, mask
+    //   xor scratch, dest, scratch
+    //   sc.w scratch, scratch, (addr)
+    //   bnez scratch, loophead        (non-zero sc.w result: retry)
+    insertMaskedMerge(TII, DL, LoopTailMBB, ScratchReg, DestReg, NewValReg,
+                      MaskReg, ScratchReg);
+    BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW32(Ordering)), ScratchReg)
+        .addReg(AddrReg)
+        .addReg(ScratchReg);
+    BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
+        .addReg(ScratchReg)
+        .addReg(RISCV::X0)
+        .addMBB(LoopHeadMBB);
+  }
+
+  NextMBBI = MBB.end(); // the rest of MBB was spliced into DoneMBB above
+  MI.eraseFromParent(); // drop the pseudo (now sitting at the top of DoneMBB)
+
+  LivePhysRegs LiveRegs; // compute live-ins for the three blocks just created
+  computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
+  computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
+  computeAndAddLiveIns(LiveRegs, *DoneMBB);
+
+  return true;
+}
+
 } // end of anonymous namespace
 
 INITIALIZE_PASS(RISCVExpandPseudo, "riscv-expand-pseudo",
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 56900a6..d8c63a4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -186,6 +186,7 @@
   case Intrinsic::riscv_masked_atomicrmw_min_i32:
   case Intrinsic::riscv_masked_atomicrmw_umax_i32:
   case Intrinsic::riscv_masked_atomicrmw_umin_i32:
+  case Intrinsic::riscv_masked_cmpxchg_i32:
     PointerType *PtrTy = cast<PointerType>(I.getArgOperand(0)->getType());
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::getVT(PtrTy->getElementType());
@@ -1708,3 +1709,23 @@
 
   return Builder.CreateCall(LrwOpScwLoop, {AlignedAddr, Incr, Mask, Ordering});
 }
+
+TargetLowering::AtomicExpansionKind // keyed on the compare operand's width
+RISCVTargetLowering::shouldExpandAtomicCmpXchgInIR(
+    AtomicCmpXchgInst *CI) const {
+  unsigned Size = CI->getCompareOperand()->getType()->getPrimitiveSizeInBits();
+  if (Size == 8 || Size == 16) // i8/i16: go through the masked intrinsic
+    return AtomicExpansionKind::MaskedIntrinsic;
+  return AtomicExpansionKind::None; // every other size: no IR-level expansion
+}
+
+Value *RISCVTargetLowering::emitMaskedAtomicCmpXchgIntrinsic( // emits the call
+    IRBuilder<> &Builder, AtomicCmpXchgInst *CI, Value *AlignedAddr,
+    Value *CmpVal, Value *NewVal, Value *Mask, AtomicOrdering Ord) const {
+  Value *Ordering = Builder.getInt32(static_cast<uint32_t>(Ord)); // i32 const
+  Type *Tys[] = {AlignedAddr->getType()}; // overload type: the address pointer
+  Function *MaskedCmpXchg = Intrinsic::getDeclaration(
+      CI->getModule(), Intrinsic::riscv_masked_cmpxchg_i32, Tys);
+  return Builder.CreateCall(MaskedCmpXchg,
+                            {AlignedAddr, CmpVal, NewVal, Mask, Ordering});
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index fe988bb..a99e2c2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -126,6 +126,13 @@
   virtual Value *emitMaskedAtomicRMWIntrinsic(
       IRBuilder<> &Builder, AtomicRMWInst *AI, Value *AlignedAddr, Value *Incr,
       Value *Mask, Value *ShiftAmt, AtomicOrdering Ord) const override;
+  TargetLowering::AtomicExpansionKind
+  shouldExpandAtomicCmpXchgInIR(AtomicCmpXchgInst *CI) const override;
+  virtual Value *
+  emitMaskedAtomicCmpXchgIntrinsic(IRBuilder<> &Builder, AtomicCmpXchgInst *CI,
+                                   Value *AlignedAddr, Value *CmpVal,
+                                   Value *NewVal, Value *Mask,
+                                   AtomicOrdering Ord) const override;
 };
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoA.td b/llvm/lib/Target/RISCV/RISCVInstrInfoA.td
index 180434e..9cb1d2f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoA.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoA.td
@@ -153,7 +153,7 @@
 }
 
 def PseudoAtomicLoadNand32 : PseudoAMO;
-// Ordering constants must be kept in sync with the AtomicOrdering enum in 
+// Ordering constants must be kept in sync with the AtomicOrdering enum in
 // AtomicOrdering.h.
 def : Pat<(atomic_load_nand_32_monotonic GPR:$addr, GPR:$incr),
           (PseudoAtomicLoadNand32 GPR:$addr, GPR:$incr, 2)>;
@@ -230,4 +230,49 @@
 def PseudoMaskedAtomicLoadUMin32 : PseudoMaskedAMOUMinUMax;
 def : PseudoMaskedAMOPat<int_riscv_masked_atomicrmw_umin_i32,
                          PseudoMaskedAtomicLoadUMin32>;
+
+/// Compare and exchange
+
+class PseudoCmpXchg // unmasked cmpxchg pseudo; see expandAtomicCmpXchg
+    : Pseudo<(outs GPR:$res, GPR:$scratch),
+             (ins GPR:$addr, GPR:$cmpval, GPR:$newval, i32imm:$ordering), []> {
+  let Constraints = "@earlyclobber $res,@earlyclobber $scratch";
+  let mayLoad = 1; // the expansion contains an lr.w
+  let mayStore = 1; // the expansion contains an sc.w
+  let hasSideEffects = 0;
+}
+
+// Ordering constants must be kept in sync with the AtomicOrdering enum in
+// AtomicOrdering.h; the last operand of each pattern is that enum's value.
+multiclass PseudoCmpXchgPat<string Op, Pseudo CmpXchgInst> {
+  def : Pat<(!cast<PatFrag>(Op#"_monotonic") GPR:$addr, GPR:$cmp, GPR:$new),
+            (CmpXchgInst GPR:$addr, GPR:$cmp, GPR:$new, 2)>; // monotonic
+  def : Pat<(!cast<PatFrag>(Op#"_acquire") GPR:$addr, GPR:$cmp, GPR:$new),
+            (CmpXchgInst GPR:$addr, GPR:$cmp, GPR:$new, 4)>; // acquire
+  def : Pat<(!cast<PatFrag>(Op#"_release") GPR:$addr, GPR:$cmp, GPR:$new),
+            (CmpXchgInst GPR:$addr, GPR:$cmp, GPR:$new, 5)>; // release
+  def : Pat<(!cast<PatFrag>(Op#"_acq_rel") GPR:$addr, GPR:$cmp, GPR:$new),
+            (CmpXchgInst GPR:$addr, GPR:$cmp, GPR:$new, 6)>; // acq_rel
+  def : Pat<(!cast<PatFrag>(Op#"_seq_cst") GPR:$addr, GPR:$cmp, GPR:$new),
+            (CmpXchgInst GPR:$addr, GPR:$cmp, GPR:$new, 7)>; // seq_cst
+}
+
+def PseudoCmpXchg32 : PseudoCmpXchg; // expanded by expandAtomicCmpXchg
+defm : PseudoCmpXchgPat<"atomic_cmp_swap_32", PseudoCmpXchg32>; // 5 orderings
+
+def PseudoMaskedCmpXchg32 // masked form: carries an extra $mask input
+    : Pseudo<(outs GPR:$res, GPR:$scratch),
+             (ins GPR:$addr, GPR:$cmpval, GPR:$newval, GPR:$mask,
+              i32imm:$ordering), []> {
+  let Constraints = "@earlyclobber $res,@earlyclobber $scratch";
+  let mayLoad = 1; // the expansion contains an lr.w
+  let mayStore = 1; // the expansion contains an sc.w
+  let hasSideEffects = 0;
+}
+
+def : Pat<(int_riscv_masked_cmpxchg_i32 // intrinsic -> pseudo, same operands
+            GPR:$addr, GPR:$cmpval, GPR:$newval, GPR:$mask, imm:$ordering),
+          (PseudoMaskedCmpXchg32
+            GPR:$addr, GPR:$cmpval, GPR:$newval, GPR:$mask, imm:$ordering)>;
+
 } // Predicates = [HasStdExtA]