Eliminate the use of spill (reserved) registers.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@11476 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/LiveIntervalAnalysis.cpp b/lib/CodeGen/LiveIntervalAnalysis.cpp
index f965c5f..f602658 100644
--- a/lib/CodeGen/LiveIntervalAnalysis.cpp
+++ b/lib/CodeGen/LiveIntervalAnalysis.cpp
@@ -30,6 +30,7 @@
 #include "Support/CommandLine.h"
 #include "Support/Debug.h"
 #include "Support/Statistic.h"
+#include "Support/STLExtras.h"
 #include <cmath>
 #include <iostream>
 #include <limits>
@@ -42,6 +43,8 @@
 
     Statistic<> numIntervals("liveintervals", "Number of intervals");
     Statistic<> numJoined   ("liveintervals", "Number of joined intervals");
+    Statistic<> numPeep     ("liveintervals", "Number of identity moves "
+                             "eliminated after coalescing");
 
     cl::opt<bool>
     join("join-liveintervals",
@@ -64,7 +67,7 @@
 {
     mbbi2mbbMap_.clear();
     mi2iMap_.clear();
-    r2iMap_.clear();
+    i2miMap_.clear();
     r2iMap_.clear();
     r2rMap_.clear();
     intervals_.clear();
@@ -74,7 +77,7 @@
 /// runOnMachineFunction - Register allocate the whole function
 ///
 bool LiveIntervals::runOnMachineFunction(MachineFunction &fn) {
-    DEBUG(std::cerr << "Machine Function\n");
+    DEBUG(std::cerr << "MACHINE FUNCTION: "; fn.print(std::cerr));
     mf_ = &fn;
     tm_ = &fn.getTarget();
     mri_ = tm_->getRegisterInfo();
@@ -94,47 +97,110 @@
              mi != miEnd; ++mi) {
             inserted = mi2iMap_.insert(std::make_pair(mi, miIndex)).second;
             assert(inserted && "multiple MachineInstr -> index mappings");
+            i2miMap_.push_back(mi);
             miIndex += 2;
         }
     }
 
     computeIntervals();
 
-    // compute spill weights
-    const LoopInfo& loopInfo = getAnalysis<LoopInfo>();
-    const TargetInstrInfo& tii = tm_->getInstrInfo();
-
-    for (MachineFunction::const_iterator mbbi = mf_->begin(),
-             mbbe = mf_->end(); mbbi != mbbe; ++mbbi) {
-        const MachineBasicBlock* mbb = mbbi;
-        unsigned loopDepth = loopInfo.getLoopDepth(mbb->getBasicBlock());
-
-        for (MachineBasicBlock::const_iterator mi = mbb->begin(),
-                 mie = mbb->end(); mi != mie; ++mi) {
-            for (int i = mi->getNumOperands() - 1; i >= 0; --i) {
-                const MachineOperand& mop = mi->getOperand(i);
-                if (mop.isRegister() &&
-                    MRegisterInfo::isVirtualRegister(mop.getReg())) {
-                    unsigned reg = mop.getReg();
-                    Reg2IntervalMap::iterator r2iit = r2iMap_.find(reg);
-                    assert(r2iit != r2iMap_.end());
-                    r2iit->second->weight += pow(10.0F, loopDepth);
-                }
-            }
-        }
-    }
+    numIntervals += intervals_.size();
 
     // join intervals if requested
     if (join) joinIntervals();
 
-    numIntervals += intervals_.size();
+    // perform a final pass over the instructions and compute spill
+    // weights, coalesce virtual registers and remove identity moves
+    const LoopInfo& loopInfo = getAnalysis<LoopInfo>();
+    const TargetInstrInfo& tii = tm_->getInstrInfo();
+
+    for (MachineFunction::iterator mbbi = mf_->begin(), mbbe = mf_->end();
+         mbbi != mbbe; ++mbbi) {
+        MachineBasicBlock* mbb = mbbi;
+        unsigned loopDepth = loopInfo.getLoopDepth(mbb->getBasicBlock());
+
+        for (MachineBasicBlock::iterator mii = mbb->begin(), mie = mbb->end();
+             mii != mie; ) {
+            for (unsigned i = 0; i < mii->getNumOperands(); ++i) {
+                const MachineOperand& mop = mii->getOperand(i);
+                if (mop.isRegister()) {
+                    // replace register with representative register
+                    unsigned reg = rep(mop.getReg());
+                    mii->SetMachineOperandReg(i, reg);
+
+                    if (MRegisterInfo::isVirtualRegister(reg)) {
+                        Reg2IntervalMap::iterator r2iit = r2iMap_.find(reg);
+                        assert(r2iit != r2iMap_.end());
+                        r2iit->second->weight += pow(10.0F, loopDepth);
+                    }
+                }
+            }
+
+            // if the move is now an identity move delete it
+            unsigned srcReg, dstReg;
+            if (tii.isMoveInstr(*mii, srcReg, dstReg) && srcReg == dstReg) {
+                // remove index -> MachineInstr and
+                // MachineInstr -> index mappings
+                Mi2IndexMap::iterator mi2i = mi2iMap_.find(mii);
+                if (mi2i != mi2iMap_.end()) {
+                    i2miMap_[mi2i->second/2] = 0;
+                    mi2iMap_.erase(mi2i);
+                }
+                mii = mbbi->erase(mii);
+                ++numPeep;
+            }
+            else
+                ++mii;
+        }
+    }
 
     intervals_.sort(StartPointComp());
+    DEBUG(std::cerr << "*** INTERVALS ***\n");
     DEBUG(std::copy(intervals_.begin(), intervals_.end(),
                     std::ostream_iterator<Interval>(std::cerr, "\n")));
+    DEBUG(std::cerr << "*** MACHINEINSTRS ***\n");
+    DEBUG(
+        for (unsigned i = 0; i != i2miMap_.size(); ++i) {
+            if (const MachineInstr* mi = i2miMap_[i]) {
+                std:: cerr << i*2 << '\t';
+                mi->print(std::cerr, *tm_);
+            }
+        });
+
     return true;
 }
 
