blob: e0c35e7039e500a7be1afc691b04d17d62d4d695 [file] [log] [blame]
Justin Holewinskib94bd052013-03-30 14:29:25 +00001//===- NVVMReflect.cpp - NVVM Emulate conditional compilation -------------===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
Justin Lebarefcc81c2016-04-01 01:09:07 +000010// This pass replaces occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect
11// with an integer.
Justin Lebare3804cc2016-03-30 20:40:11 +000012//
Justin Lebarefcc81c2016-04-01 01:09:07 +000013// We choose the value we use by looking, in this order, at:
14//
15// * the -nvvm-reflect-list flag, which has the format "foo=1,bar=42",
16// * the StringMap passed to the pass's constructor, and
17// * metadata in the module itself.
18//
19// If we see an unknown string, we replace its call with 0.
Justin Holewinskib94bd052013-03-30 14:29:25 +000020//
21//===----------------------------------------------------------------------===//
22
#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include <sstream>
#include <string>
Justin Holewinskib94bd052013-03-30 14:29:25 +000042#define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
43
44using namespace llvm;
45
Chandler Carruthe96dd892014-04-21 22:55:11 +000046#define DEBUG_TYPE "nvptx-reflect"
47
Justin Holewinskib94bd052013-03-30 14:29:25 +000048namespace llvm { void initializeNVVMReflectPass(PassRegistry &); }
49
50namespace {
Justin Lebare3804cc2016-03-30 20:40:11 +000051class NVVMReflect : public FunctionPass {
Justin Holewinskib94bd052013-03-30 14:29:25 +000052private:
Justin Holewinskib94bd052013-03-30 14:29:25 +000053 StringMap<int> VarMap;
Justin Holewinskib94bd052013-03-30 14:29:25 +000054
55public:
56 static char ID;
Justin Lebare3804cc2016-03-30 20:40:11 +000057 NVVMReflect() : NVVMReflect(StringMap<int>()) {}
Justin Holewinski18f3a1f2013-05-20 16:42:16 +000058
Justin Lebarefcc81c2016-04-01 01:09:07 +000059 NVVMReflect(const StringMap<int> &Mapping)
60 : FunctionPass(ID), VarMap(Mapping) {
Justin Holewinski18f3a1f2013-05-20 16:42:16 +000061 initializeNVVMReflectPass(*PassRegistry::getPassRegistry());
Justin Lebare3804cc2016-03-30 20:40:11 +000062 setVarMap();
Justin Holewinskib94bd052013-03-30 14:29:25 +000063 }
64
Justin Lebare3804cc2016-03-30 20:40:11 +000065 bool runOnFunction(Function &) override;
Justin Holewinskib94bd052013-03-30 14:29:25 +000066
Justin Holewinskia0d531f2014-06-27 18:36:11 +000067private:
68 bool handleFunction(Function *ReflectFunction);
Justin Holewinskib94bd052013-03-30 14:29:25 +000069 void setVarMap();
70};
Alexander Kornienkof00654e2015-06-23 09:49:53 +000071}
Justin Holewinskib94bd052013-03-30 14:29:25 +000072
Justin Lebare3804cc2016-03-30 20:40:11 +000073FunctionPass *llvm::createNVVMReflectPass() { return new NVVMReflect(); }
74FunctionPass *llvm::createNVVMReflectPass(const StringMap<int> &Mapping) {
Justin Holewinski18f3a1f2013-05-20 16:42:16 +000075 return new NVVMReflect(Mapping);
76}
77
Justin Holewinskib94bd052013-03-30 14:29:25 +000078static cl::opt<bool>
Nadav Rotem7f27e0b2013-10-18 23:38:13 +000079NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
Justin Holewinskib94bd052013-03-30 14:29:25 +000080 cl::desc("NVVM reflection, enabled by default"));
81
82char NVVMReflect::ID = 0;
83INITIALIZE_PASS(NVVMReflect, "nvvm-reflect",
Alp Tokercb402912014-01-24 17:20:08 +000084 "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
Justin Holewinskib94bd052013-03-30 14:29:25 +000085 false)
86
87static cl::list<std::string>
Nadav Rotem7f27e0b2013-10-18 23:38:13 +000088ReflectList("nvvm-reflect-list", cl::value_desc("name=<int>"), cl::Hidden,
Justin Holewinskia922c7e2013-04-02 12:37:11 +000089 cl::desc("A list of string=num assignments"),
Justin Holewinskib94bd052013-03-30 14:29:25 +000090 cl::ValueRequired);
91
Justin Holewinskib94bd052013-03-30 14:29:25 +000092/// The command line can look as follows :
Justin Holewinskia922c7e2013-04-02 12:37:11 +000093/// -nvvm-reflect-list a=1,b=2 -nvvm-reflect-list c=3,d=0 -R e=2
Justin Holewinskib94bd052013-03-30 14:29:25 +000094/// The strings "a=1,b=2", "c=3,d=0", "e=2" are available in the
95/// ReflectList vector. First, each of ReflectList[i] is 'split'
96/// using "," as the delimiter. Then each of this part is split
97/// using "=" as the delimiter.
98void NVVMReflect::setVarMap() {
99 for (unsigned i = 0, e = ReflectList.size(); i != e; ++i) {
Justin Holewinskia922c7e2013-04-02 12:37:11 +0000100 DEBUG(dbgs() << "Option : " << ReflectList[i] << "\n");
101 SmallVector<StringRef, 4> NameValList;
Chandler Carruthe4405e92015-09-10 06:12:31 +0000102 StringRef(ReflectList[i]).split(NameValList, ',');
Justin Holewinskia922c7e2013-04-02 12:37:11 +0000103 for (unsigned j = 0, ej = NameValList.size(); j != ej; ++j) {
104 SmallVector<StringRef, 2> NameValPair;
Chandler Carruthe4405e92015-09-10 06:12:31 +0000105 NameValList[j].split(NameValPair, '=');
Justin Holewinskia922c7e2013-04-02 12:37:11 +0000106 assert(NameValPair.size() == 2 && "name=val expected");
107 std::stringstream ValStream(NameValPair[1]);
108 int Val;
109 ValStream >> Val;
110 assert((!(ValStream.fail())) && "integer value expected");
111 VarMap[NameValPair[0]] = Val;
Justin Holewinskib94bd052013-03-30 14:29:25 +0000112 }
113 }
114}
115
Justin Lebare3804cc2016-03-30 20:40:11 +0000116bool NVVMReflect::runOnFunction(Function &F) {
117 if (!NVVMReflectEnabled)
118 return false;
Justin Holewinskib94bd052013-03-30 14:29:25 +0000119
Justin Lebare3804cc2016-03-30 20:40:11 +0000120 if (F.getName() == NVVM_REFLECT_FUNCTION) {
121 assert(F.isDeclaration() && "_reflect function should not have a body");
122 assert(F.getReturnType()->isIntegerTy() &&
123 "_reflect's return type should be integer");
124 return false;
125 }
Justin Holewinskib94bd052013-03-30 14:29:25 +0000126
Justin Lebare3804cc2016-03-30 20:40:11 +0000127 SmallVector<Instruction *, 4> ToRemove;
Artem Belevich9e8a0392015-03-19 17:05:35 +0000128
Justin Lebare3804cc2016-03-30 20:40:11 +0000129 // Go through the calls in this function. Each call to __nvvm_reflect or
130 // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
131 // First validate that. If the c-string corresponding to the ConstantArray can
132 // be found successfully, see if it can be found in VarMap. If so, replace the
133 // uses of CallInst with the value found in VarMap. If not, replace the use
134 // with value 0.
135
136 // The IR for __nvvm_reflect calls differs between CUDA versions.
137 //
Artem Belevich9e8a0392015-03-19 17:05:35 +0000138 // CUDA 6.5 and earlier uses this sequence:
139 // %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8
140 // (i8 addrspace(4)* getelementptr inbounds
141 // ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0))
142 // %reflect = tail call i32 @__nvvm_reflect(i8* %ptr)
143 //
Justin Lebare3804cc2016-03-30 20:40:11 +0000144 // The value returned by Sym->getOperand(0) is a Constant with a
Artem Belevich9e8a0392015-03-19 17:05:35 +0000145 // ConstantDataSequential operand which can be converted to string and used
146 // for lookup.
147 //
148 // CUDA 7.0 does it slightly differently:
149 // %reflect = call i32 @__nvvm_reflect(i8* addrspacecast
150 // (i8 addrspace(1)* getelementptr inbounds
151 // ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
152 //
153 // In this case, we get a Constant with a GlobalVariable operand and we need
154 // to dig deeper to find its initializer with the string we'll use for lookup.
Justin Lebare3804cc2016-03-30 20:40:11 +0000155 for (Instruction &I : instructions(F)) {
156 CallInst *Call = dyn_cast<CallInst>(&I);
157 if (!Call)
158 continue;
159 Function *Callee = Call->getCalledFunction();
160 if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION &&
161 Callee->getIntrinsicID() != Intrinsic::nvvm_reflect))
162 continue;
Artem Belevich9e8a0392015-03-19 17:05:35 +0000163
Justin Lebare3804cc2016-03-30 20:40:11 +0000164 // FIXME: Improve error handling here and elsewhere in this pass.
165 assert(Call->getNumOperands() == 2 &&
166 "Wrong number of operands to __nvvm_reflect function");
Justin Holewinskib94bd052013-03-30 14:29:25 +0000167
Justin Lebare3804cc2016-03-30 20:40:11 +0000168 // In cuda 6.5 and earlier, we will have an extra constant-to-generic
169 // conversion of the string.
170 const Value *Str = Call->getArgOperand(0);
171 if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) {
172 // FIXME: Add assertions about ConvCall.
Justin Holewinskia0d531f2014-06-27 18:36:11 +0000173 Str = ConvCall->getArgOperand(0);
174 }
175 assert(isa<ConstantExpr>(Str) &&
Justin Lebare3804cc2016-03-30 20:40:11 +0000176 "Format of __nvvm__reflect function not recognized");
Justin Holewinskia0d531f2014-06-27 18:36:11 +0000177 const ConstantExpr *GEP = cast<ConstantExpr>(Str);
Justin Holewinskib94bd052013-03-30 14:29:25 +0000178
Justin Holewinskia922c7e2013-04-02 12:37:11 +0000179 const Value *Sym = GEP->getOperand(0);
Justin Lebare3804cc2016-03-30 20:40:11 +0000180 assert(isa<Constant>(Sym) &&
181 "Format of __nvvm_reflect function not recognized");
Justin Holewinskib94bd052013-03-30 14:29:25 +0000182
Artem Belevich9e8a0392015-03-19 17:05:35 +0000183 const Value *Operand = cast<Constant>(Sym)->getOperand(0);
184 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) {
Justin Lebare3804cc2016-03-30 20:40:11 +0000185 // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's
Artem Belevich9e8a0392015-03-19 17:05:35 +0000186 // initializer.
187 assert(GV->hasInitializer() &&
188 "Format of _reflect function not recognized");
189 const Constant *Initializer = GV->getInitializer();
190 Operand = Initializer;
191 }
Justin Holewinskib94bd052013-03-30 14:29:25 +0000192
Artem Belevich9e8a0392015-03-19 17:05:35 +0000193 assert(isa<ConstantDataSequential>(Operand) &&
Justin Holewinskib94bd052013-03-30 14:29:25 +0000194 "Format of _reflect function not recognized");
Artem Belevich9e8a0392015-03-19 17:05:35 +0000195 assert(cast<ConstantDataSequential>(Operand)->isCString() &&
Justin Holewinskib94bd052013-03-30 14:29:25 +0000196 "Format of _reflect function not recognized");
197
Justin Lebare3804cc2016-03-30 20:40:11 +0000198 StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString();
Justin Holewinskia922c7e2013-04-02 12:37:11 +0000199 ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1);
200 DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n");
Justin Holewinskib94bd052013-03-30 14:29:25 +0000201
Justin Holewinskia922c7e2013-04-02 12:37:11 +0000202 int ReflectVal = 0; // The default value is 0
Justin Lebare3804cc2016-03-30 20:40:11 +0000203 auto Iter = VarMap.find(ReflectArg);
204 if (Iter != VarMap.end())
205 ReflectVal = Iter->second;
Justin Lebarefcc81c2016-04-01 01:09:07 +0000206 else if (ReflectArg == "__CUDA_FTZ") {
207 // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag.
208 if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
209 F.getParent()->getModuleFlag("nvvm-reflect-ftz")))
210 ReflectVal = Flag->getSExtValue();
211 }
Justin Lebare3804cc2016-03-30 20:40:11 +0000212 Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
213 ToRemove.push_back(Call);
Justin Holewinskia0d531f2014-06-27 18:36:11 +0000214 }
215
Justin Lebare3804cc2016-03-30 20:40:11 +0000216 for (Instruction *I : ToRemove)
217 I->eraseFromParent();
Justin Holewinskia0d531f2014-06-27 18:36:11 +0000218
Justin Lebare3804cc2016-03-30 20:40:11 +0000219 return ToRemove.size() > 0;
Justin Holewinskia0d531f2014-06-27 18:36:11 +0000220}