[RISCV] Support assembling TLS add and associated modifiers

This patch adds support in the MC layer for parsing and assembling the
4-operand add instruction needed for TLS addressing. This also involves
parsing the %tprel_hi, %tprel_lo and %tprel_add operand modifiers.

Differential Revision: https://reviews.llvm.org/D55341

llvm-svn: 357698
diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
index 6c46a40..7113adc 100644
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -295,6 +295,16 @@
             VK == RISCVMCExpr::VK_RISCV_CALL_PLT);
   }
 
+  bool isTPRelAddSymbol() const {
+    int64_t Imm;
+    RISCVMCExpr::VariantKind VK;
+    // Must be of 'immediate' type but not a constant.
+    if (!isImm() || evaluateConstantImm(getImm(), Imm, VK))
+      return false;
+    return RISCVAsmParser::classifySymbolRef(getImm(), VK, Imm) &&
+           VK == RISCVMCExpr::VK_RISCV_TPREL_ADD;
+  }
+
   bool isCSRSystemRegister() const { return isSystemRegister(); }
 
   /// Return true if the operand is a valid for the fence instruction e.g.
@@ -489,7 +499,8 @@
       IsValid = isInt<12>(Imm);
     return IsValid && ((IsConstantImm && VK == RISCVMCExpr::VK_RISCV_None) ||
                        VK == RISCVMCExpr::VK_RISCV_LO ||
-                       VK == RISCVMCExpr::VK_RISCV_PCREL_LO);
+                       VK == RISCVMCExpr::VK_RISCV_PCREL_LO ||
+                       VK == RISCVMCExpr::VK_RISCV_TPREL_LO);
   }
 
   bool isSImm12Lsb0() const { return isBareSimmNLsb0<12>(); }
@@ -515,10 +526,12 @@
     bool IsConstantImm = evaluateConstantImm(getImm(), Imm, VK);
     if (!IsConstantImm) {
       IsValid = RISCVAsmParser::classifySymbolRef(getImm(), VK, Imm);
-      return IsValid && VK == RISCVMCExpr::VK_RISCV_HI;
+      return IsValid && (VK == RISCVMCExpr::VK_RISCV_HI ||
+                         VK == RISCVMCExpr::VK_RISCV_TPREL_HI);
     } else {
       return isUInt<20>(Imm) && (VK == RISCVMCExpr::VK_RISCV_None ||
-                                 VK == RISCVMCExpr::VK_RISCV_HI);
+                                 VK == RISCVMCExpr::VK_RISCV_HI ||
+                                 VK == RISCVMCExpr::VK_RISCV_TPREL_HI);
     }
   }
 
@@ -872,8 +885,8 @@
   case Match_InvalidSImm12:
     return generateImmOutOfRangeError(
         Operands, ErrorInfo, -(1 << 11), (1 << 11) - 1,
-        "operand must be a symbol with %lo/%pcrel_lo modifier or an integer in "
-        "the range");
+        "operand must be a symbol with %lo/%pcrel_lo/%tprel_lo modifier or an "
+        "integer in the range");
   case Match_InvalidSImm12Lsb0:
     return generateImmOutOfRangeError(
         Operands, ErrorInfo, -(1 << 11), (1 << 11) - 2,
@@ -884,8 +897,9 @@
         "immediate must be a multiple of 2 bytes in the range");
   case Match_InvalidUImm20LUI:
     return generateImmOutOfRangeError(Operands, ErrorInfo, 0, (1 << 20) - 1,
-                                      "operand must be a symbol with %hi() "
-                                      "modifier or an integer in the range");
+                                      "operand must be a symbol with "
+                                      "%hi/%tprel_hi modifier or an integer in "
+                                      "the range");
   case Match_InvalidUImm20AUIPC:
     return generateImmOutOfRangeError(
         Operands, ErrorInfo, 0, (1 << 20) - 1,
@@ -920,6 +934,10 @@
     SMLoc ErrorLoc = ((RISCVOperand &)*Operands[ErrorInfo]).getStartLoc();
     return Error(ErrorLoc, "operand must be a bare symbol name");
   }
