[COFF, ARM64] Fix ABI implementation of struct returns

Summary:
Related LLVM patch: D60348.
Patch co-authored by Sanjin Sijaric.

Reviewers: rnk, efriedma, TomTan, ssijaric, ostannard

Reviewed By: efriedma

Subscribers: dmajor, richard.townsend.arm, ostannard, javed.absar, kristof.beyls, cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D60349

llvm-svn: 359932
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index b837fa4..bc9be14 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1999,8 +1999,7 @@
   // Attach attributes to sret.
   if (IRFunctionArgs.hasSRetArg()) {
     llvm::AttrBuilder SRETAttrs;
-    if (!RetAI.getSuppressSRet())
-      SRETAttrs.addAttribute(llvm::Attribute::StructRet);
+    SRETAttrs.addAttribute(llvm::Attribute::StructRet);
     hasUsedSRet = true;
     if (RetAI.getInReg())
       SRETAttrs.addAttribute(llvm::Attribute::InReg);
diff --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp
index fba5490..c37bfe3 100644
--- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp
+++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp
@@ -1051,33 +1051,55 @@
   return isDeletingDtor(GD);
 }
 
+static bool IsSizeGreaterThan128(const CXXRecordDecl *RD) {
+  return RD->getASTContext().getTypeSize(RD->getTypeForDecl()) > 128;
+}
+
+static bool hasMicrosoftABIRestrictions(const CXXRecordDecl *RD) {
+  // For AArch64, we use the C++14 definition of an aggregate, so we also
+  // check for:
+  //   No private or protected non-static data members.
+  //   No base classes.
+  //   No virtual functions.
+  // Additionally, we need to ensure that there is a trivial copy assignment
+  // operator, a trivial destructor, and no user-provided constructors.
+  if (RD->hasProtectedFields() || RD->hasPrivateFields())
+    return true;
+  if (RD->getNumBases() > 0)
+    return true;
+  if (RD->isPolymorphic())
+    return true;
+  if (RD->hasNonTrivialCopyAssignment())
+    return true;
+  for (const CXXConstructorDecl *Ctor : RD->ctors())
+    if (Ctor->isUserProvided())
+      return true;
+  if (RD->hasNonTrivialDestructor())
+    return true;
+  return false;
+}
+
 bool MicrosoftCXXABI::classifyReturnType(CGFunctionInfo &FI) const {
   const CXXRecordDecl *RD = FI.getReturnType()->getAsCXXRecordDecl();
   if (!RD)
     return false;
 
-  CharUnits Align = CGM.getContext().getTypeAlignInChars(FI.getReturnType());
-  if (FI.isInstanceMethod()) {
-    // If it's an instance method, aggregates are always returned indirectly via
-    // the second parameter.
-    FI.getReturnInfo() = ABIArgInfo::getIndirect(Align, /*ByVal=*/false);
-    FI.getReturnInfo().setSRetAfterThis(FI.isInstanceMethod());
+  bool isAArch64 = CGM.getTarget().getTriple().isAArch64();
+  bool isSimple = !isAArch64 || !hasMicrosoftABIRestrictions(RD);
+  bool isIndirectReturn =
+      isAArch64 ? (!RD->canPassInRegisters() ||
+                   IsSizeGreaterThan128(RD))
+                : !RD->isPOD();
+  bool isInstanceMethod = FI.isInstanceMethod();
 
-    // aarch64-windows requires that instance methods use X1 for the return
-    // address. So for aarch64-windows we do not mark the
-    // return as SRet.
-    FI.getReturnInfo().setSuppressSRet(CGM.getTarget().getTriple().getArch() ==
-                                       llvm::Triple::aarch64);
-    return true;
-  } else if (!RD->isPOD()) {
-    // If it's a free function, non-POD types are returned indirectly.
+  if (isIndirectReturn || !isSimple || isInstanceMethod) {
+    CharUnits Align = CGM.getContext().getTypeAlignInChars(FI.getReturnType());
     FI.getReturnInfo() = ABIArgInfo::getIndirect(Align, /*ByVal=*/false);
+    FI.getReturnInfo().setSRetAfterThis(isInstanceMethod);
 
-    // aarch64-windows requires that non-POD, non-instance returns use X0 for
-    // the return address. So for aarch64-windows we do not mark the return as
-    // SRet.
-    FI.getReturnInfo().setSuppressSRet(CGM.getTarget().getTriple().getArch() ==
-                                       llvm::Triple::aarch64);
+    FI.getReturnInfo().setInReg(isAArch64 &&
+                                !(isSimple && IsSizeGreaterThan128(RD)));
+
     return true;
   }
 
diff --git a/clang/lib/Sema/SemaDeclCXX.cpp b/clang/lib/Sema/SemaDeclCXX.cpp
index 2e7573d..b90ab04 100644
--- a/clang/lib/Sema/SemaDeclCXX.cpp
+++ b/clang/lib/Sema/SemaDeclCXX.cpp
@@ -5957,8 +5957,11 @@
 
     // Note: This permits small classes with nontrivial destructors to be
     // passed in registers, which is non-conforming.
+    bool isAArch64 = S.Context.getTargetInfo().getTriple().isAArch64();
+    uint64_t TypeSize = isAArch64 ? 128 : 64;
+
     if (CopyCtorIsTrivial &&
-        S.getASTContext().getTypeSize(D->getTypeForDecl()) <= 64)
+        S.getASTContext().getTypeSize(D->getTypeForDecl()) <= TypeSize)
       return true;
     return false;
   }