[LAA] Merge memchecks for accesses separated by a constant offset
Summary:
Often filter-like loops will do memory accesses that are
separated by constant offsets. In these cases it is
common that we will exceed the threshold for the
allowable number of checks.
However, it should be possible to merge such checks,
since a check of any interval against two other intervals separated
by a constant offset (a,b), (a+c, b+c) will be equivalent to
a check against (a, b+c), as long as (a,b) and (a+c, b+c) overlap.
Assuming the loop will be executed for a sufficient number of
iterations, this will be true. If not true, checking against
(a, b+c) is still safe (although not equivalent).
As long as there are no dependencies between two accesses,
we can merge their checks into a single one. We use this
technique to construct groups of accesses, and then check
the intervals associated with the groups instead of
checking the accesses directly.
Reviewers: anemet
Subscribers: llvm-commits
Differential Revision: http://reviews.llvm.org/D10386
llvm-svn: 241673
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index b11cd7e..65a2586 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -48,6 +48,13 @@
cl::location(VectorizerParams::RuntimeMemoryCheckThreshold), cl::init(8));
unsigned VectorizerParams::RuntimeMemoryCheckThreshold;
+/// \brief The maximum number of comparisons allowed when merging memory checks
+static cl::opt<unsigned> MemoryCheckMergeThreshold(
+ "memory-check-merge-threshold", cl::Hidden,
+ cl::desc("Maximum number of comparisons done when trying to merge "
+ "runtime memory checks. (default = 100)"),
+ cl::init(100));
+
/// Maximum SIMD width.
const unsigned VectorizerParams::MaxVectorWidth = 64;
@@ -113,8 +120,8 @@
}
void LoopAccessInfo::RuntimePointerCheck::insert(
- ScalarEvolution *SE, Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
- unsigned ASId, const ValueToValueMap &Strides) {
+ Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId, unsigned ASId,
+ const ValueToValueMap &Strides) {
// Get the stride replaced scev.
const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
@@ -127,6 +134,136 @@
IsWritePtr.push_back(WritePtr);
DependencySetId.push_back(DepSetId);
AliasSetId.push_back(ASId);
+ Exprs.push_back(Sc);
+}
+
+bool LoopAccessInfo::RuntimePointerCheck::needsChecking(
+ const CheckingPtrGroup &M, const CheckingPtrGroup &N,
+ const SmallVectorImpl<int> *PtrPartition) const {
+ for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I)
+ for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J)
+ if (needsChecking(M.Members[I], N.Members[J], PtrPartition))
+ return true;
+ return false;
+}
+
+/// Compare \p I and \p J and return the minimum.
+/// Return nullptr in case we couldn't find an answer.
+static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J,
+ ScalarEvolution *SE) {
+ const SCEV *Diff = SE->getMinusSCEV(J, I);
+ const SCEVConstant *C = dyn_cast<const SCEVConstant>(Diff);
+
+ if (!C)
+ return nullptr;
+ if (C->getValue()->isNegative())
+ return J;
+ return I;
+}
+
+bool LoopAccessInfo::RuntimePointerCheck::CheckingPtrGroup::addPointer(
+ unsigned Index) {
+ // Compare the starts and ends with the known minimum and maximum
+ // of this set. We need to know how we compare against the min/max
+ // of the set in order to be able to emit memchecks.
+ const SCEV *Min0 = getMinFromExprs(RtCheck.Starts[Index], Low, RtCheck.SE);
+ if (!Min0)
+ return false;
+
+ const SCEV *Min1 = getMinFromExprs(RtCheck.Ends[Index], High, RtCheck.SE);
+ if (!Min1)
+ return false;
+
+ // Update the low bound expression if we've found a new min value.
+ if (Min0 == RtCheck.Starts[Index])
+ Low = RtCheck.Starts[Index];
+
+ // Update the high bound expression if we've found a new max value.
+ if (Min1 != RtCheck.Ends[Index])
+ High = RtCheck.Ends[Index];
+
+ Members.push_back(Index);
+ return true;
+}
+
+void LoopAccessInfo::RuntimePointerCheck::groupChecks(
+ MemoryDepChecker::DepCandidates &DepCands,
+ bool UseDependencies) {
+ // We build the groups from dependency candidates equivalence classes
+ // because:
+ // - We know that pointers in the same equivalence class share
+ // the same underlying object and therefore there is a chance
+ // that we can compare pointers
+ // - We wouldn't be able to merge two pointers for which we need
+ // to emit a memcheck. The classes in DepCands are already
+ // conveniently built such that no two pointers in the same
+ // class need checking against each other.
+
+ // We use the following (greedy) algorithm to construct the groups
+ // For every pointer in the equivalence class:
+ // For each existing group:
+ // - if the difference between this pointer and the min/max bounds
+ // of the group is a constant, then make the pointer part of the
+ // group and update the min/max bounds of that group as required.
+
+ CheckingGroups.clear();
+
+ // If we don't have the dependency partitions, construct a new
+ // checking pointer group for each pointer.
+ if (!UseDependencies) {
+ for (unsigned I = 0; I < Pointers.size(); ++I)
+ CheckingGroups.push_back(CheckingPtrGroup(I, *this));
+ return;
+ }
+
+ unsigned TotalComparisons = 0;
+
+ DenseMap<Value *, unsigned> PositionMap;
+ for (unsigned Pointer = 0; Pointer < Pointers.size(); ++Pointer)
+ PositionMap[Pointers[Pointer]] = Pointer;
+
+  // Go through all equivalence classes, get the "pointer check groups"
+ // and add them to the overall solution.
+ for (auto DI = DepCands.begin(), DE = DepCands.end(); DI != DE; ++DI) {
+ if (!DI->isLeader())
+ continue;
+
+ SmallVector<CheckingPtrGroup, 2> Groups;
+
+ for (auto MI = DepCands.member_begin(DI), ME = DepCands.member_end();
+ MI != ME; ++MI) {
+ unsigned Pointer = PositionMap[MI->getPointer()];
+ bool Merged = false;
+
+ // Go through all the existing sets and see if we can find one
+ // which can include this pointer.
+ for (CheckingPtrGroup &Group : Groups) {
+ // Don't perform more than a certain amount of comparisons.
+ // This should limit the cost of grouping the pointers to something
+ // reasonable. If we do end up hitting this threshold, the algorithm
+ // will create separate groups for all remaining pointers.
+ if (TotalComparisons > MemoryCheckMergeThreshold)
+ break;
+
+ TotalComparisons++;
+
+ if (Group.addPointer(Pointer)) {
+ Merged = true;
+ break;
+ }
+ }
+
+ if (!Merged)
+ // We couldn't add this pointer to any existing set or the threshold
+ // for the number of comparisons has been reached. Create a new group
+ // to hold the current pointer.
+ Groups.push_back(CheckingPtrGroup(Pointer, *this));
+ }
+
+ // We've computed the grouped checks for this partition.
+ // Save the results and continue with the next one.
+ std::copy(Groups.begin(), Groups.end(), std::back_inserter(CheckingGroups));
+ }
}
bool LoopAccessInfo::RuntimePointerCheck::needsChecking(
@@ -156,42 +293,71 @@
void LoopAccessInfo::RuntimePointerCheck::print(
raw_ostream &OS, unsigned Depth,
const SmallVectorImpl<int> *PtrPartition) const {
- unsigned NumPointers = Pointers.size();
- if (NumPointers == 0)
- return;
OS.indent(Depth) << "Run-time memory checks:\n";
+
unsigned N = 0;
- for (unsigned I = 0; I < NumPointers; ++I)
- for (unsigned J = I + 1; J < NumPointers; ++J)
- if (needsChecking(I, J, PtrPartition)) {
- OS.indent(Depth) << N++ << ":\n";
- OS.indent(Depth + 2) << *Pointers[I];
- if (PtrPartition)
- OS << " (Partition: " << (*PtrPartition)[I] << ")";
- OS << "\n";
- OS.indent(Depth + 2) << *Pointers[J];
- if (PtrPartition)
- OS << " (Partition: " << (*PtrPartition)[J] << ")";
- OS << "\n";
+ for (unsigned I = 0; I < CheckingGroups.size(); ++I)
+ for (unsigned J = I + 1; J < CheckingGroups.size(); ++J)
+ if (needsChecking(CheckingGroups[I], CheckingGroups[J], PtrPartition)) {
+ OS.indent(Depth) << "Check " << N++ << ":\n";
+ OS.indent(Depth + 2) << "Comparing group " << I << ":\n";
+
+ for (unsigned K = 0; K < CheckingGroups[I].Members.size(); ++K) {
+ OS.indent(Depth + 2) << *Pointers[CheckingGroups[I].Members[K]]
+ << "\n";
+ if (PtrPartition)
+ OS << " (Partition: "
+ << (*PtrPartition)[CheckingGroups[I].Members[K]] << ")"
+ << "\n";
+ }
+
+ OS.indent(Depth + 2) << "Against group " << J << ":\n";
+
+ for (unsigned K = 0; K < CheckingGroups[J].Members.size(); ++K) {
+ OS.indent(Depth + 2) << *Pointers[CheckingGroups[J].Members[K]]
+ << "\n";
+ if (PtrPartition)
+ OS << " (Partition: "
+ << (*PtrPartition)[CheckingGroups[J].Members[K]] << ")"
+ << "\n";
+ }
}
+
+ OS.indent(Depth) << "Grouped accesses:\n";
+ for (unsigned I = 0; I < CheckingGroups.size(); ++I) {
+ OS.indent(Depth + 2) << "Group " << I << ":\n";
+ OS.indent(Depth + 4) << "(Low: " << *CheckingGroups[I].Low
+ << " High: " << *CheckingGroups[I].High << ")\n";
+ for (unsigned J = 0; J < CheckingGroups[I].Members.size(); ++J) {
+ OS.indent(Depth + 6) << "Member: " << *Exprs[CheckingGroups[I].Members[J]]
+ << "\n";
+ }
+ }
}
unsigned LoopAccessInfo::RuntimePointerCheck::getNumberOfChecks(
const SmallVectorImpl<int> *PtrPartition) const {
- unsigned NumPointers = Pointers.size();
+
+ unsigned NumPartitions = CheckingGroups.size();
unsigned CheckCount = 0;
- for (unsigned I = 0; I < NumPointers; ++I)
- for (unsigned J = I + 1; J < NumPointers; ++J)
- if (needsChecking(I, J, PtrPartition))
+ for (unsigned I = 0; I < NumPartitions; ++I)
+ for (unsigned J = I + 1; J < NumPartitions; ++J)
+ if (needsChecking(CheckingGroups[I], CheckingGroups[J], PtrPartition))
CheckCount++;
return CheckCount;
}
bool LoopAccessInfo::RuntimePointerCheck::needsAnyChecking(
const SmallVectorImpl<int> *PtrPartition) const {
- return getNumberOfChecks(PtrPartition) != 0;
+ unsigned NumPointers = Pointers.size();
+
+ for (unsigned I = 0; I < NumPointers; ++I)
+ for (unsigned J = I + 1; J < NumPointers; ++J)
+ if (needsChecking(I, J, PtrPartition))
+ return true;
+ return false;
}
namespace {
@@ -341,7 +507,7 @@
// Each access has its own dependence set.
DepId = RunningDepId++;
- RtCheck.insert(SE, TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap);
+ RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap);
DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
} else {
@@ -387,6 +553,9 @@
}
}
+ if (NeedRTCheck && CanDoRT)
+ RtCheck.groupChecks(DepCands, IsDepCheckNeeded);
+
return CanDoRT;
}
@@ -1360,32 +1529,35 @@
if (!PtrRtCheck.Need)
return std::make_pair(nullptr, nullptr);
- unsigned NumPointers = PtrRtCheck.Pointers.size();
- SmallVector<TrackingVH<Value> , 2> Starts;
- SmallVector<TrackingVH<Value> , 2> Ends;
+ SmallVector<TrackingVH<Value>, 2> Starts;
+ SmallVector<TrackingVH<Value>, 2> Ends;
LLVMContext &Ctx = Loc->getContext();
SCEVExpander Exp(*SE, DL, "induction");
Instruction *FirstInst = nullptr;
- for (unsigned i = 0; i < NumPointers; ++i) {
- Value *Ptr = PtrRtCheck.Pointers[i];
+ for (unsigned i = 0; i < PtrRtCheck.CheckingGroups.size(); ++i) {
+ const RuntimePointerCheck::CheckingPtrGroup &CG =
+ PtrRtCheck.CheckingGroups[i];
+ Value *Ptr = PtrRtCheck.Pointers[CG.Members[0]];
const SCEV *Sc = SE->getSCEV(Ptr);
if (SE->isLoopInvariant(Sc, TheLoop)) {
- DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" <<
- *Ptr <<"\n");
+ DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" << *Ptr
+ << "\n");
Starts.push_back(Ptr);
Ends.push_back(Ptr);
} else {
- DEBUG(dbgs() << "LAA: Adding RT check for range:" << *Ptr << '\n');
unsigned AS = Ptr->getType()->getPointerAddressSpace();
// Use this type for pointer arithmetic.
Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
+ Value *Start = nullptr, *End = nullptr;
- Value *Start = Exp.expandCodeFor(PtrRtCheck.Starts[i], PtrArithTy, Loc);
- Value *End = Exp.expandCodeFor(PtrRtCheck.Ends[i], PtrArithTy, Loc);
+ DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
+ Start = Exp.expandCodeFor(CG.Low, PtrArithTy, Loc);
+ End = Exp.expandCodeFor(CG.High, PtrArithTy, Loc);
+ DEBUG(dbgs() << "Start: " << *CG.Low << " End: " << *CG.High << "\n");
Starts.push_back(Start);
Ends.push_back(End);
}
@@ -1394,9 +1566,14 @@
IRBuilder<> ChkBuilder(Loc);
// Our instructions might fold to a constant.
Value *MemoryRuntimeCheck = nullptr;
- for (unsigned i = 0; i < NumPointers; ++i) {
- for (unsigned j = i+1; j < NumPointers; ++j) {
- if (!PtrRtCheck.needsChecking(i, j, PtrPartition))
+ for (unsigned i = 0; i < PtrRtCheck.CheckingGroups.size(); ++i) {
+ for (unsigned j = i + 1; j < PtrRtCheck.CheckingGroups.size(); ++j) {
+ const RuntimePointerCheck::CheckingPtrGroup &CGI =
+ PtrRtCheck.CheckingGroups[i];
+ const RuntimePointerCheck::CheckingPtrGroup &CGJ =
+ PtrRtCheck.CheckingGroups[j];
+
+ if (!PtrRtCheck.needsChecking(CGI, CGJ, PtrPartition))
continue;
unsigned AS0 = Starts[i]->getType()->getPointerAddressSpace();
@@ -1447,8 +1624,8 @@
const TargetLibraryInfo *TLI, AliasAnalysis *AA,
DominatorTree *DT, LoopInfo *LI,
const ValueToValueMap &Strides)
- : DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL),
- TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
+ : PtrRtCheck(SE), DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), TLI(TLI),
+ AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0),
MaxSafeDepDistBytes(-1U), CanVecMem(false),
StoreToLoopInvariantAddress(false) {
if (canAnalyzeLoop())