+  case Match_InvalidTPRelAddSymbol: {
+    SMLoc ErrorLoc = ((RISCVOperand &)*Operands[ErrorInfo]).getStartLoc();
+    return Error(ErrorLoc, "operand must be a symbol with %tprel_add modifier");
+  }
   }
 
   llvm_unreachable("Unknown match type detected!");
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp
index d5d464e..c57eb73 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.cpp
@@ -187,12 +187,15 @@
     return Value;
   case RISCV::fixup_riscv_lo12_i:
   case RISCV::fixup_riscv_pcrel_lo12_i:
+  case RISCV::fixup_riscv_tprel_lo12_i:
     return Value & 0xfff;
   case RISCV::fixup_riscv_lo12_s:
   case RISCV::fixup_riscv_pcrel_lo12_s:
+  case RISCV::fixup_riscv_tprel_lo12_s:
     return (((Value >> 5) & 0x7f) << 25) | ((Value & 0x1f) << 7);
   case RISCV::fixup_riscv_hi20:
   case RISCV::fixup_riscv_pcrel_hi20:
+  case RISCV::fixup_riscv_tprel_hi20:
     // Add 1 if bit 11 is 1, to compensate for low 12 bits being negative.
     return ((Value + 0x800) >> 12) & 0xfffff;
   case RISCV::fixup_riscv_jal: {
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h
index ae13e5f..f16bd4e 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVAsmBackend.h
@@ -105,6 +105,10 @@
       { "fixup_riscv_pcrel_lo12_i", 20,     12,  MCFixupKindInfo::FKF_IsPCRel },
       { "fixup_riscv_pcrel_lo12_s",  0,     32,  MCFixupKindInfo::FKF_IsPCRel },
       { "fixup_riscv_got_hi20",     12,     20,  MCFixupKindInfo::FKF_IsPCRel },
+      { "fixup_riscv_tprel_hi20",   12,     20,  0 },
+      { "fixup_riscv_tprel_lo12_i", 20,     12,  0 },
+      { "fixup_riscv_tprel_lo12_s",  0,     32,  0 },
+      { "fixup_riscv_tprel_add",     0,      0,  0 },
       { "fixup_riscv_jal",          12,     20,  MCFixupKindInfo::FKF_IsPCRel },
       { "fixup_riscv_branch",        0,     32,  MCFixupKindInfo::FKF_IsPCRel },
       { "fixup_riscv_rvc_jump",      2,     11,  MCFixupKindInfo::FKF_IsPCRel },
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVELFObjectWriter.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVELFObjectWriter.cpp
index 4fa3cd8..e649776 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVELFObjectWriter.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVELFObjectWriter.cpp
@@ -85,6 +85,14 @@
     return ELF::R_RISCV_PCREL_LO12_S;
   case RISCV::fixup_riscv_got_hi20:
     return ELF::R_RISCV_GOT_HI20;
+  case RISCV::fixup_riscv_tprel_hi20:
+    return ELF::R_RISCV_TPREL_HI20;
+  case RISCV::fixup_riscv_tprel_lo12_i:
+    return ELF::R_RISCV_TPREL_LO12_I;
+  case RISCV::fixup_riscv_tprel_lo12_s:
+    return ELF::R_RISCV_TPREL_LO12_S;
+  case RISCV::fixup_riscv_tprel_add:
+    return ELF::R_RISCV_TPREL_ADD;
   case RISCV::fixup_riscv_jal:
     return ELF::R_RISCV_JAL;
   case RISCV::fixup_riscv_branch:
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVFixupKinds.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVFixupKinds.h
index b931acc..e33183e 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVFixupKinds.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVFixupKinds.h
@@ -37,6 +37,18 @@
   // fixup_riscv_got_hi20 - 20-bit fixup corresponding to got_pcrel_hi(foo) for
   // instructions like auipc
   fixup_riscv_got_hi20,
+  // fixup_riscv_tprel_hi20 - 20-bit fixup corresponding to tprel_hi(foo) for
+  // instructions like lui
+  fixup_riscv_tprel_hi20,
+  // fixup_riscv_tprel_lo12_i - 12-bit fixup corresponding to tprel_lo(foo) for
+  // instructions like addi
+  fixup_riscv_tprel_lo12_i,
+  // fixup_riscv_tprel_lo12_s - 12-bit fixup corresponding to tprel_lo(foo) for
+  // the S-type store instructions
+  fixup_riscv_tprel_lo12_s,
+  // fixup_riscv_tprel_add - A fixup corresponding to %tprel_add(foo) for the
+  // TP-relative add (PseudoAddTPRel). Used to provide a hint to the linker.
+  fixup_riscv_tprel_add,
   // fixup_riscv_jal - 20-bit fixup for symbol references in the jal
   // instruction
   fixup_riscv_jal,
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp
index 692398e..f1248e5 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp
@@ -56,6 +56,10 @@
                           SmallVectorImpl<MCFixup> &Fixups,
                           const MCSubtargetInfo &STI) const;
 
