Basic codegen for MTE stack tagging.

Implement the IR intrinsics for stack tagging. The generated code is
largely unoptimized for now.

Two special intrinsics, llvm.aarch64.irg.sp and llvm.aarch64.tagp, are
used to implement a tagged stack frame pointer in a virtual register.

Differential Revision: https://reviews.llvm.org/D64172

llvm-svn: 366360
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 49a328b..c70906d 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -3666,7 +3666,8 @@
     const CallBase *Call) {
   return Call->getIntrinsicID() == Intrinsic::launder_invariant_group ||
          Call->getIntrinsicID() == Intrinsic::strip_invariant_group ||
-         Call->getIntrinsicID() == Intrinsic::aarch64_irg;
+         Call->getIntrinsicID() == Intrinsic::aarch64_irg ||
+         Call->getIntrinsicID() == Intrinsic::aarch64_tagp;
 }
 
 /// \p PN defines a loop-variant pointer to an object.  Check if the
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 61ec292..e818dd2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6805,6 +6805,19 @@
     // MachineFunction in SelectionDAGISel::PrepareEHLandingPad. We can safely
     // delete it now.
     return;
+
+  case Intrinsic::aarch64_settag:
+  case Intrinsic::aarch64_settag_zero: {
+    const SelectionDAGTargetInfo &TSI = DAG.getSelectionDAGInfo();
+    bool ZeroMemory = Intrinsic == Intrinsic::aarch64_settag_zero;
+    SDValue Val = TSI.EmitTargetCodeForSetTag(
+        DAG, getCurSDLoc(), getRoot(), getValue(I.getArgOperand(0)),
+        getValue(I.getArgOperand(1)), MachinePointerInfo(I.getArgOperand(0)),
+        ZeroMemory);
+    DAG.setRoot(Val);
+    setValue(&I, Val);
+    return;
+  }
   }
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 68076d2..210c10e 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -15,6 +15,7 @@
 
 #include "AArch64ExpandImm.h"
 #include "AArch64InstrInfo.h"
+#include "AArch64MachineFunctionInfo.h"
 #include "AArch64Subtarget.h"
 #include "MCTargetDesc/AArch64AddressingModes.h"
 #include "Utils/AArch64BaseInfo.h"
@@ -74,6 +75,9 @@
   bool expandCMP_SWAP_128(MachineBasicBlock &MBB,
                           MachineBasicBlock::iterator MBBI,
                           MachineBasicBlock::iterator &NextMBBI);
+  bool expandSetTagLoop(MachineBasicBlock &MBB,
+                        MachineBasicBlock::iterator MBBI,
+                        MachineBasicBlock::iterator &NextMBBI);
 };
 
 } // end anonymous namespace
@@ -336,6 +340,64 @@
   return true;
 }
 
