Compiler kernel fusion for RenderScript.
This was started by Tobias Grosser during his internship at Google in 2013.
This CL includes his original changes, generalized to work with the newly
proposed ScriptGroup API, and makes ImageProcessing work with it.
An enabling CL is needed in the RenderScript runtime, which I will post
separately.
Change-Id: Ia73ea917a126a5055ec97f13d90a5feaafd6a2f5
diff --git a/lib/Core/Source.cpp b/lib/Core/Source.cpp
index c64931f..e7b3781 100644
--- a/lib/Core/Source.cpp
+++ b/lib/Core/Source.cpp
@@ -76,7 +76,7 @@
return nullptr;
}
- Source *result = CreateFromModule(pContext, *module, /* pNoDelete */false);
+ Source *result = CreateFromModule(pContext, pName, *module, /* pNoDelete */false);
if (result == nullptr) {
delete module;
}
@@ -102,7 +102,7 @@
return nullptr;
}
- Source *result = CreateFromModule(pContext, *module, /* pNoDelete */false);
+ Source *result = CreateFromModule(pContext, pPath.c_str(), *module, /* pNoDelete */false);
if (result == nullptr) {
delete module;
}
@@ -110,7 +110,7 @@
return result;
}
-Source *Source::CreateFromModule(BCCContext &pContext, llvm::Module &pModule,
+Source *Source::CreateFromModule(BCCContext &pContext, const char* name, llvm::Module &pModule,
bool pNoDelete) {
std::string ErrorInfo;
llvm::raw_string_ostream ErrorStream(ErrorInfo);
@@ -120,7 +120,7 @@
return nullptr;
}
- Source *result = new (std::nothrow) Source(pContext, pModule, pNoDelete);
+ Source *result = new (std::nothrow) Source(name, pContext, pModule, pNoDelete);
if (result == nullptr) {
ALOGE("Out of memory during Source object allocation for `%s'!",
pModule.getModuleIdentifier().c_str());
@@ -128,8 +128,9 @@
return result;
}
-Source::Source(BCCContext &pContext, llvm::Module &pModule, bool pNoDelete)
- : mContext(pContext), mModule(&pModule), mNoDelete(pNoDelete) {
+Source::Source(const char* name, BCCContext &pContext, llvm::Module &pModule,
+ bool pNoDelete)
+ : mName(name), mContext(pContext), mModule(&pModule), mNoDelete(pNoDelete) {
pContext.addSource(*this);
}
@@ -160,7 +161,7 @@
return nullptr;
}
- Source *result = CreateFromModule(pContext, *module, /* pNoDelete */false);
+ Source *result = CreateFromModule(pContext, pName.c_str(), *module, /* pNoDelete */false);
if (result == nullptr) {
delete module;
}
diff --git a/lib/Renderscript/Android.mk b/lib/Renderscript/Android.mk
index f6f47a4..3280909 100644
--- a/lib/Renderscript/Android.mk
+++ b/lib/Renderscript/Android.mk
@@ -29,10 +29,12 @@
RSInfoExtractor.cpp \
RSInfoReader.cpp \
RSInfoWriter.cpp \
+ RSMetadata.cpp \
RSScript.cpp \
RSInvokeHelperPass.cpp \
RSScreenFunctionsPass.cpp \
- RSStubsWhiteList.cpp
+ RSStubsWhiteList.cpp \
+ RSScriptGroupFusion.cpp
#=====================================================================
# Device Static Library: libbccRenderscript
diff --git a/lib/Renderscript/RSCompilerDriver.cpp b/lib/Renderscript/RSCompilerDriver.cpp
index 7d46c93..6aa2e5f 100644
--- a/lib/Renderscript/RSCompilerDriver.cpp
+++ b/lib/Renderscript/RSCompilerDriver.cpp
@@ -14,21 +14,21 @@
* limitations under the License.
*/
-#include <string>
-
#include "bcc/Renderscript/RSCompilerDriver.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
#include <llvm/IR/Module.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/Path.h>
#include <llvm/Support/raw_ostream.h>
#include "bcinfo/BitcodeWrapper.h"
-
+#include "bcc/BCCContext.h"
#include "bcc/Compiler.h"
#include "bcc/Config/Config.h"
#include "bcc/Renderscript/RSInfo.h"
#include "bcc/Renderscript/RSScript.h"
+#include "bcc/Renderscript/RSScriptGroupFusion.h"
#include "bcc/Support/CompilerConfig.h"
#include "bcc/Source.h"
#include "bcc/Support/FileMutex.h"
@@ -38,6 +38,8 @@
#include "bcc/Support/Sha1Util.h"
#include "bcc/Support/OutputFile.h"
+#include <string>
+
#ifdef HAVE_ANDROID_OS
#include <cutils/properties.h>
#endif
@@ -316,6 +318,30 @@
return status == Compiler::kSuccess;
}
+// Fuses the kernels selected by `sources`/`slots` into one module and compiles
+// that module.
+//
+// pOutputFilepath is used both as the name of the fused Source and as the
+// script name; the object file path is pOutputFilepath with its extension
+// replaced by ".o". pRuntimePath and dumpIR are forwarded to compileScript().
+//
+// Returns true only if fusion, Source creation and compilation all succeed.
+bool RSCompilerDriver::buildScriptGroup(
+    BCCContext& Context, const char* pOutputFilepath, const char* pRuntimePath,
+    const std::vector<const Source*>& sources, const std::vector<int>& slots,
+    bool dumpIR) {
+  llvm::Module* module = fuseKernels(Context, sources, slots);
+  if (module == nullptr) {
+    return false;
+  }
+
+  const std::unique_ptr<Source> source(
+      Source::CreateFromModule(Context, pOutputFilepath, *module));
+  if (source == nullptr) {
+    // No Source was created to hold the module, so release it here, as the
+    // other CreateFromModule() callers do on failure.
+    delete module;
+    return false;
+  }
+  RSScript script(*source);
+
+  // There is no single input bitcode to hash for a fused script; pass a
+  // well-defined all-zero digest rather than uninitialized stack memory.
+  uint8_t bitcode_sha1[SHA1_DIGEST_LENGTH] = {};
+  const char* compileCommandLineToEmbed = "";
+
+  llvm::SmallString<80> output_path(pOutputFilepath);
+  llvm::sys::path::replace_extension(output_path, ".o");
+
+  // Report the compiler's verdict instead of unconditionally claiming success.
+  return compileScript(script, pOutputFilepath, output_path.c_str(),
+                       pRuntimePath, bitcode_sha1, compileCommandLineToEmbed,
+                       true, dumpIR) == Compiler::kSuccess;
+}
bool RSCompilerDriver::buildForCompatLib(RSScript &pScript, const char *pOut,
const char *pRuntimePath) {
diff --git a/lib/Renderscript/RSMetadata.cpp b/lib/Renderscript/RSMetadata.cpp
new file mode 100644
index 0000000..841ade7
--- /dev/null
+++ b/lib/Renderscript/RSMetadata.cpp
@@ -0,0 +1,92 @@
+/*
+ * Copyright 2015, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "bcc/Renderscript/RSMetadata.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/IR/Module.h"
+
+// Name of metadata node where pragma info resides (should be synced with
+// slang.cpp)
+const llvm::StringRef pragma_metadata_name("#pragma");
+
+/*
+ * The following names should be synced with the ones that appear in
+ * slang_rs_metadata.h.
+ */
+
+// Name of metadata node where exported variable names reside
+static const llvm::StringRef
+export_var_metadata_name("#rs_export_var");
+
+// Name of metadata node where exported function names reside
+static const llvm::StringRef
+export_func_metadata_name("#rs_export_func");
+
+// Name of metadata node where exported ForEach name information resides
+static const llvm::StringRef
+export_foreach_name_metadata_name("#rs_export_foreach_name");
+
+// Name of metadata node where exported ForEach signature information resides
+static const llvm::StringRef
+export_foreach_metadata_name("#rs_export_foreach");
+
+// Name of metadata node where RS object slot info resides (presumably also to
+// be synced with slang_rs_metadata.h, per the note above -- TODO confirm)
+static const llvm::StringRef
+object_slot_metadata_name("#rs_object_slots");
+
+// Initializes the Module member from the given module; deleteAll() and
+// markForEachFunction() operate on that member.
+bcc::RSMetadata::RSMetadata(llvm::Module &Module) : Module(Module) {}
+
+// Erases the RenderScript named-metadata nodes (pragma, exported variables,
+// exported functions, exported ForEach names and signatures, object slots)
+// from the module. Names that have no node in the module are skipped.
+void bcc::RSMetadata::deleteAll() {
+  const llvm::StringRef Names[] = {
+    pragma_metadata_name,
+    export_var_metadata_name,
+    export_func_metadata_name,
+    export_foreach_name_metadata_name,
+    export_foreach_metadata_name,
+    object_slot_metadata_name,
+  };
+
+  for (const llvm::StringRef &Name : Names) {
+    if (llvm::NamedMDNode *Node = Module.getNamedMetadata(Name)) {
+      Node->eraseFromParent();
+    }
+  }
+}
+
+// Records `Function` as an exported ForEach function in the module metadata.
+//
+// Appends the function's name to the "#rs_export_foreach_name" node and
+// `Signature`, rendered as a string by llvm::utostr_32, to the
+// "#rs_export_foreach" node. Either node is created if it does not exist yet.
+void bcc::RSMetadata::markForEachFunction(llvm::Function &Function,
+                                          uint32_t Signature) {
+  llvm::LLVMContext &Ctx = Module.getContext();
+
+  // Name entry.
+  llvm::NamedMDNode *NameList =
+      Module.getOrInsertNamedMetadata(export_foreach_name_metadata_name);
+  llvm::MDString *NameStr = llvm::MDString::get(Ctx, Function.getName());
+  NameList->addOperand(llvm::MDNode::get(Ctx, NameStr));
+
+  // Signature entry, appended in the same order as the name entry.
+  llvm::NamedMDNode *SignatureList =
+      Module.getOrInsertNamedMetadata(export_foreach_metadata_name);
+  llvm::MDString *SignatureStr =
+      llvm::MDString::get(Ctx, llvm::utostr_32(Signature));
+  SignatureList->addOperand(llvm::MDNode::get(Ctx, SignatureStr));
+}
diff --git a/lib/Renderscript/RSScriptGroupFusion.cpp b/lib/Renderscript/RSScriptGroupFusion.cpp
new file mode 100644
index 0000000..2763844
--- /dev/null
+++ b/lib/Renderscript/RSScriptGroupFusion.cpp
@@ -0,0 +1,209 @@
+/*
+ * Copyright 2015, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "bcc/Renderscript/RSScriptGroupFusion.h"
+
+#include "bcc/Assert.h"
+#include "bcc/BCCContext.h"
+#include "bcc/Renderscript/RSMetadata.h"
+#include "bcc/Renderscript/RSScript.h"
+#include "bcc/Source.h"
+#include "bcc/Support/Log.h"
+#include "bcinfo/MetadataExtractor.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Linker/Linker.h"
+#include "llvm/PassManager.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <string>
+
+using llvm::Function;
+
+using std::map;
+using std::pair;
+using std::string;
+
+namespace bcc {
+
+namespace {
+
+// Orders Source pointers by the name of the Source they point to, so that
+// SlotMap is keyed on source names rather than on pointer values.
+struct SourceCompare {
+  bool operator()(const Source* a, const Source* b) const {
+    return a->getName() < b->getName();
+  }
+};
+
+// Per source: ForEach slot -> (kernel function, number of coordinate
+// arguments the kernel takes).
+using SlotMap =
+    map<const Source*, map<int, pair<const Function*, int>>, SourceCompare>;
+
+// Looks up the kernel exported at ForEach slot `slot` of `source`.
+//
+// Returns the function in the source's own module that is named by the slot's
+// entry in the exported ForEach name list. Returns nullptr if the metadata
+// cannot be extracted, the slot is out of range, the slot has no name, or the
+// module has no function with that name.
+const Function* getFunction(const Source* source, const int slot) {
+  const llvm::Module* module = &source->getModule();
+  bcinfo::MetadataExtractor metadata(module);
+  if (!metadata.extract()) {
+    return nullptr;
+  }
+  // Reject slots outside the export table instead of indexing past its end.
+  if (slot < 0 ||
+      static_cast<size_t>(slot) >= metadata.getExportForEachSignatureCount()) {
+    return nullptr;
+  }
+  const char* functionName = metadata.getExportForEachNameList()[slot];
+  if (functionName == nullptr) {
+    return nullptr;
+  }
+  return module->getFunction(functionName);
+}
+
+// Returns the type of the first argument of the kernel at ForEach slot `slot`
+// of `source`, or nullptr if the kernel cannot be found or takes no arguments.
+llvm::Type* getArgType(const Source* source, const int slot) {
+  const Function* func = getFunction(source, slot);
+  if (func == nullptr) {
+    return nullptr;
+  }
+  // A kernel without parameters has no input type; dereferencing begin() of
+  // an empty argument list would be undefined behavior.
+  if (func->getArgumentList().empty()) {
+    return nullptr;
+  }
+  return func->getArgumentList().begin()->getType();
+}
+
+// Returns the return type of the kernel at ForEach slot `slot` of `source`,
+// or nullptr if that kernel cannot be found.
+llvm::Type* getReturnType(const Source* source, const int slot) {
+  if (const Function* kernel = getFunction(source, slot)) {
+    return kernel->getReturnType();
+  }
+  return nullptr;
+}
+
+// Resolves the kernel at ForEach slot `slot` of `source` inside the module
+// that `linker` is building, and caches the answer in `slotMap`.
+//
+// A source that has no entry in `slotMap` yet is first linked into the
+// linker's module. The returned pair holds the function found in the linker's
+// module under the kernel's exported name, and the number of coordinate
+// arguments (X and/or Y, so 0 to 2) named by the kernel's signature.
+//
+// Returns (nullptr, 0) if linking fails, the source's metadata cannot be
+// extracted, the slot is out of range or unnamed, or the kernel has more than
+// one input.
+pair<const Function*, int> getFunction(
+    SlotMap& slotMap, llvm::Linker& linker, const Source* source,
+    const int slot) {
+  const pair<const Function*, int> kFailure(nullptr, 0);
+
+  if (slotMap.find(source) == slotMap.end()) {
+    // linkInModule() takes a mutable module, while only a const Source is
+    // available here.
+    llvm::Module* module = const_cast<llvm::Module*>(&source->getModule());
+    if (linker.linkInModule(module)) {
+      ALOGE("Linking for module in source %s failed.",
+            source->getName().c_str());
+      return kFailure;
+    }
+  }
+  auto& functions = slotMap[source];
+
+  auto cached = functions.find(slot);
+  if (cached != functions.end()) {
+    return cached->second;
+  }
+
+  bcinfo::MetadataExtractor metadata(&source->getModule());
+  if (!metadata.extract()) {
+    return kFailure;
+  }
+  // Reject slots outside the export tables instead of indexing past their end.
+  if (slot < 0 ||
+      static_cast<size_t>(slot) >= metadata.getExportForEachSignatureCount()) {
+    return kFailure;
+  }
+
+  const char* functionName = metadata.getExportForEachNameList()[slot];
+  if (functionName == nullptr) {
+    return kFailure;
+  }
+
+  if (metadata.getExportForEachInputCountList()[slot] > 1) {
+    // TODO: Handle multiple inputs.
+    ALOGW("Kernel %s has multiple inputs", functionName);
+    return kFailure;
+  }
+
+  const uint32_t signature = metadata.getExportForEachSignatureList()[slot];
+  int dim = 0;
+  if (metadata.hasForEachSignatureX(signature)) {
+    dim++;
+  }
+  if (metadata.hasForEachSignatureY(signature)) {
+    dim++;
+  }
+
+  // Look the kernel up in the merged module, not in the source's own module.
+  const Function* function = linker.getModule()->getFunction(functionName);
+  return functions.emplace(slot, std::make_pair(function, dim)).first->second;
+}
+
+} // anonymous namespace
+
+// Builds a new module containing one kernel, "__rs_fused_kernels", that chains
+// the given kernels: the kernel at slots[i] of sources[i] is called on the
+// result of the previous call, starting from the fused kernel's own input.
+//
+// The fused kernel takes the first kernel's first-argument type followed by
+// two i32 coordinates (x, y), and returns the last kernel's return type. Each
+// kernel is passed x and/or y according to how many coordinate arguments its
+// signature declares. All RenderScript metadata in the new module is replaced
+// by a single ForEach entry for the fused kernel.
+//
+// `sources` and `slots` must have the same size, with at least two entries.
+// Returns the new module, which the caller owns, or nullptr on failure
+// (invalid arguments, unresolvable kernel, link failure, or a kernel with
+// more than one input).
+llvm::Module*
+fuseKernels(bcc::BCCContext& Context,
+            const std::vector<const Source *>& sources,
+            const std::vector<int>& slots) {
+  bccAssert(sources.size() > 1 && "Need at least two kernels for kernel merging");
+  bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
+  // The asserts above may be compiled out, so also fail cleanly at runtime
+  // rather than calling front()/back() on an empty vector or reading past the
+  // end of `slots`.
+  if (sources.size() < 2 || sources.size() != slots.size()) {
+    ALOGE("fuseKernels needs at least two kernels and one slot per kernel");
+    return nullptr;
+  }
+
+  llvm::LLVMContext& context = Context.getLLVMContext();
+  std::unique_ptr<llvm::Module> module(
+      new llvm::Module("Merged ScriptGroup", context));
+  if (module == nullptr) {
+    ALOGE("out of memory while creating module for fused kernels");
+    return nullptr;
+  }
+  llvm::Linker linker(module.get());
+  SlotMap slotMap;
+
+  llvm::Type* inputType = getArgType(sources.front(), slots.front());
+  if (inputType == nullptr) {
+    return nullptr;
+  }
+  llvm::Type* returnType = getReturnType(sources.back(), slots.back());
+  if (returnType == nullptr) {
+    return nullptr;
+  }
+  llvm::Type* I32Ty = llvm::IntegerType::get(context, 32);
+  // The module was created above and has no function of this name yet, so
+  // the returned constant is the newly inserted Function.
+  Function* fusedKernel = llvm::cast<Function>(module->getOrInsertFunction(
+      "__rs_fused_kernels", returnType, inputType, I32Ty, I32Ty, nullptr));
+
+  llvm::BasicBlock* block = llvm::BasicBlock::Create(context, "entry",
+                                                     fusedKernel);
+  llvm::IRBuilder<> builder(block);
+
+  Function::arg_iterator argIter = fusedKernel->arg_begin();
+  llvm::Value* dataElement = argIter++;
+  dataElement->setName("DataIn");
+  llvm::Value* X = argIter++;
+  X->setName("x");
+  llvm::Value* Y = argIter++;
+  Y->setName("y");
+
+  for (size_t i = 0; i < sources.size(); ++i) {
+    const pair<const Function*, int> kernel =
+        getFunction(slotMap, linker, sources[i], slots[i]);
+    const Function* function = kernel.first;
+    if (function == nullptr) {
+      return nullptr;
+    }
+    const int dim = kernel.second;
+
+    // Each kernel receives the running value plus as many coordinates as its
+    // signature declares.
+    std::vector<llvm::Value*> args;
+    args.push_back(dataElement);
+    if (dim > 0) {
+      args.push_back(X);
+      if (dim > 1) {
+        args.push_back(Y);
+      }
+    }
+
+    // CreateCall() takes a non-const callee, while the slot map stores const
+    // pointers.
+    dataElement = builder.CreateCall(const_cast<Function*>(function), args);
+  }
+
+  builder.CreateRet(dataElement);
+
+  bcc::RSMetadata metadata(*module);
+  metadata.deleteAll();
+  metadata.markForEachFunction(*fusedKernel, bcc::RSMetadata::FOREACH_KERNEL
+                                             | bcc::RSMetadata::FOREACH_IN
+                                             | bcc::RSMetadata::FOREACH_OUT
+                                             | bcc::RSMetadata::FOREACH_X
+                                             | bcc::RSMetadata::FOREACH_Y);
+
+  return module.release();
+}
+
+} // namespace bcc