+  void expandAddTPRel(const MCInst &MI, raw_ostream &OS,
+                      SmallVectorImpl<MCFixup> &Fixups,
+                      const MCSubtargetInfo &STI) const;
+
   /// TableGen'erated function for getting the binary encoding for an
   /// instruction.
   uint64_t getBinaryCodeForInstr(const MCInst &MI,
@@ -120,6 +124,44 @@
   support::endian::write(OS, Binary, support::little);
 }
 
+// Expand PseudoAddTPRel to a simple ADD with the correct relocation.
+void RISCVMCCodeEmitter::expandAddTPRel(const MCInst &MI, raw_ostream &OS,
+                                        SmallVectorImpl<MCFixup> &Fixups,
+                                        const MCSubtargetInfo &STI) const {
+  MCOperand DestReg = MI.getOperand(0);
+  MCOperand SrcReg = MI.getOperand(1);
+  MCOperand TPReg = MI.getOperand(2);
+  assert(TPReg.isReg() && TPReg.getReg() == RISCV::X4 &&
+         "Expected thread pointer as second input to TP-relative add");
+
+  MCOperand SrcSymbol = MI.getOperand(3);
+  assert(SrcSymbol.isExpr() &&
+         "Expected expression as third input to TP-relative add");
+
+  const RISCVMCExpr *Expr = dyn_cast<RISCVMCExpr>(SrcSymbol.getExpr());
+  assert(Expr && Expr->getKind() == RISCVMCExpr::VK_RISCV_TPREL_ADD &&
+         "Expected tprel_add relocation on TP-relative symbol");
+
+  // Emit the correct tprel_add relocation for the symbol.
+  Fixups.push_back(MCFixup::create(
+      0, Expr, MCFixupKind(RISCV::fixup_riscv_tprel_add), MI.getLoc()));
+
+  // Emit fixup_riscv_relax for tprel_add where the relax feature is enabled.
+  if (STI.getFeatureBits()[RISCV::FeatureRelax]) {
+    const MCConstantExpr *Dummy = MCConstantExpr::create(0, Ctx);
+    Fixups.push_back(MCFixup::create(
+        0, Dummy, MCFixupKind(RISCV::fixup_riscv_relax), MI.getLoc()));
+  }
+
+  // Emit a normal ADD instruction with the given operands.
+  MCInst TmpInst = MCInstBuilder(RISCV::ADD)
+                       .addOperand(DestReg)
+                       .addOperand(SrcReg)
+                       .addOperand(TPReg);
+  uint32_t Binary = getBinaryCodeForInstr(TmpInst, Fixups, STI);
+  support::endian::write(OS, Binary, support::little);
+}
+
 void RISCVMCCodeEmitter::encodeInstruction(const MCInst &MI, raw_ostream &OS,
                                            SmallVectorImpl<MCFixup> &Fixups,
                                            const MCSubtargetInfo &STI) const {
@@ -134,6 +176,12 @@
     return;
   }
 