+bool AArch64ExpandPseudo::expandSetTagLoop(
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+    MachineBasicBlock::iterator &NextMBBI) {
+  MachineInstr &MI = *MBBI; // The STGloop / STZGloop pseudo being expanded.
+  DebugLoc DL = MI.getDebugLoc();
+  Register SizeReg = MI.getOperand(2).getReg();    // $Rm input: size.
+  Register AddressReg = MI.getOperand(3).getReg(); // $Rn input: start address.
+
+  MachineFunction *MF = MBB.getParent();
+
+  bool ZeroData = MI.getOpcode() == AArch64::STZGloop;
+  const unsigned OpCode =
+      ZeroData ? AArch64::STZ2GPostIndex : AArch64::ST2GPostIndex;
+  // Split into MBB -> LoopBB -> DoneBB, laid out in that order.
+  auto LoopBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
+  auto DoneBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
+
+  MF->insert(++MBB.getIterator(), LoopBB);
+  MF->insert(++LoopBB->getIterator(), DoneBB);
+  // Loop body: post-indexed tag store, decrement the size, branch if non-zero.
+  BuildMI(LoopBB, DL, TII->get(OpCode))
+      .addDef(AddressReg) // $wback
+      .addReg(AddressReg) // $Rt
+      .addReg(AddressReg) // $Rn
+      .addImm(2) // Post-index step; presumably scaled by 16 (see SUB below).
+      .cloneMemRefs(MI)
+      .setMIFlags(MI.getFlags());
+  BuildMI(LoopBB, DL, TII->get(AArch64::SUBXri))
+      .addDef(SizeReg)
+      .addReg(SizeReg)
+      .addImm(16 * 2) // Amount of SizeReg consumed per iteration.
+      .addImm(0);
+  BuildMI(LoopBB, DL, TII->get(AArch64::CBNZX)).addUse(SizeReg).addMBB(LoopBB);
+
+  LoopBB->addSuccessor(LoopBB); // Back edge, taken while SizeReg != 0.
+  LoopBB->addSuccessor(DoneBB);
+  // Move MI and everything after it into DoneBB, along with MBB's successors.
+  DoneBB->splice(DoneBB->end(), &MBB, MI, MBB.end());
+  DoneBB->transferSuccessors(&MBB);
+
+  MBB.addSuccessor(LoopBB);
+
+  NextMBBI = MBB.end(); // Nothing is left in MBB after the split point.
+  MI.eraseFromParent();
+  // Recompute liveness bottom up.
+  LivePhysRegs LiveRegs;
+  computeAndAddLiveIns(LiveRegs, *DoneBB);
+  computeAndAddLiveIns(LiveRegs, *LoopBB);
+  // Do an extra pass in the loop to get the loop carried dependencies right.
+  // FIXME: is this necessary?
+  LoopBB->clearLiveIns();
+  computeAndAddLiveIns(LiveRegs, *LoopBB);
+  DoneBB->clearLiveIns();
+  computeAndAddLiveIns(LiveRegs, *DoneBB);
+
+  return true;
+}
+
 /// If MBBI references a pseudo instruction that should be expanded here,
 /// do the expansion and return true.  Otherwise return false.
 bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
@@ -569,6 +631,46 @@
     MI.eraseFromParent();
     return true;
    }
+   case AArch64::IRGstack: {
+     MachineFunction &MF = *MBB.getParent();
+     const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+     const AArch64FrameLowering *TFI =
+         MF.getSubtarget<AArch64Subtarget>().getFrameLowering();
+
+     // IRG does not allow immediate offset. getTaggedBasePointerOffset should
+     // almost always point to SP-after-prologue; if not, emit a longer
+     // instruction sequence.
+     int BaseOffset = -AFI->getTaggedBasePointerOffset();
+     unsigned FrameReg;
+     int FrameRegOffset = TFI->resolveFrameOffsetReference(
+         MF, BaseOffset, false /*isFixed*/, FrameReg, /*PreferFP=*/false,
+         /*ForSimm=*/true);
+     Register SrcReg = FrameReg;
+     if (FrameRegOffset != 0) {
+       // Use output register as temporary.
+       SrcReg = MI.getOperand(0).getReg();
+       emitFrameOffset(MBB, &MI, MI.getDebugLoc(), SrcReg, FrameReg,
+                       FrameRegOffset, TII);
+     }
+     BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::IRG))
+         .add(MI.getOperand(0))
+         .addUse(SrcReg)
+         .add(MI.getOperand(2));
+     MI.eraseFromParent();
+     return true;
+   }
+   case AArch64::TAGPstack: {
+     BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ADDG))
+         .add(MI.getOperand(0))
+         .add(MI.getOperand(1))
+         .add(MI.getOperand(2))
+         .add(MI.getOperand(4));
+     MI.eraseFromParent();
+     return true;
+   }
+   case AArch64::STGloop:
+   case AArch64::STZGloop:
+     return expandSetTagLoop(MBB, MBBI, NextMBBI);
   }
   return false;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index fed0fc7..8c6e5cb 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -842,6 +842,10 @@
   if (MF.getFunction().getCallingConv() == CallingConv::GHC)
     return;
 
+  // Set tagged base pointer to the bottom of the stack frame.
+  // Ideally it should match SP value after prologue.
+  AFI->setTaggedBasePointerOffset(MFI.getStackSize());
+
   // getStackSize() includes all the locals in its size calculation. We don't
   // include these locals when computing the stack size of a funclet, as they
   // are allocated in the parent's stack frame and accessed via the frame
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 2a911c4..cd7e927 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -157,6 +157,9 @@
 
   bool tryIndexedLoad(SDNode *N);
 
