[Attributor] Deduce "align" attribute

Summary:
Deduce the "align" attribute in the Attributor for function return values, function arguments, and call site arguments.

Reviewers: jdoerfert, sstefan1

Reviewed By: jdoerfert

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D64152

llvm-svn: 367187
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index fabe144..6a7e76a 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -72,6 +72,9 @@
           "Number of function arguments marked dereferenceable");
 STATISTIC(NumCSArgumentDereferenceable,
           "Number of call site arguments marked dereferenceable");
+STATISTIC(NumFnReturnedAlign, "Number of function return values marked align");
+STATISTIC(NumFnArgumentAlign, "Number of function arguments marked align");
+STATISTIC(NumCSArgumentAlign, "Number of call site arguments marked align");
 
 // TODO: Determine a good default value.
 //
@@ -115,6 +118,21 @@
     return;
 
   switch (Attr.getKindAsEnum()) {
+  case Attribute::Alignment:
+    switch (MP) {
+    case AbstractAttribute::MP_RETURNED:
+      NumFnReturnedAlign++;
+      break;
+    case AbstractAttribute::MP_ARGUMENT:
+      NumFnArgumentAlign++;
+      break;
+    case AbstractAttribute::MP_CALL_SITE_ARGUMENT:
+      NumCSArgumentAlign++;
+      break;
+    default:
+      break;
+    }
+    break;
   case Attribute::Dereferenceable:
     switch (MP) {
     case AbstractAttribute::MP_RETURNED:
@@ -1974,6 +1992,212 @@
                                                        : ChangeStatus::CHANGED;
 }
 
+// ------------------------ Align Argument Attribute ------------------------
+
+struct AAAlignImpl : AAAlign, IntegerState {
+
+  // Maximum alignment value allowed in IR (2^29).
+  static const unsigned MAX_ALIGN = 1U << 29;
+
+  AAAlignImpl(Value *AssociatedVal, Value &AnchoredValue,
+              InformationCache &InfoCache)
+      : AAAlign(AssociatedVal, AnchoredValue, InfoCache),
+        IntegerState(MAX_ALIGN) {}
+
+  AAAlignImpl(Value &V, InformationCache &InfoCache)
+      : AAAlignImpl(&V, V, InfoCache) {}
+
+  /// See AbstractAttribute::getState()
+  /// {
+  AbstractState &getState() override { return *this; }
+  const AbstractState &getState() const override { return *this; }
+  /// }
+
+  virtual const std::string getAsStr() const override {
+    return getAssumedAlign() ? ("align<" + std::to_string(getKnownAlign()) +
+                                "-" + std::to_string(getAssumedAlign()) + ">")
+                             : "unknown-align";
+  }
+
+  /// See AAAlign::getAssumedAlign().
+  unsigned getAssumedAlign() const override { return getAssumed(); }
+
+  /// See AAAlign::getKnownAlign().
+  unsigned getKnownAlign() const override { return getKnown(); }
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    Function &F = getAnchorScope();
+
+    unsigned AttrIdx =
+        getAttrIndex(getManifestPosition(), getArgNo(getAnchoredValue()));
+
+    // If the position already carries an align attribute, record it as known.
+    if (F.getAttributes().hasAttribute(AttrIdx, ID))
+      addKnownBits(F.getAttribute(AttrIdx, ID).getAlignment());
+  }
+
+  /// See AbstractAttribute::getDeducedAttributes(...).
+  virtual void
+  getDeducedAttributes(SmallVectorImpl<Attribute> &Attrs) const override {
+    LLVMContext &Ctx = AnchoredVal.getContext();
+
+    Attrs.emplace_back(Attribute::getWithAlignment(Ctx, getAssumedAlign()));
+  }
+};
+
+/// Align attribute deduction for a function's return value.
+struct AAAlignReturned : AAAlignImpl {
+
+  AAAlignReturned(Function &F, InformationCache &InfoCache)
+      : AAAlignImpl(F, InfoCache) {}
+
+  /// See AbstractAttribute::getManifestPosition().
+  virtual ManifestPosition getManifestPosition() const override {
+    return MP_RETURNED;
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  virtual ChangeStatus updateImpl(Attributor &A) override;
+};
+
+ChangeStatus AAAlignReturned::updateImpl(Attributor &A) {
+  Function &F = getAnchorScope();
+  auto *AARetValImpl = A.getAAFor<AAReturnedValuesImpl>(*this, F);
+  if (!AARetValImpl) {
+    indicatePessimisticFixpoint();
+    return ChangeStatus::CHANGED;
+  }
+
+  // Currently, align<n> is deduced if the alignment of every returned value is
+  // assumed to be at least n. We reach a pessimistic fixpoint if any returned
+  // value would not have an alignment. If no assumed state was used for
+  // reasoning, an optimistic fixpoint is reached earlier.
+
+  base_t BeforeState = getAssumed();
+  std::function<bool(Value &)> Pred = [&](Value &RV) -> bool {
+    auto *AlignAA = A.getAAFor<AAAlign>(*this, RV);
+
+    if (AlignAA)
+      takeAssumedMinimum(AlignAA->getAssumedAlign());
+    else
+      // No AAAlign available for RV; fall back to IR information.
+      takeAssumedMinimum(RV.getPointerAlignment(
+          getAnchorScope().getParent()->getDataLayout()));
+
+    return isValidState();
+  };
+
+  if (!AARetValImpl->checkForallReturnedValues(Pred)) {
+    indicatePessimisticFixpoint();
+    return ChangeStatus::CHANGED;
+  }
+
+  return (getAssumed() != BeforeState) ? ChangeStatus::CHANGED
+                                       : ChangeStatus::UNCHANGED;
+}
+
+/// Align attribute deduction for a function argument.
+struct AAAlignArgument : AAAlignImpl {
+
+  AAAlignArgument(Argument &A, InformationCache &InfoCache)
+      : AAAlignImpl(A, InfoCache) {}
+
+  /// See AbstractAttribute::getManifestPosition().
+  virtual ManifestPosition getManifestPosition() const override {
+    return MP_ARGUMENT;
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  virtual ChangeStatus updateImpl(Attributor &A) override;
+};
+
+ChangeStatus AAAlignArgument::updateImpl(Attributor &A) {
+
+  Function &F = getAnchorScope();
+  Argument &Arg = cast<Argument>(getAnchoredValue());
+
+  unsigned ArgNo = Arg.getArgNo();
+  const DataLayout &DL = F.getParent()->getDataLayout();
+
+  auto BeforeState = getAssumed();
+
+  // Callback for checkForAllCallSites: fold in the alignment at a call site.
+  std::function<bool(CallSite)> CallSiteCheck = [&](CallSite CS) {
+    assert(CS && "Sanity check: Call site was not initialized properly!");
+
+    auto *AlignAA = A.getAAFor<AAAlign>(*this, *CS.getInstruction(), ArgNo);
+
+    // Use AlignAA only if it is the call site argument AA anchored at CS.
+    if (AlignAA) {
+      ImmutableCallSite ICS(&AlignAA->getAnchoredValue());
+      if (ICS && CS.getInstruction() == ICS.getInstruction()) {
+        takeAssumedMinimum(AlignAA->getAssumedAlign());
+        return isValidState();
+      }
+    }
+
+    Value *V = CS.getArgOperand(ArgNo);
+    takeAssumedMinimum(V->getPointerAlignment(DL));
+    return isValidState();
+  };
+
+  if (!A.checkForAllCallSites(F, CallSiteCheck, true))
+    indicatePessimisticFixpoint();
+
+  return BeforeState == getAssumed() ? ChangeStatus::UNCHANGED
+                                     : ChangeStatus ::CHANGED;
+}
+
+struct AAAlignCallSiteArgument : AAAlignImpl {
+
+  /// See AAAlignImpl::AAAlignImpl(...).
+  AAAlignCallSiteArgument(CallSite CS, unsigned ArgNo,
+                          InformationCache &InfoCache)
+      : AAAlignImpl(CS.getArgOperand(ArgNo), *CS.getInstruction(), InfoCache),
+        ArgNo(ArgNo) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    CallSite CS(&getAnchoredValue()); // NOTE(review): CS appears unused here.
+    takeKnownMaximum(getAssociatedValue()->getPointerAlignment(
+        getAnchorScope().getParent()->getDataLayout()));
+  }
+
+  /// See AbstractAttribute::updateImpl(Attributor &A).
+  ChangeStatus updateImpl(Attributor &A) override;
+
+  /// See AbstractAttribute::getManifestPosition().
+  ManifestPosition getManifestPosition() const override {
+    return MP_CALL_SITE_ARGUMENT;
+  };
+
+  /// Return the argument index of the associated value.
+  int getArgNo() const { return ArgNo; }
+
+private:
+  unsigned ArgNo;
+};
+
+ChangeStatus AAAlignCallSiteArgument::updateImpl(Attributor &A) {
+  // NOTE: Never look at the argument of the callee in this method.
+  //       If we did, "align" would always be deduced due to the assumption.
+
+  auto BeforeState = getAssumed();
+
+  Value &V = *getAssociatedValue();
+
+  auto *AlignAA = A.getAAFor<AAAlign>(*this, V);
+
+  if (AlignAA)
+    takeAssumedMinimum(AlignAA->getAssumedAlign());
+  else
+    indicatePessimisticFixpoint();
+
+  return BeforeState == getAssumed() ? ChangeStatus::UNCHANGED
+                                     : ChangeStatus::CHANGED;
+}
+
 /// ----------------------------------------------------------------------------
 ///                               Attributor
 /// ----------------------------------------------------------------------------
@@ -2171,6 +2395,10 @@
       registerAA(*new AAReturnedValuesImpl(F, InfoCache));
 
     if (ReturnType->isPointerTy()) {
+      // Every function with pointer return type might be marked align.
+      if (!Whitelist || Whitelist->count(AAAlignReturned::ID))
+        registerAA(*new AAAlignReturned(F, InfoCache));
+
       // Every function with pointer return type might be marked nonnull.
       if (!Whitelist || Whitelist->count(AANonNullReturned::ID))
         registerAA(*new AANonNullReturned(F, InfoCache));
@@ -2196,6 +2424,10 @@
       // Every argument with pointer type might be marked dereferenceable.
       if (!Whitelist || Whitelist->count(AADereferenceableArgument::ID))
         registerAA(*new AADereferenceableArgument(Arg, InfoCache));
+
+      // Every argument with pointer type might be marked align.
+      if (!Whitelist || Whitelist->count(AAAlignArgument::ID))
+        registerAA(*new AAAlignArgument(Arg, InfoCache));
     }
   }
 
@@ -2254,6 +2486,10 @@
             Whitelist->count(AADereferenceableCallSiteArgument::ID))
           registerAA(*new AADereferenceableCallSiteArgument(CS, i, InfoCache),
                      i);
+
+        // Call site argument attribute "align".
+        if (!Whitelist || Whitelist->count(AAAlignCallSiteArgument::ID))
+          registerAA(*new AAAlignCallSiteArgument(CS, i, InfoCache), i);
       }
     }
   }