+  if (MI.getOpcode() == RISCV::PseudoAddTPRel) {
+    expandAddTPRel(MI, OS, Fixups, STI);
+    MCNumEmitted += 1;
+    return;
+  }
+
   switch (Size) {
   default:
     llvm_unreachable("Unhandled encodeInstruction length!");
@@ -208,6 +256,13 @@
     case RISCVMCExpr::VK_RISCV_None:
     case RISCVMCExpr::VK_RISCV_Invalid:
       llvm_unreachable("Unhandled fixup kind!");
+    case RISCVMCExpr::VK_RISCV_TPREL_ADD:
+      // tprel_add is only used to indicate that a relocation should be emitted
+      // for an add instruction used in TP-relative addressing. It should not be
+      // expanded as if representing an actual instruction operand and so to
+      // encounter it here is an error.
+      llvm_unreachable(
+          "VK_RISCV_TPREL_ADD should not represent an instruction operand");
     case RISCVMCExpr::VK_RISCV_LO:
       if (MIFrm == RISCVII::InstFormatI)
         FixupKind = RISCV::fixup_riscv_lo12_i;
@@ -238,6 +293,20 @@
     case RISCVMCExpr::VK_RISCV_GOT_HI:
       FixupKind = RISCV::fixup_riscv_got_hi20;
       break;
+    case RISCVMCExpr::VK_RISCV_TPREL_LO:
+      if (MIFrm == RISCVII::InstFormatI)
+        FixupKind = RISCV::fixup_riscv_tprel_lo12_i;
+      else if (MIFrm == RISCVII::InstFormatS)
+        FixupKind = RISCV::fixup_riscv_tprel_lo12_s;
+      else
+        llvm_unreachable(
+            "VK_RISCV_TPREL_LO used with unexpected instruction format");
+      RelaxCandidate = true;
+      break;
+    case RISCVMCExpr::VK_RISCV_TPREL_HI:
+      FixupKind = RISCV::fixup_riscv_tprel_hi20;
+      RelaxCandidate = true;
+      break;
     case RISCVMCExpr::VK_RISCV_CALL:
       FixupKind = RISCV::fixup_riscv_call;
       RelaxCandidate = true;
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCExpr.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCExpr.cpp
index d397499..6ccabee 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCExpr.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCExpr.cpp
@@ -15,6 +15,7 @@
 #include "MCTargetDesc/RISCVAsmBackend.h"
 #include "RISCV.h"
 #include "RISCVFixupKinds.h"
+#include "llvm/BinaryFormat/ELF.h"
 #include "llvm/MC/MCAsmLayout.h"
 #include "llvm/MC/MCAssembler.h"
 #include "llvm/MC/MCContext.h"
@@ -162,6 +163,9 @@
     case VK_RISCV_PCREL_LO:
     case VK_RISCV_PCREL_HI:
     case VK_RISCV_GOT_HI:
+    case VK_RISCV_TPREL_LO:
+    case VK_RISCV_TPREL_HI:
+    case VK_RISCV_TPREL_ADD:
       return false;
     }
   }
@@ -180,6 +184,9 @@
       .Case("pcrel_lo", VK_RISCV_PCREL_LO)
       .Case("pcrel_hi", VK_RISCV_PCREL_HI)
       .Case("got_pcrel_hi", VK_RISCV_GOT_HI)
+      .Case("tprel_lo", VK_RISCV_TPREL_LO)
+      .Case("tprel_hi", VK_RISCV_TPREL_HI)
+      .Case("tprel_add", VK_RISCV_TPREL_ADD)
       .Default(VK_RISCV_Invalid);
 }
 
@@ -197,15 +204,62 @@
     return "pcrel_hi";
   case VK_RISCV_GOT_HI:
     return "got_pcrel_hi";
+  case VK_RISCV_TPREL_LO:
+    return "tprel_lo";
+  case VK_RISCV_TPREL_HI:
+    return "tprel_hi";
+  case VK_RISCV_TPREL_ADD:
+    return "tprel_add";
   }
 }
 