+  bool trySelectStackSlotTagP(SDNode *N);
+  void SelectTagP(SDNode *N);
+
   void SelectLoad(SDNode *N, unsigned NumVecs, unsigned Opc,
                      unsigned SubRegIdx);
   void SelectPostLoad(SDNode *N, unsigned NumVecs, unsigned Opc,
@@ -703,7 +706,7 @@
     return true;
   }
 
-  // As opposed to the (12-bit) Indexed addressing mode below, the 7-bit signed
+  // As opposed to the (12-bit) Indexed addressing mode below, the 7/9-bit signed
   // selected here doesn't support labels/immediates, only base+offset.
   if (CurDAG->isBaseWithConstantOffset(N)) {
     if (ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N.getOperand(1))) {
@@ -2790,6 +2793,58 @@
   return true;
 }
 
+bool AArch64DAGToDAGISel::trySelectStackSlotTagP(SDNode *N) {
+  // tagp(FrameIndex, IRGstack, tag_offset):
+  // since the offset between FrameIndex and IRGstack is a compile-time
+  // constant, this can be lowered to a single ADDG instruction.
+  if (!(isa<FrameIndexSDNode>(N->getOperand(1)))) {
+    return false; // First tagp argument is not a frame index.
+  }
+  // The second tagp argument must be the result of llvm.aarch64.irg.sp.
+  SDValue IRG_SP = N->getOperand(2);
+  if (IRG_SP->getOpcode() != ISD::INTRINSIC_W_CHAIN ||
+      cast<ConstantSDNode>(IRG_SP->getOperand(1))->getZExtValue() !=
+          Intrinsic::aarch64_irg_sp) {
+    return false;
+  }
+  // Replace N with TAGPstack(FI, 0, irg.sp result, tag offset).
+  const TargetLowering *TLI = getTargetLowering();
+  SDLoc DL(N);
+  int FI = cast<FrameIndexSDNode>(N->getOperand(1))->getIndex();
+  SDValue FiOp = CurDAG->getTargetFrameIndex(
+      FI, TLI->getPointerTy(CurDAG->getDataLayout()));
+  int TagOffset = cast<ConstantSDNode>(N->getOperand(3))->getZExtValue();
+
+  SDNode *Out = CurDAG->getMachineNode( // Operands: $Rn, $imm6, $Rm, $imm4.
+      AArch64::TAGPstack, DL, MVT::i64,
+      {FiOp, CurDAG->getTargetConstant(0, DL, MVT::i64), N->getOperand(2),
+       CurDAG->getTargetConstant(TagOffset, DL, MVT::i64)});
+  ReplaceNode(N, Out);
+  return true;
+}
+
+void AArch64DAGToDAGISel::SelectTagP(SDNode *N) {
+  assert(isa<ConstantSDNode>(N->getOperand(3)) &&
+         "llvm.aarch64.tagp third argument must be an immediate");
+  if (trySelectStackSlotTagP(N))
+    return; // Already replaced by a TAGPstack node.
+  // FIXME: above applies in any case when offset between Op1 and Op2 is a
+  // compile-time constant, not just for stack allocations.
+
+  // General case for unrelated pointers in Op1 and Op2: SUBP + ADD + ADDG.
+  SDLoc DL(N);
+  int TagOffset = cast<ConstantSDNode>(N->getOperand(3))->getZExtValue();
+  SDNode *N1 = CurDAG->getMachineNode(AArch64::SUBP, DL, MVT::i64,
+                                      {N->getOperand(1), N->getOperand(2)});
+  SDNode *N2 = CurDAG->getMachineNode(AArch64::ADDXrr, DL, MVT::i64,
+                                      {SDValue(N1, 0), N->getOperand(2)});
+  SDNode *N3 = CurDAG->getMachineNode( // Zero address offset, TagOffset tag.
+      AArch64::ADDG, DL, MVT::i64,
+      {SDValue(N2, 0), CurDAG->getTargetConstant(0, DL, MVT::i64),
+       CurDAG->getTargetConstant(TagOffset, DL, MVT::i64)});
+  ReplaceNode(N, N3);
+}
+
 void AArch64DAGToDAGISel::Select(SDNode *Node) {
   // If we have a custom node, we already have selected!
   if (Node->isMachineOpcode()) {
@@ -3283,6 +3338,9 @@
     switch (IntNo) {
     default:
       break;
+    case Intrinsic::aarch64_tagp:
+      SelectTagP(Node);
+      return;
     case Intrinsic::aarch64_neon_tbl2:
       SelectTable(Node, 2,
                   VT == MVT::v8i8 ? AArch64::TBLv8i8Two : AArch64::TBLv16i8Two,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 11ee1a5..7becc99 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1234,6 +1234,10 @@
   case AArch64ISD::FRECPS:            return "AArch64ISD::FRECPS";
   case AArch64ISD::FRSQRTE:           return "AArch64ISD::FRSQRTE";
   case AArch64ISD::FRSQRTS:           return "AArch64ISD::FRSQRTS";
+  case AArch64ISD::STG:               return "AArch64ISD::STG";
+  case AArch64ISD::STZG:              return "AArch64ISD::STZG";
+  case AArch64ISD::ST2G:              return "AArch64ISD::ST2G";
+  case AArch64ISD::STZ2G:             return "AArch64ISD::STZ2G";
   }
   return nullptr;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 754caaf..4421c31 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -214,7 +214,13 @@
   LD4LANEpost,
   ST2LANEpost,
   ST3LANEpost,
-  ST4LANEpost
+  ST4LANEpost,
+
+  STG,
+  STZG,
+  ST2G,
+  STZ2G
+
 };
 
 } // end namespace AArch64ISD
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 74fa5ef..d619137 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -4067,12 +4067,12 @@
                     (outs), (ins GPR64sp:$Rt, GPR64sp:$Rn, simm9s16:$offset)>;
   def PreIndex :
     BaseMemTagStore<opc1, 0b11, insn, "\t$Rt, [$Rn, $offset]!",
-                    "$Rn = $wback,@earlyclobber $wback",
+                    "$Rn = $wback",
                     (outs GPR64sp:$wback),
                     (ins GPR64sp:$Rt, GPR64sp:$Rn, simm9s16:$offset)>;
   def PostIndex :
     BaseMemTagStore<opc1, 0b01, insn, "\t$Rt, [$Rn], $offset",
-                    "$Rn = $wback,@earlyclobber $wback",
+                    "$Rn = $wback",
                     (outs GPR64sp:$wback),
                     (ins GPR64sp:$Rt, GPR64sp:$Rn, simm9s16:$offset)>;
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 599a5ab..215e96a 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1772,6 +1772,7 @@
   case AArch64::STNPWi:
   case AArch64::STNPSi:
   case AArch64::LDG:
+  case AArch64::STGPi:
     return 3;
   case AArch64::ADDG:
   case AArch64::STGOffset:
@@ -2151,6 +2152,7 @@
     MaxOffset = 4095;
     break;
   case AArch64::ADDG:
+  case AArch64::TAGPstack:
     Scale = 16;
     Width = 0;
     MinOffset = 0;
@@ -2158,10 +2160,23 @@
     break;
   case AArch64::LDG:
   case AArch64::STGOffset:
+  case AArch64::STZGOffset:
     Scale = Width = 16;
     MinOffset = -256;
     MaxOffset = 255;
     break;
+  case AArch64::ST2GOffset:
+  case AArch64::STZ2GOffset:
+    Scale = 16;
+    Width = 32;
+    MinOffset = -256;
+    MaxOffset = 255;
+    break;
+  case AArch64::STGPi:
+    Scale = Width = 16;
+    MinOffset = -64;
+    MaxOffset = 63;
+    break;
   }
 
   return true;
