[WebAssembly] Do not emit tail calls with return type mismatch

Summary:
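return_call and return_call_indirect are only valid if the return
types of the callee and caller match. We were previously not enforcing
that, which produced invalid modules. Tail calls whose return types do
not match are now lowered as ordinary calls, varargs calls are never
lowered as tail calls, and the signature emitted for
return_call_indirect now uses the caller's return types.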

Reviewers: aheejin

Subscribers: dschuff, sbc100, jgravelle-google, hiraditya, sunfish, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D65246

llvm-svn: 367339
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 4064a98..065a1a7 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -644,13 +644,36 @@
   if (CLI.IsPatchPoint)
     fail(DL, DAG, "WebAssembly doesn't support patch point yet");
 
-  // Fail if tail calls are required but not enabled
-  if (!Subtarget->hasTailCall()) {
-    if ((CallConv == CallingConv::Fast && CLI.IsTailCall &&
-         MF.getTarget().Options.GuaranteedTailCallOpt) ||
-        (CLI.CS && CLI.CS.isMustTailCall()))
-      fail(DL, DAG, "WebAssembly 'tail-call' feature not enabled");
-    CLI.IsTailCall = false;
+  if (CLI.IsTailCall) {
+    bool MustTail = CLI.CS && CLI.CS.isMustTailCall();
+    if (Subtarget->hasTailCall() && !CLI.IsVarArg) {
+      // Do not tail call unless caller and callee return types match
+      const Function &F = MF.getFunction();
+      const TargetMachine &TM = getTargetMachine();
+      Type *RetTy = F.getReturnType();
+      SmallVector<MVT, 4> CallerRetTys;
+      SmallVector<MVT, 4> CalleeRetTys;
+      computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
+      computeLegalValueVTs(F, TM, CLI.RetTy, CalleeRetTys);
+      bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() &&
+                        std::equal(CallerRetTys.begin(), CallerRetTys.end(),
+                                   CalleeRetTys.begin());
+      if (!TypesMatch) {
+        // musttail in this case would be an LLVM IR validation failure
+        assert(!MustTail);
+        CLI.IsTailCall = false;
+      }
+    } else {
+      CLI.IsTailCall = false;
+      if (MustTail) {
+        if (CLI.IsVarArg) {
+          // The return would pop the argument buffer
+          fail(DL, DAG, "WebAssembly does not support varargs tail calls");
+        } else {
+          fail(DL, DAG, "WebAssembly 'tail-call' feature not enabled");
+        }
+      }
+    }
   }
 
   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index d9089de..26baac3 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -226,6 +226,17 @@
           if (WebAssembly::isCallIndirect(MI->getOpcode()))
             Params.pop_back();
 
+          // return_call_indirect instructions have the return type of the
+          // caller
+          if (MI->getOpcode() == WebAssembly::RET_CALL_INDIRECT) {
+            const Function &F = MI->getMF()->getFunction();
+            const TargetMachine &TM = MI->getMF()->getTarget();
+            Type *RetTy = F.getReturnType();
+            SmallVector<MVT, 4> CallerRetTys;
+            computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
+            valTypesFromMVTs(CallerRetTys, Returns);
+          }
+
           auto *WasmSym = cast<MCSymbolWasm>(Sym);
           auto Signature = make_unique<wasm::WasmSignature>(std::move(Returns),
                                                             std::move(Params));