+static void fixELFSymbolsInTLSFixupsImpl(const MCExpr *Expr, MCAssembler &Asm) {
+  switch (Expr->getKind()) {
+  case MCExpr::Target:
+    llvm_unreachable("Can't handle nested target expression");
+    break;
+  case MCExpr::Constant:
+    break;
+
+  case MCExpr::Binary: {
+    const MCBinaryExpr *BE = cast<MCBinaryExpr>(Expr);
+    fixELFSymbolsInTLSFixupsImpl(BE->getLHS(), Asm);
+    fixELFSymbolsInTLSFixupsImpl(BE->getRHS(), Asm);
+    break;
+  }
+
+  case MCExpr::SymbolRef: {
+    // We're known to be under a TLS fixup, so any symbol should be
+    // modified. There should be only one.
+    const MCSymbolRefExpr &SymRef = *cast<MCSymbolRefExpr>(Expr);
+    cast<MCSymbolELF>(SymRef.getSymbol()).setType(ELF::STT_TLS);
+    break;
+  }
+
+  case MCExpr::Unary:
+    fixELFSymbolsInTLSFixupsImpl(cast<MCUnaryExpr>(Expr)->getSubExpr(), Asm);
+    break;
+  }
+}
+
+void RISCVMCExpr::fixELFSymbolsInTLSFixups(MCAssembler &Asm) const {
+  switch (getKind()) {
+  default:
+    return;
+  case VK_RISCV_TPREL_HI:
+    break;
+  }
+
+  fixELFSymbolsInTLSFixupsImpl(getSubExpr(), Asm);
+}
+
 bool RISCVMCExpr::evaluateAsConstant(int64_t &Res) const {
   MCValue Value;
 
   if (Kind == VK_RISCV_PCREL_HI || Kind == VK_RISCV_PCREL_LO ||
-      Kind == VK_RISCV_GOT_HI || Kind == VK_RISCV_CALL ||
-      Kind == VK_RISCV_CALL_PLT)
+      Kind == VK_RISCV_GOT_HI || Kind == VK_RISCV_TPREL_HI ||
+      Kind == VK_RISCV_TPREL_LO || Kind == VK_RISCV_TPREL_ADD ||
+      Kind == VK_RISCV_CALL || Kind == VK_RISCV_CALL_PLT)
     return false;
 
   if (!getSubExpr()->evaluateAsRelocatable(Value, nullptr, nullptr))
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCExpr.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCExpr.h
index 2cd3087..5846811 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCExpr.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCExpr.h
@@ -29,6 +29,9 @@
     VK_RISCV_PCREL_LO,
     VK_RISCV_PCREL_HI,
     VK_RISCV_GOT_HI,
+    VK_RISCV_TPREL_LO,
+    VK_RISCV_TPREL_HI,
+    VK_RISCV_TPREL_ADD,
     VK_RISCV_CALL,
     VK_RISCV_CALL_PLT,
     VK_RISCV_Invalid
@@ -69,8 +72,7 @@
     return getSubExpr()->findAssociatedFragment();
   }
 
-  // There are no TLS RISCVMCExprs at the moment.
-  void fixELFSymbolsInTLSFixups(MCAssembler &Asm) const override {}
+  void fixELFSymbolsInTLSFixups(MCAssembler &Asm) const override;
 
   bool evaluateAsConstant(int64_t &Res) const;
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index b0de892..99002386 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -202,6 +202,18 @@
   let ParserMatchClass = CallSymbol;
 }
 
+def TPRelAddSymbol : AsmOperandClass {
+  let Name = "TPRelAddSymbol";
+  let RenderMethod = "addImmOperands";
+  let DiagnosticType = "InvalidTPRelAddSymbol";
+  let ParserMethod = "parseOperandWithModifier";
+}
+
+// A symbol operand wrapped in the %tprel_add modifier.
+def tprel_add_symbol : Operand<XLenVT> {
+  let ParserMatchClass = TPRelAddSymbol;
+}
+
 def CSRSystemRegister : AsmOperandClass {
   let Name = "CSRSystemRegister";
   let ParserMethod = "parseCSRSystemRegister";
@@ -765,6 +777,15 @@
 def : PatGprGpr<shiftop<srl>, SRL>;
 def : PatGprGpr<shiftop<sra>, SRA>;
 
+// This is a special case of the ADD instruction that accepts a fourth operand,
+// a symbol with the %tprel_add modifier, solely so a relocation can be emitted
+// for it. The relocation does not affect any bits of the instruction itself
+// but is used as a hint to the linker.
+let hasSideEffects = 0, mayLoad = 0, mayStore = 0, isCodeGenOnly = 0 in
+def PseudoAddTPRel : Pseudo<(outs GPR:$rd),
+                            (ins GPR:$rs1, GPR:$rs2, tprel_add_symbol:$src), [],
+                            "add", "$rd, $rs1, $rs2, $src">;
+
 /// FrameIndex calculations
 
 def : Pat<(add (i32 AddrFI:$Rs), simm12:$imm12),