Introduce the ScopExpander as a SCEVExpander replacement

  The SCEVExpander cannot deal with all SCEVs Polly allows in all kinds
  of expressions. To this end we introduce a ScopExpander that handles
  the additional expressions separately and falls back to the
  SCEVExpander for everything else.

Reviewers: grosser, Meinersbur

Subscribers: #polly

Differential Revision: http://reviews.llvm.org/D12066

llvm-svn: 245288
diff --git a/polly/lib/Support/ScopHelper.cpp b/polly/lib/Support/ScopHelper.cpp
index 0827dc0..602009c 100644
--- a/polly/lib/Support/ScopHelper.cpp
+++ b/polly/lib/Support/ScopHelper.cpp
@@ -17,12 +17,14 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
+using namespace polly;
 
 #define DEBUG_TYPE "polly-scop-helper"
 
@@ -252,3 +254,110 @@
   // splitBlock updates DT, LI and RI.
   splitBlock(EntryBlock, I, DT, LI, RI);
 }
+
+/// Wrapper around SCEVExpander that recomputes in-region SDiv/SRem values.
+///
+/// The SCEVExpander does __not__ generate code for an existing SDiv/SRem
+/// instruction that is referenced as a SCEVUnknown; it just reuses it. When we
+/// expand code outside the analyzed region, we instead expand the operands and
+/// create a new SDiv/SRem at the terminator of the region's entering block.
+struct ScopExpander : SCEVVisitor<ScopExpander, const SCEV *> {
+  friend struct SCEVVisitor<ScopExpander, const SCEV *>;
+
+  explicit ScopExpander(const Region &R, ScalarEvolution &SE,
+                        const DataLayout &DL, const char *Name)
+      : Expander(SCEVExpander(SE, DL, Name)), SE(SE), Name(Name), R(R) {}
+
+  Value *expandCodeFor(const SCEV *E, Type *Ty, Instruction *I) {
+    // For an insertion point inside the region, hand E to the SCEVExpander
+    // unchanged. Otherwise rewrite E first, so that unknowns for SDiv/SRem
+    // instructions in the region are replaced by copies in the entering block.
+    if (!R.contains(I))
+      E = visit(E);
+    return Expander.expandCodeFor(E, Ty, I);
+  }
+
+private:
+  SCEVExpander Expander; // Performs the actual code generation.
+  ScalarEvolution &SE;   // Used to look up and rebuild SCEVs.
+  const char *Name;      // Given to Expander; also suffix of new SDiv/SRem names.
+  const Region &R;       // The analyzed region.
+
+  const SCEV *visitUnknown(const SCEVUnknown *E) {
+    Instruction *Inst = dyn_cast<Instruction>(E->getValue());
+    if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
+                  Inst->getOpcode() != Instruction::SDiv))
+      return E; // Not an SDiv/SRem instruction, nothing to recreate.
+
+    if (!R.contains(Inst))
+      return E; // Defined outside the region, keep the existing value.
+
+    Instruction *StartIP = R.getEnteringBlock()->getTerminator();
+
+    const SCEV *LHSScev = visit(SE.getSCEV(Inst->getOperand(0)));
+    const SCEV *RHSScev = visit(SE.getSCEV(Inst->getOperand(1)));
+
+    Value *LHS = Expander.expandCodeFor(LHSScev, E->getType(), StartIP);
+    Value *RHS = Expander.expandCodeFor(RHSScev, E->getType(), StartIP);
+
+    Inst = BinaryOperator::Create((Instruction::BinaryOps)Inst->getOpcode(),
+                                  LHS, RHS, Inst->getName() + Name, StartIP);
+    return SE.getSCEV(Inst); // SCEV for the newly created instruction.
+  }
+
+  /// The following visitors just visit all operands and rebuild the SCEV from
+  /// the results with the matching SE.get*Expr (constants are returned as is).
+  ///
+  ///{
+  const SCEV *visitConstant(const SCEVConstant *E) { return E; }
+  const SCEV *visitTruncateExpr(const SCEVTruncateExpr *E) {
+    return SE.getTruncateExpr(visit(E->getOperand()), E->getType());
+  }
+  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *E) {
+    return SE.getZeroExtendExpr(visit(E->getOperand()), E->getType());
+  }
+  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *E) {
+    return SE.getSignExtendExpr(visit(E->getOperand()), E->getType());
+  }
+  const SCEV *visitUDivExpr(const SCEVUDivExpr *E) {
+    return SE.getUDivExpr(visit(E->getLHS()), visit(E->getRHS()));
+  }
+  const SCEV *visitAddExpr(const SCEVAddExpr *E) {
+    SmallVector<const SCEV *, 4> NewOps;
+    for (const SCEV *Op : E->operands())
+      NewOps.push_back(visit(Op));
+    return SE.getAddExpr(NewOps);
+  }
+  const SCEV *visitMulExpr(const SCEVMulExpr *E) {
+    SmallVector<const SCEV *, 4> NewOps;
+    for (const SCEV *Op : E->operands())
+      NewOps.push_back(visit(Op));
+    return SE.getMulExpr(NewOps);
+  }
+  const SCEV *visitUMaxExpr(const SCEVUMaxExpr *E) {
+    SmallVector<const SCEV *, 4> NewOps;
+    for (const SCEV *Op : E->operands())
+      NewOps.push_back(visit(Op));
+    return SE.getUMaxExpr(NewOps);
+  }
+  const SCEV *visitSMaxExpr(const SCEVSMaxExpr *E) {
+    SmallVector<const SCEV *, 4> NewOps;
+    for (const SCEV *Op : E->operands())
+      NewOps.push_back(visit(Op));
+    return SE.getSMaxExpr(NewOps);
+  }
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *E) {
+    SmallVector<const SCEV *, 4> NewOps;
+    for (const SCEV *Op : E->operands())
+      NewOps.push_back(visit(Op));
+    return SE.getAddRecExpr(NewOps, E->getLoop(), E->getNoWrapFlags());
+  }
+  ///}
+};
+
+Value *polly::expandCodeFor(Scop &S, ScalarEvolution &SE, const DataLayout &DL,
+                            const char *Name, const SCEV *E, Type *Ty,
+                            Instruction *IP) {
+  ScopExpander Expander(S.getRegion(), SE, DL, Name);
+  return Expander.expandCodeFor(E, Ty, IP); // Region-aware, see ScopExpander.
+}