@@ -3257,6 +3272,8 @@
   case AArch64::ST1Twov1d:
   case AArch64::ST1Threev1d:
   case AArch64::ST1Fourv1d:
+  case AArch64::IRG:
+  case AArch64::IRGstack:
     return AArch64FrameOffsetCannotUpdate;
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 897b3eb..eed53f3 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -409,6 +409,12 @@
 def AArch64smaxv    : SDNode<"AArch64ISD::SMAXV", SDT_AArch64UnaryVec>;
 def AArch64umaxv    : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>;
 
+def SDT_AArch64SETTAG : SDTypeProfile<0, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>;
+def AArch64stg : SDNode<"AArch64ISD::STG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+def AArch64stzg : SDNode<"AArch64ISD::STZG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+def AArch64st2g : SDNode<"AArch64ISD::ST2G", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+def AArch64stz2g : SDNode<"AArch64ISD::STZ2G", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
 //===----------------------------------------------------------------------===//
 
 //===----------------------------------------------------------------------===//
@@ -1289,6 +1295,15 @@
 defm ST2G  : MemTagStore<0b10, "st2g">;
 defm STZ2G : MemTagStore<0b11, "stz2g">;
 
+def : Pat<(AArch64stg GPR64sp:$Rn, (am_indexeds9s128 GPR64sp:$Rm, simm9s16:$imm)),
+          (STGOffset $Rn, $Rm, $imm)>;
+def : Pat<(AArch64stzg GPR64sp:$Rn, (am_indexeds9s128 GPR64sp:$Rm, simm9s16:$imm)),
+          (STZGOffset $Rn, $Rm, $imm)>;
+def : Pat<(AArch64st2g GPR64sp:$Rn, (am_indexeds9s128 GPR64sp:$Rm, simm9s16:$imm)),
+          (ST2GOffset $Rn, $Rm, $imm)>;
+def : Pat<(AArch64stz2g GPR64sp:$Rn, (am_indexeds9s128 GPR64sp:$Rm, simm9s16:$imm)),
+          (STZ2GOffset $Rn, $Rm, $imm)>;
+
 defm STGP     : StorePairOffset <0b01, 0, GPR64z, simm7s16, "stgp">;
 def  STGPpre  : StorePairPreIdx <0b01, 0, GPR64z, simm7s16, "stgp">;
 def  STGPpost : StorePairPostIdx<0b01, 0, GPR64z, simm7s16, "stgp">;
