Teach the bottom-up pre-RA scheduler to track register pressure. Work in progress.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@108991 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp b/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp
index 3ef521c..9503b0c 100644
--- a/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp
+++ b/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp
@@ -24,15 +24,20 @@
 #include "llvm/Target/TargetData.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetInstrInfo.h"
+#include "llvm/Target/TargetLowering.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include <climits>
 using namespace llvm;
 
+static cl::opt<bool> RegPressureAware("reg-pressure-aware-sched",
+                                      cl::init(false), cl::Hidden);  // off by default
+
 STATISTIC(NumBacktracks, "Number of times scheduler backtracked");
 STATISTIC(NumUnfolds,    "Number of nodes unfolded");
 STATISTIC(NumDups,       "Number of duplicated nodes");
@@ -181,7 +186,9 @@
 
 /// Schedule - Schedule the DAG using list scheduling.
 void ScheduleDAGRRList::Schedule() {
-  DEBUG(dbgs() << "********** List Scheduling **********\n");
+  DEBUG(dbgs()
+        << "********** List Scheduling BB#" << BB->getNumber()
+        << " **********\n");
 
   NumLiveRegs = 0;
   LiveRegDefs.resize(TRI->getNumRegs(), NULL);  
@@ -989,7 +996,7 @@
       : SPQ(spq) {}
     hybrid_ls_rr_sort(const hybrid_ls_rr_sort &RHS)
       : SPQ(RHS.SPQ) {}
-    
+
     bool operator()(const SUnit* left, const SUnit* right) const;
   };
 }  // end anonymous namespace
@@ -1029,23 +1036,46 @@
     std::vector<SUnit*> Queue;
     SF Picker;
     unsigned CurQueueId;
+    bool isBottomUp;
 
   protected:
     // SUnits - The SUnits for the current graph.
     std::vector<SUnit> *SUnits;
-    
+
+    MachineFunction &MF;
     const TargetInstrInfo *TII;
     const TargetRegisterInfo *TRI;
+    const TargetLowering *TLI;
     ScheduleDAGRRList *scheduleDAG;
 
     // SethiUllmanNumbers - The SethiUllman number for each node.
     std::vector<unsigned> SethiUllmanNumbers;
 
+    /// RegPressure - Tracking current reg pressure per register class.
+    /// Indexed by register class ID (TargetRegisterClass::getID()).
+    std::vector<int> RegPressure;
+
+    /// RegLimit - Per register class pressure limit: the allocatable register
+    /// count minus one, as set by the constructor. Indexed by register class ID.
+    std::vector<int> RegLimit;
+
   public:
-    RegReductionPriorityQueue(const TargetInstrInfo *tii,
-                              const TargetRegisterInfo *tri)
-      : Picker(this), CurQueueId(0),
-        TII(tii), TRI(tri), scheduleDAG(NULL) {}
+    RegReductionPriorityQueue(MachineFunction &mf,
+                              bool isbottomup,
+                              const TargetInstrInfo *tii,
+                              const TargetRegisterInfo *tri,
+                              const TargetLowering *tli)  // tli may be null
+      : Picker(this), CurQueueId(0), isBottomUp(isbottomup),
+        MF(mf), TII(tii), TRI(tri), TLI(tli), scheduleDAG(NULL) {
+      unsigned NumRC = TRI->getNumRegClasses();  // one table slot per class
+      RegLimit.resize(NumRC);
+      RegPressure.resize(NumRC);
+      std::fill(RegLimit.begin(), RegLimit.end(), 0);
+      std::fill(RegPressure.begin(), RegPressure.end(), 0);
+      for (TargetRegisterInfo::regclass_iterator I = TRI->regclass_begin(),
+             E = TRI->regclass_end(); I != E; ++I)  // limit = allocatable - 1
+        RegLimit[(*I)->getID()] = tri->getAllocatableSet(MF, *I).count() - 1;
+    }
     
     void initNodes(std::vector<SUnit> &sunits) {
       SUnits = &sunits;
@@ -1072,6 +1102,7 @@
     void releaseState() {
       SUnits = 0;
       SethiUllmanNumbers.clear();
+      std::fill(RegPressure.begin(), RegPressure.end(), 0);  // reset pressure
     }
 
     unsigned getNodePriority(const SUnit *SU) const {
@@ -1139,10 +1170,191 @@
       SU->NodeQueueId = 0;
     }
 
+    // EstimateSpills - Given a scheduling unit, estimate the number of spills
+    // it would cause if scheduled now. Returns 0 when TLI is null.
+    unsigned EstimateSpills(const SUnit *SU) const {
+      if (!TLI)
+        return 0;
+
+      unsigned Spills = 0;
+      for (SUnit::const_pred_iterator I = SU->Preds.begin(),E = SU->Preds.end();
+           I != E; ++I) {
+        if (I->isCtrl())
+          continue;  // data dependencies only
+        SUnit *PredSU = I->getSUnit();
+        if (PredSU->NumSuccsLeft != PredSU->NumSuccs - 1)
+          continue;  // NOTE(review): keeps NumSuccsLeft == NumSuccs - 1; confirm
+        const SDNode *N = PredSU->getNode();
+        if (!N->isMachineOpcode())
+          continue;  // def count below needs a machine opcode
+        unsigned NumDefs = TII->get(N->getMachineOpcode()).getNumDefs();
+        for (unsigned i = 0; i != NumDefs; ++i) {
+          EVT VT = N->getValueType(i);
+          if (!N->hasAnyUseOfValue(i))
+            continue;  // skip results with no uses
+          unsigned RCId = TLI->getRepRegClassFor(VT)->getID();
+          unsigned Cost = TLI->getRepRegClassCostFor(VT);
+          // Check if this increases register pressure of the specific register
+          // class to the point where it would cause spills.
+          int Excess = RegPressure[RCId] + Cost - RegLimit[RCId];
+          if (Excess > 0)
+            Spills += Excess;
+        }
+      }
+
+      if (!SU->NumSuccs || !Spills)
+        return Spills;
+      const SDNode *N = SU->getNode();  // SU's own defs may offset the estimate
+      if (!N->isMachineOpcode())
+        return Spills;
+      unsigned NumDefs = TII->get(N->getMachineOpcode()).getNumDefs();
+      for (unsigned i = 0; i != NumDefs; ++i) {
+        EVT VT = N->getValueType(i);
+        if (!N->hasAnyUseOfValue(i))
+          continue;
+        unsigned RCId = TLI->getRepRegClassFor(VT)->getID();
+        unsigned Cost = TLI->getRepRegClassCostFor(VT);
+        if (RegPressure[RCId] > RegLimit[RCId]) {
+          int Less = RegLimit[RCId] - (RegPressure[RCId] - Cost);
+          if (Less > 0) {
+            if (Spills <= (unsigned)Less)
+              return 0;  // offset covers the whole estimate
+            Spills -= Less;
+          }
+        }
+      }
+
+      return Spills;
+    }
+
+    void OpenPredLives(SUnit *SU) {  // adjust RegPressure as SU is scheduled
+      const SDNode *N = SU->getNode();
+      if (!N->isMachineOpcode())
+        return;
+      unsigned Opc = N->getMachineOpcode();
+      if (Opc == TargetOpcode::EXTRACT_SUBREG || 
+          Opc == TargetOpcode::INSERT_SUBREG ||
+          Opc == TargetOpcode::SUBREG_TO_REG ||
+          Opc == TargetOpcode::COPY_TO_REGCLASS ||
+          Opc == TargetOpcode::REG_SEQUENCE ||
+          Opc == TargetOpcode::IMPLICIT_DEF)
+        return;  // these opcodes are not tracked
+
+      for (SUnit::pred_iterator I = SU->Preds.begin(), E = SU->Preds.end();
+           I != E; ++I) {
+        if (I->isCtrl())
+          continue;  // data dependencies only
+        SUnit *PredSU = I->getSUnit();
+        if (PredSU->NumSuccsLeft != PredSU->NumSuccs - 1)
+          continue;  // same predecessor filter as EstimateSpills
+        const SDNode *PN = PredSU->getNode();
+        if (!PN->isMachineOpcode())
+          continue;
+        unsigned NumDefs = TII->get(PN->getMachineOpcode()).getNumDefs();
+        for (unsigned i = 0; i != NumDefs; ++i) {
+          EVT VT = PN->getValueType(i);
+          if (!PN->hasAnyUseOfValue(i))
+            continue;
+          unsigned RCId = TLI->getRepRegClassFor(VT)->getID();
+          RegPressure[RCId] += TLI->getRepRegClassCostFor(VT);  // pred's used defs
+        }
+      }
+
+      if (!SU->NumSuccs)
+        return;  // no successors: SU's own defs are not subtracted
+      unsigned NumDefs = TII->get(N->getMachineOpcode()).getNumDefs();
+      for (unsigned i = 0; i != NumDefs; ++i) {
+        EVT VT = N->getValueType(i);
+        if (!N->hasAnyUseOfValue(i))
+          continue;
+        unsigned RCId = TLI->getRepRegClassFor(VT)->getID();
+        RegPressure[RCId] -= TLI->getRepRegClassCostFor(VT);  // SU's used defs
+        if (RegPressure[RCId] < 0)
+          // Register pressure tracking is imprecise. This can happen.
+          RegPressure[RCId] = 0;
+      }
+    }
+
+    void ClosePredLives(SUnit *SU) {  // adjust RegPressure as SU is unscheduled
+      const SDNode *N = SU->getNode();
+      if (!N->isMachineOpcode())
+        return;
+      unsigned Opc = N->getMachineOpcode();
+      if (Opc == TargetOpcode::EXTRACT_SUBREG || 
+          Opc == TargetOpcode::INSERT_SUBREG ||
+          Opc == TargetOpcode::SUBREG_TO_REG ||
+          Opc == TargetOpcode::COPY_TO_REGCLASS ||
+          Opc == TargetOpcode::REG_SEQUENCE ||
+          Opc == TargetOpcode::IMPLICIT_DEF)
+        return;  // these opcodes are not tracked
+
+      for (SUnit::pred_iterator I = SU->Preds.begin(), E = SU->Preds.end();
+           I != E; ++I) {
+        if (I->isCtrl())
+          continue;  // data dependencies only
+        SUnit *PredSU = I->getSUnit();
+        if (PredSU->NumSuccsLeft != PredSU->NumSuccs - 1)
+          continue;  // same predecessor filter as OpenPredLives
+        const SDNode *PN = PredSU->getNode();
+        if (!PN->isMachineOpcode())
+          continue;
+        unsigned NumDefs = TII->get(PN->getMachineOpcode()).getNumDefs();
+        for (unsigned i = 0; i != NumDefs; ++i) {
+          EVT VT = PN->getValueType(i);
+          if (!PN->hasAnyUseOfValue(i))
+            continue;
+          unsigned RCId = TLI->getRepRegClassFor(VT)->getID();
+          RegPressure[RCId] -= TLI->getRepRegClassCostFor(VT);  // pred's used defs
+          if (RegPressure[RCId] < 0)
+            // Register pressure tracking is imprecise. This can happen.
+            RegPressure[RCId] = 0;
+        }
+      }
+
+      if (!SU->NumSuccs)
+        return;  // NOTE(review): range below differs from OpenPredLives; confirm
+      unsigned NumDefs = TII->get(N->getMachineOpcode()).getNumDefs();
+      for (unsigned i = NumDefs, e = N->getNumValues(); i != e; ++i) {
+        EVT VT = N->getValueType(i);
+        if (VT == MVT::Flag || VT == MVT::Other)
+          continue;  // Flag/Other values are skipped
+        if (!N->hasAnyUseOfValue(i))
+          continue;
+        unsigned RCId = TLI->getRepRegClassFor(VT)->getID();
+        RegPressure[RCId] += TLI->getRepRegClassCostFor(VT);
+      }
+    }
+
+    void ScheduledNode(SUnit *SU) {  // pressure hook for a scheduled node
+      if (!TLI || !isBottomUp)
+        return;  // tracking needs TLI and bottom-up mode
+      OpenPredLives(SU);
+      dumpRegPressure();
+    }
+
+    void UnscheduledNode(SUnit *SU) {  // pressure hook for an unscheduled node
+      if (!TLI || !isBottomUp)
+        return;  // tracking needs TLI and bottom-up mode
+      ClosePredLives(SU);
+      dumpRegPressure();
+    }
+
     void setScheduleDAG(ScheduleDAGRRList *scheduleDag) { 
       scheduleDAG = scheduleDag; 
     }
 
+    void dumpRegPressure() const {  // DEBUG-print "name: pressure / limit"
+      for (TargetRegisterInfo::regclass_iterator I = TRI->regclass_begin(),
+             E = TRI->regclass_end(); I != E; ++I) {
+        const TargetRegisterClass *RC = *I;
+        unsigned Id = RC->getID();
+        unsigned RP = RegPressure[Id];
+        if (!RP) continue;  // classes with zero pressure are not printed
+        DEBUG(dbgs() << RC->getName() << ": " << RP << " / " << RegLimit[Id]
+              << '\n');
+      }
+    }
+
   protected:
     bool canClobber(const SUnit *SU, const SUnit *Op);
     void AddPseudoTwoAddrDeps();
@@ -1635,8 +1847,8 @@
   const TargetInstrInfo *TII = TM.getInstrInfo();
   const TargetRegisterInfo *TRI = TM.getRegisterInfo();
   
-  BURegReductionPriorityQueue *PQ = new BURegReductionPriorityQueue(TII, TRI);
-
+  BURegReductionPriorityQueue *PQ =
+    new BURegReductionPriorityQueue(*IS->MF, true, TII, TRI, 0);
   ScheduleDAGRRList *SD = new ScheduleDAGRRList(*IS->MF, true, false, PQ);
   PQ->setScheduleDAG(SD);
   return SD;  
@@ -1648,8 +1860,8 @@
   const TargetInstrInfo *TII = TM.getInstrInfo();
   const TargetRegisterInfo *TRI = TM.getRegisterInfo();
   
-  TDRegReductionPriorityQueue *PQ = new TDRegReductionPriorityQueue(TII, TRI);
-
+  TDRegReductionPriorityQueue *PQ =
+    new TDRegReductionPriorityQueue(*IS->MF, false, TII, TRI, 0);
   ScheduleDAGRRList *SD = new ScheduleDAGRRList(*IS->MF, false, false, PQ);
   PQ->setScheduleDAG(SD);
   return SD;
@@ -1661,8 +1873,8 @@
   const TargetInstrInfo *TII = TM.getInstrInfo();
   const TargetRegisterInfo *TRI = TM.getRegisterInfo();
   
-  SrcRegReductionPriorityQueue *PQ = new SrcRegReductionPriorityQueue(TII, TRI);
-
+  SrcRegReductionPriorityQueue *PQ =
+    new SrcRegReductionPriorityQueue(*IS->MF, true, TII, TRI, 0);
   ScheduleDAGRRList *SD = new ScheduleDAGRRList(*IS->MF, true, false, PQ);
   PQ->setScheduleDAG(SD);
   return SD;  
@@ -1673,9 +1885,11 @@
   const TargetMachine &TM = IS->TM;
   const TargetInstrInfo *TII = TM.getInstrInfo();
   const TargetRegisterInfo *TRI = TM.getRegisterInfo();
+  const TargetLowering *TLI = &IS->getTargetLowering();
   
-  HybridBURRPriorityQueue *PQ = new HybridBURRPriorityQueue(TII, TRI);
-
+  HybridBURRPriorityQueue *PQ =
+    new HybridBURRPriorityQueue(*IS->MF, true, TII, TRI,
+                                (RegPressureAware ? TLI : 0));
   ScheduleDAGRRList *SD = new ScheduleDAGRRList(*IS->MF, true, true, PQ);
   PQ->setScheduleDAG(SD);
   return SD;  
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index dafda50..6e6fede 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -678,9 +678,12 @@
 }
 
 /// findRepresentativeClass - Return the largest legal super-reg register class
