ARM: recommit r237590: allow jump tables to be placed as constant islands.

The original version didn't properly account for the base register
being modified before the final jump, and so caused miscompilations in
Chromium and LLVM. I've fixed this and tested with an LLVM self-host
(I don't have the means to build & test Chromium).

The general idea remains the same: in pathological cases jump tables
can be too far away from the instructions referencing them (like other
constants), so they need to be movable.

Should fix PR23627.

llvm-svn: 238680
diff --git a/llvm/lib/Target/ARM/ARMConstantIslandPass.cpp b/llvm/lib/Target/ARM/ARMConstantIslandPass.cpp
index 6fa5ad7..d42b124 100644
--- a/llvm/lib/Target/ARM/ARMConstantIslandPass.cpp
+++ b/llvm/lib/Target/ARM/ARMConstantIslandPass.cpp
@@ -180,9 +180,7 @@
       MachineInstr *MI;
       MachineInstr *CPEMI;
       MachineBasicBlock *HighWaterMark;
-    private:
       unsigned MaxDisp;
-    public:
       bool NegOk;
       bool IsSoImm;
       bool KnownAlignment;
@@ -216,12 +214,24 @@
     };
 
     /// CPEntries - Keep track of all of the constant pool entry machine
-    /// instructions. For each original constpool index (i.e. those that
-    /// existed upon entry to this pass), it keeps a vector of entries.
-    /// Original elements are cloned as we go along; the clones are
-    /// put in the vector of the original element, but have distinct CPIs.
+    /// instructions. For each original constpool index (i.e. those that existed
+    /// upon entry to this pass), it keeps a vector of entries.  Original
+    /// elements are cloned as we go along; the clones are put in the vector of
+    /// the original element, but have distinct CPIs.
+    ///
+    /// The first half of CPEntries contains generic constants, the second half
+    /// contains jump tables. Use getCombinedIndex on a generic CPEMI to look up
+    /// which vector it will be in here.
     std::vector<std::vector<CPEntry> > CPEntries;
 
+    /// Maps a JT index to the offset in CPEntries containing copies of that
+    /// table. The equivalent map for a CONSTPOOL_ENTRY is the identity.
+    DenseMap<int, int> JumpTableEntryIndices;
+
+    /// Maps a JT index to the LEA that actually uses the index to calculate its
+    /// base address.
+    DenseMap<int, int> JumpTableUserIndices;
+
     /// ImmBranch - One per immediate branch, keeping the machine instruction
     /// pointer, conditional or unconditional, the max displacement,
     /// and (if isCond is true) the corresponding unconditional branch
@@ -269,7 +279,8 @@
     }
 
   private:
-    void doInitialPlacement(std::vector<MachineInstr*> &CPEMIs);
+    void doInitialConstPlacement(std::vector<MachineInstr *> &CPEMIs);
+    void doInitialJumpTablePlacement(std::vector<MachineInstr *> &CPEMIs);
     bool BBHasFallthrough(MachineBasicBlock *MBB);
     CPEntry *findConstPoolEntry(unsigned CPI, const MachineInstr *CPEMI);
     unsigned getCPELogAlign(const MachineInstr *CPEMI);
@@ -279,6 +290,7 @@
     void updateForInsertedWaterBlock(MachineBasicBlock *NewBB);
     void adjustBBOffsetsAfter(MachineBasicBlock *BB);
     bool decrementCPEReferenceCount(unsigned CPI, MachineInstr* CPEMI);
+    unsigned getCombinedIndex(const MachineInstr *CPEMI);
     int findInRangeCPEntry(CPUser& U, unsigned UserOffset);
     bool findAvailableWater(CPUser&U, unsigned UserOffset,
                             water_iterator &WaterIter);
@@ -301,8 +313,9 @@
     bool optimizeThumb2Instructions();
     bool optimizeThumb2Branches();
     bool reorderThumb2JumpTables();
-    unsigned removeDeadDefinitions(MachineInstr *MI, unsigned BaseReg,
-                                   unsigned IdxReg);
+    bool preserveBaseRegister(MachineInstr *JumpMI, MachineInstr *LEAMI,
+                              unsigned &DeadSize, bool &CanDeleteLEA,
+                              bool &BaseRegKill);
     bool optimizeThumb2JumpTables();
     MachineBasicBlock *adjustJTTargetBlockForward(MachineBasicBlock *BB,
                                                   MachineBasicBlock *JTBB);