@@ -1296,6 +1311,36 @@
 def : Pat<(int_aarch64_stg GPR64:$Rt, (am_indexeds9s128 GPR64sp:$Rn, simm9s16:$offset)),
           (STGOffset GPR64:$Rt, GPR64sp:$Rn,  simm9s16:$offset)>;
 
+def : Pat<(int_aarch64_stgp (am_indexed7s128 GPR64sp:$Rn, simm7s16:$imm), GPR64:$Rt, GPR64:$Rt2),
+          (STGPi $Rt, $Rt2, $Rn, $imm)>;
+
+def IRGstack // Pseudo; rewritten to IRG in AArch64ExpandPseudoInsts.
+    : Pseudo<(outs GPR64sp:$Rd), (ins GPR64sp:$Rsp, GPR64:$Rm), []>,
+      Sched<[]>;
+def TAGPstack // Pseudo; rewritten to ADDG (which drops $Rm) when expanded.
+    : Pseudo<(outs GPR64sp:$Rd), (ins GPR64sp:$Rn, uimm6s16:$imm6, GPR64sp:$Rm, imm0_15:$imm4), []>,
+      Sched<[]>;
+
+// Explicit SP in the first operand prevents ShrinkWrap optimization
+// from leaving this instruction out of the stack frame. When IRGstack
+// is transformed into IRG, this operand is replaced with the actual
+// register / expression for the tagged base pointer of the current function.
+def : Pat<(int_aarch64_irg_sp i64:$Rm), (IRGstack SP, i64:$Rm)>;
+
+// Large STG to be expanded into a loop. $Rm is the size, $Rn is start address.
+// $Rn_wback is one past the end of the range.
+let isCodeGenOnly=1, mayStore=1 in {
+def STGloop // Expanded into an ST2G post-index loop.
+    : Pseudo<(outs GPR64common:$Rm_wback, GPR64sp:$Rn_wback), (ins GPR64common:$Rm, GPR64sp:$Rn),
+             [], "$Rn = $Rn_wback,@earlyclobber $Rn_wback,$Rm = $Rm_wback,@earlyclobber $Rm_wback" >,
+      Sched<[WriteAdr, WriteST]>;
+
+def STZGloop // Same operands; expanded into an STZ2G post-index loop.
+    : Pseudo<(outs GPR64common:$Rm_wback, GPR64sp:$Rn_wback), (ins GPR64common:$Rm, GPR64sp:$Rn),
+             [], "$Rn = $Rn_wback,@earlyclobber $Rn_wback,$Rm = $Rm_wback,@earlyclobber $Rm_wback" >,
+      Sched<[WriteAdr, WriteST]>;
+}
+
 } // Predicates = [HasMTE]
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index f4e810f..0efeeb2 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -105,6 +105,12 @@
   /// ForwardedMustTailRegParms - A list of virtual and physical registers
   /// that must be forwarded to every musttail call.
   SmallVector<ForwardedRegister, 1> ForwardedMustTailRegParms;