+/// updateSpilledInterval - rebuild the ranges of spilled interval li so
+/// that it covers only the instructions in its old ranges that have a
+/// register operand whose rep() is li.reg; li then gets infinite weight.
+void LiveIntervals::updateSpilledInterval(Interval& li)
+{
+    assert(li.weight != std::numeric_limits<float>::infinity() &&
+           "attempt to spill already spilled interval!");
+    Interval::Ranges oldRanges;
+    swap(oldRanges, li.ranges);
+
+    for (Interval::Ranges::iterator i = oldRanges.begin(), e = oldRanges.end();
+         i != e; ++i) {
+        unsigned index = i->first & ~1; // round down to an instruction index
+        unsigned end = i->second;
+
+        for (; index < end; index += 2) {
+            // deleted instructions leave null slots; skip them without
+            // stepping past the end of the current range
+            MachineInstr* mi = getInstructionFromIndex(index);
+            if (!mi) continue;
+            for (unsigned j = 0; j < mi->getNumOperands(); ++j) {
+                const MachineOperand& mop = mi->getOperand(j);
+                if (mop.isRegister() && rep(mop.getReg()) == li.reg)
+                    li.addRange(index, index + 2);
+            }
+        }
+    }
+    // the new spill weight is now infinity as it cannot be spilled again
+    li.weight = std::numeric_limits<float>::infinity();
+}
+
 void LiveIntervals::printRegName(unsigned reg) const
 {
     if (MRegisterInfo::isPhysicalRegister(reg))
@@ -277,9 +343,16 @@
 
 unsigned LiveIntervals::getInstructionIndex(MachineInstr* instr) const
 {
-    assert(mi2iMap_.find(instr) != mi2iMap_.end() &&
-           "instruction not assigned a number");
-    return mi2iMap_.find(instr)->second;
+    Mi2IndexMap::const_iterator it = mi2iMap_.find(instr); // UINT_MAX if unnumbered
+    return it == mi2iMap_.end() ? std::numeric_limits<unsigned>::max() : it->second;
+}
+/// getInstructionFromIndex - inverse of getInstructionIndex; null if deleted
+MachineInstr* LiveIntervals::getInstructionFromIndex(unsigned index) const
+{
+    index /= 2; // indices are assigned in steps of 2 (miIndex += 2)
+    assert(index < i2miMap_.size() &&
+           "index does not correspond to an instruction");
+    return i2miMap_[index];
 }
 
 /// computeIntervals - computes the live intervals for virtual
@@ -288,20 +361,19 @@
 /// which a variable is live
 void LiveIntervals::computeIntervals()
 {
-    DEBUG(std::cerr << "computing live intervals:\n");
+    DEBUG(std::cerr << "*** COMPUTING LIVE INTERVALS ***\n");
 
     for (MbbIndex2MbbMap::iterator
              it = mbbi2mbbMap_.begin(), itEnd = mbbi2mbbMap_.end();
          it != itEnd; ++it) {
         MachineBasicBlock* mbb = it->second;
-        DEBUG(std::cerr << "machine basic block: "
-              << mbb->getBasicBlock()->getName() << "\n");
+        DEBUG(std::cerr << mbb->getBasicBlock()->getName() << ":\n");
 
         for (MachineBasicBlock::iterator mi = mbb->begin(), miEnd = mbb->end();
              mi != miEnd; ++mi) {
             const TargetInstrDescriptor& tid =
                 tm_->getInstrInfo().get(mi->getOpcode());
-            DEBUG(std::cerr << "\t[" << getInstructionIndex(mi) << "] ";
+            DEBUG(std::cerr << "[" << getInstructionIndex(mi) << "]\t";
                   mi->print(std::cerr, *tm_););
 
             // handle implicit defs
@@ -329,22 +401,20 @@
 
 void LiveIntervals::joinIntervals()
 {
-    DEBUG(std::cerr << "joining compatible intervals:\n");
+    DEBUG(std::cerr << "** JOINING INTERVALS ***\n");
 
     const TargetInstrInfo& tii = tm_->getInstrInfo();
 
     for (MachineFunction::iterator mbbi = mf_->begin(), mbbe = mf_->end();
          mbbi != mbbe; ++mbbi) {
         MachineBasicBlock* mbb = mbbi;
-        DEBUG(std::cerr << "machine basic block: "
-              << mbb->getBasicBlock()->getName() << "\n");
+        DEBUG(std::cerr << mbb->getBasicBlock()->getName() << ":\n");
 
         for (MachineBasicBlock::iterator mi = mbb->begin(), mie = mbb->end();
              mi != mie; ++mi) {
             const TargetInstrDescriptor& tid =
                 tm_->getInstrInfo().get(mi->getOpcode());
-            DEBUG(std::cerr << "\t\tinstruction["
-                  << getInstructionIndex(mi) << "]: ";
+            DEBUG(std::cerr << "[" << getInstructionIndex(mi) << "]\t";
                   mi->print(std::cerr, *tm_););
 
             // we only join virtual registers with allocatable