@@ -413,7 +426,10 @@
   // we put them all at the end of the function.
   std::vector<MachineInstr*> CPEMIs;
   if (!MCP->isEmpty())
-    doInitialPlacement(CPEMIs);
+    doInitialConstPlacement(CPEMIs);
+
+  if (MF->getJumpTableInfo())
+    doInitialJumpTablePlacement(CPEMIs);
 
   /// The next UID to take is the first unused one.
   AFI->initPICLabelUId(CPEMIs.size());
@@ -478,7 +494,8 @@
   for (unsigned i = 0, e = CPEntries.size(); i != e; ++i) {
     for (unsigned j = 0, je = CPEntries[i].size(); j != je; ++j) {
       const CPEntry & CPE = CPEntries[i][j];
-      AFI->recordCPEClone(i, CPE.CPI);
+      if (CPE.CPEMI && CPE.CPEMI->getOperand(1).isCPI())
+        AFI->recordCPEClone(i, CPE.CPI);
     }
   }
 
@@ -488,6 +505,8 @@
   WaterList.clear();
   CPUsers.clear();
   CPEntries.clear();
+  JumpTableEntryIndices.clear();
+  JumpTableUserIndices.clear();
   ImmBranches.clear();
   PushPopMIs.clear();
   T2JumpTables.clear();
@@ -495,10 +514,10 @@
   return MadeChange;
 }
 
-/// doInitialPlacement - Perform the initial placement of the constant pool
-/// entries.  To start with, we put them all at the end of the function.
+/// \brief Perform the initial placement of the regular constant pool entries.
+/// To start with, we put them all at the end of the function.
 void