+
+  // Offset from SP-at-entry to the tagged base pointer.
+  // Tagged base pointer is set up to point to the first (lowest address) tagged
+  // stack slot. Zero until setTaggedBasePointerOffset() is called.
+  unsigned TaggedBasePointerOffset = 0;
+
 public:
   AArch64FunctionInfo() = default;
 
@@ -224,6 +230,13 @@
     return ForwardedMustTailRegParms;
   }
 
+  unsigned getTaggedBasePointerOffset() const {
+    return TaggedBasePointerOffset; // Offset from SP-at-entry.
+  }
+  void setTaggedBasePointerOffset(unsigned Offset) {
+    TaggedBasePointerOffset = Offset; // Offset from SP-at-entry.
+  }
+
 private:
   // Hold the lists of LOHs.
   MILOHContainer LOHContainerSet;
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index c44d77c..6d5a4e3 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -468,10 +468,19 @@
     return;
   }
 
-  // Modify MI as necessary to handle as much of 'Offset' as possible
-  Offset = TFI->resolveFrameIndexReference(
-      MF, FrameIndex, FrameReg, /*PreferFP=*/false, /*ForSimm=*/true);
+  if (MI.getOpcode() == AArch64::TAGPstack) {
+    // TAGPstack must use the virtual frame register in its 3rd operand.
+    const MachineFrameInfo &MFI = MF.getFrameInfo();
+    const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+    FrameReg = MI.getOperand(3).getReg();
+    Offset =
+        MFI.getObjectOffset(FrameIndex) + AFI->getTaggedBasePointerOffset();
+  } else {
+    Offset = TFI->resolveFrameIndexReference(
+        MF, FrameIndex, FrameReg, /*PreferFP=*/false, /*ForSimm=*/true);
+  }
 
+  // Modify MI as necessary to handle as much of 'Offset' as possible
   if (rewriteAArch64FrameIndex(MI, FIOperandNum, FrameReg, Offset, TII))
     return;
 
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index 953d738..60dbace 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -56,3 +56,91 @@
     CodeGenOpt::Level OptLevel) const {
   return OptLevel >= CodeGenOpt::Aggressive;
 }
