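[WebAssembly] Fix types for address-taken functions

Pass the function's MCSymbol, rather than its name, to
emitIndirectFunctionType. The Wasm streamer now uses it to record the
parameter and result types on the symbol and mark it as a function, unless
the symbol is already marked as one. lowerConstant and MCInstLower no longer
set the symbol's function flag themselves; MCInstLower previously set it from
IsFunc, which could clear the flag on a function symbol.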

Differential Revision: https://reviews.llvm.org/D34966

llvm-svn: 307198
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.cpp
index ad59f2f..00bf024 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.cpp
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.cpp
@@ -115,8 +115,9 @@
 void WebAssemblyTargetAsmStreamer::emitEndFunc() { OS << "\t.endfunc\n"; }
 
 void WebAssemblyTargetAsmStreamer::emitIndirectFunctionType(
-    StringRef name, SmallVectorImpl<MVT> &Params, SmallVectorImpl<MVT> &Results) {
-  OS << "\t.functype\t" << name;
+    MCSymbol *Symbol, SmallVectorImpl<MVT> &Params,
+    SmallVectorImpl<MVT> &Results) {
+  OS << "\t.functype\t" << Symbol->getName();
   if (Results.empty())
     OS << ", void";
   else {
@@ -171,7 +171,8 @@
 }
 
 void WebAssemblyTargetELFStreamer::emitIndirectFunctionType(
-    StringRef name, SmallVectorImpl<MVT> &Params, SmallVectorImpl<MVT> &Results) {
+    MCSymbol *Symbol, SmallVectorImpl<MVT> &Params,
+    SmallVectorImpl<MVT> &Results) {
   // Nothing to emit here. TODO: Re-design how linking works and re-evaluate
   // whether it's necessary for .o files to declare indirect function types.
 }
@@ -255,9 +255,25 @@
 }
 
 void WebAssemblyTargetWasmStreamer::emitIndirectFunctionType(
-    StringRef name, SmallVectorImpl<MVT> &Params, SmallVectorImpl<MVT> &Results) {
-  // Nothing to emit here. TODO: Re-design how linking works and re-evaluate
-  // whether it's necessary for .o files to declare indirect function types.
+    MCSymbol *Symbol, SmallVectorImpl<MVT> &Params,
+    SmallVectorImpl<MVT> &Results) {
+  MCSymbolWasm *WasmSym = cast<MCSymbolWasm>(Symbol);
+  if (WasmSym->isFunction()) {
+    // Already flagged as a function; its signature is assumed to be set.
+    return;
+  }
+  // Record the signature on the symbol, converting each MVT to a ValType.
+  SmallVector<wasm::ValType, 4> ValParams;
+  for (MVT Ty : Params)
+    ValParams.push_back(WebAssembly::toValType(Ty));
+
+  SmallVector<wasm::ValType, 1> ValResults;
+  for (MVT Ty : Results)
+    ValResults.push_back(WebAssembly::toValType(Ty));
+  // Attach the converted types, then flag the symbol as a function.
+  WasmSym->setParams(std::move(ValParams));
+  WasmSym->setReturns(std::move(ValResults));
+  WasmSym->setIsFunction(true);
 }
 
 void WebAssemblyTargetWasmStreamer::emitGlobalImport(StringRef name) {
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.h
index 5ad147e..102d721 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.h
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.h
@@ -44,7 +44,7 @@
   /// .endfunc
   virtual void emitEndFunc() = 0;
   /// .functype
-  virtual void emitIndirectFunctionType(StringRef name,
+  virtual void emitIndirectFunctionType(MCSymbol *Symbol,
                                         SmallVectorImpl<MVT> &Params,
                                         SmallVectorImpl<MVT> &Results) = 0;
   /// .indidx
@@ -69,7 +69,7 @@
   void emitGlobal(ArrayRef<wasm::Global> Globals) override;
   void emitStackPointer(uint32_t Index) override;
   void emitEndFunc() override;
-  void emitIndirectFunctionType(StringRef name,
+  void emitIndirectFunctionType(MCSymbol *Symbol,
                                 SmallVectorImpl<MVT> &Params,
                                 SmallVectorImpl<MVT> &Results) override;
   void emitIndIdx(const MCExpr *Value) override;
@@ -87,7 +87,7 @@
   void emitGlobal(ArrayRef<wasm::Global> Globals) override;
   void emitStackPointer(uint32_t Index) override;
   void emitEndFunc() override;
-  void emitIndirectFunctionType(StringRef name,
+  void emitIndirectFunctionType(MCSymbol *Symbol,
                                 SmallVectorImpl<MVT> &Params,
                                 SmallVectorImpl<MVT> &Results) override;
   void emitIndIdx(const MCExpr *Value) override;
@@ -105,7 +105,7 @@
   void emitGlobal(ArrayRef<wasm::Global> Globals) override;
   void emitStackPointer(uint32_t Index) override;
   void emitEndFunc() override;
-  void emitIndirectFunctionType(StringRef name,
+  void emitIndirectFunctionType(MCSymbol *Symbol,
                                 SmallVectorImpl<MVT> &Params,
                                 SmallVectorImpl<MVT> &Results) override;
   void emitIndIdx(const MCExpr *Value) override;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
index f51585a..211358a 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
@@ -84,7 +84,7 @@
       SmallVector<MVT, 4> Results;
       SmallVector<MVT, 4> Params;
       ComputeSignatureVTs(F, TM, Params, Results);
-      getTargetStreamer()->emitIndirectFunctionType(F.getName(), Params,
+      getTargetStreamer()->emitIndirectFunctionType(getSymbol(&F), Params,
                                                     Results);
     }
   }
@@ -214,11 +214,8 @@
 const MCExpr *WebAssemblyAsmPrinter::lowerConstant(const Constant *CV) {
   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV))
     if (GV->getValueType()->isFunctionTy()) {
-      MCSymbol* Sym = getSymbol(GV);
-      if (!isa<MCSymbolELF>(Sym))
-        cast<MCSymbolWasm>(Sym)->setIsFunction(true);
       return MCSymbolRefExpr::create(
-          Sym, MCSymbolRefExpr::VK_WebAssembly_FUNCTION, OutContext);
+          getSymbol(GV), MCSymbolRefExpr::VK_WebAssembly_FUNCTION, OutContext);
     }
   return AsmPrinter::lowerConstant(CV);
 }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index ff186eb..8880539 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -112,8 +112,6 @@
   MCSymbolRefExpr::VariantKind VK =
       IsFunc ? MCSymbolRefExpr::VK_WebAssembly_FUNCTION
              : MCSymbolRefExpr::VK_None;
-  if (!isa<MCSymbolELF>(Sym))
-    cast<MCSymbolWasm>(Sym)->setIsFunction(IsFunc);
 
   const MCExpr *Expr = MCSymbolRefExpr::create(Sym, VK, Ctx);