-ARMConstantIslands::doInitialPlacement(std::vector<MachineInstr*> &CPEMIs) {
+ARMConstantIslands::doInitialConstPlacement(std::vector<MachineInstr*> &CPEMIs) {
   // Create the basic block to hold the CPE's.
   MachineBasicBlock *BB = MF->CreateMachineBasicBlock();
   MF->push_back(BB);
@@ -556,6 +575,66 @@
   DEBUG(BB->dump());
 }
 
+/// \brief Do initial placement of the jump tables. Because Thumb2's TBB and TBH
+/// instructions can be made more efficient if the jump table immediately
+/// follows the instruction, each table starts in a new block placed right after
+/// the block ending in its jump, and is registered in CPEntries/CPEMIs. In
+/// almost all cases they'll never be moved from that position.
+void ARMConstantIslands::doInitialJumpTablePlacement(
+    std::vector<MachineInstr *> &CPEMIs) {
+  unsigned i = CPEntries.size();
+  auto MJTI = MF->getJumpTableInfo();
+  const std::vector<MachineJumpTableEntry> &JT = MJTI->getJumpTables();
+  MachineBasicBlock *LastCorrectlyNumberedBB = nullptr;
+  for (MachineBasicBlock &MBB : *MF) {
+    auto MI = MBB.getLastNonDebugInstr();
+    if (MI == MBB.end()) // Empty (or debug-only) block: no jump to inspect.
+      continue;
+    unsigned JTOpcode;
+    switch (MI->getOpcode()) {
+    default:
+      continue;
+    case ARM::BR_JTadd:
+    case ARM::BR_JTr:
+    case ARM::tBR_JTr:
+    case ARM::BR_JTm:
+      JTOpcode = ARM::JUMPTABLE_ADDRS;
+      break;
+    case ARM::t2BR_JT:
+      JTOpcode = ARM::JUMPTABLE_INSTS;
+      break;
+    case ARM::t2TBB_JT:
+      JTOpcode = ARM::JUMPTABLE_TBB;
+      break;
+    case ARM::t2TBH_JT:
+      JTOpcode = ARM::JUMPTABLE_TBH;
+      break;
+    }
+
+    unsigned NumOps = MI->getDesc().getNumOperands();
+    const MachineOperand &JTOp =
+      MI->getOperand(NumOps - (MI->isPredicable() ? 2 : 1));
+    unsigned JTI = JTOp.getIndex();
+    unsigned Size = JT[JTI].MBBs.size() * sizeof(uint32_t);
+    MachineBasicBlock *JumpTableBB = MF->CreateMachineBasicBlock();
+    MF->insert(std::next(MachineFunction::iterator(MBB)), JumpTableBB);
+    MachineInstr *CPEMI = BuildMI(*JumpTableBB, JumpTableBB->begin(),
+                                  DebugLoc(), TII->get(JTOpcode))
+                              .addImm(i++)
+                              .addJumpTableIndex(JTI)
+                              .addImm(Size);
+    CPEMIs.push_back(CPEMI);
+    CPEntries.emplace_back(1, CPEntry(CPEMI, JTI));
+    JumpTableEntryIndices.insert(std::make_pair(JTI, CPEntries.size() - 1));
+    if (!LastCorrectlyNumberedBB)
+      LastCorrectlyNumberedBB = &MBB;
+  }
+
+  // If we did anything then we need to renumber the subsequent blocks.
+  if (LastCorrectlyNumberedBB)
+    MF->RenumberBlocks(LastCorrectlyNumberedBB);
+}
+
 /// BBHasFallthrough - Return true if the specified basic block can fallthrough
 /// into the block immediately after it.
 bool ARMConstantIslands::BBHasFallthrough(MachineBasicBlock *MBB) {
@@ -595,9 +674,21 @@
 /// getCPELogAlign - Returns the required alignment of the constant pool entry
 /// represented by CPEMI.  Alignment is measured in log2(bytes) units.
 unsigned ARMConstantIslands::getCPELogAlign(const MachineInstr *CPEMI) {
-  assert(CPEMI && CPEMI->getOpcode() == ARM::CONSTPOOL_ENTRY);
+  switch (CPEMI->getOpcode()) {
+  case ARM::CONSTPOOL_ENTRY:
+    break;
+  case ARM::JUMPTABLE_TBB:
+    return 0;
+  case ARM::JUMPTABLE_TBH:
+  case ARM::JUMPTABLE_INSTS:
+    return 1;
+  case ARM::JUMPTABLE_ADDRS:
+    return 2;
+  default:
+    llvm_unreachable("unknown constpool entry kind");
+  }
 
-  unsigned CPI = CPEMI->getOperand(1).getIndex();
+  unsigned CPI = getCombinedIndex(CPEMI);
   assert(CPI < MCP->getConstants().size() && "Invalid constant pool index.");
   unsigned Align = MCP->getConstants()[CPI].getAlignment();
   assert(isPowerOf2_32(Align) && "Invalid CPE alignment");
@@ -706,12 +797,14 @@
       if (Opc == ARM::tPUSH || Opc == ARM::tPOP_RET)
         PushPopMIs.push_back(I);
 
-      if (Opc == ARM::CONSTPOOL_ENTRY)
+      if (Opc == ARM::CONSTPOOL_ENTRY || Opc == ARM::JUMPTABLE_ADDRS ||
+          Opc == ARM::JUMPTABLE_INSTS || Opc == ARM::JUMPTABLE_TBB ||
+          Opc == ARM::JUMPTABLE_TBH)
         continue;
 
       // Scan the instructions for constant pool operands.
       for (unsigned op = 0, e = I->getNumOperands(); op != e; ++op)
-        if (I->getOperand(op).isCPI()) {
+        if (I->getOperand(op).isCPI() || I->getOperand(op).isJTI()) {
           // We found one.  The addressing mode tells us the max displacement
           // from the PC that this instruction permits.
 
@@ -727,6 +820,7 @@
 
           // Taking the address of a CP entry.
           case ARM::LEApcrel:
+          case ARM::LEApcrelJT:
             // This takes a SoImm, which is 8 bit immediate rotated. We'll
             // pretend the maximum offset is 255 * 4. Since each instruction
             // 4 byte wide, this is always correct. We'll check for other
@@ -737,10 +831,12 @@
             IsSoImm = true;
             break;
           case ARM::t2LEApcrel:
+          case ARM::t2LEApcrelJT:
             Bits = 12;
             NegOk = true;
             break;
           case ARM::tLEApcrel:
+          case ARM::tLEApcrelJT:
             Bits = 8;
             Scale = 4;
             break;
@@ -768,6 +864,11 @@
 
           // Remember that this is a user of a CP entry.
           unsigned CPI = I->getOperand(op).getIndex();
+          if (I->getOperand(op).isJTI()) {
+            JumpTableUserIndices.insert(std::make_pair(CPI, CPUsers.size()));
+            CPI = JumpTableEntryIndices[CPI];
+          }
+
           MachineInstr *CPEMI = CPEMIs[CPI];
           unsigned MaxOffs = ((1 << Bits)-1) * Scale;
           CPUsers.push_back(CPUser(I, CPEMI, MaxOffs, NegOk, IsSoImm));
@@ -1101,6 +1202,13 @@
   return false;
 }
 
+unsigned ARMConstantIslands::getCombinedIndex(const MachineInstr *CPEMI) {
+  if (CPEMI->getOperand(1).isCPI())
+    return CPEMI->getOperand(1).getIndex();
+  // Jump tables follow the constants in CPEntries; map the JT index across.
+  return JumpTableEntryIndices[CPEMI->getOperand(1).getIndex()];
+}
+
 /// LookForCPEntryInRange - see if the currently referenced CPE is in range;
 /// if not, see if an in-range clone of the CPE is in range, and if so,
 /// change the data structures so the user references the clone.  Returns:
@@ -1120,7 +1228,7 @@
   }
 
   // No.  Look for previously created clones of the CPE that are in range.
-  unsigned CPI = CPEMI->getOperand(1).getIndex();
+  unsigned CPI = getCombinedIndex(CPEMI);
   std::vector<CPEntry> &CPEs = CPEntries[CPI];
   for (unsigned i = 0, e = CPEs.size(); i != e; ++i) {
     // We already tried this one
@@ -1365,7 +1473,7 @@
   CPUser &U = CPUsers[CPUserIndex];
   MachineInstr *UserMI = U.MI;
   MachineInstr *CPEMI  = U.CPEMI;
-  unsigned CPI = CPEMI->getOperand(1).getIndex();
+  unsigned CPI = getCombinedIndex(CPEMI);
   unsigned Size = CPEMI->getOperand(2).getImm();
   // Compute this only once, it's expensive.
   unsigned UserOffset = getUserOffset(U);
@@ -1429,17 +1537,17 @@
   // Update internal data structures to account for the newly inserted MBB.
   updateForInsertedWaterBlock(NewIsland);
 
-  // Decrement the old entry, and remove it if refcount becomes 0.
-  decrementCPEReferenceCount(CPI, CPEMI);
-
   // Now that we have an island to add the CPE to, clone the original CPE and
   // add it to the island.
   U.HighWaterMark = NewIsland;
-  U.CPEMI = BuildMI(NewIsland, DebugLoc(), TII->get(ARM::CONSTPOOL_ENTRY))
-                .addImm(ID).addConstantPoolIndex(CPI).addImm(Size);
+  U.CPEMI = BuildMI(NewIsland, DebugLoc(), CPEMI->getDesc())
+                .addImm(ID).addOperand(CPEMI->getOperand(1)).addImm(Size);
   CPEntries[CPI].push_back(CPEntry(U.CPEMI, ID, 1));
   ++NumCPEs;
 
+  // Decrement the old entry, and remove it if refcount becomes 0.
+  decrementCPEReferenceCount(CPI, CPEMI);
+
   // Mark the basic block as aligned as required by the const-pool entry.
   NewIsland->setAlignment(getCPELogAlign(U.CPEMI));
 
@@ -1844,77 +1952,121 @@
   return MadeChange;
 }
 
-/// If we've formed a TBB or TBH instruction, the base register is now
-/// redundant. In most cases, the instructions defining it will now be dead and
-/// can be tidied up. This function removes them if so, and returns the number
-/// of bytes saved.
-unsigned ARMConstantIslands::removeDeadDefinitions(MachineInstr *MI,
-                                                   unsigned BaseReg,
-                                                   unsigned IdxReg) {
-  unsigned BytesRemoved = 0;
-  MachineBasicBlock *MBB = MI->getParent();
+/// \brief Returns true if \p I has the form "EntryReg = t2ADDrs BaseReg, ...".
+static bool isSimpleIndexCalc(MachineInstr &I, unsigned EntryReg,
+                              unsigned BaseReg) {
+  if (I.getOpcode() != ARM::t2ADDrs)
+    return false;
 
-  // Scan backwards to find the instruction that defines the base
-  // register. Due to post-RA scheduling, we can't count on it
-  // immediately preceding the branch instruction.
-  MachineBasicBlock::iterator PrevI = MI;
-  MachineBasicBlock::iterator B = MBB->begin();
-  while (PrevI != B && !PrevI->definesRegister(BaseReg))
-    --PrevI;
+  if (I.getOperand(0).getReg() != EntryReg)
+    return false;
 
-  // If for some reason we didn't find it, we can't do anything, so
-  // just skip this one.
-  if (!PrevI->definesRegister(BaseReg) || PrevI->hasUnmodeledSideEffects() ||
-      PrevI->mayStore())
-    return BytesRemoved;
+  if (I.getOperand(1).getReg() != BaseReg)
+    return false;
 
-  MachineInstr *AddrMI = PrevI;
-  unsigned NewBaseReg = BytesRemoved;
+  // FIXME: what about CC and IdxReg?
+  return true;
+}
 
-  // Examine the instruction that calculates the jumptable entry address.  Make
-  // sure it only defines the base register and kills any uses other than the
-  // index register. We also need precisely one use to trace backwards to
-  // (hopefully) the LEA.
-  for (unsigned k = 0, eee = AddrMI->getNumOperands(); k != eee; ++k) {
-    const MachineOperand &MO = AddrMI->getOperand(k);
-    if (!MO.isReg() || !MO.getReg())
-      continue;
-    if (MO.isDef() && MO.getReg() != BaseReg)
-      return BytesRemoved;
+/// \brief When forming a TBB/TBH whose table doesn't immediately follow the
+/// BR_JT, we need the jump-table base computed by LEAMI. Returns true if that
+/// register reaches JumpMI unclobbered, possibly after erasing the index add
+/// (adding its size to DeadSize). CanDeleteLEA is only left true if, apart
+/// from an erased add, nothing between LEAMI and JumpMI touches BaseReg;
+/// BaseRegKill is set if a read of BaseReg before the add was a kill.
+bool ARMConstantIslands::preserveBaseRegister(MachineInstr *JumpMI,
+                                              MachineInstr *LEAMI,
+                                              unsigned &DeadSize,
+                                              bool &CanDeleteLEA,
+                                              bool &BaseRegKill) {
+  if (JumpMI->getParent() != LEAMI->getParent())
+    return false;
 
-    if (MO.isUse() && MO.getReg() != IdxReg) {
-      if (!MO.isKill() || (NewBaseReg != 0 && NewBaseReg != MO.getReg()))
-        return BytesRemoved;
-      NewBaseReg = MO.getReg();
+  // Now we hope that we have at least these instructions in the basic block:
+  //     BaseReg = t2LEA ...
+  //     [...]
+  //     EntryReg = t2ADDrs BaseReg, ...
+  //     [...]
+  //     t2BR_JT EntryReg
+  //
+  // We have to be very conservative about what we recognise here though. The
+  // main perturbing factors to watch out for are:
+  //    + Spills at any point in the chain: not direct problems but we would
+  //      expect a blocking Def of the spilled register so in practice what we
+  //      can do is limited.
+  //    + EntryReg == BaseReg: this is the one situation we should allow a Def
+  //      of BaseReg, but only if the t2ADDrs can be removed.
+  //    + Some instruction other than t2ADDrs computing the entry. Not seen in
+  //      the wild, but we should be careful.
+  //    + Any remaining reader of BaseReg (including an add we fail to remove)
+  //      means the LEA must be kept, whatever we return.
+  unsigned EntryReg = JumpMI->getOperand(0).getReg();
+  unsigned BaseReg = LEAMI->getOperand(0).getReg();
+
+  CanDeleteLEA = true;
+  BaseRegKill = false;
+  MachineInstr *RemovableAdd = nullptr;
+  MachineBasicBlock::iterator I(LEAMI);
+  for (++I; &*I != JumpMI; ++I) {
+    if (isSimpleIndexCalc(*I, EntryReg, BaseReg)) {
+      RemovableAdd = &*I;
+      break;
+    }
+
+    for (const MachineOperand &MO : I->operands()) {
+      if (!MO.isReg() || MO.getReg() != BaseReg)
+        continue;
+      // Be conservative: any access to BaseReg here keeps the LEA alive.
+      CanDeleteLEA = false;
+      if (MO.isDef())
+        return false;
+      BaseRegKill = BaseRegKill || MO.isKill();
     }
   }
 
-  // Want to continue searching for AddrMI, but there are 2 problems: AddrMI is
-  // going away soon, and even decrementing once may be invalid.
-  if (PrevI != B)
-    PrevI = std::prev(PrevI);
+  if (!RemovableAdd)
+    return true;
 
-  DEBUG(dbgs() << "remove addr: " << *AddrMI);
-  BytesRemoved += TII->GetInstSizeInBytes(AddrMI);
-  AddrMI->eraseFromParent();
+  // Check the add really is removable, and that nothing else clobbers BaseReg.
+  for (++I; &*I != JumpMI; ++I) {
+    for (const MachineOperand &MO : I->operands()) {
+      if (!MO.isReg() || !MO.getReg())
+        continue;
+      if (MO.isUse() && MO.getReg() == EntryReg)
+        RemovableAdd = nullptr;
+      if (MO.getReg() == BaseReg)
+        CanDeleteLEA = false; // Read or clobbered after the add.
+      if (MO.isDef() && MO.getReg() == BaseReg)
+        return false;
+    }
+  }
 
-  // Now scan back again to find the tLEApcrel or t2LEApcrelJT instruction
-  // that gave us the initial base register definition.
-  for (; PrevI != B && !PrevI->definesRegister(NewBaseReg); --PrevI)
-    ;
+  if (RemovableAdd) {
+    RemovableAdd->eraseFromParent();
+    DeadSize += 4;
+  } else {
+    // The add stays: it reads BaseReg, and redefines it if BaseReg == EntryReg.
+    CanDeleteLEA = false;
+    return BaseReg != EntryReg;
+  }
 
-  // The instruction should be a tLEApcrel or t2LEApcrelJT; we want
-  // to delete it as well.
-  MachineInstr *LeaMI = PrevI;
-  if ((LeaMI->getOpcode() != ARM::tLEApcrelJT &&
-       LeaMI->getOpcode() != ARM::t2LEApcrelJT) ||
-      LeaMI->getOperand(0).getReg() != NewBaseReg)
-    return BytesRemoved;
+  // We reached the jump without seeing another definition of BaseReg (except,
+  // possibly, the t2ADDrs, which was removed). BaseReg can be used in the
+  // TBB/TBH if necessary.
+  return true;
+}
 
-  DEBUG(dbgs() << "remove lea: " << *LeaMI);
-  BytesRemoved += TII->GetInstSizeInBytes(LeaMI);
-  LeaMI->eraseFromParent();
-  return BytesRemoved;
+/// \brief Returns whether CPEMI is the first instruction in the block
+/// immediately following JTMI (the jump-table branch ending its block). If so,
+/// we can switch the first register to PC and usually remove the address
+/// calculation that preceded it.
+static bool jumpTableFollowsTB(MachineInstr *JTMI, MachineInstr *CPEMI) {
+  MachineFunction::iterator MBB = JTMI->getParent();
+  MachineFunction *MF = MBB->getParent();
+  ++MBB;
+  // The table must be the very first instruction of the next block, if any.
+  return MBB != MF->end() && MBB->begin() != MBB->end() &&
+         &*MBB->begin() == CPEMI;
 }
 
 /// optimizeThumb2JumpTables - Use tbb / tbh instructions to generate smaller
@@ -1955,37 +2107,79 @@
         break;
     }
 
-    if (ByteOk || HalfWordOk) {
-      MachineBasicBlock *MBB = MI->getParent();
-      unsigned BaseReg = MI->getOperand(0).getReg();
-      bool BaseRegKill = MI->getOperand(0).isKill();
-      if (!BaseRegKill)
-        continue;
-      unsigned IdxReg = MI->getOperand(1).getReg();
-      bool IdxRegKill = MI->getOperand(1).isKill();
+    if (!ByteOk && !HalfWordOk)
+      continue;
 
-      DEBUG(dbgs() << "Shrink JT: " << *MI);
-      unsigned Opc = ByteOk ? ARM::t2TBB_JT : ARM::t2TBH_JT;
-      MachineBasicBlock::iterator MI_JT = MI;
-      MachineInstr *NewJTMI =
+    MachineBasicBlock *MBB = MI->getParent();
+    if (!MI->getOperand(0).isKill()) // FIXME: needed now?
+      continue;
+    unsigned IdxReg = MI->getOperand(1).getReg();
+    bool IdxRegKill = MI->getOperand(1).isKill();
+
+    CPUser &User = CPUsers[JumpTableUserIndices[JTI]];
+    unsigned DeadSize = 0;
+    bool CanDeleteLEA = false;
+    bool BaseRegKill = false;
+    bool PreservedBaseReg =
+        preserveBaseRegister(MI, User.MI, DeadSize, CanDeleteLEA, BaseRegKill);
+
+    if (!jumpTableFollowsTB(MI, User.CPEMI) && !PreservedBaseReg)
+      continue;
+
+    DEBUG(dbgs() << "Shrink JT: " << *MI);
+    MachineInstr *CPEMI = User.CPEMI;
+    unsigned Opc = ByteOk ? ARM::t2TBB_JT : ARM::t2TBH_JT;
+    MachineBasicBlock::iterator MI_JT = MI;
+    MachineInstr *NewJTMI =
         BuildMI(*MBB, MI_JT, MI->getDebugLoc(), TII->get(Opc))
-        .addReg(IdxReg, getKillRegState(IdxRegKill))
-        .addJumpTableIndex(JTI, JTOP.getTargetFlags());
-      DEBUG(dbgs() << "BB#" << MBB->getNumber() << ": " << *NewJTMI);
-      // FIXME: Insert an "ALIGN" instruction to ensure the next instruction
-      // is 2-byte aligned. For now, asm printer will fix it up.
-      unsigned NewSize = TII->GetInstSizeInBytes(NewJTMI);
-      unsigned OrigSize = TII->GetInstSizeInBytes(MI);
-      unsigned DeadSize = removeDeadDefinitions(MI, BaseReg, IdxReg);
-      MI->eraseFromParent();
+            .addReg(User.MI->getOperand(0).getReg(),
+                    getKillRegState(BaseRegKill))
+            .addReg(IdxReg, getKillRegState(IdxRegKill))
+            .addJumpTableIndex(JTI, JTOP.getTargetFlags())
+            .addImm(CPEMI->getOperand(0).getImm());
+    DEBUG(dbgs() << "BB#" << MBB->getNumber() << ": " << *NewJTMI);
 
-      int delta = OrigSize - NewSize + DeadSize;
-      BBInfo[MBB->getNumber()].Size -= delta;
-      adjustBBOffsetsAfter(MBB);
+    unsigned JTOpc = ByteOk ? ARM::JUMPTABLE_TBB : ARM::JUMPTABLE_TBH;
+    CPEMI->setDesc(TII->get(JTOpc));
 
-      ++NumTBs;
-      MadeChange = true;
+    if (jumpTableFollowsTB(MI, User.CPEMI)) {
+      NewJTMI->getOperand(0).setReg(ARM::PC);
+      NewJTMI->getOperand(0).setIsKill(false);
+
+      if (CanDeleteLEA)  {
+        User.MI->eraseFromParent();
+        DeadSize += 4;
+
+        // The LEA was eliminated, the TBB instruction becomes the only new user
+        // of the jump table.
+        User.MI = NewJTMI;
+        User.MaxDisp = 4;
+        User.NegOk = false;
+        User.IsSoImm = false;
+        User.KnownAlignment = false;
+      } else {
+        // The LEA couldn't be eliminated, so we must add another CPUser to
+        // record the TBB or TBH use.
+        int CPEntryIdx = JumpTableEntryIndices[JTI];
+        auto &CPEs = CPEntries[CPEntryIdx];
+        auto Entry = std::find_if(CPEs.begin(), CPEs.end(), [&](CPEntry &E) {
+          return E.CPEMI == User.CPEMI;
+        });
+        ++Entry->RefCount;
+        CPUsers.emplace_back(CPUser(NewJTMI, User.CPEMI, 4, false, false));
+      }
     }
+
+    unsigned NewSize = TII->GetInstSizeInBytes(NewJTMI);
+    unsigned OrigSize = TII->GetInstSizeInBytes(MI);
+    MI->eraseFromParent();
+
+    int Delta = OrigSize - NewSize + DeadSize;
+    BBInfo[MBB->getNumber()].Size -= Delta;
+    adjustBBOffsetsAfter(MBB);
+
+    ++NumTBs;
+    MadeChange = true;
   }
 
   return MadeChange;