-/// of the specified register class.
-const TargetRegisterClass *
-TargetLowering::findRepresentativeClass(const TargetRegisterClass *RC) const {
+/// of the register class for the specified type and its associated "cost".
+std::pair<const TargetRegisterClass*, uint8_t>
+TargetLowering::findRepresentativeClass(EVT VT) const {
+  const TargetRegisterClass *RC = RegClassForVT[VT.getSimpleVT().SimpleTy];
+  if (!RC)
+    return std::make_pair(RC, 0);
   const TargetRegisterClass *BestRC = RC;
   for (TargetRegisterInfo::regclass_iterator I = RC->superregclasses_begin(),
          E = RC->superregclasses_end(); I != E; ++I) {
@@ -688,10 +691,10 @@
     if (RRC->isASubClass() || !isLegalRC(RRC))
       continue;
     if (!hasLegalSuperRegRegClasses(RRC))
-      return RRC;
+      return std::make_pair(RRC, 1);
     BestRC = RRC;
   }
-  return BestRC;
+  return std::make_pair(BestRC, 1);
 }
 
 /// computeRegisterProperties - Once all of the register classes are added,
@@ -820,8 +823,11 @@
   // a group of value types. For example, on i386, i8, i16, and i32
   // representative would be GR32; while on x86_64 it's GR64.
   for (unsigned i = 0; i != MVT::LAST_VALUETYPE; ++i) {
-    const TargetRegisterClass *RC = RegClassForVT[i];
-    RepRegClassForVT[i] = RC ? findRepresentativeClass(RC) : 0;
+    const TargetRegisterClass* RRC;
+    uint8_t Cost;
+    tie(RRC, Cost) =  findRepresentativeClass((MVT::SimpleValueType)i);
+    RepRegClassForVT[i] = RRC;
+    RepRegClassCostForVT[i] = Cost;
   }
 }