+
+static const int kSetTagLoopThreshold = 176; // >= this: use loop; < 0: never.
+
+static SDValue EmitUnrolledSetTag(SelectionDAG &DAG, const SDLoc &dl,
+                                  SDValue Chain, SDValue Ptr, uint64_t ObjSize,
+                                  const MachineMemOperand *BaseMemOperand,
+                                  bool ZeroData) {
+  MachineFunction &MF = DAG.getMachineFunction();
+  unsigned ObjSizeScaled = ObjSize / 16; // Number of 16-sized steps to emit.
+  // By default the address value is also used as the tag source operand.
+  SDValue TagSrc = Ptr;
+  if (Ptr.getOpcode() == ISD::FrameIndex) {
+    int FI = cast<FrameIndexSDNode>(Ptr)->getIndex();
+    Ptr = DAG.getTargetFrameIndex(FI, MVT::i64);
+    // A frame index operand may end up as [SP + offset] => it is fine to use SP
+    // register as the tag source.
+    TagSrc = DAG.getRegister(AArch64::SP, MVT::i64);
+  }
+  // OpCode1 covers one 16-sized step per store, OpCode2 covers two.
+  const unsigned OpCode1 = ZeroData ? AArch64ISD::STZG : AArch64ISD::STG;
+  const unsigned OpCode2 = ZeroData ? AArch64ISD::STZ2G : AArch64ISD::ST2G;
+  // Every store uses the incoming Chain; the results are joined at the end.
+  SmallVector<SDValue, 8> OutChains;
+  unsigned OffsetScaled = 0;
+  while (OffsetScaled < ObjSizeScaled) {
+    if (ObjSizeScaled - OffsetScaled >= 2) { // Prefer the two-step store.
+      SDValue AddrNode = DAG.getMemBasePlusOffset(Ptr, OffsetScaled * 16, dl);
+      SDValue St = DAG.getMemIntrinsicNode(
+          OpCode2, dl, DAG.getVTList(MVT::Other),
+          {Chain, TagSrc, AddrNode},
+          MVT::v4i64,
+          MF.getMachineMemOperand(BaseMemOperand, OffsetScaled * 16, 16 * 2));
+      OffsetScaled += 2;
+      OutChains.push_back(St);
+      continue;
+    }
+    // Exactly one step is left at this point.
+    if (ObjSizeScaled - OffsetScaled > 0) {
+      SDValue AddrNode = DAG.getMemBasePlusOffset(Ptr, OffsetScaled * 16, dl);
+      SDValue St = DAG.getMemIntrinsicNode(
+          OpCode1, dl, DAG.getVTList(MVT::Other),
+          {Chain, TagSrc, AddrNode},
+          MVT::v2i64,
+          MF.getMachineMemOperand(BaseMemOperand, OffsetScaled * 16, 16));
+      OffsetScaled += 1;
+      OutChains.push_back(St);
+    }
+  }
+
+  SDValue Res = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
+  return Res; // Single chain that depends on all emitted stores.
+}
+
+SDValue AArch64SelectionDAGInfo::EmitTargetCodeForSetTag(
+    SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Addr,
+    SDValue Size, MachinePointerInfo DstPtrInfo, bool ZeroData) const {
+  uint64_t ObjSize = cast<ConstantSDNode>(Size)->getZExtValue();
+  assert(ObjSize % 16 == 0 && "settag size must be a multiple of 16");
+  // A single store memory operand describes the whole tagged range.
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachineMemOperand *BaseMemOperand = MF.getMachineMemOperand(
+      DstPtrInfo, MachineMemOperand::MOStore, ObjSize, 16);
+  // Compare unsigned: (int)ObjSize could go negative for very large sizes.
+  bool UseSetTagRangeLoop =
+      kSetTagLoopThreshold >= 0 && ObjSize >= (uint64_t)kSetTagLoopThreshold;
+  if (!UseSetTagRangeLoop)
+    return EmitUnrolledSetTag(DAG, dl, Chain, Addr, ObjSize, BaseMemOperand,
+                              ZeroData);
+  // The loop form consumes 32 per iteration: peel one 16-sized store if needed.
+  if (ObjSize % 32 != 0) {
+    SDNode *St1 = DAG.getMachineNode(
+        ZeroData ? AArch64::STZGPostIndex : AArch64::STGPostIndex, dl,
+        {MVT::i64, MVT::Other},
+        {Addr, Addr, DAG.getTargetConstant(1, dl, MVT::i64), Chain});
+    DAG.setNodeMemRefs(cast<MachineSDNode>(St1), {BaseMemOperand});
+    ObjSize -= 16;
+    Addr = SDValue(St1, 0); // Continue from the written-back address.
+    Chain = SDValue(St1, 1);
+  }
+  // Results: $Rm_wback, $Rn_wback, chain. Operands: size, address, chain.
+  const EVT ResTys[] = {MVT::i64, MVT::i64, MVT::Other};
+  SDValue Ops[] = {DAG.getConstant(ObjSize, dl, MVT::i64), Addr, Chain};
+  SDNode *St = DAG.getMachineNode(
+      ZeroData ? AArch64::STZGloop : AArch64::STGloop, dl, ResTys, Ops);
+
+  DAG.setNodeMemRefs(cast<MachineSDNode>(St), {BaseMemOperand});
+  return SDValue(St, 2); // The chain result of the loop pseudo.
+}
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
index 9d38612..d0967fb 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h
@@ -23,6 +23,10 @@
                                   SDValue Chain, SDValue Dst, SDValue Src,
                                   SDValue Size, unsigned Align, bool isVolatile,
                                   MachinePointerInfo DstPtrInfo) const override;
+  SDValue EmitTargetCodeForSetTag(SelectionDAG &DAG, const SDLoc &dl,
+                                  SDValue Chain, SDValue Op1, SDValue Op2,
+                                  MachinePointerInfo DstPtrInfo,
+                                  bool ZeroData) const override;
   bool generateFMAsInMachineCombiner(CodeGenOpt::Level OptLevel) const override;
 };
 }