Merge pull request #33025 from wallysslima:master
PiperOrigin-RevId: 287527549
Change-Id: I6bcadc3969f47f900ea142c565dd65def4b99ada
diff --git a/.bazelrc b/.bazelrc
index 94a425a..9f80f1d 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -238,6 +238,10 @@
build:macos --copt=-w
build:windows --copt=/w
+# TensorFlow uses M_* math constants that only get defined by MSVC headers if
+# _USE_MATH_DEFINES is defined.
+build:windows --copt=/D_USE_MATH_DEFINES
+
# Default paths for TF_SYSTEM_LIBS
build:linux --define=PREFIX=/usr
build:linux --define=LIBDIR=$(PREFIX)/lib
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 756b7f0..b4dc0e7 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -72,7 +72,7 @@
[tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core)
and
[tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
- TensorFlow has reached version 1 and hence cannot make
+ TensorFlow has passed version 1.0 and hence cannot make
non-backward-compatible API changes without a major release. Reviewers of
your pull request will comment on any API compatibility issues.
* When you contribute a new feature to TensorFlow, the maintenance burden is
diff --git a/configure.py b/configure.py
index a55acb1..b98cc9f 100644
--- a/configure.py
+++ b/configure.py
@@ -175,7 +175,8 @@
library_paths = run_shell([
python_bin_path, '-c',
'import site; print("\\n".join(site.getsitepackages()))'
- ], stderr=stderr).split('\n')
+ ],
+ stderr=stderr).split('\n')
except subprocess.CalledProcessError:
library_paths = [
run_shell([
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 081edb2..d8a681c 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -860,7 +860,7 @@
output_files = TENSORFLOW_API_INIT_FILES_V1,
output_package = "tensorflow._api.v1",
root_file_name = "v1.py",
- root_init_template = "api_template_v1.__init__.py",
+ root_init_template = "$(location api_template_v1.__init__.py)",
)
gen_api_init_files(
@@ -883,7 +883,7 @@
output_files = TENSORFLOW_API_INIT_FILES_V2,
output_package = "tensorflow._api.v2",
root_file_name = "v2.py",
- root_init_template = "api_template.__init__.py",
+ root_init_template = "$(location api_template.__init__.py)",
)
py_library(
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index c515cc7..a8cd6d1 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -89,6 +89,7 @@
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
_compat.enable_v2_behavior()
+_major_api_version = 2
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
@@ -119,8 +120,14 @@
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
+ # TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
- # TODO(gunan): Add sanity checks to loaded modules here.
+ # Load first party dynamic kernels.
+ _main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
+ if _fi.file_exists(_main_dir):
+ _ll.load_library(_main_dir)
+
+ # Load third party dynamic kernels.
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)
diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py
index 2b2899c..b6b5e36 100644
--- a/tensorflow/api_template_v1.__init__.py
+++ b/tensorflow/api_template_v1.__init__.py
@@ -104,6 +104,8 @@
_current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)
+_major_api_version = 1
+
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
@@ -132,8 +134,14 @@
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
+ # TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
- # TODO(gunan): Add sanity checks to loaded modules here.
+ # Load first party dynamic kernels.
+ _main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
+ if _fi.file_exists(_main_dir):
+ _ll.load_library(_main_dir)
+
+ # Load third party dynamic kernels.
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)
diff --git a/tensorflow/c/eager/c_api_internal.cc b/tensorflow/c/eager/c_api_internal.cc
index f609271..4f3de47 100644
--- a/tensorflow/c/eager/c_api_internal.cc
+++ b/tensorflow/c/eager/c_api_internal.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/host_info.h"
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
@@ -26,29 +27,22 @@
if (!status->status.ok()) {
return nullptr;
}
- auto create_or_reset =
- [&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
- bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
- if (op_to_reset) {
- status->status = op_to_reset->Reset(ctx, name, is_function, types,
- raw_device_name, inference_ctx);
- return op_to_reset;
- } else {
- TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
- status->status = new_op->operation.SetDeviceName(raw_device_name);
- return new_op;
- }
- };
+ if (op_to_reset && op_to_reset->ctx != ctx) {
+ status->status = tensorflow::errors::Internal(
+ "Cannot reset a TFE_Op from another TFE_Context");
+ return nullptr;
+ }
+
+ std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
- return create_or_reset(false, new TFE_OpInferenceContext(op_def));
- }
- if (!ctx->context->FindFunctionByName(name)) {
+ inference_ctx.reset(new TFE_OpInferenceContext(op_def));
+ } else if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
@@ -58,5 +52,15 @@
"registered in the binary running in this process.");
return nullptr;
}
- return create_or_reset(true, nullptr);
+
+ if (op_to_reset) {
+ status->status = op_to_reset->Reset(
+ name, is_function, types, raw_device_name, std::move(inference_ctx));
+ return op_to_reset;
+ }
+
+ TFE_Op* new_op =
+ new TFE_Op(ctx, name, is_function, types, std::move(inference_ctx));
+ status->status = new_op->operation.SetDeviceName(raw_device_name);
+ return new_op;
}
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 29106e2..df19291 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -125,24 +125,26 @@
struct TFE_Op {
TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
- TFE_OpInferenceContext* inference_ctx)
- : operation(ctx->context, op, is_function, t),
- inference_ctx(inference_ctx) {}
+ std::unique_ptr<TFE_OpInferenceContext> inference_ctx)
+ : ctx(ctx),
+ operation(ctx->context, op, is_function, t),
+ inference_ctx(std::move(inference_ctx)) {}
void Clear() {
operation.Clear();
inference_ctx.reset();
}
- tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
+ tensorflow::Status Reset(const char* op, bool is_function,
const tensorflow::AttrTypeMap* t,
const char* raw_device_name,
- TFE_OpInferenceContext* infer_ctx) {
- inference_ctx.reset(infer_ctx);
+ std::unique_ptr<TFE_OpInferenceContext> infer_ctx) {
+ inference_ctx = std::move(infer_ctx);
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
nullptr);
}
+ TFE_Context* ctx;
tensorflow::EagerOperation operation;
std::unique_ptr<TFE_OpInferenceContext> inference_ctx;
};
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index b3c1e6a..f67c6f9 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -13,7 +13,6 @@
limitations under the License.
==============================================================================*/
-#define _USE_MATH_DEFINES
#include <cmath>
#include "tensorflow/cc/ops/array_ops_internal.h"
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h
index 58f7c4b..8487802 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.h
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.h
@@ -22,8 +22,9 @@
namespace tensorflow {
-// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and
-// executes (using XLA) TF function calls marked with "_XlaCompiledKernel".
+// Replaces TF function calls marked with `_XlaCompiledKernel` with _XlaCompile
+// and _XlaRun nodes (which compile and launch, respectively, the corresponding
+// HLO module).
class BuildXlaOpsPass : public GraphOptimizationPass {
public:
// If enable_lazy_compilation is not nullopt then *enable_lazy_compilation
diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc
index b23f6ec..4bea71e 100644
--- a/tensorflow/compiler/jit/defs.cc
+++ b/tensorflow/compiler/jit/defs.cc
@@ -17,6 +17,8 @@
namespace tensorflow {
+const char* const kXlaMustCompileAttr = "_XlaMustCompile";
+
const char* const kXlaCompileAttr = "_XlaCompile";
// User-provided through jit_scope APIs. Effective only when auto_jit is OFF.
diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h
index bf80093..9eb4c2c 100644
--- a/tensorflow/compiler/jit/defs.h
+++ b/tensorflow/compiler/jit/defs.h
@@ -22,7 +22,16 @@
namespace tensorflow {
// Name of attribute used to tag operators for compilation with XLA
+
+// Implies must-compile semantics: either it will be compiled
+// with XLA, or an error will be thrown.
+extern const char* const kXlaMustCompileAttr; // "_XlaMustCompile"
+
+// Implies auto-clustering: tagged nodes will be clustered and compiled with XLA
+// on a best-effort basis.
extern const char* const kXlaCompileAttr; // "_XlaCompile"
+
+// Implies auto-clustering within the given scope.
extern const char* const kXlaScopeAttr; // "_XlaScope"
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 8b627cd..bf8b2c4 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -27,6 +27,15 @@
namespace tensorflow {
+// EncapsulateSubgraphs pass takes all the nodes with the same cluster ID
+// (taken from the kXlaClusterAttr=ID attribute), puts them into
+// a TF function, and replaces the subgraph in the main graph with a call to
+// that TF function annotated with kXlaCompiledKernelAttr (_XlaCompiledKernel).
+class EncapsulateSubgraphsPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
// A rewriting function to apply to each subgraph during encapsulation.
// 'arg_source_tensors' are the tensors corresponding to the arguments in the
// original source graph (*not* 'graph').
@@ -100,11 +109,6 @@
// TODO(hpucha): Move the utilities to a more appropriate place.
void SortControlInputs(GraphDef* gdef);
-class EncapsulateSubgraphsPass : public GraphOptimizationPass {
- public:
- Status Run(const GraphOptimizationPassOptions& options) override;
-};
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
index 99e9dfd..3057e4c 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
@@ -28,7 +28,7 @@
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
- namespace tensorflow {
+namespace tensorflow {
// Encapsulates nodes marked with the _xla_compile_id attribute into
// XlaLaunch operators.
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 9f2bc3f7f..edcec28 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1187,7 +1187,7 @@
}
if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
- VLOG(1) << "Rejecting " << node->name()
+ VLOG(1) << "Rejecting TF operation " << node->def().op()
<< " as it is not listed in --tf_xla_ops_to_cluster.";
continue;
}
@@ -2036,6 +2036,7 @@
"XlaDynamicSlice",
"XlaDynamicUpdateSlice",
"XlaEinsum",
+ "XlaGather",
"XlaIf",
"XlaKeyValueSort",
"XlaPad",
@@ -2043,6 +2044,7 @@
"XlaReduce",
"XlaReduceWindow",
"XlaReplicaId",
+ "XlaScatter",
"XlaSelectAndScatter",
"XlaSelfAdjointEig",
"XlaSend",
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h
index 0c9b407..8b66071 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.h
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h
@@ -34,8 +34,9 @@
// compilation by the encapsulate subgraphs pass.
extern const char* const kXlaOutsideCompilationAttr;
-// Pass that marks a subset of operators in the graph with attribute
-// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
+// Marks a subset of nodes in the graph which are to be clustered
+// with an attribute _XlaCluster=<cluster id> so they are picked up by the
+// EncapsulateSubgraphsPass.
class MarkForCompilationPass : public GraphOptimizationPass {
public:
MarkForCompilationPass() = default;
diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc
index e3706a0..23bd742 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator.cc
@@ -21,7 +21,7 @@
bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) const {
- return CanCreateXlaKernel(flr, node_def);
+ return CanCreateXlaKernel(node_def);
}
Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr,
diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc
index 28606ab..7ec3733 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc
@@ -95,15 +95,17 @@
TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) {
FunctionDef fdef = XTimesY();
- (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
+ (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
Init({fdef});
XlaKernelCreator xla_kernel_creator;
-
- Status status = xla_kernel_creator.CreateKernel(
- flr_, ToNodeDef(R"pb(
+ NodeDef callsite =
+ ToNodeDef(R"pb(
name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
- )pb"),
- &kernel_);
+ )pb");
+ (*callsite.mutable_attr())["_XlaMustCompile"] = BoolAttr(true);
+
+ // Note: the attribute must be set on the call-site node, not just the FunctionDef.
+ Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_);
ASSERT_TRUE(status.ok()) << status.ToString();
EXPECT_EQ("XTimesY", kernel_->name());
@@ -137,7 +139,7 @@
TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) {
FunctionDef fdef = XTimesY();
- (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
+ (*fdef.mutable_attr())["_XlaMustCompile"] = BoolAttr(false);
Init({fdef});
XlaKernelCreator xla_kernel_creator;
diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc
index 6441dd3..94727fd 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc
+++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc
@@ -23,7 +23,9 @@
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
@@ -68,40 +70,10 @@
};
} // namespace
-bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
- const NodeDef& node_def) {
- const FunctionDef* function_def =
- flr.GetFunctionLibraryDefinition()->Find(node_def.name());
- if (function_def == nullptr) {
- // The node def is not calling a function. Individual ops can be
- // run directly using on-demand mode, no need to create XlaLaunch
- // kernel for them.
- return false;
- }
-
- // If kXlaCompileAttr is set on the node_def, use its value.
- const auto& it = node_def.attr().find(kXlaCompileAttr);
- if (it != node_def.attr().end()) {
- return it->second.b();
- }
-
- // kXlaCompileAttr is not set on node_def, check if it is set on
- // FunctionDef.
- bool xla_compile = false;
- Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
- node_def, kXlaCompileAttr, &xla_compile);
- if (!status.ok() || !xla_compile) {
- if (VLOG_IS_ON(3)) {
- if (!status.ok()) {
- VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
- << node_def.op() << ". status=" << status.ToString();
- } else {
- VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
- }
- }
- return false;
- }
- return true;
+bool CanCreateXlaKernel(const NodeDef& node_def) {
+ // Use kXlaMustCompileAttr's value if set on node_def; if absent, return false.
+ const auto& it = node_def.attr().find(kXlaMustCompileAttr);
+ return it != node_def.attr().end() && it->second.b();
}
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
@@ -118,8 +90,11 @@
FunctionLibraryRuntime::Handle handle;
// If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
+ NameAttrList function;
+ TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
+
TF_RETURN_IF_ERROR(
- flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+ flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
*fbody = flr->GetFunctionBody(handle);
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
const DataTypeVector& arg_types = (*fbody)->arg_types;
@@ -149,7 +124,7 @@
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
- if (!CanCreateXlaKernel(*flr, node_def)) {
+ if (!CanCreateXlaKernel(node_def)) {
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
}
@@ -241,9 +216,7 @@
// Create the kernel.
NameAttrList function;
- function.set_name(node_def.op());
- *(function.mutable_attr()) = node_def.attr();
-
+ TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
OpKernelConstruction construction(
diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.h b/tensorflow/compiler/jit/xla_kernel_creator_util.h
index 71398c3..5ec8df0 100644
--- a/tensorflow/compiler/jit/xla_kernel_creator_util.h
+++ b/tensorflow/compiler/jit/xla_kernel_creator_util.h
@@ -24,11 +24,9 @@
class FunctionLibraryRuntime;
class OpKernel;
- // Given a NodeDef 'node_def' and the function library runtime 'flr', returns
- // true if 'node_def' is a call to a compilable function defined in 'flr',
- // with the kXlaCompileAttr set.
-bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr,
- const NodeDef& node_def);
+// Given a NodeDef `node_def`, returns true iff `node_def` has the
+// kXlaMustCompileAttr attribute set to true.
+bool CanCreateXlaKernel(const NodeDef& node_def);
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index eeff920..d301898 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -8,6 +8,7 @@
default_visibility = [
# TODO(jpienaar): Make the visibility more restrictive.
":friends",
+ "//tensorflow/lite/experimental/tf_runtime:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
@@ -590,7 +591,6 @@
":flatbuffer_translate_lib",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
- "//tensorflow/core/platform/default/build_config:base",
"//tensorflow/lite:framework",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops",
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
index 11d9712..a561fbc 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -212,7 +212,7 @@
// type, thus none stats op is required and nullptr is retruned.
// If the min max information is invalid, nullptr is returned.
mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
- Value* res) {
+ Value res) {
// If the `tensor` has scale/zero_point, it must have been quantized, then the
// min/max stats is just for comments, so ignore it.
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
@@ -497,12 +497,12 @@
// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
- const tflite::OperatorT& op, const std::vector<Value*>& vals_map,
- Value* optional_arg_marker, const std::vector<std::string>& op_names,
+ const tflite::OperatorT& op, const std::vector<Value>& vals_map,
+ Value optional_arg_marker, const std::vector<std::string>& op_names,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
OpBuilder builder) {
- llvm::SmallVector<Value*, 4> operands;
+ llvm::SmallVector<Value, 4> operands;
llvm::SmallVector<mlir::Type, 2> outputTypes;
if (op.outputs.empty()) {
@@ -692,19 +692,19 @@
auto& body = func.getBody();
OpBuilder op_builder{body};
- std::vector<Value*> vals_map(subgraph.tensors.size(), nullptr);
- Value* maybe_optional_arg_marker = nullptr;
+ std::vector<Value> vals_map(subgraph.tensors.size(), nullptr);
+ Value maybe_optional_arg_marker = nullptr;
// Get or construct MLIR values for each input
for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
auto input_tensor = subgraph.inputs[i];
const auto& tensor = *subgraph.tensors.at(input_tensor);
auto loc = TensorLoc(tensor, builder, base_loc);
- if (nullptr != vals_map[input_tensor]) {
+ if (vals_map[input_tensor]) {
auto err = errors::FailedPrecondition("duplicate input arguments");
return emitError(loc, err.ToString()), err;
}
- Value* input_value = func.getArgument(i);
+ Value input_value = func.getArgument(i);
// If the `tensor` has min/max and doesn't have scale/zero_point
// information, a stats op is created to use the input_value, then the
@@ -745,7 +745,7 @@
builder.getUnitAttr())
.getResult();
}
- } else if (nullptr == vals_map.at(input_num)) {
+ } else if (!vals_map.at(input_num)) {
auto& const_tensor = *subgraph.tensors[input_num];
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
auto op_or_err =
@@ -768,7 +768,7 @@
? base_loc
: TensorLoc(*subgraph.tensors[op->outputs[0]], builder, base_loc);
// If there's an optional argument, maybe_optional_arg_marker has been set
- // to a valid Value*
+ // to a valid Value
TF_ASSIGN_OR_RETURN(
auto* mlir_op,
ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names,
@@ -791,9 +791,9 @@
}
// Construct return values
- llvm::SmallVector<Value*, 4> return_operands;
+ llvm::SmallVector<Value, 4> return_operands;
for (auto index : func_outputs) {
- if (nullptr == vals_map.at(index)) {
+ if (!vals_map.at(index)) {
auto& const_tensor = *subgraph.tensors[index];
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
auto op_or_err =
diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
index 3f2b8a7..a873a40 100644
--- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
+++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc
@@ -231,7 +231,7 @@
}
template <typename T>
-static bool HasValidTFLiteType(Value* value, T& error_handler) {
+static bool HasValidTFLiteType(Value value, T& error_handler) {
// None type is allowed to represent unspecified operands.
if (value->getType().isa<NoneType>()) return true;
@@ -280,7 +280,7 @@
}
auto& bb = fn.getBlocks().front();
- for (auto* arg : bb.getArguments()) {
+ for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn))
return fn.emitError("invalid TFLite type: ") << arg->getType(), false;
}
@@ -290,7 +290,7 @@
for (auto& inst : bb) {
if (inst.isKnownTerminator()) break;
- for (auto* result : inst.getResults()) {
+ for (auto result : inst.getResults()) {
if (!HasValidTFLiteType(result, inst))
return fn.emitError("invalid TFLite type: ") << result->getType(),
false;
@@ -362,7 +362,7 @@
// Builds TFLite tensor from the given value. `buffer_idx` is index of the
// corresponding buffer. Emits error and returns llvm::None on failure.
- Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value* value,
+ Optional<BufferOffset<tflite::Tensor>> BuildTensor(Value value,
const std::string& name,
unsigned buffer_idx);
@@ -420,7 +420,7 @@
bool IsStatefulOperand(mlir::Operation* op, int operand_index);
// Returns a unique name for `val`.
- std::string UniqueName(mlir::Value* val);
+ std::string UniqueName(mlir::Value val);
ModuleOp module_;
@@ -450,7 +450,7 @@
std::vector<std::string> failed_custom_ops_;
};
-std::string Translator::UniqueName(mlir::Value* val) {
+std::string Translator::UniqueName(mlir::Value val) {
return name_mapper_.GetUniqueName(val);
}
@@ -503,7 +503,7 @@
}
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
- Value* value, const std::string& name, unsigned buffer_idx) {
+ Value value, const std::string& name, unsigned buffer_idx) {
auto type = value->getType().cast<TensorType>();
// TFLite requires tensor shape only for the inputs and constants.
@@ -917,11 +917,11 @@
bool has_input_attr = false;
InitializeNamesFromAttribute(fn, &has_input_attr);
std::vector<BufferOffset<tflite::Tensor>> tensors;
- llvm::DenseMap<Value*, int> tensor_index_map;
+ llvm::DenseMap<Value, int> tensor_index_map;
// Builds tensor and buffer for argument or operation result. Returns false
// on failure.
- auto build_tensor_and_buffer = [&](Value* value, const std::string& name) {
+ auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
// NoneType represents optional and may be skipped here.
if (value->getType().isa<NoneType>()) {
return true;
@@ -953,7 +953,7 @@
// have associated tensor and buffer. Build FlatBuffer tensor and buffer for
// other functions.
for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
- mlir::BlockArgument* arg = bb.getArgument(i);
+ mlir::BlockArgument arg = bb.getArgument(i);
std::string name;
if (has_input_attr) name = name_mapper_.GetUniqueName(arg);
if (name.empty()) name = absl::StrCat("arg", i);
@@ -975,7 +975,7 @@
// Fetch operand and result tensor indices.
std::vector<int32_t> operands;
operands.reserve(inst.getNumOperands());
- for (auto* operand : inst.getOperands()) {
+ for (auto operand : inst.getOperands()) {
if (operand->getType().isa<NoneType>())
operands.push_back(kTfLiteOptionalTensor);
else
@@ -983,7 +983,7 @@
}
std::vector<int32_t> results;
results.reserve(inst.getNumOperands());
- for (auto* result : inst.getResults()) {
+ for (auto result : inst.getResults()) {
results.push_back(tensor_index_map.lookup(result));
}
@@ -997,10 +997,10 @@
// Get input and output tensor indices for the subgraph.
std::vector<int32_t> inputs, outputs;
- for (auto* arg : bb.getArguments()) {
+ for (auto arg : bb.getArguments()) {
inputs.push_back(tensor_index_map[arg]);
}
- for (auto* result : bb.getTerminator()->getOperands()) {
+ for (auto result : bb.getTerminator()->getOperands()) {
outputs.push_back(tensor_index_map[result]);
}
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 44fcc15..c0cdc6c 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -301,8 +301,8 @@
return {};
}
-void buildComparisonBinOp(Builder *builder, OperationState &result, Value *lhs,
- Value *rhs) {
+void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
+ Value rhs) {
auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
if (!result_type)
@@ -321,7 +321,7 @@
}
void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
- Value *lhs, Value *rhs,
+ Value lhs, Value rhs,
StringAttr fused_activation_function) {
auto result_type =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
@@ -462,7 +462,7 @@
return op.emitOpError("concatenation dimension must be in [-rank, rank)");
SmallVector<TensorType, 4> operand_types;
- for (Value *operand : op.values())
+ for (Value operand : op.values())
operand_types.push_back(operand->getType().cast<TensorType>());
return VerifyConcatenationOpTypes(op.getOperation(), output_type,
@@ -528,8 +528,8 @@
}
// Remove all empty values.
- SmallVector<Value *, 4> non_empty_values;
- for (Value *value : this->values()) {
+ SmallVector<Value, 4> non_empty_values;
+ for (Value value : this->values()) {
const auto shaped_type = value->getType().cast<ShapedType>();
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
continue;
@@ -609,7 +609,7 @@
//===----------------------------------------------------------------------===//
static void BuildGatherOp(Builder *builder, OperationState &result,
- Value *params, Value *indices, IntegerAttr axis) {
+ Value params, Value indices, IntegerAttr axis) {
auto params_type = params->getType().cast<TensorType>();
auto indices_type = indices->getType().cast<TensorType>();
@@ -704,7 +704,7 @@
if (op.getOperation()->getNumOperands() != op.values_count())
return op.emitOpError("input count should match 'values_count' attribute");
- Value *operand0 = op.getOperand(0);
+ Value operand0 = op.getOperand(0);
auto input_type = operand0->getType().cast<ShapedType>();
// Check axis bounds.
@@ -717,7 +717,7 @@
// Make sure all inputs have the same shape and element type.
// TODO(rahulsp): Simplify once b/135032064 is fixed.
- for (Value *operand : op.getOperands()) {
+ for (Value operand : op.getOperands()) {
auto other_type = operand->getType().cast<ShapedType>();
if (input_type != other_type)
return op.emitOpError("operands should be of the same type. got ")
@@ -880,8 +880,8 @@
return matchFailure();
for (auto input_output :
llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) {
- Value *pack_input = std::get<0>(input_output);
- Value *unpack_output = std::get<1>(input_output);
+ Value pack_input = std::get<0>(input_output);
+ Value unpack_output = std::get<1>(input_output);
// Make sure the ordering is the same for the pack op & unpack op.
if (pack_input != unpack_output) return matchFailure();
}
@@ -984,8 +984,8 @@
// TopKOp
//===----------------------------------------------------------------------===//
-static void BuildTopKOp(Builder *builder, OperationState &result, Value *input,
- Value *k) {
+static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
+ Value k) {
// Output size is only known if k is constant value. A negative dimension is
// considered dynamic so use -1 here if k is not a constant value.
int const_k = -1;
@@ -1075,7 +1075,7 @@
// Extracts and returns the signed integer constant in a 0-rank integer tensor
// or 1-element 1-rank integer tensor if 'value' is a constant.
-static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value *value) {
+static llvm::Optional<int64_t> ExtractConstantIntFromTensor(Value value) {
ElementsAttr attr;
if (!matchPattern(value, m_Constant(&attr))) return {};
if (attr.getNumElements() != 1) return {};
@@ -1101,7 +1101,7 @@
ExpectedOutputTypeGetter get_expected_output_type) {
for (int64_t i = 0; i < num_splits; ++i) {
auto expected_output_type = get_expected_output_type(i);
- Value *output = op->getResult(i);
+ Value output = op->getResult(i);
auto output_type = output->getType().dyn_cast<RankedTensorType>();
if (!output_type || output_type != expected_output_type)
return op->emitOpError()
@@ -1443,7 +1443,7 @@
//===----------------------------------------------------------------------===//
static void BuildSelectV2Op(Builder *builder, OperationState &result,
- Value *cond, Value *x, Value *y) {
+ Value cond, Value x, Value y) {
auto operand_type =
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 6cba241..b8b0ef6 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -224,7 +224,7 @@
//===----------------------------------------------------------------------===//
def TFL_BroadcastableBinaryBuilder : OpBuilder<
- "Builder *builder, OperationState &result, Value *lhs, Value *rhs",
+ "Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
auto resultType =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
@@ -235,7 +235,7 @@
}]>;
def TFL_FusedBroadcastableBinaryBuilder : OpBuilder<
- "Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
+ "Builder *builder, OperationState &result, Value lhs, Value rhs, "
"StringAttr fusedActivationFunction",
[{
buildFusedBroadcastableBinOp(
@@ -243,7 +243,7 @@
}]>;
def TFL_ComparisonBinaryBuilder : OpBuilder<
- "Builder *builder, OperationState &result, Value *lhs, Value *rhs",
+ "Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
buildComparisonBinOp(builder, result, lhs, rhs);
}]>;
@@ -669,7 +669,7 @@
let builders =
[
OpBuilder<"Builder *builder, OperationState &result, "
- "Value *params, Value *indices, IntegerAttr axis",
+ "Value params, Value indices, IntegerAttr axis",
[{ BuildGatherOp(builder, result, params, indices, axis); }]>
];
@@ -932,7 +932,7 @@
let builders =
[
OpBuilder<
- "Builder *builder, OperationState &result, Value *lhs, Value *rhs",
+ "Builder *builder, OperationState &result, Value lhs, Value rhs",
[{
buildComparisonBinOp(builder, result, lhs, rhs);
}]>
@@ -2081,7 +2081,7 @@
// TODO(jpienaar): autogenerate this.
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
- "Value *condition, Value *x, Value *y",
+ "Value condition, Value x, Value y",
[{
auto resultType = x->getType();
result.addOperands({condition, x, y});
@@ -2109,7 +2109,7 @@
let results = (outs AnyTensor:$output);
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
- "Value *cond, Value *x, Value *y",
+ "Value cond, Value x, Value y",
[{
BuildSelectV2Op(builder, result, cond, x, y);
}]>];
@@ -2303,7 +2303,7 @@
I32Tensor:$indices);
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
- "Value *input, Value *k",
+ "Value input, Value k",
[{ BuildTopKOp(builder, result, input, k); }]>];
let hasOptions = 1;
@@ -2359,14 +2359,14 @@
}];
let arguments = (ins
- TensorOf<[F32, I8, I32, QI8, QUI8]>:$input,
+ TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input,
I32Attr:$num,
I32Attr:$axis
);
let results = (outs
- Variadic<TensorOf<[F32, I8, I32, QI8, QUI8]>>:$outputs
+ Variadic<TensorOf<[F32, I1, I8, I32, QI8, QUI8]>>:$outputs
);
let verifier = [{ return Verify(*this); }];
diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index fcd2c38..7738f1e 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -277,7 +277,6 @@
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, quant_specs, result, &pm);
-
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
index 0326d12..c6a94f4 100644
--- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
@@ -70,14 +70,14 @@
void ImportAsStatsOps(OpBuilder b, Operation *op, int index,
const QuantParamsEntry &info);
- void InsertStatsOpAtResult(OpBuilder b, Value *res, ElementsAttr layer_stats,
+ void InsertStatsOpAtResult(OpBuilder b, Value res, ElementsAttr layer_stats,
ElementsAttr axis_stats, IntegerAttr axis);
// If the index is out of range, this method returns false. Otherwise it
// returns true if the value is a float tensor.
bool IsQuantizableResult(Operation *op, int index) {
if (index < 0 || index >= op->getNumResults()) return false;
- Value *res = op->getResult(index);
+ Value res = op->getResult(index);
return res->getType().isa<ShapedType>() &&
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
}
@@ -117,7 +117,7 @@
return false;
}
-void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value *res,
+void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
ElementsAttr layer_stats,
ElementsAttr axis_stats,
IntegerAttr axis) {
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index fbf6f7f..7eb5bdf 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -183,20 +183,20 @@
// of the op.
void QuantizeOpResult(Operation *op, int index, QuantParams params);
- void QuantizeArg(BlockArgument *arg, QuantParams params);
+ void QuantizeArg(BlockArgument arg, QuantParams params);
// Inserts the Quantize and Dequantize ops to quantize the value and returns
// the Quantize op.
- void QuantizeValue(Value *value, QuantParams params, Location loc);
+ void QuantizeValue(Value value, QuantParams params, Location loc);
// Inserts the Quantize ops for requantizing the index-th result of the op.
void RequantizeOpResult(Operation *op, int index, RequantizeState *state);
- void RequantizeArg(BlockArgument *arg, RequantizeState *state);
+ void RequantizeArg(BlockArgument arg, RequantizeState *state);
// Inserts the Quantize and Dequantize ops to quantize the value and returns
// the Quantize op.
- void RequantizeValue(Value *value, RequantizeState *state, Location loc);
+ void RequantizeValue(Value value, RequantizeState *state, Location loc);
// A heuristic to get the quantization parameter satisfies the same scale
// constraints for the op. Returns an empty option if this quantization
@@ -213,7 +213,7 @@
return states_[result_states_[{op, index}]];
}
- QuantState &GetArgQuantState(BlockArgument *arg) {
+ QuantState &GetArgQuantState(BlockArgument arg) {
return states_[arg_states_[arg]];
}
@@ -227,7 +227,7 @@
return rescale_states_[result_states_[{op, index}]];
}
- RequantizeState &GetArgRequantizeState(BlockArgument *arg) {
+ RequantizeState &GetArgRequantizeState(BlockArgument arg) {
return rescale_states_[arg_states_[arg]];
}
@@ -235,32 +235,45 @@
// `as_result` is true or index-th operand if `as_result` is false. The state
// is immutable if the type is a quantized type. Returns the index of this
// new state in the state vector.
- int InitializeState(Operation *op, int index, Value *val, bool as_result);
+ int InitializeState(Operation *op, int index, Value val, bool as_result);
+
+ // Sets the state of an argument. If this value is cached, uses the cached
+ // result without creating a new entry in the state vector. Otherwise,
+ // allocates a new entry in the state vector.
+ void InitializeArgState(BlockArgument arg, Value in,
+ llvm::DenseMap<Value, int> *cache) {
+ auto cached = cache->insert({in, 0});
+ if (!cached.second) {
+ arg_states_[arg] = cached.first->second;
+ return;
+ }
+ QuantParams params =
+ quant::QuantizedType::getQuantizedElementType(in->getType());
+ bool immutable = !EmptyParams(params);
+ int next_state_index = states_.size();
+ states_.push_back({params, immutable});
+ arg_states_[arg] = next_state_index;
+ cached.first->second = next_state_index;
+ }
// Sets the state of the index-th operand of the op. If this operand is
// cached, uses the cached result without creating new entry in the state
// vector. Otherwise, allocate a new entry in the state vector.
- void InitializeOperandState(Operation *op, int index, Value *in,
- llvm::DenseMap<Value *, int> *cache,
- bool is_argument) {
+ void InitializeOperandState(Operation *op, int index, Value in,
+ llvm::DenseMap<Value, int> *cache) {
auto cached = cache->insert({in, 0});
if (!cached.second) {
operand_states_.insert({{op, index}, cached.first->second});
return;
}
cached.first->second = InitializeState(op, index, in, /*as_result=*/false);
- if (is_argument) {
- auto *arg = llvm::cast<BlockArgument>(in);
- arg_states_[arg] = cached.first->second;
- args_.push_back(arg);
- }
}
// Sets the state of the index-th result of the op. If this result is cached,
// uses the cached result without creating new entry in the state vector.
// Otherwise, allocate a new entry in the state vector.
- void InitializeResultState(Operation *op, int index, Value *res,
- llvm::DenseMap<Value *, int> *cache) {
+ void InitializeResultState(Operation *op, int index, Value res,
+ llvm::DenseMap<Value, int> *cache) {
auto cached = cache->insert({res, 0});
if (!cached.second) {
result_states_.insert({{op, index}, cached.first->second});
@@ -279,7 +292,8 @@
// rest are weights.
llvm::DenseSet<Operation *> weights_;
- // The weights require narrow_range quantization. If the value of this map is
+ // The weights require narrow_range quantization. This map collects all the
+ // weight operands defined by the op quant spec. If the value of the entry is
// positive, per-channel quantization is required.
llvm::DenseMap<Operation *, int> optimized_weights_;
@@ -300,11 +314,11 @@
// results and arguments.
llvm::DenseMap<OpValue, int> operand_states_;
llvm::DenseMap<OpValue, int> result_states_;
- llvm::DenseMap<BlockArgument *, int> arg_states_;
+ llvm::DenseMap<BlockArgument, int> arg_states_;
// This vector is to preserve the arguments order, so the newly inserted
// quantized ops for the arguments are deterministically ordered.
- llvm::SmallVector<BlockArgument *, 4> args_;
+ llvm::SmallVector<BlockArgument, 4> args_;
OpQuantSpecGetter op_quant_spec_getter_;
};
@@ -321,7 +335,7 @@
return true;
}
-int QuantizationDriver::InitializeState(Operation *op, int index, Value *val,
+int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
bool as_result) {
QuantParams params =
quant::QuantizedType::getQuantizedElementType(val->getType());
@@ -338,7 +352,7 @@
bool QuantizationDriver::SetConstantResultParams(Operation *op) {
ElementsAttr attr;
- Value *res = op->getResult(0);
+ Value res = op->getResult(0);
if (!matchPattern(res, m_Constant(&attr))) {
return false;
}
@@ -428,16 +442,16 @@
void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
QuantParams params) {
builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
- Value *original_result = op->getResult(index);
+ Value original_result = op->getResult(index);
QuantizeValue(original_result, params, op->getLoc());
}
-void QuantizationDriver::QuantizeArg(BlockArgument *arg, QuantParams params) {
+void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
builder_.setInsertionPointToStart(arg->getOwner());
QuantizeValue(arg, params, builder_.getUnknownLoc());
}
-void QuantizationDriver::QuantizeValue(Value *value, QuantParams params,
+void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
Location loc) {
Type expressed_type = value->getType();
Type new_type = params.castFromExpressedType(expressed_type);
@@ -459,7 +473,7 @@
RequantizeState *state) {
if (state->pos == RequantizeState::NO_REQUANTIZE) return;
builder_.setInsertionPointAfter(op);
- Value *value = op->getResult(index);
+ Value value = op->getResult(index);
if (state->pos == RequantizeState::ON_OUTPUT) {
Operation *user = value->getUses().begin().getUser();
if (llvm::isa<TFL::QuantizeOp>(user)) {
@@ -471,9 +485,9 @@
RequantizeValue(value, state, op->getLoc());
}
-void QuantizationDriver::RequantizeArg(BlockArgument *arg,
+void QuantizationDriver::RequantizeArg(BlockArgument arg,
RequantizeState *state) {
- Value *value = arg;
+ Value value = arg;
builder_.setInsertionPointToStart(arg->getOwner());
if (value->hasOneUse()) {
auto user = value->use_begin().getUser();
@@ -485,7 +499,7 @@
RequantizeValue(value, state, builder_.getUnknownLoc());
}
-void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state,
+void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
Location loc) {
Type new_type;
if (state->pos == RequantizeState::ON_INPUT) {
@@ -586,7 +600,7 @@
auto type = cst.getType().dyn_cast<ShapedType>();
if (!type || !type.getElementType().isa<FloatType>()) return;
- Value *value = cst.getResult();
+ Value value = cst.getResult();
SmallVector<std::pair<Operation *, int>, 4> bias_users;
bool used_as_weight = false;
for (auto &use : value->getUses()) {
@@ -629,7 +643,20 @@
}
void QuantizationDriver::SetupAllStates() {
- llvm::DenseMap<Value *, int> value_to_state;
+ llvm::DenseMap<Value, int> value_to_state;
+
+ for (auto arg : fn_.getArguments()) {
+ args_.push_back(arg);
+ Value value = arg;
+ // If the argument is quantized, it should only have one user.
+ if (arg->hasOneUse()) {
+ auto user = value->use_begin().getUser();
+ if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
+ value = q.output();
+ }
+ }
+ InitializeArgState(arg, value, &value_to_state);
+ }
fn_.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
@@ -638,21 +665,19 @@
work_list_.push_back(op);
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
- auto *operand = op->getOperand(i);
- bool is_argument = true;
+ auto operand = op->getOperand(i);
if (auto *inst = operand->getDefiningOp()) {
// If the operand comes from a tfl.dequantize op, we use the quantized
// input of this tfl.dequantize op to set the state.
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
operand = dq.input();
}
- is_argument = false;
}
- InitializeOperandState(op, i, operand, &value_to_state, is_argument);
+ InitializeOperandState(op, i, operand, &value_to_state);
}
for (int res = 0, e = op->getNumResults(); res != e; ++res) {
- auto *result = op->getResult(res);
+ auto result = op->getResult(res);
// If the result has been quantized, it should only be used by a
// tfl.quantize op. For this case, we uses the quantized result to
// create the state and mark it immutable.
@@ -746,7 +771,7 @@
}
void QuantizationDriver::Finalize() {
- for (auto *arg : args_) {
+ for (auto arg : args_) {
auto &state = GetArgQuantState(arg);
auto &requantize = GetArgRequantizeState(arg);
if (state.IsEmpty() ||
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
index ca10809..398c751 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc
@@ -412,7 +412,7 @@
if (user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
!PreferResultScale(user)) {
- for (Value* res : user->getResults()) {
+ for (Value res : user->getResults()) {
if (res->hasOneUse()) {
if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
*res->getUsers().begin())) {
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
index 0f7ec91..e84b77a 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h
@@ -161,7 +161,7 @@
if (op->getNumResults() != 1) {
return matchFailure();
}
- Value* quantized_value = op->getResult(0);
+ Value quantized_value = op->getResult(0);
for (Operation* quantized_op : quantized_value->getUsers()) {
// If it is requantize op, we shouldn't rewrite this op.
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
@@ -176,7 +176,7 @@
// Collect all the quantized inputs and "clone" the matched op by these
// inputs.
- SmallVector<Value*, 4> inputs;
+ SmallVector<Value, 4> inputs;
inputs.reserve(quantized_op->getNumOperands());
for (auto operand : quantized_op->getOperands()) {
Type operand_type = operand->getType();
@@ -201,12 +201,12 @@
// Collect all the quantized outputs and replace them by the results of
// the new quantized op.
- llvm::SmallDenseMap<Value*, int> outputs_replaced;
+ llvm::SmallDenseMap<Value, int> outputs_replaced;
SmallVector<Type, 4> output_types;
output_types.reserve(quantized_op->getNumResults());
for (auto enumerated_result :
llvm::enumerate(quantized_op->getResults())) {
- Value* result = enumerated_result.value();
+ Value result = enumerated_result.value();
Type result_type = result->getType();
// Add this to the test coverage once we create test ops with none type
// results.
diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD
index 2498a10..61faf45 100644
--- a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD
+++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD
@@ -10,7 +10,9 @@
driver = "@local_config_mlir//:run_lit.sh",
test_file_exts = [
"pbtxt",
- "py",
+ # TODO(fengliuai): reenable these tests after the fused loc is
+ # supported in the diagnostic handler.
+ # "py",
],
)
diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt
index 0fcee7d..8045271 100644
--- a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt
+++ b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt
@@ -38,6 +38,6 @@
# CHECK: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {inputs = "input0,input1", outputs = "output"}} {
-# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = "", name = "output"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
+# CHECK-NEXT: %0 = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %0 : tensor<*xi32>
# CHECK-NEXT: }
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index bab6433..5a07946 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -140,22 +140,6 @@
// CHECK-SAME: fused_activation_function = "RELU6"
}
-// CHECK-LABEL: intermOpUsedTwice
-func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
- %cst = constant dense<1.5> : tensor<16xf32>
- %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
- %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
- %1 = "tfl.add"(%0, %cst) {fused_activation_function = "RELU6"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
- return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
-
- // CHECK: %cst = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00,
- // CHECK: %cst_0 = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00,
- // CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
- // CHECK: %1 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32}
- // CHECK: return %0, %1
-
-}
-
// CHECK-LABEL: @fuseMulIntoFullyConnected
func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
%cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
@@ -631,3 +615,18 @@
// CHECK: %[[RES:.*]] = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"}
// CHECK: return %[[RES]]
}
+
+// CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers
+func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
+ %cst = constant dense<1.5> : tensor<16xf32>
+ %cst_1 = constant dense<3.5> : tensor<16xf32>
+ %cst_0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
+ %0 = "tfl.conv_2d"(%arg0, %arg1, %cst_0) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+ %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+ %2 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
+ return %1, %2 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
+
+ // CHECK: %[[tfl_conv2d:[0-9].*]] = "tfl.conv_2d"
+ // CHECK: tfl.add
+ // CHECK-NEXT: tfl.add
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
index cd11117..fc9c550 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir
@@ -379,26 +379,26 @@
// CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
// CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHeCK: return %4 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: return %4 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
// CHECK-LABEL: QuantizeConcatResToAllRequantize
func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1:128>> {
^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>):
- %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 2.0:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>
+ %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>
%1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform<u8:f32, 2.0:128>>) -> tensor<1x2xf32>
%2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>
-// CHECK %1 = "tfl.quantize"(%0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK %2 = "tfl.dequantize"(%1) : (tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2xf32>
-// CHECK %3 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK %4 = "tfl.dequantize"(%3) : (tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<2xf32>
-// CHECK %5 = "tfl.concatenation"(%2, %4) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
-// CHECK %6 = "tfl.quantize"(%5) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK return %6 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
+// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>
+// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
+// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
+// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
// CHECK-LABEL: QuantizeConcatResToAllRequantizeArg
@@ -409,13 +409,13 @@
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK %1 = "tfl.quantize"(%arg0) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK %2 = "tfl.dequantize"(%1) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
-// CHECK %3 = "tfl.quantize"(%arg1) {qtype = tensor<2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK %4 = "tfl.dequantize"(%3) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
-// CHECK %5 = "tfl.concatenation"(%2, %4) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
-// CHECK %6 = "tfl.quantize"(%5) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK return %6 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
+// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>) -> tensor<1x2xf32>
+// CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
+// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
// CHECK-LABEL: RequantizeAlreadyQuantizedModel
diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir
index 225123e..89d1e7c 100644
--- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir
@@ -204,8 +204,9 @@
%3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
return %3 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
-// CHECK: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
-// CHECK: %[[cc:.*]] = "tfl.concatenation"(%arg0, %[[q]]) {axis = 0 : i32, fused_activation_function = "NONE"}
+// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>}
+// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>} : (tensor<1x2x!quant.uniform<u8:f32, 2.000000e+00:128>>) -> tensor<1x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
+// CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %[[cc]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc
index 52eb621..e9e0b26 100644
--- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc
@@ -188,10 +188,10 @@
// This function will process the aggregated inputs based on different
// strategies like "first", "last", "stack".
- std::map<int, Value*> GetAggregatedInputs(OpBuilder* builder) {
- std::map<int, Value*> aggregated_inputs;
+ std::map<int, Value> GetAggregatedInputs(OpBuilder* builder) {
+ std::map<int, Value> aggregated_inputs;
for (const auto& kv : inputs) {
- Value* op_input = nullptr;
+ Value op_input = nullptr;
const AggregatedOperand& operand = kv.second;
// Dealing with "stack" strategy:
// This breaks into two parts:
@@ -203,7 +203,7 @@
if (operand.ops.size() == 1) {
// If ops size is 1, it will be simply expanding dimensions at dim 0.
Operation* current_identity_op = operand.ops.begin()->second;
- Value* input = current_identity_op->getOperand(0);
+ Value input = current_identity_op->getOperand(0);
RankedTensorType input_type =
input->getType().cast<RankedTensorType>();
// The Reshape will be {1, (original_shape)}
@@ -234,8 +234,8 @@
} else {
// Insert a pack op to pack all the inputs together.
- std::vector<Value*> pack_input_operands;
- std::vector<Value*> packed_input_consumers;
+ std::vector<Value> pack_input_operands;
+ std::vector<Value> packed_input_consumers;
for (int i = 0, e = operand.ops.size(); i < e; ++i) {
pack_input_operands.push_back(operand.ops.at(i)->getOperand(0));
packed_input_consumers.push_back(operand.ops.at(i)->getResult(0));
@@ -288,7 +288,7 @@
const AggregatedOperand& operand = kv.second;
if (operand.aggregation == kStrategyStack) {
const int output_numer = operand.ops.size();
- Value* first_output = operand.ops.at(0)->getOperand(0);
+ Value first_output = operand.ops.at(0)->getOperand(0);
RankedTensorType first_output_type =
first_output->getType().cast<RankedTensorType>();
// The aggregated output shape will be {N, original_shape}.
@@ -300,11 +300,11 @@
aggregated_output_types[kv.first] =
RankedTensorType::get(shape, first_output_type.getElementType());
} else if (operand.aggregation == kStrategyLast) {
- Value* last_output =
+ Value last_output =
operand.ops.at(operand.ops.size() - 1)->getOperand(0);
aggregated_output_types[kv.first] = last_output->getType();
} else {
- Value* first_output = operand.ops.at(0)->getOperand(0);
+ Value first_output = operand.ops.at(0)->getOperand(0);
aggregated_output_types[kv.first] = first_output->getType();
}
}
@@ -507,14 +507,14 @@
Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
Operation* insert_before_op,
- const std::map<int, Value*>& inputs,
+ const std::map<int, Value>& inputs,
const std::map<int, Type>& output_types,
OpBuilder* builder, ModuleOp* module_op) {
SmallVector<Type, 4> input_types;
- SmallVector<Value*, 4> input_values;
+ SmallVector<Value, 4> input_values;
SmallVector<int, 4> input_indexes;
for (const auto& kv : inputs) {
- Value* input = kv.second;
+ Value input = kv.second;
input_types.push_back(input->getType());
input_values.push_back(input);
input_indexes.push_back(kv.first);
@@ -588,7 +588,7 @@
llvm::DenseSet<Operation*> reachable_ops;
std::queue<Operation*> ops_queue;
for (auto& input_op : input_ops) {
- for (Value* value : input_op->getOperands()) {
+ for (Value value : input_op->getOperands()) {
Operation* op = value->getDefiningOp();
if (op != nullptr) ops_queue.push(op);
}
@@ -598,7 +598,7 @@
Operation* current_op = ops_queue.front();
ops_queue.pop();
reachable_ops.insert(current_op);
- for (Value* value : current_op->getOperands()) {
+ for (Value value : current_op->getOperands()) {
Operation* upstream_op = value->getDefiningOp();
// Not visited, put it into the queue.
if (upstream_op != nullptr &&
@@ -625,7 +625,7 @@
BfsForReachableOps(ophint_composite_op.GetAllOutputOps());
// Step 3, deal with inputs aggregation strategies.
- const std::map<int, Value*>& aggregated_inputs =
+ const std::map<int, Value>& aggregated_inputs =
ophint_composite_op.GetAggregatedInputs(builder);
// Step 4, get aggregated output types.
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc
index ed3a9ea..30b93cc 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc
@@ -92,15 +92,15 @@
if (call_op.getNumResults() != 1) return failure();
// Inputs is indexed at 0.
- Value* input = call_op.getOperand(0);
+ Value input = call_op.getOperand(0);
// Input_weight is indexed at 1.
- Value* weight = call_op.getOperand(1);
+ Value weight = call_op.getOperand(1);
// Recurrent_weight is indexed at 2.
- Value* recurrent_weight = call_op.getOperand(2);
+ Value recurrent_weight = call_op.getOperand(2);
// Bias is indexed at 3.
- Value* bias = call_op.getOperand(3);
+ Value bias = call_op.getOperand(3);
// Hidden_state is indexed at 4.
- Value* hidden_state = call_op.getOperand(4);
+ Value hidden_state = call_op.getOperand(4);
// Build Output.
auto output_type = call_op.getResult(0)->getType();
@@ -127,7 +127,7 @@
auto input_index_attr = composite_func_op.getAttr(kTfLiteFunctionInputIndex)
.cast<ArrayAttr>()
.getValue();
- llvm::DenseMap<int, Value*> fused_ops_index_to_call_op_args;
+ llvm::DenseMap<int, Value> fused_ops_index_to_call_op_args;
for (int i = 0; i < call_op.getNumOperands(); ++i) {
int input_index = input_index_attr[i].cast<IntegerAttr>().getInt();
@@ -139,7 +139,7 @@
// We encounter some optional arguments not filled, so we need to create an
// empty Value.
- Value* none_value;
+ Value none_value;
if (call_op.getNumOperands() <
kUnidirectionalSequenceLSTMOpTotalIArgumentNum) {
builder->setInsertionPoint(call_op.getOperation());
@@ -148,7 +148,7 @@
}
// Prepare all operands for the UnidirectionalSequenceLSTMOp.
- SmallVector<Value*, kUnidirectionalSequenceLSTMOpTotalIArgumentNum> operands;
+ SmallVector<Value, kUnidirectionalSequenceLSTMOpTotalIArgumentNum> operands;
for (int i = 0; i < kUnidirectionalSequenceLSTMOpTotalIArgumentNum; ++i) {
auto operand_it = fused_ops_index_to_call_op_args.find(i);
if (operand_it == fused_ops_index_to_call_op_args.end()) {
@@ -169,7 +169,7 @@
if (call_op.getNumResults() > 1) {
for (int i = 0; i < call_op.getNumResults() - 1; ++i) {
// This one should not be used.
- Value* unused_output = call_op.getResult(i);
+ Value unused_output = call_op.getResult(i);
if (!unused_output->use_empty()) return failure();
}
}
@@ -206,7 +206,7 @@
LogicalResult build_fused_op_result = BuildUnidirectionalSequenceLSTMOp(
composite_func_op, call_op, builder, &fused_op);
if (failed(build_fused_op_result)) return build_fused_op_result;
- Value* call_output = call_op.getResult(call_op.getNumResults() - 1);
+ Value call_output = call_op.getResult(call_op.getNumResults() - 1);
if (call_output->getType() != fused_op->getResult(0)->getType()) {
return failure();
}
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 0b4b0dd..e07dcdb 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -71,7 +71,7 @@
auto values = op->getOperands();
int index = 0;
ArrayRef<int64_t> shape;
- for (Value* value : values) {
+ for (Value value : values) {
auto shaped_type = value->getType().dyn_cast<ShapedType>();
if (!shaped_type && !shaped_type.hasStaticShape()) {
return false;
@@ -183,7 +183,7 @@
Operation* op, PatternRewriter& rewriter) const {
auto tf_pack_op = cast<TF::PackOp>(op);
- SmallVector<Value*, 4> values(tf_pack_op.values());
+ SmallVector<Value, 4> values(tf_pack_op.values());
auto output_type = tf_pack_op.output()->getType();
auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
// Axis can be negative.
@@ -198,8 +198,8 @@
Operation* op, PatternRewriter& rewriter) const {
auto tf_reshape_op = cast<TF::ReshapeOp>(op);
- auto* input = tf_reshape_op.tensor();
- auto* shape = tf_reshape_op.shape();
+ auto input = tf_reshape_op.tensor();
+ auto shape = tf_reshape_op.shape();
ShapedType shape_type = shape->getType().cast<ShapedType>();
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
@@ -222,7 +222,7 @@
Operation* op, PatternRewriter& rewriter) const {
auto tf_split_op = cast<TF::SplitOp>(op);
- auto output_types = functional::map([](Value* v) { return v->getType(); },
+ auto output_types = functional::map([](Value v) { return v->getType(); },
tf_split_op.output());
// Number of splits cannot be negative.
auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
@@ -237,7 +237,7 @@
Operation* op, PatternRewriter& rewriter) const {
auto tf_splitv_op = cast<TF::SplitVOp>(op);
- auto output_types = functional::map([](Value* v) { return v->getType(); },
+ auto output_types = functional::map([](Value v) { return v->getType(); },
tf_splitv_op.output());
// Number of splits cannot be negative.
auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
@@ -248,9 +248,9 @@
return matchSuccess();
}
-Value* PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
- Value* attribute,
- ArrayRef<int32_t> padding_val, int* mask) {
+Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
+ Value attribute,
+ ArrayRef<int32_t> padding_val, int* mask) {
DenseIntElementsAttr dense_elem_attr;
SmallVector<int32_t, 8> padded_val;
@@ -305,17 +305,17 @@
// Pad `begin` array with zero values and update the `begin_mask`.
SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
int begin_mask = tf_strided_slice_op.begin_mask().getSExtValue();
- Value* padded_begin = PadStridedSliceAttributeArray(
+ Value padded_begin = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
// Pad `end` array with `input_shape` and update the `end_mask`.
int end_mask = tf_strided_slice_op.end_mask().getSExtValue();
auto input_shape = ranked_input_type.getShape();
SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
- Value* padded_end = PadStridedSliceAttributeArray(
+ Value padded_end = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask);
// Pad `strides` array with ones.
SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1);
- Value* padded_strides = PadStridedSliceAttributeArray(
+ Value padded_strides = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output()->getType(), tf_strided_slice_op.input(),
@@ -335,8 +335,8 @@
Operation* op, PatternRewriter& rewriter) const {
auto tf_unpack_op = cast<TF::UnpackOp>(op);
- auto* input = tf_unpack_op.value();
- auto output_types = functional::map([](Value* v) { return v->getType(); },
+ auto input = tf_unpack_op.value();
+ auto output_types = functional::map([](Value v) { return v->getType(); },
tf_unpack_op.output());
auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
// Axis can be negative.
diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc
index f1668b0..c1d567e6 100644
--- a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc
@@ -50,13 +50,13 @@
// Create LSTM gates with different weights for input, recurrent and
// cell state, and also the layer normalization parameters.
- Operation* CreateGate(Location loc, Value* in, Value* in_w, Value* rec,
- Value* rec_w,
- llvm::Optional<std::pair<Value*, Value*>> cell,
- Value* ln_w, Value* ln_bias, OpBuilder* builder);
+ Operation* CreateGate(Location loc, Value in, Value in_w, Value rec,
+ Value rec_w,
+ llvm::Optional<std::pair<Value, Value>> cell,
+ Value ln_w, Value ln_bias, OpBuilder* builder);
- Operation* CreateLayerNorm(Location loc, Value* in, Value* ln_w,
- Value* ln_bias, OpBuilder* builder);
+ Operation* CreateLayerNorm(Location loc, Value in, Value ln_w, Value ln_bias,
+ OpBuilder* builder);
// Add the internal implementation of the LSTM to its regions.
void LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder);
@@ -92,8 +92,8 @@
int16 = any_int16.castFromExpressedType(lstm.input()->getType());
}
-Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value* in,
- Value* ln_w, Value* ln_bias,
+Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,
+ Value ln_w, Value ln_bias,
OpBuilder* builder) {
// Note that l2_normalization and add ops here are not the execution kernel
// implementation for layer_normalization and we just want to use them to
@@ -105,8 +105,8 @@
}
Operation* LoadQuantizationRecipe::CreateGate(
- Location loc, Value* in, Value* in_w, Value* rec, Value* rec_w,
- llvm::Optional<std::pair<Value*, Value*>> cell, Value* ln_w, Value* ln_bias,
+ Location loc, Value in, Value in_w, Value rec, Value rec_w,
+ llvm::Optional<std::pair<Value, Value>> cell, Value ln_w, Value ln_bias,
OpBuilder* builder) {
auto s1 = builder->create<FullyConnectedOp>(loc, int16, in, in_w, none_cst,
none_af, fc_format, keep_dims);
@@ -119,13 +119,13 @@
cell.getValue().second, none_af);
s4 = builder->create<AddNOp>(
loc, int16,
- llvm::ArrayRef<Value*>(
+ llvm::ArrayRef<Value>(
{*s1.output().begin(), *s2.output().begin(), s3.output()}));
} else {
s4 = builder->create<AddNOp>(
loc, int16,
- llvm::ArrayRef<Value*>({*s1.output().begin(), *s2.output().begin()}));
+ llvm::ArrayRef<Value>({*s1.output().begin(), *s2.output().begin()}));
}
auto s5 = CreateLayerNorm(loc, s4.sum(), ln_w, ln_bias, builder);
@@ -144,22 +144,20 @@
region.push_back(new Block);
builder->setInsertionPointToEnd(®ion.front());
Location loc = lstm.getLoc();
- Type int32_type = builder->getIntegerType(32);
- Type int32_tensor = UnrankedTensorType::get(int32_type);
none_cst = builder->create<ConstantOp>(loc, builder->getNoneType(),
builder->getUnitAttr());
auto input_gate = CreateGate(
loc, lstm.input(), lstm.input_to_input_weights(),
lstm.input_activation_state(), lstm.recurrent_to_input_weights(),
- llvm::Optional<std::pair<Value*, Value*>>(
+ llvm::Optional<std::pair<Value, Value>>(
{lstm.input_cell_state(), lstm.cell_to_input_weights()}),
lstm.input_layer_norm_coefficients(), lstm.input_gate_bias(), builder);
auto forget_gate = CreateGate(
loc, lstm.input(), lstm.input_to_forget_weights(),
lstm.input_activation_state(), lstm.recurrent_to_forget_weights(),
- llvm::Optional<std::pair<Value*, Value*>>(
+ llvm::Optional<std::pair<Value, Value>>(
{lstm.input_cell_state(), lstm.cell_to_forget_weights()}),
lstm.forget_layer_norm_coefficients(), lstm.forget_gate_bias(), builder);
@@ -179,7 +177,7 @@
auto output_gate = CreateGate(
loc, lstm.input(), lstm.input_to_output_weights(),
lstm.input_activation_state(), lstm.recurrent_to_output_weights(),
- llvm::Optional<std::pair<Value*, Value*>>(
+ llvm::Optional<std::pair<Value, Value>>(
{new_cell, lstm.cell_to_output_weights()}),
lstm.output_layer_norm_coefficients(), lstm.output_gate_bias(), builder);
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 7c02342..c21dad8 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -84,8 +84,8 @@
TensorListPatternRewriter *rewriter);
};
-Value *CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
- ArrayRef<int64_t> shape, int32_t val) {
+Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
+ ArrayRef<int64_t> shape, int32_t val) {
RankedTensorType type =
RankedTensorType::get(shape, rewriter->getIntegerType(32));
DenseElementsAttr attr =
@@ -93,9 +93,9 @@
return rewriter->create<ConstantOp>(loc, type, attr);
}
-Value *CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
- Value *shape_tensor, int32_t val) {
- Value *scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
+Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter,
+ Value shape_tensor, int32_t val) {
+ Value scalar_val = CreateI32SplatConst(loc, rewriter, {}, val);
return rewriter->create<TF::FillOp>(
loc, RankedTensorType::get({-1}, rewriter->getIntegerType(32)),
shape_tensor, scalar_val);
@@ -131,32 +131,32 @@
// Requires that `start_index` and `size` are scalar tensors and
// `item_position_shape` is a 1-D tensor with only one element equal to the rank
// of an item in the tensorlist.
-TF::SliceOp CreateSliceOpForTensorList(Location loc, Value *input_list,
- Value *start_index, Value *size,
- Value *item_rank, Type result_type,
+TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
+ Value start_index, Value size,
+ Value item_rank, Type result_type,
PatternRewriter *rewriter) {
// Create the start position of slice. This is done by concatenating
// `start_index` and `partial_start_position` together.
IntegerType shape_dtype = rewriter->getIntegerType(32);
RankedTensorType position_type = RankedTensorType::get({-1}, shape_dtype);
- Value *partial_start_position =
+ Value partial_start_position =
CreateI32SplatTensor(loc, rewriter, item_rank, 0);
- Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
+ Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
RankedTensorType vector_type = RankedTensorType::get({1}, shape_dtype);
auto expanded_start_index = rewriter->create<TF::ExpandDimsOp>(
loc, vector_type, start_index, scalar_zero);
auto start_position = rewriter->create<TF::ConcatOp>(
loc, position_type, scalar_zero,
- ArrayRef<Value *>({expanded_start_index, partial_start_position}));
+ ArrayRef<Value>({expanded_start_index, partial_start_position}));
// Create the slice size tensor. This is done by concatenating `size` and
// `partial_size`.
auto size_leading_dim =
rewriter->create<TF::ExpandDimsOp>(loc, vector_type, size, scalar_zero);
- Value *partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
+ Value partial_size = CreateI32SplatTensor(loc, rewriter, item_rank, -1);
auto slice_size = rewriter->create<TF::ConcatOp>(
loc, position_type, scalar_zero,
- ArrayRef<Value *>({size_leading_dim, partial_size}));
+ ArrayRef<Value>({size_leading_dim, partial_size}));
return rewriter->create<TF::SliceOp>(loc, result_type, input_list,
start_position, slice_size);
@@ -180,18 +180,18 @@
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
PatternMatchResult matchAndRewrite(
- Operation *operation, ArrayRef<Value *> operands,
+ Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListSetItemOp>(operation);
Location loc = op.getLoc();
- Value *input = operands[0];
- Value *index = operands[1];
- Value *item = operands[2];
+ Value input = operands[0];
+ Value index = operands[1];
+ Value item = operands[2];
IntegerType shape_dtype = rewriter.getIntegerType(32);
auto item_rank = rewriter.create<TF::RankOp>(
loc, RankedTensorType::get({}, shape_dtype), item);
- Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
+ Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Calculate `index` + 1, which is used to generate the start position for
// the second slice op.
@@ -204,7 +204,7 @@
// Create two slice ops.
Type element_type = input->getType().cast<TensorType>().getElementType();
UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
- Value *scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
+ Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
TF::SliceOp slice1 =
CreateSliceOpForTensorList(loc, /*input_list=*/input,
/*start_index=*/scalar_zero,
@@ -226,7 +226,7 @@
// Concatenate three parts together to generate the final result.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, input->getType(), scalar_zero,
- ArrayRef<Value *>({slice1, expanded_item, slice2}));
+ ArrayRef<Value>({slice1, expanded_item, slice2}));
return matchSuccess();
}
};
@@ -241,14 +241,14 @@
// Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with.
- virtual Value *GetNumElements(OpT op, ArrayRef<Value *> operands,
- PatternRewriter *rewriter) const = 0;
+ virtual Value GetNumElements(OpT op, ArrayRef<Value> operands,
+ PatternRewriter *rewriter) const = 0;
// Rewrites the original op into `tf.fill`. The result tensor shape is
// [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0.
PatternMatchResult matchAndRewrite(
- Operation *operation, ArrayRef<Value *> operands,
+ Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OpT op = llvm::cast<OpT>(operation);
@@ -263,7 +263,7 @@
return matchFailure();
}
- Value *element_shape = operands[0];
+ Value element_shape = operands[0];
Type shape_dtype = getElementTypeOrSelf(element_shape->getType());
DenseIntElementsAttr dense_elem_attr;
@@ -330,11 +330,11 @@
Location loc = op.getLoc();
// Add number of elements as the prefix to the element shape to get shape of
// the output tensor.
- Value *leading_dim = GetNumElements(op, operands, &rewriter);
- Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
+ Value leading_dim = GetNumElements(op, operands, &rewriter);
+ Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
auto list_shape = rewriter.create<TF::ConcatOp>(
loc, shape_type, scalar_zero,
- ArrayRef<Value *>({leading_dim, element_shape}));
+ ArrayRef<Value>({leading_dim, element_shape}));
// Create a zero-initialized constant tensor that has the same type
// as specified by element_dtype.
@@ -352,11 +352,11 @@
explicit ConvertTensorListReserve(MLIRContext *context)
: ConvertTensorListInitOp(context) {}
- Value *GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value *> operands,
- PatternRewriter *rewriter) const override {
- Value *scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
+ Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
+ PatternRewriter *rewriter) const override {
+ Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
Type shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
- Value *num_elements = operands[1];
+ Value num_elements = operands[1];
return rewriter->create<TF::ExpandDimsOp>(
op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
scalar_zero);
@@ -371,8 +371,8 @@
explicit ConvertEmptyTensorList(MLIRContext *context)
: ConvertTensorListInitOp(context) {}
- Value *GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value *> operands,
- PatternRewriter *rewriter) const override {
+ Value GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value> operands,
+ PatternRewriter *rewriter) const override {
return CreateI32SplatConst(op.getLoc(), rewriter, {1}, 0);
}
};
@@ -383,17 +383,17 @@
context) {}
PatternMatchResult matchAndRewrite(
- Operation *op, ArrayRef<Value *> operands,
+ Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
- Value *input_handle = operands[0];
- Value *item = operands[1];
+ Value input_handle = operands[0];
+ Value item = operands[1];
// Expand the shape of the item so that it will have rank same as the input
// tensor and it is compatible for the Concat Op.
Type expanded_item_type =
PrependLeadingDimIfRanked(1, item->getType(), &rewriter);
- Value *scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
+ Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), expanded_item_type, item, scalar_zero);
@@ -408,7 +408,7 @@
// get a tensor equivalent to the TensorList generated by this op.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
push_back_op, result_type, scalar_zero,
- ArrayRef<Value *>({input_handle, expanded_item}));
+ ArrayRef<Value>({input_handle, expanded_item}));
return matchSuccess();
}
};
@@ -429,14 +429,14 @@
context) {}
PatternMatchResult matchAndRewrite(
- Operation *op, ArrayRef<Value *> operands,
+ Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TF::TensorListResizeOp resize_op = cast<TF::TensorListResizeOp>(op);
- Value *input_handle = operands[0];
- Value *size = operands[1];
+ Value input_handle = operands[0];
+ Value size = operands[1];
Location loc = resize_op.getLoc();
- Value *scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
+ Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Compute the input tensorlist's length and store it in `input_size`.
IntegerType shape_dtype = rewriter.getIntegerType(32);
@@ -491,7 +491,7 @@
rewriter.replaceOpWithNewOp<TF::IfOp>(
op, result_type, if_cond,
/*input=*/
- ArrayRef<Value *>({input_handle, input_shape, size_diff, size}),
+ ArrayRef<Value>({input_handle, input_shape, size_diff, size}),
/*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op),
/*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op),
/*output_shapes=*/rewriter.getStrArrayAttr({"{}"}),
@@ -517,9 +517,9 @@
Location loc = resize_op.getLoc();
// Get the element shape by slicing from index 1 in the input shape.
- Value *slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
- Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
- Value *slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
+ Value slice_size = CreateI32SplatConst(loc, rewriter, {1}, -1);
+ Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
+ Value slice_start = CreateI32SplatConst(loc, rewriter, {1}, 1);
auto elem_shape = rewriter->create<TF::SliceOp>(
loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
slice_size);
@@ -536,8 +536,8 @@
/*num_elements=*/rewriter->getI32IntegerAttr(-1));
auto concat_op = rewriter->create<TF::ConcatOp>(
loc, result_type, scalar_zero,
- ArrayRef<Value *>({input, stacked_extended_part}));
- rewriter->create<ReturnOp>(loc, ArrayRef<Value *>({concat_op}));
+ ArrayRef<Value>({input, stacked_extended_part}));
+ rewriter->create<ReturnOp>(loc, ArrayRef<Value>({concat_op}));
}
void CreateCondFalseBranch(Location loc, Type shape_dtype, Type result_type,
@@ -550,8 +550,8 @@
Block *block = branch_func.addEntryBlock();
rewriter->setInsertionPointToStart(block);
- Value *scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
- Value *vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
+ Value scalar_zero = CreateI32SplatConst(loc, rewriter, {}, 0);
+ Value vector_one = CreateI32SplatConst(loc, rewriter, {1}, 1);
auto input = block->getArgument(0);
auto size = block->getArgument(3);
@@ -566,7 +566,7 @@
/*start_index=*/scalar_zero, /*size=*/size,
/*item_rank=*/partial_position_shape,
/*result_type=*/result_type, rewriter);
- rewriter->create<ReturnOp>(loc, ArrayRef<Value *>({slice_op}));
+ rewriter->create<ReturnOp>(loc, ArrayRef<Value>({slice_op}));
}
};
@@ -576,11 +576,11 @@
context) {}
PatternMatchResult matchAndRewrite(
- Operation *operation, ArrayRef<Value *> operands,
+ Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListGetItemOp>(operation);
- Value *input = operands[0];
- Value *index = operands[1];
+ Value input = operands[0];
+ Value index = operands[1];
rewriter.replaceOpWithNewOp<TF::GatherOp>(
operation, op.getType(), input, index, rewriter.getBoolAttr(true));
return matchSuccess();
@@ -593,11 +593,11 @@
context) {}
PatternMatchResult matchAndRewrite(
- Operation *operation, ArrayRef<Value *> operands,
+ Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListLengthOp>(operation);
Location loc = op.getLoc();
- Value *input_handle = operands[0];
+ Value input_handle = operands[0];
BoolAttr true_attr = rewriter.getBoolAttr(true);
auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
@@ -615,12 +615,12 @@
context) {}
PatternMatchResult matchAndRewrite(
- Operation *operation, ArrayRef<Value *> operands,
+ Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListStackOp>(operation);
Location loc = op.getLoc();
- Value *input = operands[0];
- Value *element_shape = operands[1];
+ Value input = operands[0];
+ Value element_shape = operands[1];
// If the `element_shape` is a known constant (which is defined when calling
// `tensor_list_stack`) and also valid (not scalar), we rewrite this op to a
@@ -655,10 +655,10 @@
: ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
- Operation *operation, ArrayRef<Value *> operands,
+ Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::IdentityOp>(operation);
- Value *input = operands[0];
+ Value input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
op.getAttrs());
return matchSuccess();
@@ -728,7 +728,7 @@
: ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
- Operation *operation, ArrayRef<Value *> operands,
+ Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::WhileOp>(operation);
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 1313bae..9241fbe 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -50,7 +50,7 @@
// The actual Optimize Pass.
namespace {
-bool L2NormalizeReduceAxis(Value *sq_op, DenseElementsAttr axis) {
+bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
if (sq_op->getType().cast<ShapedType>().getRank() - 1 ==
*axis.getValues<int>().begin() ||
*axis.getValues<int>().begin() == -1) {
@@ -142,7 +142,7 @@
// Returns shape of a ranked tensor.
// Precondition: output_val's is ranked tensor.
-DenseElementsAttr GetShape(Value *output_val) {
+DenseElementsAttr GetShape(Value output_val) {
auto output_type = output_val->getType().cast<RankedTensorType>();
auto shape_vector = output_type.getShape();
std::vector<int32_t> shape(shape_vector.size());
@@ -167,7 +167,7 @@
PatternRewriter &rewriter) const override {
// Add.
DenseElementsAttr added_value;
- Value *constant_val = add_op.rhs();
+ Value constant_val = add_op.rhs();
if (!matchPattern(constant_val, m_Constant(&added_value)))
return matchFailure();
@@ -176,8 +176,8 @@
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs()->getDefiningOp());
if (!fc_op) return matchFailure();
- Value *filter = fc_op.filter();
- Value *bias = fc_op.bias();
+ Value filter = fc_op.filter();
+ Value bias = fc_op.bias();
ElementsAttr bias_value;
const bool is_none_bias = bias->getType().isa<NoneType>();
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
@@ -242,15 +242,15 @@
PatternRewriter &rewriter) const override {
// Mul.
DenseElementsAttr cst;
- Value *constant_val = mul_op.rhs();
+ Value constant_val = mul_op.rhs();
if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure();
// Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
if (!fc_op) return matchFailure();
- Value *filter = fc_op.filter();
- Value *bias = fc_op.bias();
+ Value filter = fc_op.filter();
+ Value bias = fc_op.bias();
ElementsAttr cst_tmp;
if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
if (!bias->getType().isa<NoneType>() &&
@@ -261,7 +261,7 @@
// Broadcast the constant operand of Mul if it isn't compatible to the
// filter input. We only support broadcasting the operand along the depth
// dimension, when the operand's depth is 1.
- Value *new_const_val = constant_val;
+ Value new_const_val = constant_val;
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
auto original_shape = cst.getType().getShape();
llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
@@ -325,8 +325,8 @@
APFloat cst_value = *cst.float_value_begin();
// Affine op.
- Value *filter = fc_op.filter();
- Value *bias = fc_op.bias();
+ Value filter = fc_op.filter();
+ Value bias = fc_op.bias();
DenseFPElementsAttr filter_cst, bias_cst;
if (!matchPattern(filter, m_Constant(&filter_cst))) {
// The filter maybe quantized, then we should set it to the real constant.
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
index 59dc271..3d362b4 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc
@@ -98,13 +98,13 @@
for (int i = 0, e = func.getNumArguments(); i != e; ++i)
mapper.map(func.getArgument(i), op.getOperand(i + 1));
- llvm::SmallVector<Value*, 4> updated_results;
+ llvm::SmallVector<Value, 4> updated_results;
for (auto& op_to_inline : func.getBody().front()) {
// If this is a terminator, identify the values to use to replace the
// original If op.
if (op_to_inline.isKnownTerminator()) {
updated_results.reserve(op_to_inline.getNumOperands());
- for (Value* operand : op_to_inline.getOperands())
+ for (Value operand : op_to_inline.getOperands())
updated_results.push_back(mapper.lookup(operand));
break;
}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 1a22c80..99ad081 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -54,13 +54,15 @@
[TFL_Relu1Op, TFL_AF_Relu1]] in
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
+// Checks if the value has only one user.
+def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>;
// If we see a binary op (add, sub) op adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand. The following
// pattern restricts to float constant values for now.
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
- def : Pat<(binaryOp (TFL_Conv2DOp $input, $filter,
+ def : Pat<(binaryOp (TFL_Conv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w),
@@ -69,8 +71,9 @@
(binaryOp (ConstantOp $bias),
(ConstantOp $value), TFL_AF_None),
$h_factor, $w_factor, $act_fn,
- $padding, $stride_h, $stride_w)>;
- def : Pat<(binaryOp (TFL_DepthwiseConv2DOp $input, $filter,
+ $padding, $stride_h, $stride_w),
+ [(HasOneUse $output)]>;
+ def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
@@ -82,7 +85,8 @@
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w,
- $multiplier)>;
+ $multiplier),
+ [(HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp] in
defm : FuseBinaryOpToPrecedingAffine<binaryOp>;
@@ -102,7 +106,7 @@
// The following pattern restricts to float constant values for now.
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
- def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp $input,
+ def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp:$output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
@@ -120,8 +124,9 @@
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w,
$multiplier),
- [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value)]>;
- def : Pat<(BinaryOp (TFL_Conv2DOp $input,
+ [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
+ (HasOneUse $output)]>;
+ def : Pat<(BinaryOp (TFL_Conv2DOp:$conv_output $input,
(ConstantOp F32ElementsAttr:$filter),
(ConstantOp F32ElementsAttr:$bias),
$h_factor, $w_factor, TFL_AF_None,
@@ -136,7 +141,8 @@
TFL_AF_None),
$h_factor, $w_factor, $act_fn,
$padding, $stride_h, $stride_w),
- [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value)]>;
+ [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
+ (HasOneUse $conv_output)]>;
}
foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
index 4f56de2..5394455 100644
--- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc
@@ -67,13 +67,13 @@
// In each iteration, a new argument is appended to the end of the list
// and the current argument is erased, so here we always process the first
// argument in the list.
- auto* arg = bb.getArgument(0);
+ auto arg = bb.getArgument(0);
auto remove_quantize_op = [&](QuantizeOp quantize_op) {
auto quantize_output = quantize_op.output();
auto quantize_type = quantize_output->getType();
input_types.push_back(quantize_type);
- auto* new_arg = bb.addArgument(quantize_type);
+ auto new_arg = bb.addArgument(quantize_type);
quantize_output->replaceAllUsesWith(new_arg);
quantize_op.erase();
arg->dropAllUses();
@@ -91,7 +91,7 @@
// the pattern isn't found.
Type arg_type = arg->getType();
input_types.push_back(arg_type);
- auto* new_arg = bb.addArgument(arg_type);
+ auto new_arg = bb.addArgument(arg_type);
arg->replaceAllUsesWith(new_arg);
arg->dropAllUses();
bb.eraseArgument(0);
@@ -102,11 +102,11 @@
llvm::SmallVector<Type, 4> output_types;
output_types.reserve(num_return_operands);
for (int i = 0; i != num_return_operands; ++i) {
- auto* returned_value = terminator->getOperand(i);
+ auto returned_value = terminator->getOperand(i);
Operation* returned_op = returned_value->getDefiningOp();
if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
- Value* dequantized_result = dequantize_op.input();
+ Value dequantized_result = dequantize_op.input();
output_types.push_back(dequantized_result->getType());
terminator->setOperand(i, dequantized_result);
returned_op->erase();
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
index c299064..99d8a06 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc
@@ -53,8 +53,8 @@
void RewriteFunc() {
func_.setAttr(kTFImplements,
StringAttr::get("embedding_lookup", func_.getContext()));
- Value* lookup = func_.getArgument(1);
- Value* value = func_.getArgument(0);
+ Value lookup = func_.getArgument(1);
+ Value value = func_.getArgument(0);
auto output_type = func_.getType().getResult(0);
OpBuilder builder(func_.getBody());
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index 5d139f8..b058b41 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -139,7 +139,7 @@
BoolAttr narrow_range = builder.getBoolAttr(false);
auto add_quantize_op = [&](Location loc, Type input_type, Block* block,
- Block::iterator insertion_point, Value* arg,
+ Block::iterator insertion_point, Value arg,
int i) {
if (auto shaped = input_type.dyn_cast<ShapedType>()) {
if (shaped.getElementType().isa<FloatType>()) {
@@ -160,7 +160,7 @@
};
for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
- BlockArgument* arg = func.getArgument(i);
+ BlockArgument arg = func.getArgument(i);
auto* arg_block = arg->getOwner();
add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
std::next(arg_block->begin(), i), arg, i);
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index 45248dd..e2d046f 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -121,7 +121,7 @@
// Extract the min/max constant values from the operands. We also consider
// a special case that there are tf.Identity ops between the min/max
// constants and the tf.FakeQuantWithMinMaxVarsOp.
- Value *min = tf_op.min(), *max = tf_op.max();
+ Value min = tf_op.min(), max = tf_op.max();
DenseFPElementsAttr min_value, max_value;
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min->getDefiningOp()))
min = id1.input();
@@ -150,7 +150,7 @@
// Finally, use the quantization parameter to create the quantize and
// dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
// and its users.
- Value *value = tf_op.outputs();
+ Value value = tf_op.outputs();
auto quantize = rewriter.create<TFL::QuantizeOp>(
tf_op.getLoc(), qtype.getValue(), value, qtype);
auto dequantize = rewriter.create<TFL::DequantizeOp>(
@@ -177,8 +177,8 @@
//
// TFL::[op] createTFLOp(ConvertTFConvOpMatchState *state,
// PatternRewriter &rewriter, Location loc,
-// Type result_type, Value *input,
-// Value *filter, Value *bias) const;
+// Type result_type, Value input,
+// Value filter, Value bias) const;
//
// And also the following method for getting the dimension for bias tensor:
//
@@ -294,8 +294,8 @@
TFL::Conv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
PatternRewriter &rewriter, Location loc,
- Type result_type, Value *input, Value *filter,
- Value *bias) const {
+ Type result_type, Value input, Value filter,
+ Value bias) const {
filter = legalizeFilter(rewriter, loc, filter);
return rewriter.create<TFL::Conv2DOp>(
loc, result_type, input, filter, bias,
@@ -312,8 +312,8 @@
// format HWIO to TFLite Conv2D op filter data format OHWI and return Value
// for the converted filter. Requires that filter is verified by the match
// method that it is a 4-D RankedTensorType.
- Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
- Value *filter) const {
+ Value legalizeFilter(PatternRewriter &rewriter, Location loc,
+ Value filter) const {
// Create a constant op for HWIO to OHWI transpose permutation.
SmallVector<int, 4> perm = {3, 0, 1, 2};
auto perm_type = RankedTensorType::get({static_cast<int>(perm.size())},
@@ -349,8 +349,8 @@
TFL::DepthwiseConv2DOp createTFLOp(ConvertTFConvOpMatchState *state,
PatternRewriter &rewriter, Location loc,
- Type result_type, Value *input,
- Value *filter, Value *bias) const {
+ Type result_type, Value input,
+ Value filter, Value bias) const {
// Compared to tfl.conv_2d, tfl.depthwise_conv_2d has an additional
// 'depth_multiplier' attribute. However, tf.DepthwiseConv2dNative does not
// have a corresponding 'depth_multiplier' attribute; the multiplier is the
@@ -378,8 +378,8 @@
/// filter data format is [1, filter_height, filter_width, out_channels].
/// Requires that filter is verified by the match method that it is a 4-D
/// RankedTensorType.
- Value *legalizeFilter(PatternRewriter &rewriter, Location loc,
- Value *filter) const {
+ Value legalizeFilter(PatternRewriter &rewriter, Location loc,
+ Value filter) const {
auto filter_type = filter->getType().cast<RankedTensorType>();
auto filterShape = filter_type.getShape();
SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
@@ -430,7 +430,7 @@
if (new_axis_mask == 0) return matchFailure();
// Insert a new reshape op.
- Value *original_input = strided_slice_op.input();
+ Value original_input = strided_slice_op.input();
RankedTensorType original_input_type =
original_input->getType().cast<RankedTensorType>();
const ArrayRef<int64_t> &original_input_shape =
diff --git a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
index 123d1f8..8e5eff7 100644
--- a/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc
@@ -71,13 +71,13 @@
};
LogicalResult DuplicateValueIfNeeded(Operation* op,
- llvm::DenseSet<Value*>* values,
+ llvm::DenseSet<Value>* values,
OpBuilder* builder) {
std::vector<int> stateful_operands_index;
if (!IsStatefulOp(op, &stateful_operands_index)) return success();
for (int index : stateful_operands_index) {
- Value* operand = op->getOperand(index);
+ Value operand = op->getOperand(index);
auto inserted_value = values->insert(operand).second;
if (inserted_value) continue;
// We can only clone the constant op at this point.
@@ -102,7 +102,7 @@
}
void SplitMergedOperandsPass::runOnFunction() {
- llvm::DenseSet<Value*> stateful_values;
+ llvm::DenseSet<Value> stateful_values;
auto func = getFunction();
OpBuilder builder(func);
for (auto& bb : func.getBody()) {
diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
index 61d33a5..20351f7 100644
--- a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
@@ -67,7 +67,7 @@
template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
- Value* value, ArrayRef<int64_t> shape, Type element_type, Location loc,
+ Value value, ArrayRef<int64_t> shape, Type element_type, Location loc,
PatternRewriter& rewriter) {
int64_t shape_rank = shape.size();
auto shape_spec_type =
@@ -81,8 +81,8 @@
}
template <typename BatchMatMulOpType>
-std::vector<Value*> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
- Value* value, int batch_size, Location loc, PatternRewriter& rewriter) {
+std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
+ Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
Type element_type = tensorType.getElementType();
@@ -96,7 +96,7 @@
SmallVector<int64_t, 3> slice_size = {1, num_rows, num_cols};
- std::vector<Value*> sliced;
+ std::vector<Value> sliced;
Type int64_type = rewriter.getIntegerType(64);
Type slice_result_type = RankedTensorType::get(slice_size, element_type);
@@ -126,7 +126,7 @@
template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
- Value* value, Location loc, PatternRewriter& rewriter) {
+ Value value, Location loc, PatternRewriter& rewriter) {
auto value_type = value->getType().cast<RankedTensorType>();
auto shape = value_type.getShape();
int dims = shape.size();
@@ -158,13 +158,12 @@
template <typename BatchMatMulOpType>
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
- const std::vector<Value*>& sliced_lhs,
- const std::vector<Value*>& sliced_rhs, const tensorflow::MatMulBCast& bcast,
- int rows, int cols, Type element_type, Location loc,
- PatternRewriter& rewriter) {
+ const std::vector<Value>& sliced_lhs, const std::vector<Value>& sliced_rhs,
+ const tensorflow::MatMulBCast& bcast, int rows, int cols, Type element_type,
+ Location loc, PatternRewriter& rewriter) {
auto matmul_type = RankedTensorType::get({rows, cols}, element_type);
- std::vector<Value*> matmuls;
+ std::vector<Value> matmuls;
for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
int lhs_batch_idx, rhs_batch_idx;
if (bcast.IsBroadcastingRequired()) {
@@ -195,8 +194,8 @@
template <typename BatchMatMulOpType>
PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
BatchMatMulOpType op, PatternRewriter& rewriter) const {
- Value* input_lhs = op.x();
- Value* input_rhs = op.y();
+ Value input_lhs = op.x();
+ Value input_rhs = op.y();
if (!input_lhs->getType().isa<RankedTensorType>()) {
// LHS must be a ranked tensor type
@@ -276,9 +275,9 @@
}
// Compute slices for each batch in the LHS and RHS.
- std::vector<Value*> sliced_lhs =
+ std::vector<Value> sliced_lhs =
sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
- std::vector<Value*> sliced_rhs =
+ std::vector<Value> sliced_rhs =
sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);
// Compute (single batch) MatMul for each output batch. The MatMul outputs
diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h
index 19b7596..0e72b3b 100644
--- a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h
+++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h
@@ -33,19 +33,18 @@
class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;
- static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef<int64_t> shape,
+ static TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape,
Type element_type, Location loc,
PatternRewriter& rewriter);
- static std::vector<Value*> sliceInput(Value* value, int batch_size,
- Location loc,
- PatternRewriter& rewriter);
+ static std::vector<Value> sliceInput(Value value, int batch_size,
+ Location loc, PatternRewriter& rewriter);
- static TF::TransposeOp createTransposeOp(Value* value, Location loc,
+ static TF::TransposeOp createTransposeOp(Value value, Location loc,
PatternRewriter& rewriter);
- static TF::PackOp createMatMulOps(const std::vector<Value*>& sliced_lhs,
- const std::vector<Value*>& sliced_rhs,
+ static TF::PackOp createMatMulOps(const std::vector<Value>& sliced_lhs,
+ const std::vector<Value>& sliced_rhs,
const tensorflow::MatMulBCast& bcast,
int rows, int cols, Type element_type,
Location loc, PatternRewriter& rewriter);
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
index 92a8ad4..0d5d177 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
@@ -42,35 +42,35 @@
namespace {
-Value* CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
- int32_t val, mlir::Location location) {
+Value CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
+ int32_t val, mlir::Location location) {
auto type = RankedTensorType::get(shape, builder->getIntegerType(32));
auto attr = DenseElementsAttr::get(type, val);
return builder->create<ConstantOp>(location, type, attr);
}
-Value* CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
- float val, mlir::Location location) {
+Value CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
+ float val, mlir::Location location) {
auto type = RankedTensorType::get(shape, builder->getF32Type());
auto attr = DenseElementsAttr::get(type, val);
return builder->create<ConstantOp>(location, type, attr);
}
-Value* CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
- ArrayRef<int64_t> values, mlir::Location location) {
+Value CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> values, mlir::Location location) {
auto type = RankedTensorType::get(static_cast<int>(shape.size()),
builder->getIntegerType(64));
auto attr = DenseElementsAttr::get(type, values);
return builder->create<ConstantOp>(location, type, attr);
}
-Value* CreateNoneValue(OpBuilder* builder, mlir::Location location) {
+Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
return builder->create<mlir::ConstantOp>(location, builder->getNoneType(),
builder->getUnitAttr());
}
-Value* Transpose2D(OpBuilder* builder, Value* value_to_transpose,
- RankedTensorType type, mlir::Location location) {
+Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
+ RankedTensorType type, mlir::Location location) {
// Create a constant op for transpose permutation.
SmallVector<int64_t, 2> perm = {1, 0};
auto perm_op = CreateI64DenseConst(builder, perm, perm, location);
@@ -87,16 +87,16 @@
value_to_transpose, perm_op);
}
-ArrayRef<int64_t> GetRankedTensorShape(Value* value) {
+ArrayRef<int64_t> GetRankedTensorShape(Value value) {
return value->getType().cast<RankedTensorType>().getShape();
}
-Value* SliceRankedTensor(OpBuilder* builder, Value* input,
- ArrayRef<int64_t> begin_shape,
- ArrayRef<int64_t> begin_values,
- ArrayRef<int64_t> size_shape,
- ArrayRef<int64_t> size_values,
- mlir::Location location) {
+Value SliceRankedTensor(OpBuilder* builder, Value input,
+ ArrayRef<int64_t> begin_shape,
+ ArrayRef<int64_t> begin_values,
+ ArrayRef<int64_t> size_shape,
+ ArrayRef<int64_t> size_values,
+ mlir::Location location) {
// If the size of the tensor to be sliced from the input overflows
// the input tensor's dimensions, return 0-valued tensor of the requested
// shape.
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h
index 235d438..ea28aaa 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h
@@ -102,15 +102,15 @@
// specified state
FuncOp fused_func_op_;
- Value* input_;
- Value* weight_;
- Value* bias_;
- Value* projection_;
+ Value input_;
+ Value weight_;
+ Value bias_;
+ Value projection_;
bool couple_input_forget_gates_;
// internal state
- Value* weight_transposed_;
- Value* projection_transposed_;
+ Value weight_transposed_;
+ Value projection_transposed_;
RankedTensorType weight_type_;
RankedTensorType projection_type_;
int num_gates_;
@@ -121,40 +121,40 @@
int num_cols_projection_transposed_;
// input -> cifg
- Value* input2input_;
- Value* input2forget_;
- Value* input2cell_;
- Value* input2output_;
+ Value input2input_;
+ Value input2forget_;
+ Value input2cell_;
+ Value input2output_;
// recurrent -> cifg
- Value* rec2input_;
- Value* rec2forget_;
- Value* rec2cell_;
- Value* rec2output_;
+ Value rec2input_;
+ Value rec2forget_;
+ Value rec2cell_;
+ Value rec2output_;
// bias -> cifg
- Value* bias2input_;
- Value* bias2forget_;
- Value* bias2cell_;
- Value* bias2output_;
+ Value bias2input_;
+ Value bias2forget_;
+ Value bias2cell_;
+ Value bias2output_;
// projection
- Value* proj_weight_;
- Value* proj_bias_;
+ Value proj_weight_;
+ Value proj_bias_;
// state
- Value* input_activation_state_;
- Value* input_cell_state_;
+ Value input_activation_state_;
+ Value input_cell_state_;
// layer norm coefficients
- Value* input_layer_norm_coefficients_;
- Value* forget_layer_norm_coefficients_;
- Value* cell_layer_norm_coefficients_;
- Value* output_layer_norm_coefficients_;
+ Value input_layer_norm_coefficients_;
+ Value forget_layer_norm_coefficients_;
+ Value cell_layer_norm_coefficients_;
+ Value output_layer_norm_coefficients_;
mlir::TFL::LSTMOp lstm_;
- Value* none_;
+ Value none_;
SmallVector<int64_t, 1> bias_slice_shape_;
SmallVector<int64_t, 1> bias_size_values_;
SmallVector<int64_t, 2> weight_slice_shape_;
@@ -199,7 +199,7 @@
private:
// specified state
- Value* layer_norm_scale_;
+ Value layer_norm_scale_;
// internal state
RankedTensorType layer_norm_scale_type_;
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h
index 0a5d790..2fd8630 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.h
+++ b/tensorflow/compiler/mlir/lite/utils/validators.h
@@ -51,7 +51,7 @@
// Returns true iff the given value is a float tensor.
// is "DT_FLOAT".
-inline bool TFTypeIsFloatTensor(Value *value) {
+inline bool TFTypeIsFloatTensor(Value value) {
auto tensorType = value->getType().dyn_cast<TensorType>();
if (!tensorType) return false;
return tensorType.getElementType().isa<FloatType>();
diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
index 6b8dd7b..09d16fa 100644
--- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
+++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc
@@ -148,13 +148,13 @@
// generated using the op type.
return op->getName().getStringRef();
}
- auto* val = op_or_val.dyn_cast<mlir::Value*>();
+ auto val = op_or_val.dyn_cast<mlir::Value>();
auto name_from_loc = GetNameFromLoc(val->getLoc());
if (!name_from_loc.empty()) return name_from_loc;
// If the location is none of the expected types, then simply use name
// generated using the op type. Follow TF convention and append the result
// index unless 0.
- if (auto* result = llvm::dyn_cast<mlir::OpResult>(val)) {
+ if (auto result = val->dyn_cast<mlir::OpResult>()) {
if (result->getResultNumber() > 0)
return llvm::formatv("{0}:{1}",
result->getOwner()->getName().getStringRef(),
diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h
index 6517349..a51035b 100644
--- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h
+++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h
@@ -30,7 +30,7 @@
// PointerUnion for operation and value.
// TODO(jpienaar): Rename the files.
-using OpOrVal = llvm::PointerUnion<mlir::Operation*, mlir::Value*>;
+using OpOrVal = llvm::PointerUnion<mlir::Operation*, mlir::Value>;
// Mapper from operation or value to name.
class OpOrArgNameMapper {
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 288a63e..a93d2a7 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -11,6 +11,7 @@
includes = ["@local_config_mlir//:subpackages"],
packages = [
"//tensorflow/compiler/...",
+ "//tensorflow/lite/experimental/tf_runtime/...",
"//tensorflow/python/...",
],
)
diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
index f9df753..6748cb8 100644
--- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
+++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc
@@ -86,15 +86,15 @@
func_op.getBody().front().getTerminator()->getOperand(return_index);
assert(mlir::getElementTypeOrSelf(value->getType()).isa<TF::ResourceType>());
int64_t arg_index = -1;
- auto try_parse_arg_index = [&arg_index](Value* v) {
- auto resource_arg = llvm::dyn_cast<BlockArgument>(v);
+ auto try_parse_arg_index = [&arg_index](Value v) {
+ auto resource_arg = v->dyn_cast<BlockArgument>();
if (resource_arg) arg_index = resource_arg->getArgNumber();
return arg_index;
};
while (try_parse_arg_index(value) == -1) {
auto op = value->getDefiningOp();
assert(op);
- int64_t res_num = llvm::dyn_cast<OpResult>(value)->getResultNumber();
+ int64_t res_num = value->cast<OpResult>()->getResultNumber();
if (auto graph = llvm::dyn_cast<tf_executor::GraphOp>(op)) {
value = graph.GetFetch().getOperand(res_num);
} else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
@@ -131,7 +131,7 @@
resource_value_to_ids_[arg].insert(next_unique_id++);
}
llvm::StringMap<int64_t> var_handle_name_id_map;
- auto forward_input_to_output = [&](Value* operand, Value* result) {
+ auto forward_input_to_output = [&](Value operand, Value result) {
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
return;
auto& result_ids = resource_value_to_ids_[result];
@@ -220,7 +220,7 @@
});
}
-bool ResourceAliasAnalysis::IsUnknownResource(const Value* resource) const {
+bool ResourceAliasAnalysis::IsUnknownResource(const Value resource) const {
auto it = resource_value_to_ids_.find(resource);
assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
// The set is sorted so we only need to check the first element since
@@ -231,7 +231,7 @@
}
const llvm::SmallSet<int64_t, 8>& ResourceAliasAnalysis::GetResourceUniqueIds(
- const Value* resource) const {
+ const Value resource) const {
auto it = resource_value_to_ids_.find(resource);
assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
return it->getSecond();
diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h
index 98df094..bd39b3a 100644
--- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h
+++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h
@@ -42,12 +42,12 @@
ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default;
// Returns if the analysis fails to resolve a resource-type value.
- bool IsUnknownResource(const Value* resource) const;
+ bool IsUnknownResource(const Value resource) const;
// Returns the set unique IDs which `resource` could alias. Requires that
// IsUnknownResource(resource) == true.
const llvm::SmallSet<int64_t, 8>& GetResourceUniqueIds(
- const Value* resource) const;
+ const Value resource) const;
private:
ResourceAliasAnalysis() = default;
@@ -56,7 +56,7 @@
void AnalyzeFunction(FuncOp func_op);
// Maps each resource-type value to a set of unique IDs that it could alias.
- llvm::SmallDenseMap<const Value*, llvm::SmallSet<int64_t, 8>, 8>
+ llvm::SmallDenseMap<Value, llvm::SmallSet<int64_t, 8>, 8>
resource_value_to_ids_;
};
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h
index d3cf173..913d57f 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h
@@ -90,8 +90,8 @@
static StringRef getOperationName() { return "_tf.Enter"; }
- Value *getData() { return getOperand(0); }
- void setData(Value *value) { setOperand(0, value); }
+ Value getData() { return getOperand(0); }
+ void setData(Value value) { setOperand(0, value); }
LogicalResult verify();
};
@@ -172,8 +172,8 @@
static StringRef getOperationName() { return "_tf.NextIteration.sink"; }
- Value *getData() { return getOperand(0); }
- void setData(Value *value) { setOperand(0, value); }
+ Value getData() { return getOperand(0); }
+ void setData(Value value) { setOperand(0, value); }
LogicalResult verify();
};
@@ -202,8 +202,8 @@
using Op::Op;
static StringRef getOperationName() { return "_tf.LoopCond"; }
- Value *getData() { return getOperand(0); }
- void setData(Value *value) { setOperand(0, value); }
+ Value getData() { return getOperand(0); }
+ void setData(Value value) { setOperand(0, value); }
LogicalResult verify();
};
@@ -233,11 +233,11 @@
static StringRef getOperationName() { return "_tf.Switch"; }
- Value *getData() { return getOperand(0); }
- void setData(Value *value) { setOperand(0, value); }
+ Value getData() { return getOperand(0); }
+ void setData(Value value) { setOperand(0, value); }
- Value *getPredicate() { return getOperand(1); }
- void setPredicate(Value *value) { setOperand(1, value); }
+ Value getPredicate() { return getOperand(1); }
+ void setPredicate(Value value) { setOperand(1, value); }
LogicalResult verify();
};
@@ -266,8 +266,8 @@
using Op::Op;
static StringRef getOperationName() { return "_tf.Exit"; }
- Value *getData() { return getOperand(0); }
- void setData(Value *value) { setOperand(0, value); }
+ Value getData() { return getOperand(0); }
+ void setData(Value value) { setOperand(0, value); }
LogicalResult verify();
};
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
index ffba86e..78ac91f 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc
@@ -183,7 +183,7 @@
if (op.getNumOperands()) {
*p << '(';
Block& block = op.body().front();
- interleaveComma(block.getArguments(), *p, [&](BlockArgument* arg) {
+ interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
const int block_arg_num = arg->getArgNumber();
*p << '[';
p->printOperands(std::next(op.operand_begin(), block_arg_num * n),
@@ -280,7 +280,7 @@
for (auto& replicated_input : replicated_inputs) {
DCHECK_EQ(llvm::size(replicated_input.first), n);
- for (auto* input : replicated_input.first) {
+ for (auto input : replicated_input.first) {
DCHECK(succeeded(
VerifyCompatibleTypes(input->getType(), replicated_input.second)));
state->addOperands(input);
@@ -296,7 +296,7 @@
void ReplicateOp::build(
Builder* builder, OperationState& state, int n,
llvm::ArrayRef<llvm::StringRef> devices,
- llvm::ArrayRef<std::pair<llvm::ArrayRef<Value*>, Type>> replicated_inputs,
+ llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs,
llvm::ArrayRef<Type> replica_output_types) {
BuildReplicateOp(builder, &state, n, devices, replicated_inputs,
replica_output_types);
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
index 403932e..88cc08a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td
@@ -185,7 +185,7 @@
let builders = [
OpBuilder<"Builder* builder, OperationState& state, int n, "
"llvm::ArrayRef<llvm::StringRef> devices, "
- "llvm::ArrayRef<std::pair<llvm::ArrayRef<Value*>, Type>>"
+ "llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>>"
" replicated_inputs, "
"llvm::ArrayRef<Type> replica_output_types">,
OpBuilder<"Builder* builder, OperationState& state, int n, "
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index 5a018a3..dd35478 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -216,7 +216,7 @@
return fetch.emitOpError() << "does not have enough operands to cover the "
"graph returned values";
for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
- Value *operand = fetch.getOperand(i);
+ Value operand = fetch.getOperand(i);
// Break out of the loop at the first control operand encountered.
if (operand->getType().isa<ControlType>()) {
if (i != graph.getNumResults())
@@ -536,7 +536,7 @@
<< (switchn.getNumResults() - 1);
auto operand0_type = switchn.getOperand(0)->getType();
- for (Value *result : switchn.outputs())
+ for (Value result : switchn.outputs())
if (operand0_type != result->getType())
return switchn.emitOpError()
<< "type mismatch between data operand and result: "
@@ -824,7 +824,7 @@
namespace {
LogicalResult Verify(NextIterationSourceOp source) {
- Value *token = source.token();
+ Value token = source.token();
if (!token->hasOneUse())
return source.emitOpError() << "expects a single user for produced token";
if (!isa<NextIterationSinkOp>(*token->user_begin()))
@@ -858,7 +858,7 @@
namespace {
LogicalResult Verify(NextIterationSinkOp sink) {
- Value *token = sink.token();
+ Value token = sink.token();
Operation *definingOp = token->getDefiningOp();
if (!definingOp)
return sink.emitOpError() << "expects a token directly produced by a "
@@ -1087,8 +1087,8 @@
YieldOp yield_op = island_op.GetYield();
// Map graph results to inner ops results of single island.
- llvm::SmallVector<Value *, 8> new_rets;
- for (Value *operand : fetch_op.fetches()) {
+ llvm::SmallVector<Value, 8> new_rets;
+ for (Value operand : fetch_op.fetches()) {
// Control results should not be propagated out.
if (operand->getType().isa<ControlType>()) break;
@@ -1097,7 +1097,7 @@
new_rets.push_back(operand);
} else {
// Lookup yield operand in island for inner op result.
- auto result = llvm::cast<OpResult>(operand);
+ auto result = operand->cast<OpResult>();
new_rets.push_back(yield_op.getOperand(result->getResultNumber()));
}
}
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
index 0f24395..4d5b40a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td
@@ -514,8 +514,8 @@
);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *token, "
- "ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",
+ "Builder *builder, OperationState &result, Value token, "
+ "ArrayRef<Value> operands, ArrayRef<NamedAttribute> attributes = {}",
[{
assert(operands.size() >= 1 && "tf_executor.NextIteration.Sink builder "
"expects at least one operand");
@@ -594,7 +594,7 @@
let builders = [OpBuilder<
"Builder *builder, OperationState &result, "
- "ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes = {}",
+ "ArrayRef<Value> operands, ArrayRef<NamedAttribute> attributes = {}",
[{
assert(operands.size() >= 1 && "tf_executor.ControlTrigger builder "
"expects at least one operand");
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index a3d0ff1..78724ea 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1544,8 +1544,8 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let builders = [
- OpBuilder<"Builder* builder, OperationState& result, Value* x, "
- "Value* y, BoolAttr incompatible_shape_error">
+ OpBuilder<"Builder* builder, OperationState& result, Value x, "
+ "Value y, BoolAttr incompatible_shape_error">
];
let verifier = [{
@@ -1647,8 +1647,8 @@
TF_DerivedOperandTypeAttr Tdim = TF_DerivedOperandTypeAttr<1>;
let builders = [
- OpBuilder<"Builder* builder, OperationState& result, Value* condition, "
- "Value* dim">
+ OpBuilder<"Builder* builder, OperationState& result, Value condition, "
+ "Value dim">
];
}
@@ -1926,6 +1926,102 @@
}];
}
+def TF_FusedBatchNormGradOp : TF_Op<"FusedBatchNormGrad", [NoSideEffect]> {
+ let summary = "Gradient for batch normalization.";
+
+ let description = [{
+Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
+The size of 1D Tensors matches the dimension C of the 4D Tensors.
+ }];
+
+ let arguments = (ins
+ F32Tensor:$y_backprop,
+ F32Tensor:$x,
+ F32Tensor:$scale,
+ F32Tensor:$reserve_space_1,
+ F32Tensor:$reserve_space_2,
+
+ DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+ DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+ DefaultValuedAttr<BoolAttr, "true">:$is_training
+ );
+
+ let results = (outs
+ F32Tensor:$x_backprop,
+ F32Tensor:$scale_backprop,
+ F32Tensor:$offset_backprop,
+ F32Tensor:$reserve_space_3,
+ F32Tensor:$reserve_space_4
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_FusedBatchNormGradV2Op : TF_Op<"FusedBatchNormGradV2", [NoSideEffect]> {
+ let summary = "Gradient for batch normalization.";
+
+ let description = [{
+Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
+The size of 1D Tensors matches the dimension C of the 4D Tensors.
+ }];
+
+ let arguments = (ins
+ TensorOf<[BF16, F16, F32]>:$y_backprop,
+ TensorOf<[BF16, F16, F32]>:$x,
+ F32Tensor:$scale,
+ F32Tensor:$reserve_space_1,
+ F32Tensor:$reserve_space_2,
+
+ DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+ DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+ DefaultValuedAttr<BoolAttr, "true">:$is_training
+ );
+
+ let results = (outs
+ TensorOf<[BF16, F16, F32]>:$x_backprop,
+ F32Tensor:$scale_backprop,
+ F32Tensor:$offset_backprop,
+ F32Tensor:$reserve_space_3,
+ F32Tensor:$reserve_space_4
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+ TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
+}
+
+def TF_FusedBatchNormGradV3Op : TF_Op<"FusedBatchNormGradV3", [NoSideEffect]> {
+ let summary = "Gradient for batch normalization.";
+
+ let description = [{
+Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
+The size of 1D Tensors matches the dimension C of the 4D Tensors.
+ }];
+
+ let arguments = (ins
+ TensorOf<[BF16, F16, F32]>:$y_backprop,
+ TensorOf<[BF16, F16, F32]>:$x,
+ F32Tensor:$scale,
+ F32Tensor:$reserve_space_1,
+ F32Tensor:$reserve_space_2,
+ F32Tensor:$reserve_space_3,
+
+ DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+ DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+ DefaultValuedAttr<BoolAttr, "true">:$is_training
+ );
+
+ let results = (outs
+ TensorOf<[BF16, F16, F32]>:$x_backprop,
+ F32Tensor:$scale_backprop,
+ F32Tensor:$offset_backprop,
+ F32Tensor:$reserve_space_4,
+ F32Tensor:$reserve_space_5
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+ TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
+}
+
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> {
let summary = "Batch normalization.";
@@ -2640,6 +2736,31 @@
let hasCanonicalizer = 1;
}
+def TF_Log1pOp : TF_Op<"Log1p", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Computes natural logarithm of (1 + x) element-wise.";
+
+ let description = [{
+I.e., \\(y = \log_e (1 + x)\\).
+
+Example:
+
+```python
+x = tf.constant([0, 0.5, 1, 5])
+tf.math.log1p(x) ==> [0., 0.4054651, 0.6931472, 1.7917595]
+```
+ }];
+
+ let arguments = (ins
+ TF_FpOrComplexTensor:$x
+ );
+
+ let results = (outs
+ TF_FpOrComplexTensor:$y
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
def TF_LogSoftmaxOp : TF_Op<"LogSoftmax", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes log softmax activations.";
@@ -3257,8 +3378,8 @@
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *input, "
- "Value *reduction_indices, BoolAttr keep_dims"
+ "Builder *builder, OperationState &result, Value input, "
+ "Value reduction_indices, BoolAttr keep_dims"
>];
}
@@ -3669,8 +3790,8 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let builders = [
- OpBuilder<"Builder* builder, OperationState& result, Value* x, "
- "Value* y, BoolAttr incompatible_shape_error">
+ OpBuilder<"Builder* builder, OperationState& result, Value x, "
+ "Value y, BoolAttr incompatible_shape_error">
];
let verifier = [{
@@ -3788,8 +3909,8 @@
TF_DerivedOperandTypeAttr TI = TF_DerivedOperandTypeAttr<0>;
let builders = [
- OpBuilder<"Builder* builder, OperationState& result, Value* indices, "
- "Value* depth, Value* on_value, Value* off_value, "
+ OpBuilder<"Builder* builder, OperationState& result, Value indices, "
+ "Value depth, Value on_value, Value off_value, "
"IntegerAttr axis">
];
@@ -4223,8 +4344,8 @@
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<0>;
let builders = [
- OpBuilder<"Builder* builder, OperationState& result, Value* start, "
- "Value* limit, Value* delta">
+ OpBuilder<"Builder* builder, OperationState& result, Value start, "
+ "Value limit, Value delta">
];
}
@@ -4258,7 +4379,7 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let builders = [
- OpBuilder<"Builder* builder, OperationState& result, Value* input">
+ OpBuilder<"Builder* builder, OperationState& result, Value input">
];
}
@@ -4494,7 +4615,7 @@
let builders = [
OpBuilder<
- "Builder* builder, OperationState& result, Value* tensor, Value* shape">
+ "Builder* builder, OperationState& result, Value tensor, Value shape">
];
let verifier = [{
@@ -4951,7 +5072,7 @@
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
let builders = [
- OpBuilder<"Builder* builder, OperationState& result, Value* condition, Value* e, Value* t">
+ OpBuilder<"Builder* builder, OperationState& result, Value condition, Value e, Value t">
];
}
@@ -4985,7 +5106,7 @@
}];
let builders = [
- OpBuilder<"Builder* builder, OperationState& result, Value* input, BoolAttr use32Bit">
+ OpBuilder<"Builder* builder, OperationState& result, Value input, BoolAttr use32Bit">
];
let hasFolder = 1;
@@ -5836,8 +5957,8 @@
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *input, "
- "Value *reduction_indices, BoolAttr keep_dims"
+ "Builder *builder, OperationState &result, Value input, "
+ "Value reduction_indices, BoolAttr keep_dims"
>];
}
@@ -6164,6 +6285,103 @@
}];
}
+def TF_TensorScatterUpdateOp : TF_Op<"TensorScatterUpdate", [NoSideEffect]> {
+ let summary = [{
+Scatter `updates` into an existing tensor according to `indices`.
+ }];
+
+ let description = [{
+This operation creates a new tensor by applying sparse `updates` to the passed
+in `tensor`.
+This operation is very similar to `tf.scatter_nd`, except that the updates are
+scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
+for the existing tensor cannot be re-used, a copy is made and updated.
+
+If `indices` contains duplicates, then their updates are accumulated (summed).
+
+**WARNING**: The order in which updates are applied is nondeterministic, so the
+output will be nondeterministic if `indices` contains duplicates -- because
+of some numerical approximation issues, numbers summed in different order
+may yield different results.
+
+`indices` is an integer tensor containing indices into a new tensor of shape
+`shape`. The last dimension of `indices` can be at most the rank of `shape`:
+
+ indices.shape[-1] <= shape.rank
+
+The last dimension of `indices` corresponds to indices into elements
+(if `indices.shape[-1] = shape.rank`) or slices
+(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of
+`shape`. `updates` is a tensor with shape
+
+ indices.shape[:-1] + shape[indices.shape[-1]:]
+
+The simplest form of scatter is to insert individual elements in a tensor by
+index. For example, say we want to insert 4 scattered elements in a rank-1
+tensor with 8 elements.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt>
+</div>
+
+In Python, this scatter operation would look like this:
+
+ >>> indices = tf.constant([[4], [3], [1], [7]])
+ >>> updates = tf.constant([9, 10, 11, 12])
+ >>> tensor = tf.ones([8], dtype=tf.int32)
+ >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
+ tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32)
+
+We can also, insert entire slices of a higher rank tensor all at once. For
+example, if we wanted to insert two slices in the first dimension of a
+rank-3 tensor with two matrices of new values.
+
+In Python, this scatter operation would look like this:
+
+ >>> indices = tf.constant([[0], [2]])
+ >>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
+ ... [7, 7, 7, 7], [8, 8, 8, 8]],
+ ... [[5, 5, 5, 5], [6, 6, 6, 6],
+ ... [7, 7, 7, 7], [8, 8, 8, 8]]])
+ >>> tensor = tf.ones([4, 4, 4], dtype=tf.int32)
+ >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
+ [[[5 5 5 5]
+ [6 6 6 6]
+ [7 7 7 7]
+ [8 8 8 8]]
+ [[1 1 1 1]
+ [1 1 1 1]
+ [1 1 1 1]
+ [1 1 1 1]]
+ [[5 5 5 5]
+ [6 6 6 6]
+ [7 7 7 7]
+ [8 8 8 8]]
+ [[1 1 1 1]
+ [1 1 1 1]
+ [1 1 1 1]
+ [1 1 1 1]]]
+
+Note that on CPU, if an out of bound index is found, an error is returned.
+On GPU, if an out of bound index is found, the index is ignored.
+ }];
+
+ let arguments = (ins
+ TF_Tensor:$tensor,
+ TF_I32OrI64Tensor:$indices,
+ TF_Tensor:$updates
+ );
+
+ let results = (outs
+ TF_Tensor:$output
+ );
+
+ TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+ let verifier = [{ return Verify(*this); }];
+}
+
def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> {
let summary = "Constructs a tensor by tiling a given tensor.";
@@ -6270,7 +6488,7 @@
let builders = [
OpBuilder<
- "Builder* builder, OperationState& result, Value* x, Value* perm">
+ "Builder* builder, OperationState& result, Value x, Value perm">
];
let verifier = [{
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
index c3a5161..a63276b 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -171,6 +171,8 @@
// Any integer or floating-point tensor types
def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>;
+def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>;
+
def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>;
def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex],
@@ -297,7 +299,7 @@
// behavior. The result type has the same element type as both operands.
class WithBroadcastableBinOpBuilder {
list<OpBuilder> builders = [OpBuilder<
-"Builder *builder, OperationState &result, Value* x, Value* y",
+"Builder *builder, OperationState &result, Value x, Value y",
[{
auto resultType =
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
@@ -312,7 +314,7 @@
// behavior. The result type has bool element type.
class WithBroadcastableCmpOpBuilder {
list<OpBuilder> builders = [OpBuilder<
-"Builder *builder, OperationState &result, Value* x, Value* y",
+"Builder *builder, OperationState &result, Value x, Value y",
[{
Type resultType;
if (x->getType().isa<UnrankedTensorType>() ||
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 9dded9c..62d2af2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -72,7 +72,7 @@
// may have non-static shape because the shape is not propagated during constant
// folding. If the defining op for the given operand is a constant op, this
// routine uses the constant op's attribute to get the actual shape.
-static RankedTensorType GetRankedTensorTypeForOperand(Value *operand) {
+static RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
DenseElementsAttr attr;
if (matchPattern(operand, m_Constant(&attr))) {
return attr.getType().dyn_cast<RankedTensorType>();
@@ -82,7 +82,7 @@
// Returns true if the given `value` is of ranked float tensor type with the
// given `rank`.
-static inline bool isOfRankedFloatTensorType(Value *value, int rank) {
+static inline bool isOfRankedFloatTensorType(Value value, int rank) {
RankedTensorType type = GetRankedTensorTypeForOperand(value);
return type && type.getRank() == rank &&
type.getElementType().isa<FloatType>();
@@ -90,21 +90,21 @@
// Returns true if the given `value` has the specified rank or has unranked
// type.
-static inline bool IsOfRankOrUnranked(Value *value, int64_t rank) {
+static inline bool IsOfRankOrUnranked(Value value, int64_t rank) {
RankedTensorType type = GetRankedTensorTypeForOperand(value);
return !type || type.getRank() == rank;
}
// Returns true if the given `value` has at least the specified rank or has
// unranked type.
-static inline bool HasRankAtLeast(Value *value, int64_t rank) {
+static inline bool HasRankAtLeast(Value value, int64_t rank) {
RankedTensorType type = GetRankedTensorTypeForOperand(value);
return !type || type.getRank() >= rank;
}
// Returns true if the given `value` has at most the specified rank or has
// unranked type.
-static inline bool HasRankAtMost(Value *value, int64_t rank) {
+static inline bool HasRankAtMost(Value value, int64_t rank) {
RankedTensorType type = GetRankedTensorTypeForOperand(value);
return !type || type.getRank() <= rank;
}
@@ -158,8 +158,8 @@
// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If
// `incompatible_shape_error` is true, reports error if `x` and `y` has
// incompatible shapes. Otherwise, returns a tensor type with unknown rank.
-static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value *x,
- Value *y, BoolAttr incompatible_shape_error) {
+static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x,
+ Value y, BoolAttr incompatible_shape_error) {
auto result_type =
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
if (!result_type) {
@@ -185,7 +185,7 @@
// Infers output type for reduction ops such as SumOp, MaxOp etc.
// TODO(b/e667204a): Move this logic to shape inference once it supports custom
// inference functions.
-static Type InferReductionOpType(Value *input, Value *reduction_indices,
+static Type InferReductionOpType(Value input, Value reduction_indices,
BoolAttr keep_dims, Builder *builder) {
Type input_ty = input->getType();
Type element_ty = getElementTypeOrSelf(input_ty);
@@ -328,7 +328,7 @@
//===----------------------------------------------------------------------===//
// Verifies an reduction op's `input` and reduction `dims`.
-static LogicalResult VerifyReductionInputAndDims(Value *input, Value *dims,
+static LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
Location loc) {
auto dims_type = dims->getType().dyn_cast<RankedTensorType>();
if (!dims_type) return success();
@@ -528,7 +528,7 @@
Operation::operand_range values = op.values();
int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1;
- Value *axis = *op.getODSOperands(axis_idx).begin();
+ Value axis = *op.getODSOperands(axis_idx).begin();
if (!HasRankAtMost(axis, 1)) {
return op.emitOpError(
"requires axis to be of scalar type (or vector type for older "
@@ -561,8 +561,8 @@
int64_t num_dims = -1;
for (auto shape_offset_idx :
llvm::enumerate(llvm::zip(op.shape(), op.offset()))) {
- Value *shape = std::get<0>(shape_offset_idx.value());
- Value *offset = std::get<1>(shape_offset_idx.value());
+ Value shape = std::get<0>(shape_offset_idx.value());
+ Value offset = std::get<1>(shape_offset_idx.value());
const size_t idx = shape_offset_idx.index();
if (failed(verifyCompatibleShape(shape->getType(), offset->getType())))
@@ -860,7 +860,7 @@
int32_t max_index = -1;
llvm::Optional<SmallVector<int64_t, 4>> inferred_item_shape;
for (auto it : llvm::zip(op.indices(), op.data())) {
- Value *index = std::get<0>(it);
+ Value index = std::get<0>(it);
DenseIntElementsAttr index_attr;
if (matchPattern(index, m_Constant(&index_attr))) {
@@ -875,7 +875,7 @@
all_indices_const = false;
}
- Value *data = std::get<1>(it);
+ Value data = std::get<1>(it);
RankedTensorType index_ty = index->getType().dyn_cast<RankedTensorType>();
RankedTensorType data_ty = data->getType().dyn_cast<RankedTensorType>();
if (!index_ty || !data_ty) continue;
@@ -981,8 +981,8 @@
op.getOperation());
}
-void EqualOp::build(Builder *builder, OperationState &result, Value *x,
- Value *y, BoolAttr incompatible_shape_error) {
+void EqualOp::build(Builder *builder, OperationState &result, Value x, Value y,
+ BoolAttr incompatible_shape_error) {
auto result_type = DeduceEqualCmpOpType(builder, result.location, x, y,
incompatible_shape_error);
return build(builder, result, result_type, x, y, incompatible_shape_error);
@@ -992,7 +992,7 @@
// ExpandDimsOp
//===----------------------------------------------------------------------===//
-Type InferExpandDimsOpType(Value *input, Value *dim) {
+Type InferExpandDimsOpType(Value input, Value dim) {
Type element_ty = input->getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty);
@@ -1014,8 +1014,8 @@
return RankedTensorType::get(shape, element_ty);
}
-void ExpandDimsOp::build(Builder *builder, OperationState &result, Value *input,
- Value *dim) {
+void ExpandDimsOp::build(Builder *builder, OperationState &result, Value input,
+ Value dim) {
return build(builder, result, InferExpandDimsOpType(input, dim), input, dim);
}
@@ -1074,7 +1074,7 @@
if (!isOfRankedFloatTensorType(op.max(), 1))
return op.emitOpError("requires max to be a 1d float tensor");
- Value *inputs = op.inputs();
+ Value inputs = op.inputs();
if (!HasRankAtLeast(inputs, 1) ||
inputs->getType().isa<UnrankedTensorType>()) {
return op.emitError("requires inputs to be at least 1d float tensor");
@@ -1304,8 +1304,8 @@
// MaxOp
//===----------------------------------------------------------------------===//
-void MaxOp::build(Builder *builder, OperationState &result, Value *input,
- Value *reduction_indices, BoolAttr keep_dims) {
+void MaxOp::build(Builder *builder, OperationState &result, Value input,
+ Value reduction_indices, BoolAttr keep_dims) {
Type out_ty =
InferReductionOpType(input, reduction_indices, keep_dims, builder);
build(builder, result, out_ty, input, reduction_indices, keep_dims);
@@ -1350,8 +1350,8 @@
op.getOperation());
}
-void NotEqualOp::build(Builder *builder, OperationState &result, Value *x,
- Value *y, BoolAttr incompatible_shape_error) {
+void NotEqualOp::build(Builder *builder, OperationState &result, Value x,
+ Value y, BoolAttr incompatible_shape_error) {
auto result_type = DeduceEqualCmpOpType(builder, result.location, x, y,
incompatible_shape_error);
return build(builder, result, result_type, x, y, incompatible_shape_error);
@@ -1400,9 +1400,8 @@
return success();
}
-static TensorType InferOneHotOpType(Value *indices, Value *depth,
- Value *on_value, Value *off_value,
- IntegerAttr axis) {
+static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
+ Value off_value, IntegerAttr axis) {
int64_t axis_val = axis.getInt();
Type element_ty = on_value->getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty);
@@ -1423,8 +1422,8 @@
return RankedTensorType::get(shape, element_ty);
}
-void OneHotOp::build(Builder *builder, OperationState &result, Value *indices,
- Value *depth, Value *on_value, Value *off_value,
+void OneHotOp::build(Builder *builder, OperationState &result, Value indices,
+ Value depth, Value on_value, Value off_value,
IntegerAttr axis) {
build(builder, result,
InferOneHotOpType(indices, depth, on_value, off_value, axis), indices,
@@ -1446,7 +1445,7 @@
}
int64_t inputs_rank = -1;
- for (Value *value : values) {
+ for (Value value : values) {
if (auto ty = value->getType().dyn_cast<RankedTensorType>()) {
// Exit early as input types are verified to be compatible so all ranked
// tensors have the same rank.
@@ -1472,6 +1471,59 @@
}
//===----------------------------------------------------------------------===//
+// ParseExampleV2Op
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(ParseExampleV2Op op) {
+ // NOTE(mrry): This validates properties of an op that would previously be
+ // validated by the TensorFlow OpDef type checker. In addition to these
+ // checks, the shape inference function for ParseExampleV2 validates the
+ // consistency of the argument and result types.
+
+ // Validate dense variadic input and output lengths.
+ // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we
+ // do not need to validate dense_defaults.
+ auto dense_types_count =
+ std::distance(op.Tdense().begin(), op.Tdense().end());
+ auto dense_values_count =
+ std::distance(op.dense_values().begin(), op.dense_values().end());
+ if (dense_values_count != dense_types_count) {
+ return op.emitError() << "output 'dense_values' should have same length "
+ << "as attribute 'Tdense'";
+ }
+
+ // Validate sparse variadic output lengths.
+ // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we
+ // do not need to validate sparse_values.
+ auto sparse_types_count =
+ std::distance(op.sparse_types().begin(), op.sparse_types().end());
+ if (op.num_sparse() != sparse_types_count) {
+ return op.emitError() << "attribute 'num_sparse' should be the same as "
+ << "the length of attribute 'sparse_types'";
+ }
+ if (op.sparse_indices().size() != sparse_types_count) {
+ return op.emitError() << "output 'sparse_indices' should have same length "
+ << "as attribute 'sparse_types'";
+ }
+ if (op.sparse_shapes().size() != sparse_types_count) {
+ return op.emitError() << "output 'sparse_shapes' should have same length "
+ << "as attribute 'sparse_types'";
+ }
+
+ // Validate ragged variadic output lengths.
+ auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(),
+ op.ragged_value_types().end());
+ auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(),
+ op.ragged_split_types().end());
+ if (ragged_value_types_count != ragged_split_types_count) {
+ return op.emitError() << "attribute 'ragged_value_types' should have same "
+ << "length as attribute 'ragged_split_types'";
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ReciprocalOp
//===----------------------------------------------------------------------===//
@@ -1494,8 +1546,8 @@
// RangeOp
//===----------------------------------------------------------------------===//
-void RangeOp::build(Builder *builder, OperationState &result, Value *start,
- Value *limit, Value *delta) {
+void RangeOp::build(Builder *builder, OperationState &result, Value start,
+ Value limit, Value delta) {
assert(start->getType() == limit->getType());
assert(start->getType() == delta->getType());
DenseIntElementsAttr start_val;
@@ -1524,7 +1576,7 @@
// RankOp
//===----------------------------------------------------------------------===//
-void RankOp::build(Builder *builder, OperationState &result, Value *input) {
+void RankOp::build(Builder *builder, OperationState &result, Value input) {
return RankOp::build(builder, result,
RankedTensorType::get({}, builder->getIntegerType(32)),
input);
@@ -1608,8 +1660,8 @@
return success();
}
-void ReshapeOp::build(Builder *builder, OperationState &result, Value *tensor,
- Value *shape) {
+void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor,
+ Value shape) {
auto ttype = tensor->getType().cast<ShapedType>();
auto etype = ttype.getElementType();
@@ -1670,7 +1722,7 @@
// SelectV2Op
//===----------------------------------------------------------------------===//
-static Type InferSelectV2OpType(Value *condition, Value *e, Value *t) {
+static Type InferSelectV2OpType(Value condition, Value e, Value t) {
Type element_ty = e->getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty);
@@ -1693,7 +1745,7 @@
}
void SelectV2Op::build(Builder *builder, OperationState &result,
- Value *condition, Value *e, Value *t) {
+ Value condition, Value e, Value t) {
build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t);
}
@@ -1767,7 +1819,7 @@
return ConvertShapeToAttr(getOperand()->getType(), width);
}
-void ShapeOp::build(Builder *builder, OperationState &result, Value *input,
+void ShapeOp::build(Builder *builder, OperationState &result, Value input,
BoolAttr use32Bit) {
auto rankedTensorType = input->getType().dyn_cast<RankedTensorType>();
int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
@@ -1967,7 +2019,7 @@
LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
*dim_index = llvm::None;
- Value *split_dim = op.split_dim();
+ Value split_dim = op.split_dim();
if (auto split_dim_type = split_dim->getType().dyn_cast<RankedTensorType>())
if (split_dim_type.getRank() != 0)
return op.emitOpError(
@@ -2101,8 +2153,8 @@
// SumOp
//===----------------------------------------------------------------------===//
-void SumOp::build(Builder *builder, OperationState &result, Value *input,
- Value *reduction_indices, BoolAttr keep_dims) {
+void SumOp::build(Builder *builder, OperationState &result, Value input,
+ Value reduction_indices, BoolAttr keep_dims) {
Type out_ty =
InferReductionOpType(input, reduction_indices, keep_dims, builder);
build(builder, result, out_ty, input, reduction_indices, keep_dims);
@@ -2125,7 +2177,7 @@
// Expected size for operands begin, end and strides vector operands.
int64_t expected_size = -1;
- for (Value *val : {op.begin(), op.end(), op.strides()}) {
+ for (Value val : {op.begin(), op.end(), op.strides()}) {
auto operand_ty = val->getType().dyn_cast<ShapedType>();
if (!operand_ty || !operand_ty.hasStaticShape()) {
// TensorFlow constant ops may have non-static shape because the shape is
@@ -2367,6 +2419,35 @@
}
//===----------------------------------------------------------------------===//
+// TensorScatterUpdateOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(TensorScatterUpdateOp op) {
+ if (!HasRankAtLeast(op.tensor(), 1))
+ return op.emitOpError(
+ "requires tensor operand to have at least 1 dimension");
+ if (!HasRankAtLeast(op.indices(), 1))
+ return op.emitOpError(
+ "requires indices operand to have at least 1 dimension");
+ if (!HasRankAtLeast(op.updates(), 1))
+ return op.emitOpError(
+ "requires updates operand to have at least 1 dimension");
+
+ auto tensor_ty = op.tensor()->getType().dyn_cast<RankedTensorType>();
+ auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>();
+ if (!tensor_ty || !indices_ty) return success();
+
+ int64_t num_index_dims = indices_ty.getShape().back();
+ if (ShapedType::isDynamic(num_index_dims)) return success();
+
+ if (num_index_dims > tensor_ty.getRank())
+ return op.emitOpError(
+ "requires tensor operand with rank greater than or equal to the "
+ "indices operand's last dimensions");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//
@@ -2395,8 +2476,8 @@
}
// TODO(jpienaar): perm could be optional too.
-void TransposeOp::build(Builder *builder, OperationState &result, Value *x,
- Value *perm) {
+void TransposeOp::build(Builder *builder, OperationState &result, Value x,
+ Value perm) {
auto x_type = x->getType().cast<TensorType>();
// If value is unranked, then so is results.
if (!x_type.hasRank())
@@ -2679,7 +2760,7 @@
// operation that takes 'input' as the only operand, and produces a single
// result of 'resultType'. If a conversion can not be generated, nullptr
// should be returned.
- Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+ Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type result_type,
Location conversion_loc) const final {
if (!result_type.isa<TensorType>() || !input->getType().isa<TensorType>())
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index 9b6196c..620690d 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -29,6 +29,7 @@
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
include "mlir/Analysis/CallInterfaces.td"
+include "mlir/IR/OpBase.td"
class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
let results = (outs
@@ -232,6 +233,50 @@
}];
}
+def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2",
+ [NoSideEffect,
+ AttrSizedResultSegments]> {
+
+ let summary =
+ "Transforms a vector of tf.Example protos (as strings) into typed tensors.";
+
+ let arguments = (ins
+ TF_StrTensor:$serialized,
+ TF_StrTensor:$names,
+ TF_StrTensor:$sparse_keys,
+ TF_StrTensor:$dense_keys,
+ TF_StrTensor:$ragged_keys,
+ Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_defaults,
+
+ Confined<I64Attr, [IntMinValue<0>]>:$num_sparse,
+ I32ElementsAttr:$result_segment_sizes
+ );
+
+ let results = (outs
+ Variadic<I64Tensor>:$sparse_indices, // len(sparse_types)
+ Variadic<TensorOf<[F32, I64, TF_Str]>>:$sparse_values, // len(sparse_types)
+ Variadic<I64Tensor>:$sparse_shapes, // len(sparse_types)
+ Variadic<TensorOf<[F32, I64, TF_Str]>>:$dense_values, // len(Tdense)
+ Variadic<TensorOf<[F32, I64, TF_Str]>>:$ragged_values, // len(ragged_value_types)
+ // = len(ragged_split_types)
+ Variadic<TensorOf<[I32, I64]>>:$ragged_row_splits // len(ragged_split_types)
+ // = len(ragged_value_types)
+ );
+
+ // The Verify(ParseExampleV2Op) function validates that the lengths and types
+ // of these attrs are compatible.
+ TF_DerivedOperandTypeListAttr Tdense = TF_DerivedOperandTypeListAttr<5>;
+ TF_DerivedResultTypeListAttr sparse_types = TF_DerivedResultTypeListAttr<1>;
+ TF_DerivedResultTypeListAttr ragged_value_types =
+ TF_DerivedResultTypeListAttr<4>;
+ TF_DerivedResultTypeListAttr ragged_split_types =
+ TF_DerivedResultTypeListAttr<5>;
+
+ let verifier = [{
+ return Verify(*this);
+ }];
+}
+
def TF_PartitionedCallOp : TF_Op<"PartitionedCall",
[CallOpInterface, NoSideEffect]> {
let summary =
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt
index cbfa973..8eca308 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt
@@ -1,11 +1,11 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - | FileCheck %s
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - -mlir-print-debuginfo | FileCheck %s
# Verify that TensorFlow If and StatelessIf ops are mapped to the
# composite If op in MLIR with is_stateless attribute set accordingly to
# distinguish between them.
-# CHECK-DAG: "tf.If"{{.*}} is_stateless = false, name = "StatefulIf"
-# CHECK-DAG: "tf.If"{{.*}} is_stateless = true, name = "StatelessIf"
+# CHECK-DAG: "tf.If"{{.*}} is_stateless = false{{.*}} loc("StatefulIf")
+# CHECK-DAG: "tf.If"{{.*}} is_stateless = true{{.*}} loc("StatelessIf")
node {
name: "tf.Less"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt
index 953f83a..ede01ebf 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt
@@ -1,11 +1,11 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - | FileCheck %s
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo | FileCheck %s
# Verify that TensorFlow While and StatelessWhile ops are mapped to the
# composite While op in MLIR with is_stateless attribute set accordingly to
# distinguish between them.
-# CHECK-DAG: "tf.While"{{.*}} is_stateless = false, name = "StatefulWhile"
-# CHECK-DAG: "tf.While"{{.*}} is_stateless = true, name = "StatelessWhile"
+# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile")
+# CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile")
node {
name: "StatefulWhile"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
index 1df903d..da79023 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt
@@ -1,7 +1,7 @@
# RUN: tf-mlir-translate -graphdef-to-mlir %s | FileCheck %s
# CHECK:"tf.MlirPassthroughOp"
-# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A", name = "MlirPassthroughOp"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
+# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
node {
name: "x"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt
index a8f58c4..fdf279f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt
@@ -90,6 +90,6 @@
}
# TODO(b/142400497): What is the semantic contract for locations?
-# CHECK: "tf.Const"{{.*}}value = dense<2>{{.*}}loc(fused["n1@f1", "n2@f2"])
+# CHECK: "tf.Const"{{.*}}value = dense<2>{{.*}}loc(fused["n1@f1", "n2@f2", "fused_node_outside_function"])
# CHECK: "tf.Const"{{.*}}value = dense<0>{{.*}}loc("node_outside_function")
# CHECK: "tf.Const"{{.*}}value = dense<1>{{.*}}loc("node_inside_function@foo")
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt
new file mode 100644
index 0000000..7411a5e
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt
@@ -0,0 +1,225 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input0 -tf-input-data-types=DT_STRING -tf-input-shapes=32 -tf-output-arrays=ParseExample/ParseExampleV2:0,ParseExample/ParseExampleV2:7 -o - | FileCheck %s
+
+# CHECK: %[[parse_example:.*]]:8, %[[parse_example_control:.*]] = tf_executor.island wraps "tf.ParseExampleV2"(%arg0,
+# CHECK: result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>
+# CHECK: tf_executor.fetch %[[parse_example]]#0, %[[parse_example]]#7 : tensor<?x2xi64>, tensor<32xf32>
+
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+}
+node {
+ name: "ParseExample/Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+}
+node {
+ name: "ParseExample/Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+}
+node {
+ name: "ParseExample/ParseExampleV2/names"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+}
+node {
+ name: "ParseExample/ParseExampleV2/sparse_keys"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ string_val: "feature_key3"
+ string_val: "feature_key4"
+ }
+ }
+ }
+}
+node {
+ name: "ParseExample/ParseExampleV2/dense_keys"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ string_val: "feature_key1"
+ string_val: "feature_key2"
+ }
+ }
+ }
+}
+node {
+ name: "ParseExample/ParseExampleV2/ragged_keys"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_STRING
+ tensor_shape {
+ dim {
+ }
+ }
+ }
+ }
+ }
+}
+node {
+ name: "ParseExample/ParseExampleV2"
+ op: "ParseExampleV2"
+ input: "input0"
+ input: "ParseExample/ParseExampleV2/names"
+ input: "ParseExample/ParseExampleV2/sparse_keys"
+ input: "ParseExample/ParseExampleV2/dense_keys"
+ input: "ParseExample/ParseExampleV2/ragged_keys"
+ input: "ParseExample/Const"
+ input: "ParseExample/Const_1"
+ attr {
+ key: "Tdense"
+ value {
+ list {
+ type: DT_FLOAT
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ key: "dense_shapes"
+ value {
+ list {
+ shape {
+ }
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "num_sparse"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "ragged_split_types"
+ value {
+ list {
+ }
+ }
+ }
+ attr {
+ key: "ragged_value_types"
+ value {
+ list {
+ }
+ }
+ }
+ attr {
+ key: "sparse_types"
+ value {
+ list {
+ type: DT_STRING
+ type: DT_INT64
+ }
+ }
+ }
+}
+versions {
+ producer: 175
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
index 748bc99..cf8051f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt
@@ -1,4 +1,4 @@
-# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s
+# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - -mlir-print-debuginfo | FileCheck %s
node {
name: "Quantized_Constant"
@@ -28,5 +28,5 @@
}
# CHECK: tf.Const
-# CHECK-SAME: name = "Quantized_Constant"
# CHECK-SAME: value = opaque<"tf", "{{0[xX][0-9a-fA-F]*}}"> : tensor<!tf.quint8>
+# CHECK-SAME: loc("Quantized_Constant")
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt
index 3dd5ce5..e819efc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt
@@ -1,13 +1,13 @@
-# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s --dump-input-on-failure
+# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - -mlir-print-debuginfo | FileCheck %s --dump-input-on-failure
# CHECK: tf_executor.SwitchN
# CHECK-SAME: of 3 : tensor<i32>
# CHECK-SAME: T = i32
-# CHECK-SAME: name = "Case/branch_index/_3"
+# CHECK-SAME: loc("Case/branch_index/_3")
# CHECK: tf_executor.SwitchN
# CHECK-SAME: of 2 : tensor<f32>
# CHECK-SAME: T = f32
-# CHECK-SAME: name = "Case/Case/input_0/_7"
+# CHECK-SAME: loc("Case/Case/input_0/_7")
node {
name: "Case/branch_index"
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
index 7f3a4c2..c1c5f41 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
@@ -424,3 +424,10 @@
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
+
+func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+ // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
+ %0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir
index 52e4c52..e6e2272 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir
@@ -1,8 +1,8 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main() -> (tensor<1x2xf16>, tensor<2xf16>) {
- %0:2 = "_tf.Const"() {device = "", name = "foo", dtype = "tfdtype$DT_HALF", value = dense<1.0> : tensor<1x2xf16>} : () -> (tensor<1x2xf16>, !_tf.control)
- %1:2 = "_tf.Const"() {device = "", name = "bar", dtype = "tfdtype$DT_HALF", value = dense<[1.0, 2.0]> : tensor<2xf16>} : () -> (tensor<2xf16>, !_tf.control)
+ %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_HALF", value = dense<1.0> : tensor<1x2xf16>} : () -> (tensor<1x2xf16>, !_tf.control) loc("foo")
+ %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_HALF", value = dense<[1.0, 2.0]> : tensor<2xf16>} : () -> (tensor<2xf16>, !_tf.control) loc("bar")
return %0#0, %1#0 : tensor<1x2xf16>, tensor<2xf16>
// CHECK: node {
@@ -13,4 +13,4 @@
// CHECK-NEXT: op: "Const"
// CHECK: half_val: 15360
// CHECK: half_val: 16384
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir
index 24cb7b7..515e03a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir
@@ -2,7 +2,7 @@
func @main() -> tensor<*x!tf.resource> attributes {tf.entry_function = {inputs = "", outputs = "func_call"}} {
%0 = tf_executor.graph {
- %outputs, %control = tf_executor.island wraps "tf.VarHandleOp"() {container = "a", device = "/CPU:0", dtype = i64, name = "x", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<i64>>>
+ %outputs, %control = tf_executor.island wraps "tf.VarHandleOp"() {container = "a", device = "/CPU:0", dtype = i64, shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<i64>>> loc("x")
%outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) {_disable_call_shape_inference = true, f = @test_func_name0} : (tensor<!tf.resource<tensor<i64>>>, tensor<!tf.resource<tensor<i64>>>) -> tensor<*x!tf.resource>
tf_executor.fetch %outputs_0 : tensor<*x!tf.resource>
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/graph-as-function.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/graph-as-function.mlir
index 40ddad9..cb9c5c3 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/graph-as-function.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/graph-as-function.mlir
@@ -2,15 +2,15 @@
func @main(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, %arg2: tensor<*xf32>, %arg3: tensor<2x4x6x8xi32>) -> (tensor<f32>, tensor<f32>)
attributes {tf.entry_function = {inputs = "args_0,args_1,args_2,args_3", outputs = "rets_0_RetVal,rets_1_RetVal"}} {
- %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "const", value = dense<0.000000e+00> : tensor<f32>} : () -> (tensor<f32>, !_tf.control)
- %1:2 = "_tf.Identity"(%0#0) {T = "tfdtype$DT_FLOAT", device = "", name = "identity"} : (tensor<f32>) -> (tensor<f32>, !_tf.control)
- %2:2 = "_tf.StatefulPartitionedCall"(%0#0, %arg1) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_RESOURCE"], Tout = ["tfdtype$DT_FLOAT"], _gradient_op_type = "PartitionedCall-1205", config = "", config_proto = "\0A\07\0A\03GPU\10\00\0A\07\0A\03CPU\10\012\02J\008\01", device = "", executor_type = "", f = @function0, name = "statefulpartitionedcall"} : (tensor<f32>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> (tensor<f32>, !_tf.control)
- return %1#0, %2#0 : tensor<f32>, tensor<f32>
+ %0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32> loc("const")
+ %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", device = ""} : (tensor<f32>) -> tensor<f32> loc("identity")
+ %2 = "tf.StatefulPartitionedCall"(%0, %arg1) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_RESOURCE"], Tout = ["tfdtype$DT_FLOAT"], _gradient_op_type = "PartitionedCall-1205", config = "", config_proto = "\0A\07\0A\03GPU\10\00\0A\07\0A\03CPU\10\012\02J\008\01", device = "", executor_type = "", f = @function0} : (tensor<f32>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<f32> loc("statefulpartitionedcall")
+ return %1, %2 : tensor<f32>, tensor<f32>
}
func @function0(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32>
attributes {tf.signature.is_stateful} {
- %0:2 = "_tf.Identity"(%arg0) {T = "tfdtype$DT_FLOAT", device = "", name = "Identity"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control)
+ %0 = "tf.Identity"(%arg0) {T = "tfdtype$DT_FLOAT", device = ""} : (tensor<*xf32>) -> tensor<*xf32> loc("Identity@function0")
return %0#0 : tensor<*xf32>
}
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/legalized_name.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/legalized_name.mlir
index 67ccf52..60b239a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/legalized_name.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/legalized_name.mlir
@@ -6,7 +6,7 @@
%0 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<i32>} : () -> (tensor<i32>) loc("^foo")
// CHECK: name: "fo.o"
%1 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : () -> (tensor<i32>) loc("fo{o")
- // CHECK: name: "foo.1"
+ // CHECK: name: "foo"
%2 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : () -> (tensor<i32>) loc("foo@1")
// CHECK: name: "ba.r"
%3 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : () -> (tensor<i32>) loc("ba r")
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir
new file mode 100644
index 0000000..ec51fdc
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir
@@ -0,0 +1,86 @@
+// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
+
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 175 : i32}} {
+ func @main(%arg0: tensor<32x!tf.string>) -> (tensor<?x2xi64>) attributes {tf.entry_function = {inputs = "input0", outputs = "ParseExample/ParseExampleV2"}} {
+
+ %0 = tf_executor.graph {
+ // NOTE(mrry): This dummy input was manually added because the exporter expects it and fails otherwise.
+ %dummy_input, %control_dummy = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_STRING", shape = "tfshape$dim { size: 32 }"} : (tensor<32x!tf.string>) -> tensor<32x!tf.string>
+
+ %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
+ %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
+ %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2032207D207D2074656E736F725F636F6E74656E743A20225C3031345C303134666561747572655F6B657931666561747572655F6B65793222"> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string>
+ %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ %outputs_6, %control_7 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2032207D207D2074656E736F725F636F6E74656E743A20225C3031345C303134666561747572655F6B657933666561747572655F6B65793422"> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string>
+
+ %outputs_10:8, %control_11 = tf_executor.island wraps "tf.ParseExampleV2"(%dummy_input, %outputs_4, %outputs_8, %outputs_2, %outputs_6, %outputs, %outputs_0) {Tdense = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], dense_shapes = ["tfshape$", "tfshape$"], device = "", name = "ParseExample/ParseExampleV2", num_sparse = 2 : i64, ragged_split_types = [], ragged_value_types = [], result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>, sparse_types = ["tfdtype$DT_STRING", "tfdtype$DT_INT64"]} : (tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0xf32>, tensor<0xf32>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>, tensor<32xf32>, tensor<32xf32>)
+ // CHECK: name: "ParseExample/ParseExampleV2"
+ // CHECK-NEXT: op: "ParseExampleV2"
+ // CHECK-NEXT: input: "input0"
+ // CHECK-NEXT: input: "_tf.Const3"
+ // CHECK-NEXT: input: "_tf.Const5"
+ // CHECK-NEXT: input: "_tf.Const2"
+ // CHECK-NEXT: input: "_tf.Const4"
+ // CHECK-NEXT: input: "_tf.Const"
+ // CHECK-NEXT: input: "_tf.Const1"
+ // CHECK-NEXT: attr {
+ // CHECK-NEXT: key: "Tdense"
+ // CHECK-NEXT: value {
+ // CHECK-NEXT: list {
+ // CHECK-NEXT: type: DT_FLOAT
+ // CHECK-NEXT: type: DT_FLOAT
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: attr {
+ // CHECK-NEXT: key: "dense_shapes"
+ // CHECK-NEXT: value {
+ // CHECK-NEXT: list {
+ // CHECK-NEXT: shape {
+ // CHECK-NEXT: }
+ // CHECK-NEXT: shape {
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: attr {
+ // CHECK-NEXT: key: "num_sparse"
+ // CHECK-NEXT: value {
+ // CHECK-NEXT: i: 2
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: attr {
+ // CHECK-NEXT: key: "ragged_split_types"
+ // CHECK-NEXT: value {
+ // CHECK-NEXT: list {
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: attr {
+ // CHECK-NEXT: key: "ragged_value_types"
+ // CHECK-NEXT: value {
+ // CHECK-NEXT: list {
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: attr {
+ // CHECK-NEXT: key: "sparse_types"
+ // CHECK-NEXT: value {
+ // CHECK-NEXT: list {
+ // CHECK-NEXT: type: DT_STRING
+ // CHECK-NEXT: type: DT_INT64
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+
+ tf_executor.fetch %outputs_10#0 : tensor<?x2xi64>
+ }
+ return %0#0 : tensor<?x2xi64>
+ // CHECK: name: "main"
+ // CHECK-NEXT: op: "_Retval"
+ // CHECK-NEXT: input: "ParseExample/ParseExampleV2"
+
+ }
+}
+
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 09fdb5a..d58a0b8 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -2181,3 +2181,107 @@
%0:2 = "tf.ConcatOffset"(%concat_dim, %shape0, %shape1) : (tensor<i32>, tensor<3xi32>, tensor<8xi32>) -> (tensor<3xi32>, tensor<8xi32>)
return
}
+
+// -----
+
+func @tensor_scatter_update(%tensor: tensor<f32>, %indices: tensor<4x2xi32>, %updates: tensor<4x4xf32>) -> tensor<f32> {
+ // expected-error @+1 {{op requires tensor operand to have at least 1 dimension}}
+ %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<f32>, tensor<4x2xi32>, tensor<4x4xf32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func @tensor_scatter_update(%tensor: tensor<4x4x4xf32>, %indices: tensor<i32>, %updates: tensor<4x4xf32>) -> tensor<4x4x4xf32> {
+ // expected-error @+1 {{op requires indices operand to have at least 1 dimension}}
+ %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x4x4xf32>, tensor<i32>, tensor<4x4xf32>) -> tensor<4x4x4xf32>
+ return %0 : tensor<4x4x4xf32>
+}
+
+// -----
+
+func @tensor_scatter_update(%tensor: tensor<4x4x4xf32>, %indices: tensor<4x2xi32>, %updates: tensor<f32>) -> tensor<4x4x4xf32> {
+ // expected-error @+1 {{op requires updates operand to have at least 1 dimension}}
+ %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x4x4xf32>, tensor<4x2xi32>, tensor<f32>) -> tensor<4x4x4xf32>
+ return %0 : tensor<4x4x4xf32>
+}
+
+// -----
+
+func @tensor_scatter_update(%tensor: tensor<4xf32>, %indices: tensor<4x2xi32>, %updates: tensor<4x4xf32>) -> tensor<4x4x4xf32> {
+ // expected-error @+1 {{op requires tensor operand with rank greater than or equal to the indices operand's last dimensions}}
+ %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4xf32>, tensor<4x2xi32>, tensor<4x4xf32>) -> tensor<4x4x4xf32>
+ return %0 : tensor<4x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @testParseExampleV2DenseOnlyValid
+func @testParseExampleV2DenseOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %dense_keys : tensor<2x!tf.string>, %dense_default_0 : tensor<?xf32>, %dense_default_1 : tensor<?xf32>) -> (tensor<32xf32>) {
+ %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ %result:2 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = ["tfshape$", "tfshape$"], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 2, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<?xf32>, tensor<?xf32>) -> (tensor<32xf32>, tensor<32xf32>)
+ return %result#0 : tensor<32xf32>
+}
+
+// -----
+
+func @testParseExampleV2DenseMismatchedInputOutput(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %dense_keys : tensor<2x!tf.string>, %dense_default_0 : tensor<?xf32>, %dense_default_1 : tensor<?xf32>) -> (tensor<32xf32>) {
+ %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ // expected-error @+1 {{output 'dense_values' should have same length as attribute 'Tdense'}}
+ %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = ["tfshape$", "tfshape$"], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 3, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<?xf32>, tensor<?xf32>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xi64>)
+ return %result#0 : tensor<32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @testParseExampleV2SparseOnlyValid
+func @testParseExampleV2SparseOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %sparse_keys : tensor<2x!tf.string>) -> (tensor<?x2xi64>) {
+ %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ %result:6 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[2, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>)
+ return %result#0 : tensor<?x2xi64>
+}
+
+// -----
+
+func @testParseExampleV2SparseInvalidNumSparse(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %sparse_keys : tensor<2x!tf.string>) -> (tensor<?x2xi64>) {
+ %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ // expected-error @+1 {{attribute 'num_sparse' should be the same as the length of attribute 'sparse_types'}}
+ %result:6 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 3 : i64, result_segment_sizes = dense<[2, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>)
+ return %result#0 : tensor<?x2xi64>
+}
+
+// -----
+
+func @testParseExampleV2SparseInvalidSparseIndicesOutput(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %sparse_keys : tensor<2x!tf.string>) -> (tensor<?x2xi64>) {
+ %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ // expected-error @+1 {{output 'sparse_indices' should have same length as attribute 'sparse_types'}}
+ %result:5 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[1, 2, 2, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>) -> (tensor<?x2xi64>, tensor<?x!tf.string>, tensor<?xi64>, tensor<2xi64>, tensor<2xi64>)
+ return %result#0 : tensor<?x2xi64>
+}
+
+// -----
+
+func @testParseExampleV2SparseOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %sparse_keys : tensor<2x!tf.string>) -> (tensor<?x2xi64>) {
+ %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ // expected-error @+1 {{output 'sparse_shapes' should have same length as attribute 'sparse_types'}}
+ %result:5 = "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %empty_str_vector, %empty_str_vector) {dense_shapes = [], num_sparse = 2 : i64, result_segment_sizes = dense<[2, 2, 1, 0, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>) -> (tensor<?x2xi64>, tensor<?x2xi64>, tensor<?x!tf.string>, tensor<?xi64>, tensor<2xi64>)
+ return %result#0 : tensor<?x2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @testParseExampleV2RaggedOnlyValid
+func @testParseExampleV2RaggedOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %ragged_keys : tensor<2x!tf.string>) -> (tensor<?xf32>) {
+ %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ %result:4 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 2]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>) -> (tensor<?xf32>, tensor<?x!tf.string>, tensor<?xi32>, tensor<?xi64>)
+ return %result#0 : tensor<?xf32>
+}
+
+// -----
+
+func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %ragged_keys : tensor<2x!tf.string>) -> (tensor<?xf32>) {
+ %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string>
+ // expected-error @+1 {{attribute 'ragged_value_types' should have same length as attribute 'ragged_split_types'}}
+ %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %empty_str_vector, %ragged_keys) {dense_shapes = [], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 0, 2, 1]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>) -> (tensor<?xf32>, tensor<?x!tf.string>, tensor<?xi32>)
+ return %result#0 : tensor<?xf32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc
index 165d1b2..7686767 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc
@@ -68,9 +68,9 @@
// re-ordered but forming clusters of non-continuous ops is effectively
// re-ordering them..
bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
- return llvm::all_of(to_merge->getOperands(), [&](Value* operand) {
+ return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
// Block arguments.
- if (isa<BlockArgument>(operand)) return true;
+ if (operand->isa<BlockArgument>()) return true;
Operation* defining_op = operand->getDefiningOp();
@@ -95,11 +95,11 @@
});
}
-void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value*> live_outs,
+void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
tf_device::LaunchOp launch_op) {
Region* launch_op_region = &launch_op.body();
for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
- Value* from = std::get<0>(p);
+ Value from = std::get<0>(p);
for (auto& use : from->getUses()) {
if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
continue;
@@ -109,11 +109,11 @@
}
// Get all escaped live-out values of a region.
-void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value*>* live_outs) {
+void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
live_outs->clear();
for (Operation& op : region->front()) {
- for (Value* v : op.getResults()) {
+ for (Value v : op.getResults()) {
// A value is live-out if any of its users are not inside value producer's
// region.
bool is_live_out = llvm::any_of(v->getUsers(), [&](Operation* user) {
@@ -145,7 +145,7 @@
// Get all escaped live-out values of region, they are used later to determine
// return values and types of launch op.
- llvm::SmallVector<Value*, 4> live_outs;
+ llvm::SmallVector<Value, 4> live_outs;
GetLiveOuts(&region, &live_outs);
// Build a `tf_device.return` op at end of region, with all live-out values
@@ -157,7 +157,7 @@
llvm::SmallVector<Type, 4> live_out_types;
live_out_types.reserve(live_outs.size());
- for (Value* v : live_outs) {
+ for (Value v : live_outs) {
live_out_types.emplace_back(v->getType());
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
index 10337df..b38a603 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc
@@ -51,12 +51,12 @@
// Builds a function that outlines region attached to launch_op and inserts
// built function into given module.
-FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
+FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value> live_ins,
tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
OpBuilder* builder) {
llvm::SmallVector<Type, 4> operand_types;
operand_types.reserve(live_ins.size());
- for (Value* v : live_ins) operand_types.emplace_back(v->getType());
+ for (Value v : live_ins) operand_types.emplace_back(v->getType());
llvm::SmallVector<Type, 4> result_types(launch_op.getResultTypes());
@@ -101,7 +101,7 @@
// removed afterwards.`
void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
OpBuilder* builder) {
- llvm::SetVector<Value*> live_ins;
+ llvm::SetVector<Value> live_ins;
getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
StringRef device =
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
index 918e6ac..116b9fc 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc
@@ -49,11 +49,11 @@
// IslandResult is a helper struct holding an islands result and associated
// inner op result.
struct IslandResult {
- IslandResult(Value* inner_op_result, Value* island_result)
+ IslandResult(Value inner_op_result, Value island_result)
: inner_op_result(inner_op_result), island_result(island_result) {}
- Value* inner_op_result;
- Value* island_result;
+ Value inner_op_result;
+ Value island_result;
};
struct ExecutorIslandCoarsening
@@ -70,7 +70,7 @@
Operation* candidate = nullptr;
// Check island control operands.
- for (Value* input : island.controlInputs()) {
+ for (Value input : island.controlInputs()) {
Operation* def = input->getDefiningOp();
DCHECK_EQ(def->getParentOp(), graph_op);
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
@@ -78,7 +78,7 @@
// Check island data operands.
island.walk([graph_op, &candidate](Operation* op) {
- for (Value* input : op->getOperands()) {
+ for (Value input : op->getOperands()) {
Operation* def = input->getDefiningOp();
if (!def || def->getParentOp() != graph_op) continue;
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
@@ -106,7 +106,7 @@
// Check island data results.
Block& graph_body = llvm::cast<GraphOp>(graph_op).GetBody();
- for (Value* result : island.outputs()) {
+ for (Value result : island.outputs()) {
for (Operation* user : result->getUsers()) {
Operation* def = graph_body.findAncestorOpInBlock(*user);
DCHECK_NE(def, nullptr);
@@ -121,9 +121,9 @@
// Collects the operands for the new island by collecting all control inputs of
// the islands being merged.
-llvm::SmallSetVector<Value*, 8> GetNewIslandOperands(IslandOp parent,
- IslandOp child) {
- llvm::SmallSetVector<Value*, 8> operands;
+llvm::SmallSetVector<Value, 8> GetNewIslandOperands(IslandOp parent,
+ IslandOp child) {
+ llvm::SmallSetVector<Value, 8> operands;
operands.insert(parent.getOperands().begin(), parent.getOperands().end());
operands.insert(child.getOperands().begin(), child.getOperands().end());
operands.remove(parent.control());
@@ -145,8 +145,8 @@
for (auto ret_vals :
llvm::zip(parent.GetYield().getOperands(), parent.outputs())) {
bool result_captured = false;
- Value* inner_op_result = std::get<0>(ret_vals);
- Value* island_result = std::get<1>(ret_vals);
+ Value inner_op_result = std::get<0>(ret_vals);
+ Value island_result = std::get<1>(ret_vals);
for (auto& use : llvm::make_early_inc_range(island_result->getUses())) {
if (child_body.findAncestorOpInBlock(*use.getOwner())) {
// Forward result from inner op.
@@ -160,8 +160,8 @@
for (auto ret_vals :
llvm::zip(child.GetYield().getOperands(), child.outputs())) {
- Value* inner_op_result = std::get<0>(ret_vals);
- Value* island_result = std::get<1>(ret_vals);
+ Value inner_op_result = std::get<0>(ret_vals);
+ Value island_result = std::get<1>(ret_vals);
if (!island_result->use_empty()) {
results.emplace_back(inner_op_result, island_result);
}
@@ -173,7 +173,7 @@
// Creates the new merged island.
IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
IslandType insert_position,
- llvm::ArrayRef<Value*> operands,
+ llvm::ArrayRef<Value> operands,
llvm::ArrayRef<IslandResult> results) {
// Collect types from results.
llvm::SmallVector<Type, 8> result_types;
@@ -194,7 +194,7 @@
// Creates respective YieldOp for the new merged island.
YieldOp CreateNewIslandYieldOp(IslandOp new_island,
llvm::ArrayRef<IslandResult> results) {
- llvm::SmallVector<Value*, 8> yield_operands;
+ llvm::SmallVector<Value, 8> yield_operands;
yield_operands.reserve(results.size());
for (auto ret_vals : llvm::zip(results, new_island.outputs())) {
@@ -232,8 +232,7 @@
// Merges two islands and places new merged island before parent or child.
void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) {
// Collect operands for the new merged island.
- llvm::SmallSetVector<Value*, 8> operands =
- GetNewIslandOperands(parent, child);
+ llvm::SmallSetVector<Value, 8> operands = GetNewIslandOperands(parent, child);
// Collect results for the new merged island.
llvm::SmallVector<IslandResult, 8> results =
@@ -288,9 +287,9 @@
// This allows our def-use based island coarsening algorithm to merge
// islands that independently feed into a fetch.
void InsertDummyIslandForFetch(FetchOp fetch) {
- llvm::SmallVector<Value*, 4> data_fetches;
+ llvm::SmallVector<Value, 4> data_fetches;
llvm::SmallVector<Type, 4> data_types;
- llvm::SmallVector<Value*, 4> control_fetches;
+ llvm::SmallVector<Value, 4> control_fetches;
for (auto value : fetch.fetches()) {
if (value->getType().isa<ControlType>()) {
control_fetches.push_back(value);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
index 52b425c..cc668b3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc
@@ -65,12 +65,12 @@
} // namespace
// Returns the defining op for a value looking through islands.
-static Operation* GetDefiningOp(Value* val) {
+static Operation* GetDefiningOp(Value val) {
Operation* op = val->getDefiningOp();
auto island_op = dyn_cast<tf_executor::IslandOp>(op);
if (!island_op) return op;
auto yield_op = island_op.GetYield();
- auto index = cast<mlir::OpResult>(val)->getResultNumber();
+ auto index = val->cast<mlir::OpResult>()->getResultNumber();
return yield_op.getOperand(index)->getDefiningOp();
}
@@ -81,7 +81,7 @@
// identity nodes are common so handle them specially when considering
// predicate in a minimally invasive way until identity's are handled more
// generally.
-static Value* LookThroughIdentityOp(Value* pred_val) {
+static Value LookThroughIdentityOp(Value pred_val) {
if (!pred_val) return pred_val;
auto op = GetDefiningOp(pred_val);
if (auto id_op = dyn_cast<TF::IdentityOp>(op)) pred_val = id_op.input();
@@ -124,7 +124,7 @@
}
// Enqueue users of a value.
- void EnqueueUsers(Value* val) {
+ void EnqueueUsers(Value val) {
for (auto user : val->getUsers()) {
Enqueue(user, val->getType().isa<tf_executor::ControlType>());
}
@@ -175,7 +175,7 @@
// Enqueues values of foldable switch ops.
static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
DeadQueue* queue) {
- Value* pred_val = LookThroughIdentityOp(switch_op.predicate());
+ Value pred_val = LookThroughIdentityOp(switch_op.predicate());
// If predicate or input is null then enqueue entire op for deletion.
if (pred_val == nullptr || switch_op.data() == nullptr) {
@@ -187,8 +187,8 @@
if (!matchPattern(pred_val, m_Constant(&pred))) return;
bool taken = pred.getSplatValue<bool>();
- Value* dead = taken ? switch_op.falseOutput() : switch_op.trueOutput();
- Value* live = !taken ? switch_op.falseOutput() : switch_op.trueOutput();
+ Value dead = taken ? switch_op.falseOutput() : switch_op.trueOutput();
+ Value live = !taken ? switch_op.falseOutput() : switch_op.trueOutput();
live->replaceAllUsesWith(switch_op.data());
queue->EnqueueUsers(dead);
@@ -210,12 +210,12 @@
for (auto it : queue.merge_nodes()) {
// Find the valid input to merge node.
- Value* val = nullptr;
+ Value val = nullptr;
int index = -1;
auto* merge = it.first;
auto merge_op = cast<tf_executor::MergeOp>(merge);
for (auto e : llvm::enumerate(merge->getOperands())) {
- Value* operand = e.value();
+ Value operand = e.value();
if (!operand) continue;
// Skip control operands.
if (operand->getType().isa<tf_executor::ControlType>()) break;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
index e9b3879..94fa222 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc
@@ -48,7 +48,7 @@
// non-empty means True and empty means False. If the tensor is not a scalar,
// being empty means False and being non-empty means True.
//
-static Value* LowerCondition(Location loc, Value* value, OpBuilder* builder) {
+static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
// TODO: Right now we just handle zero-D tensors of boolean values.
// FIXME: This is almost all wrong, but is a placeholder to unblock the one
// testcases, later patches will build on this once I build the right infra to
@@ -70,15 +70,14 @@
// Requires the function to provide arguments for each of the `fn` operands
// that is compatible for tensor cast.
//
-static Operation* CallFn(Location loc,
- const std::function<Value*(int)>& get_arg, FuncOp fn,
- OpBuilder* builder) {
+static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
+ FuncOp fn, OpBuilder* builder) {
FunctionType fn_type = fn.getType();
- llvm::SmallVector<Value*, 4> operands;
+ llvm::SmallVector<Value, 4> operands;
int num_operands = fn_type.getNumInputs();
operands.reserve(num_operands);
for (int i = 0; i < num_operands; ++i) {
- Value* val = get_arg(i);
+ Value val = get_arg(i);
Type expected = fn_type.getInput(i);
if (val->getType() != expected) {
val =
@@ -95,14 +94,14 @@
//
// Requires the function to provide values for each of the block arguments and
// they should be pair-wise compatible for tensor cast.
-static llvm::SmallVector<Value*, 4> PrepareValsForJump(
- Location loc, const std::function<Value*(int)>& get_val, Block* block,
+static llvm::SmallVector<Value, 4> PrepareValsForJump(
+ Location loc, const std::function<Value(int)>& get_val, Block* block,
OpBuilder* builder) {
- llvm::SmallVector<Value*, 4> result;
+ llvm::SmallVector<Value, 4> result;
int num_vals = block->getNumArguments();
result.reserve(num_vals);
for (int i = 0; i < num_vals; ++i) {
- Value* val = get_val(i);
+ Value val = get_val(i);
Type expected = block->getArgument(i)->getType();
if (val->getType() != expected) {
val =
@@ -119,7 +118,7 @@
//
// Requires the function to provide values for each of the block arguments and
// they should be pair-wise compatible for tensor cast.
-static void JumpToBlock(Location loc, const std::function<Value*(int)>& get_arg,
+static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
Block* block, OpBuilder* builder) {
auto operands = PrepareValsForJump(loc, get_arg, block, builder);
builder->create<BranchOp>(loc, block, operands);
@@ -136,8 +135,8 @@
Block* block, OpBuilder* builder) {
assert(op->getNumResults() == block->getNumArguments());
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
- Value* arg = block->getArgument(i);
- Value* result = op->getResult(i);
+ Value arg = block->getArgument(i);
+ Value result = op->getResult(i);
if (arg->getType() != result->getType()) {
arg =
builder->create<TF::CastOp>(loc, result->getType(), arg,
@@ -160,7 +159,7 @@
OpBuilder builder(op_inst);
// Lower the condition to a boolean value (i1).
- Value* cond_i1 = LowerCondition(loc, op.cond(), &builder);
+ Value cond_i1 = LowerCondition(loc, op.cond(), &builder);
if (!cond_i1) return failure();
auto module = op_inst->getParentOfType<ModuleOp>();
@@ -174,7 +173,7 @@
// Add the block arguments to the merge point, and replace all uses of the
// original operation results with them.
- for (Value* value : op_inst->getResults())
+ for (Value value : op_inst->getResults())
merge_block->addArgument(value->getType());
ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);
@@ -200,8 +199,8 @@
// orig_block with a conditional branch.
builder.setInsertionPointToEnd(orig_block);
builder.create<CondBranchOp>(loc, cond_i1, then_block,
- llvm::ArrayRef<Value*>(), else_block,
- llvm::ArrayRef<Value*>());
+ llvm::ArrayRef<Value>(), else_block,
+ llvm::ArrayRef<Value>());
// Finally, delete the op in question.
op_inst->erase();
@@ -277,7 +276,7 @@
Operation* cond_call_op = CallFn(loc, get_cond_arg, cond_fn, &builder);
assert(cond_call_op->getNumResults() == 1);
- Value* condition = LowerCondition(loc, cond_call_op->getResult(0), &builder);
+ Value condition = LowerCondition(loc, cond_call_op->getResult(0), &builder);
auto br_operands =
PrepareValsForJump(loc, get_cond_arg, body_block, &builder);
builder.create<CondBranchOp>(loc, condition, body_block, br_operands,
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
index 882e769..f71584c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc
@@ -38,7 +38,7 @@
// Visit an op's operands if it is output of an Operation in same graph.
auto visit_op = [&](Operation* op) {
- for (Value* operand : op->getOperands()) {
+ for (Value operand : op->getOperands()) {
Operation* def = operand->getDefiningOp();
if (def && def->getParentOp() == graph &&
reachable_ops.insert(def).second) {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
index d0bf397..7f17807 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
@@ -134,8 +134,8 @@
// TODO(hinsu): Improve parallelism by splitting operands in two halves and
// accumulating them first.
- Value *result = *op.inputs().begin();
- for (Value *operand : llvm::drop_begin(op.inputs(), 1)) {
+ Value result = *op.inputs().begin();
+ for (Value operand : llvm::drop_begin(op.inputs(), 1)) {
result = rewriter.create<TF::AddV2Op>(op.getLoc(), result, operand);
}
@@ -189,8 +189,8 @@
SmallVector<DenseIntElementsAttr, 4> indices;
indices.reserve(op.N());
for (auto it : llvm::zip(op.indices(), op.data())) {
- Value *index = std::get<0>(it);
- Value *data = std::get<1>(it);
+ Value index = std::get<0>(it);
+ Value data = std::get<1>(it);
DenseIntElementsAttr index_attr;
if (!matchPattern(index, m_Constant(&index_attr))) return matchFailure();
@@ -214,10 +214,10 @@
// Prepare each of the output item by unpacking data and then putting it to
// the specified index.
- SmallVector<Value *, 8> values(out_ty.getDimSize(0));
+ SmallVector<Value, 8> values(out_ty.getDimSize(0));
for (auto it : llvm::zip(indices, op.data())) {
DenseIntElementsAttr index_attr = std::get<0>(it);
- Value *data = std::get<1>(it);
+ Value data = std::get<1>(it);
auto reshaped_data =
rewriter.create<ReshapeOp>(loc, data, packed_shape_val);
@@ -228,7 +228,7 @@
/*axis=*/APInt(64, 0));
for (auto index_item : llvm::zip(index_attr, items.getResults())) {
int64_t output_index = std::get<0>(index_item).getSExtValue();
- Value *item = std::get<1>(index_item);
+ Value item = std::get<1>(index_item);
values[output_index] = item;
}
}
@@ -264,9 +264,9 @@
int64_t axis = op.axis().getSExtValue();
Type prev_input_ty, inferred_ty;
- SmallVector<Value *, 4> expanded_inputs;
+ SmallVector<Value, 4> expanded_inputs;
expanded_inputs.reserve(op.N());
- for (Value *input : op.values()) {
+ for (Value input : op.values()) {
// If input type is different than the previous input type, infer the
// output type. Otherwise, use the already inferred output type from the
// previous iteration.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
index 5dc173a..07792d5 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
@@ -184,6 +184,14 @@
(TF_ConstOp (GetScalarOfType<0> $input)))>;
//===----------------------------------------------------------------------===//
+// Reciprocal op patterns.
+//===----------------------------------------------------------------------===//
+
+// TODO(hinsu): Support complex and unsigned input types.
+def LowerReciprocal : Pat<(TF_ReciprocalOp TF_SintOrFpTensor:$x),
+ (TF_DivOp (TF_ConstOp (GetScalarOfType<1> $x)), $x)>;
+
+//===----------------------------------------------------------------------===//
// Rsqrt op patterns.
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
index e7acbb3..7658c01 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc
@@ -52,7 +52,7 @@
// be keep in sync.
bool IsReadOnlyVariableOp(Operation* op) { return isa<TF::ReadVariableOp>(op); }
-void RewriteReadOnlyVariableOpToTensorOp(Operation* op, Value* tensor_value) {
+void RewriteReadOnlyVariableOpToTensorOp(Operation* op, Value tensor_value) {
auto read_variable = cast<TF::ReadVariableOp>(op);
read_variable.value()->replaceAllUsesWith(tensor_value);
}
@@ -73,7 +73,7 @@
// func for tf.ReadVariableOp. If the resource is passed into other functions
// or control flow, we fail to prove it is freezable even though we could.
for (auto& global_tensor_use : global_tensor_uses) {
- auto* arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
+ auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
for (auto user : arg->getUsers()) {
if (!IsReadOnlyVariableOp(user)) {
return false;
@@ -129,7 +129,7 @@
for (auto global_tensor_use : global_tensor_uses) {
auto func = global_tensor_use.func;
auto arg_index = global_tensor_use.arg_index;
- Value* arg = func.getArgument(arg_index);
+ Value arg = func.getArgument(arg_index);
for (Operation* user : llvm::make_early_inc_range(arg->getUsers())) {
RewriteReadOnlyVariableOpToTensorOp(user, arg);
user->erase();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index f870ca2..32d98a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -46,7 +46,8 @@
// Optimizes Tensorflow graph.
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
-struct StandardPipelineOptions : public PassOptions<StandardPipelineOptions> {
+struct StandardPipelineOptions
+ : public PassPipelineOptions<StandardPipelineOptions> {
Option<bool> enable_inliner{*this, "enable-inliner",
llvm::cl::desc("Enable inliner."),
llvm::cl::init(false)};
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc
index d6acb74..4bbf071 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc
@@ -110,7 +110,7 @@
// Add a result type for each non-control result we find.
bool sawControlResult = false;
- for (auto *opResult : op.getResults()) {
+ for (auto opResult : op.getResults()) {
if (opResult->getType().isa<TFControlType>()) {
sawControlResult = true;
} else {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
index 36f6f3a..a2b9f1c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc
@@ -71,10 +71,10 @@
// }
void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
Block* replicate_block, TF::ShapeOp shape_op) {
- Value* input = shape_op.input();
+ Value input = shape_op.input();
// If ShapeOp operand is replicate tensor block argument, replace with the
// associated first replica operand.
- if (auto block_arg = llvm::dyn_cast<BlockArgument>(input)) {
+ if (auto block_arg = input->dyn_cast<BlockArgument>()) {
if (block_arg->getOwner() != replicate_block) return;
shape_op.setOperand(
@@ -96,7 +96,7 @@
// shape has not changed in replicate prior to read. Currently after both
// ResourceOpLiftingPass and TPURewritePass, there should not be any updates
// to resources prior to their respective ReadVariableOp.
- if (auto block_arg = llvm::dyn_cast<BlockArgument>(read_var_op.resource())) {
+ if (auto block_arg = read_var_op.resource()->dyn_cast<BlockArgument>()) {
if (block_arg->getOwner() != replicate_block) return;
OpBuilder builder(shape_op);
@@ -111,7 +111,7 @@
// Checks if op and inner op operands are all replicate invariant.
bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
auto result = op->walk([&](Operation* inner_op) {
- for (Value* operand : inner_op->getOperands()) {
+ for (Value operand : inner_op->getOperands()) {
Region* parent_region = operand->getParentRegion();
if (!parent_region || !parent_region->isProperAncestor(replicate_region))
return WalkResult::interrupt();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
index 9787ac0..ca594ac 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc
@@ -60,7 +60,7 @@
Operation& terminator = replicate_op.GetBody().back();
llvm::SmallVector<Type, 8> output_types(terminator.getOperandTypes());
auto control_type = tf_executor::ControlType::get(island_op.getContext());
- llvm::SmallVector<Value*, 8> replica_inputs(island_op.controlInputs());
+ llvm::SmallVector<Value, 8> replica_inputs(island_op.controlInputs());
// Replace replicate terminator with YieldOp.
builder->setInsertionPoint(&terminator);
@@ -149,8 +149,8 @@
num_replicas);
// Collect all replica results.
- llvm::SmallVector<Value*, 8> replicas_outputs(replicate_op.getNumResults(),
- nullptr);
+ llvm::SmallVector<Value, 8> replicas_outputs(replicate_op.getNumResults(),
+ nullptr);
for (auto replica_and_idx : llvm::enumerate(replicas))
for (auto replica_result_and_idx :
llvm::enumerate(replica_and_idx.value().outputs()))
@@ -163,7 +163,7 @@
// Collect per replica control dependency and add to island operand if replica
// island has no uses.
- llvm::SmallVector<Value*, 8> island_operands;
+ llvm::SmallVector<Value, 8> island_operands;
for (auto& replica : replicas)
if (replica.use_empty()) island_operands.push_back(replica.control());
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
index 6dc3e87..db1bbaa 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc
@@ -64,7 +64,7 @@
// Returns the recorded device assignment for a resource, if any.
llvm::Optional<llvm::StringRef> DeviceForResource(
- const Value* resource) const {
+ const Value resource) const {
llvm::Optional<llvm::StringRef> result;
if (alias_analysis_.IsUnknownResource(resource)) return result;
for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
@@ -87,7 +87,7 @@
// conflicts with an existing one, returns an error.
//
// If `changed` is provided, assign *changed to true if anything is modified.
- LogicalResult AddResourceDevice(const Value* resource, llvm::StringRef device,
+ LogicalResult AddResourceDevice(const Value resource, llvm::StringRef device,
bool* changed = nullptr) {
if (alias_analysis_.IsUnknownResource(resource)) return success();
for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
@@ -108,7 +108,7 @@
};
// Tries to record device assignment for a resource.
-LogicalResult AddResourceDeviceAndEmitError(const Value* resource,
+LogicalResult AddResourceDeviceAndEmitError(const Value resource,
llvm::StringRef device,
Operation* error_reporting_op,
PerFunctionResult* result,
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 2f32a3a..a3d3af0 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -87,14 +87,14 @@
// resource_handle_to_last_store_op keeps track of the most recent (last)
// store to each resource. Non-existent entry indicates that a resource has
// not been stored to yet.
- llvm::SmallDenseMap<Value*, TF::AssignVariableOp>
+ llvm::SmallDenseMap<Value, TF::AssignVariableOp>
resource_handle_to_last_store_op;
// Only iterate through ops directly in launch_op's body as we can't handle
// ops nested deeper in regions.
for (Operation& op : llvm::make_early_inc_range(launch_op.GetBody())) {
if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
- Value* resource = read_variable_op.resource();
+ Value resource = read_variable_op.resource();
auto last_store = resource_handle_to_last_store_op[resource];
if (!last_store) continue;
@@ -106,7 +106,7 @@
}
if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
- Value* resource = assign_variable_op.resource();
+ Value resource = assign_variable_op.resource();
auto last_store = resource_handle_to_last_store_op[resource];
// Previous store ops to same resource can be erased.
if (last_store) last_store.erase();
@@ -120,14 +120,14 @@
// forwarding has been performed on this launch_op such that all loads of same
// resource are on its initial values.
void HoistResourceLoads(tf_device::LaunchOp launch_op) {
- llvm::SmallDenseMap<Value*, TF::ReadVariableOp> resource_to_read_ops;
+ llvm::SmallDenseMap<Value, TF::ReadVariableOp> resource_to_read_ops;
// Only iterate through ops directly in launch_op's body as we can't handle
// ops nested deeper in regions.
for (Operation& op : llvm::make_early_inc_range(launch_op.GetBody())) {
auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op);
if (!read_variable_op) continue;
- Value* resource = read_variable_op.resource();
+ Value resource = read_variable_op.resource();
// Skip resources created inside of launch_op.
if (resource->getParentRegion() == &launch_op.body()) continue;
@@ -156,14 +156,14 @@
Block* body = &launch_op.GetBody();
auto old_return = body->getTerminator();
- llvm::SmallVector<Value*, 4> new_return_operands(old_return->getOperands());
+ llvm::SmallVector<Value, 4> new_return_operands(old_return->getOperands());
// Only iterate through ops directly in launch_op's body as we can't handle
// ops nested deeper in regions.
for (Operation& op : launch_op.GetBody()) {
auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op);
if (!assign_variable_op) continue;
- Value* resource = assign_variable_op.resource();
+ Value resource = assign_variable_op.resource();
if (!resource) continue;
// Skip resources created inside of launch_op.
@@ -202,7 +202,7 @@
builder->setInsertionPoint(launch_op);
auto new_launch_op = builder->create<tf_device::LaunchOp>(
launch_op.getLoc(), new_launch_return_types,
- /*operands=*/llvm::SmallVector<Value*, 4>(), launch_op.getAttrs());
+ /*operands=*/llvm::SmallVector<Value, 4>(), launch_op.getAttrs());
new_launch_op.body().takeBody(launch_op.body());
// Replace uses of old launch_op results with those of new_launch_op.
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 844ae2f..fb06d0e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -74,12 +74,12 @@
if (cast_op.SrcT() != cast_op.DstT()) continue;
// We only refine the result shape if the result a dynamic shape, the
// input has static shape, and the two shapes are compatible.
- auto has_static_shape = [](const Value* value) {
+ auto has_static_shape = [](const Value value) {
auto shaped_type = value->getType().dyn_cast<ShapedType>();
return shaped_type && shaped_type.hasStaticShape();
};
- Value* input = cast_op.x();
- Value* result = cast_op.y();
+ Value input = cast_op.x();
+ Value result = cast_op.y();
if (!has_static_shape(input) || has_static_shape(result) ||
failed(verifyCompatibleShape(input->getType(), result->getType())))
continue;
@@ -161,7 +161,7 @@
op->getNumOperands());
std::vector<tensorflow::Tensor> tensors(op->getNumOperands());
for (auto it : llvm::enumerate(op->getOperands())) {
- Value* operand = it.value();
+ Value operand = it.value();
size_t index = it.index();
// If the operand is constant, then convert it to Tensor.
@@ -214,7 +214,7 @@
builder.setInsertionPointAfter(op);
for (int output : llvm::seq<int>(0, c.num_outputs())) {
// Skip already statically shaped results.
- Value* result = op->getResult(output);
+ Value result = op->getResult(output);
auto shaped_type = result->getType().dyn_cast<ShapedType>();
if (!shaped_type || shaped_type.hasStaticShape()) continue;
@@ -306,7 +306,7 @@
int64_t max_iteration) {
llvm::SmallVector<Type, 4> input_types;
input_types.reserve(std::distance(op.input().begin(), op.input().end()));
- for (Value* v : op.input()) {
+ for (Value v : op.input()) {
input_types.push_back(v->getType());
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
index e4358e7..f7d5bbe 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc
@@ -48,10 +48,10 @@
// The sunk_constant map keeps a mapping from a ConstOp defined above to
// a sunk clone of it. This allows for reusing a sunk constant with
// multiple uses in the region.
- llvm::DenseMap<Value *, TF::ConstOp> sunk_constant;
+ llvm::DenseMap<Value, TF::ConstOp> sunk_constant;
Region &body = launch.body();
visitUsedValuesDefinedAbove(body, [&](OpOperand *use) {
- Value *constant = use->get();
+ Value constant = use->get();
auto const_op =
dyn_cast_or_null<TF::ConstOp>(constant->getDefiningOp());
if (!const_op) return;
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index 7a840aa..058d62f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -140,7 +140,7 @@
const llvm::SmallSetVector<Operation*, 8>& cluster_ops,
const llvm::SmallSetVector<Operation*, 8>& preceding_users) {
auto result = op->walk([&](Operation* op) {
- for (Value* operand : op->getOperands()) {
+ for (Value operand : op->getOperands()) {
Operation* def = operand->getDefiningOp();
// Operands may not have a defining op (BlockArgument) or is from a
// different block.
@@ -179,12 +179,12 @@
// `tf_device::LaunchOp` and associated terminator. Results that have no uses
// outside of the cluster (i.e. results of ops in the cluster are only consumed
// by other ops in the cluster) are pruned.
-llvm::SmallVector<Value*, 8> CollectClusterResults(
+llvm::SmallVector<Value, 8> CollectClusterResults(
Block* block, const llvm::SmallSetVector<Operation*, 8>& cluster_ops) {
- llvm::SmallVector<Value*, 8> results;
+ llvm::SmallVector<Value, 8> results;
for (Operation* op : cluster_ops) {
- for (Value* result : op->getResults()) {
+ for (Value result : op->getResults()) {
for (Operation* user : result->getUsers()) {
// Check if user is not an op in the cluster.
if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
@@ -200,13 +200,13 @@
// Creates a `tf_device::LaunchOp` to wrap cluster ops.
tf_device::LaunchOp CreateLaunchOpForCluster(Operation* last_cluster_op,
- llvm::ArrayRef<Value*> results) {
+ llvm::ArrayRef<Value> results) {
// `tf_device::LaunchOp` will be placed at where the last op of the cluster
// is.
OpBuilder builder(last_cluster_op);
llvm::SmallVector<Type, 8> result_types;
- for (Value* result : results) result_types.push_back(result->getType());
+ for (Value result : results) result_types.push_back(result->getType());
// An empty string placeholder is used for the device as that will be later
// populated with the device of the associated TPUReplicateMetadata op.
@@ -241,11 +241,11 @@
// Replaces uses of cluster ops results outside of cluster with the associated
// `tf_device::LaunchOp` results.
void UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op,
- llvm::ArrayRef<Value*> results) {
+ llvm::ArrayRef<Value> results) {
Block& launch_op_block = launch_op.GetBody();
for (auto ret_vals : llvm::zip(results, launch_op.getResults())) {
- Value* old_ret = std::get<0>(ret_vals);
- Value* new_ret = std::get<1>(ret_vals);
+ Value old_ret = std::get<0>(ret_vals);
+ Value new_ret = std::get<1>(ret_vals);
for (auto& use : old_ret->getUses())
if (!launch_op_block.findAncestorOpInBlock(*use.getOwner()))
use.set(new_ret);
@@ -337,7 +337,7 @@
// Replace replicated cluster results with replicate op results.
for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) {
- Value* result = result_and_idx.value();
+ Value result = result_and_idx.value();
int idx = result_and_idx.index();
for (auto& use : result->getUses()) {
Operation* def = use.getOwner();
@@ -360,7 +360,7 @@
for (auto input_and_block_arg :
llvm::zip(replicated_input_ops, replicate_op.GetBody().getArguments())) {
Operation* input = std::get<0>(input_and_block_arg);
- Value* block_arg = std::get<1>(input_and_block_arg);
+ Value block_arg = std::get<1>(input_and_block_arg);
mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
launch_op.body());
}
@@ -412,7 +412,7 @@
llvm::SmallSetVector<Operation*, 8> preceding_users =
CollectClusterPrecedingUsers(block, cluster_ops);
- llvm::SmallVector<Value*, 8> results =
+ llvm::SmallVector<Value, 8> results =
CollectClusterResults(block, cluster_ops);
tf_device::LaunchOp launch_op =
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc
index f2f885d..b45ea48 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc
@@ -60,7 +60,7 @@
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands()))
- if (auto block_arg = llvm::dyn_cast<BlockArgument>(operand_and_idx.value()))
+ if (auto block_arg = operand_and_idx.value()->dyn_cast<BlockArgument>())
if (block_arg->getOwner() == replicate_block)
remapped_indices[block_arg->getArgNumber()] = operand_and_idx.index();
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
index 2833250..ce54b6a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
@@ -92,15 +92,15 @@
// Information about all resource accesses to be fused into a TPUExecute op.
struct VariableAccessesForTPUExecute {
// Maps each resource detected to VariableAccessInfo.
- llvm::SmallDenseMap<Value*, VariableAccessInfo, 8> per_resource_info;
+ llvm::SmallDenseMap<Value, VariableAccessInfo, 8> per_resource_info;
// The corresponding new output index in TPUExecuteAndUpdateVariables for
// each old output index in TPUExecute.
llvm::SmallVector<int, 8> old_to_new_output_mapping;
// The resources read by ReadVariableOps that are inputs to TPUExecute.
// Ordered by the input indices to TPUExecute
- llvm::SmallVector<Value*, 8> resources_read;
+ llvm::SmallVector<Value, 8> resources_read;
// Operands for the new TPUExecuteAndUpdateVariables.
- llvm::SmallVector<Value*, 8> new_operand_values;
+ llvm::SmallVector<Value, 8> new_operand_values;
};
// Returns if an op accesses a resource.
@@ -147,7 +147,7 @@
// Check device matching for the node defining the resource.
if (!resource_attr || resource_attr != device_attr) continue;
} else {
- auto resource_arg = llvm::dyn_cast<BlockArgument>(resource);
+ auto resource_arg = resource->dyn_cast<BlockArgument>();
assert(resource_arg);
// Check device matching for the argument defining the resource.
auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
@@ -206,7 +206,7 @@
}
infos.resources_read.erase(
llvm::remove_if(infos.resources_read,
- [&](const Value* resource) {
+ [&](const Value resource) {
return infos.per_resource_info.count(resource) == 0;
}),
infos.resources_read.end());
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 1033670..bfd7af8 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -277,7 +277,7 @@
// TODO(b/139377366): When shape inference is ready, we can use compile time
// shape inference to get inputs that have static shapes and only use shape
// ops for the rest.
- llvm::SmallVector<Value*, 4> compile_op_operands;
+ llvm::SmallVector<Value, 4> compile_op_operands;
compile_op_operands.reserve(launch_func.getNumOperands());
for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) {
@@ -332,7 +332,7 @@
OpBuilder* builder) {
// TPUExecute inherits all launch_func inputs, and takes an additional input
// for compilation cache key.
- llvm::SmallVector<Value*, 4> tensor_inputs(launch_func.getOperands());
+ llvm::SmallVector<Value, 4> tensor_inputs(launch_func.getOperands());
tensor_inputs.push_back(compile_op->getResult(1));
// TODO(b/139377366): Need to snapshot all resource variable inputs in
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
index 764c791..334baec 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc
@@ -44,7 +44,7 @@
void BreakUpIsland(tf_executor::IslandOp op,
const TF::SideEffectAnalysis& side_effect_analysis,
- llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>>*
+ llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
new_control_edges);
};
@@ -64,7 +64,7 @@
// Map from the users of the existing islands to the list of control
// edges that need to be added.
- llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>> new_control_edges;
+ llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_edges;
auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
// Iterate in reverse order to avoid invalidating Operation* stored in
// new_control_edges.
@@ -78,7 +78,7 @@
// Apply edge additions in reverse order so that the ops don't get
// invalidated.
- llvm::SmallVector<Value*, 8> edges;
+ llvm::SmallVector<Value, 8> edges;
llvm::SmallPtrSet<Operation*, 4> dups;
llvm::SmallVector<Type, 4> types;
for (auto& item :
@@ -96,11 +96,11 @@
edges.assign(item.operand_begin(), item.operand_end());
dups.clear();
- for (Value* input : edges) {
+ for (Value input : edges) {
dups.insert(input->getDefiningOp());
}
// Insert new control edges removing duplicates.
- for (Value* value : llvm::reverse(edge.second)) {
+ for (Value value : llvm::reverse(edge.second)) {
if (dups.insert(value->getDefiningOp()).second) edges.push_back(value);
}
state.addOperands(edges);
@@ -114,7 +114,7 @@
// Helper that creates an island. If `sub_op` is not nullptr, it will be moved
// to the island.
tf_executor::IslandOp CreateIsland(ArrayRef<Type> result_types,
- ArrayRef<Value*> control_inputs,
+ ArrayRef<Value> control_inputs,
const tf_executor::ControlType& control_type,
const Location& loc, Operation* sub_op,
tf_executor::IslandOp original_island) {
@@ -132,7 +132,7 @@
if (sub_op) {
island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
} else {
- island_builder.create<tf_executor::YieldOp>(loc, ArrayRef<Value*>{});
+ island_builder.create<tf_executor::YieldOp>(loc, ArrayRef<Value>{});
}
return island;
}
@@ -178,7 +178,7 @@
void BreakUpIslands::BreakUpIsland(
tf_executor::IslandOp op,
const TF::SideEffectAnalysis& side_effect_analysis,
- llvm::DenseMap<Operation*, llvm::SmallVector<Value*, 4>>*
+ llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
new_control_edges) {
auto island_body = op.GetBody().without_terminator();
// Skip islands that are already only a single op.
@@ -188,7 +188,7 @@
auto island_control_inputs = llvm::to_vector<4>(op.controlInputs());
// Add control dependencies for yields of values defined by other islands to
// the island that defines that fetched value.
- for (auto* fetch : op.GetYield().fetches()) {
+ for (auto fetch : op.GetYield().fetches()) {
// Ok, because there is no op to add control to (eg: function args).
if (!fetch->getDefiningOp()) continue;
if (fetch->getDefiningOp()->getParentOp() == op) {
@@ -214,9 +214,9 @@
auto sources_and_sinks =
FindSourcesAndSinksInIsland(op, side_effect_analysis);
// The corresponding control output of the new island created for each sub-op.
- llvm::SmallDenseMap<Operation*, Value*, 8> new_control_for_sub_ops;
+ llvm::SmallDenseMap<Operation*, Value, 8> new_control_for_sub_ops;
// Control outputs of newly created islands that are sinks.
- llvm::SmallVector<Value*, 8> sink_island_controls;
+ llvm::SmallVector<Value, 8> sink_island_controls;
// For each operation in the island, construct a new island to wrap the op,
// yield all the results, and replace all the usages with the results of the
// new island.
@@ -224,7 +224,7 @@
const auto predecessors =
side_effect_analysis.DirectControlPredecessors(&sub_op);
// Get the controls from the predecessors.
- llvm::SmallVector<Value*, 4> predecessors_control;
+ llvm::SmallVector<Value, 4> predecessors_control;
predecessors_control.reserve(predecessors.size());
for (auto predecessor : predecessors) {
predecessors_control.push_back(new_control_for_sub_ops[predecessor]);
@@ -233,9 +233,9 @@
// by inter-islands dependencies; otherwise, we do not need to include
// island_control_inputs, since they must have been tracked by the (direct
// or indirect) control predecessors or operands.
- ArrayRef<Value*> control = sources_and_sinks.sources.count(&sub_op) > 0
- ? island_control_inputs
- : predecessors_control;
+ ArrayRef<Value> control = sources_and_sinks.sources.count(&sub_op) > 0
+ ? island_control_inputs
+ : predecessors_control;
auto island =
CreateIsland(llvm::to_vector<4>(sub_op.getResultTypes()), control,
control_type, sub_op.getLoc(), &sub_op, op);
@@ -258,7 +258,7 @@
op.control()->replaceAllUsesWith(sink_island_controls[0]);
// All existing outputs need to add a control flow edge from
// sink_island_controls[0].
- for (Value* out : op.outputs()) {
+ for (Value out : op.outputs()) {
for (auto& use : out->getUses()) {
Operation* owner = use.getOwner();
if (auto island_op =
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc
index 29979c0..54e3e45 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc
@@ -68,8 +68,8 @@
tf_executor::IslandOp ControlToExecutorDialectConversion::CreateIslandForOp(
Operation *op, OpBuilder *builder) {
// Create a new region for the tf_executor.island body
- SmallVector<Value *, 8> operands;
- for (Value *operand : op->getOperands())
+ SmallVector<Value, 8> operands;
+ for (Value operand : op->getOperands())
if (operand->getType().isa<tf_executor::ControlType>())
operands.push_back(operand);
SmallVector<Type, 8> types;
@@ -118,8 +118,8 @@
// This is the return of the function, we will create a fetch in the graph
// matching the operands of the returns. The return is then updated to
// take as operands the results of the tf_executor.graph operation.
- SmallVector<Value *, 8> ret_vals;
- for (Value *operand : op.getOperands()) ret_vals.push_back(operand);
+ SmallVector<Value, 8> ret_vals;
+ for (Value operand : op.getOperands()) ret_vals.push_back(operand);
for (auto &graph_result : llvm::enumerate(graph_op.getResults()))
op.setOperand(graph_result.index(), graph_result.value());
builder.create<tf_executor::FetchOp>(getFunction().getLoc(), ret_vals);
@@ -128,7 +128,7 @@
assert(IsUnderscoredTFOp(&op) && "Expected only _tf operations");
// The operands and types arrays are used to create the tf_executor ops.
- SmallVector<Value *, 8> operands;
+ SmallVector<Value, 8> operands;
operands.append(op.getOperands().begin(), op.getOperands().end());
SmallVector<Type, 8> types;
for (Type result_type : op.getResultTypes()) {
@@ -201,7 +201,7 @@
// Only the non-control operands are carried over, the island is handling
// the control input.
- for (Value *operand : op.getOperands())
+ for (Value operand : op.getOperands())
if (!operand->getType().isa<tf_executor::ControlType>())
result.operands.push_back(operand);
@@ -223,7 +223,7 @@
inner_op->setAttrs(op.getAttrList());
// Add the terminator for the island
- SmallVector<Value *, 8> ret_vals(inner_op->getResults());
+ SmallVector<Value, 8> ret_vals(inner_op->getResults());
island_builder.create<tf_executor::YieldOp>(loc, ret_vals);
}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
index 8a4f8aa..827f0d6 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc
@@ -46,7 +46,7 @@
// Replace all uses of value `v` with a list of new values. Because number of
// new values might be greater than 1, users of `v` might be replaced with their
// clones in case of non-resizable operands list.
-void ReplaceAllUsesOfValueWithValues(Value *v,
+void ReplaceAllUsesOfValueWithValues(Value v,
Operation::operand_range new_values) {
int new_values_size = std::distance(new_values.begin(), new_values.end());
if (new_values_size == 1) {
@@ -58,9 +58,9 @@
for (Operation *user : llvm::make_early_inc_range(v->getUsers())) {
builder.setInsertionPoint(user);
- llvm::SmallVector<Value *, 4> new_operands;
+ llvm::SmallVector<Value, 4> new_operands;
new_operands.reserve(user->getNumOperands() - 1 + new_values_size);
- for (Value *operand : user->getOperands()) {
+ for (Value operand : user->getOperands()) {
if (operand == v) {
new_operands.append(new_values.begin(), new_values.end());
} else {
@@ -135,7 +135,7 @@
builder.setInsertionPoint(&op);
if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
- Value *ctl_sequence = nullptr;
+ Value ctl_sequence = nullptr;
for (Operation &wrapped_op : island.GetBody()) {
LLVM_DEBUG(llvm::dbgs()
<< " In island: " << wrapped_op.getName() << "\n");
@@ -162,7 +162,7 @@
if (ctl_sequence) {
state.operands.push_back(ctl_sequence);
} else {
- for (Value *ctl_operand : island.getOperands())
+ for (Value ctl_operand : island.getOperands())
state.operands.push_back(ctl_operand);
}
@@ -228,7 +228,7 @@
// dialect.
auto non_null_operands = llvm::make_filter_range(
op.getOperands(),
- [](Value *v) { return !v->getType().isa<tf_executor::TokenType>(); });
+ [](Value v) { return !v->getType().isa<tf_executor::TokenType>(); });
state.operands.append(non_null_operands.begin(), non_null_operands.end());
for (Type result_type : op.getResultTypes()) {
// Filter out TokenType, they don't exist in the control dialect.
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index 9d57220..c5829f2 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -32,6 +32,7 @@
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
+#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
@@ -110,25 +111,28 @@
return legalized_name;
}
-// TODO(jpienaar): unify and move from here to be able to reuse with tflite
-std::string GetName(Operation* inst) {
- // TODO(prakalps): b/137006652 prevents us from using location info (derived
- // from experimental_debug_info) to generate node names. Until it is fixed,
- // first check for "name" attribute to get node name.
-
- // Default name is Operation type.
- auto name = inst->getName().getStringRef();
- if (auto attr = inst->getAttrOfType<mlir::StringAttr>("name")) {
- name = attr.getValue();
- } else if (auto name_loc = inst->getLoc().dyn_cast<mlir::NameLoc>()) {
- name = name_loc.getName().strref();
- } else if (auto call_loc = inst->getLoc().dyn_cast<mlir::CallSiteLoc>()) {
+llvm::StringRef GetNameFromLoc(mlir::Location loc,
+ llvm::StringRef default_name) {
+ if (auto name_loc = loc.dyn_cast<mlir::NameLoc>()) {
+ return name_loc.getName().strref().split('@').first;
+ } else if (auto call_loc = loc.dyn_cast<mlir::CallSiteLoc>()) {
// Return name if CallSiteLoc's callee has a NameLoc (as should be the case
// if imported with DebugInfo), else use the fallback naming scheme below.
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>())
- name = name_loc.getName().strref();
+ return name_loc.getName().strref().split('@').first;
+ } else if (auto fused_loc = loc.dyn_cast<mlir::FusedLoc>()) {
+ // According to the importer, the last location of a fused location is
+ // the name from the node_def and the rests are from the experimental debug
+ // info.
+ return GetNameFromLoc(fused_loc.getLocations().back(), default_name);
}
+ return default_name;
+}
+// TODO(jpienaar): unify and move from here to be able to reuse with tflite
+std::string GetName(Operation* inst) {
+ // Default name is Operation type.
+ auto name = GetNameFromLoc(inst->getLoc(), inst->getName().getStringRef());
return LegalizeNodeName(name);
}
@@ -161,7 +165,7 @@
explicit Exporter(Graph* graph, const Dialect* tf_dialect)
: graph_(graph), tf_dialect_(tf_dialect) {}
- Status AddArgumentNode(BlockArgument* arg, unsigned index,
+ Status AddArgumentNode(BlockArgument arg, unsigned index,
llvm::StringRef name);
Status AddReturnNode(mlir::ReturnOp op,
llvm::ArrayRef<llvm::StringRef> names);
@@ -169,7 +173,7 @@
Status AddNextIterationNode(Operation* inst);
Status AddEdge(Operation* inst);
- StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument* arg,
+ StatusOr<std::unique_ptr<NodeDef>> GetArgumentNode(BlockArgument arg,
unsigned index,
llvm::StringRef name);
StatusOr<std::unique_ptr<NodeDef>> GetReturnNode(Operation* inst,
@@ -177,7 +181,7 @@
llvm::StringRef name);
// Adds one edge between src_node and dst_node. If it is not a control edge,
// an index is used to find out the right operand of the dst_node.
- Status AddEdgeBetweenNodes(Value* src, Node* dst_node, unsigned dst_index);
+ Status AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index);
// Returns a unique name for `op`.
std::string UniqueName(Operation* op);
@@ -189,7 +193,7 @@
absl::flat_hash_map<Operation*, string> op_to_name_;
absl::flat_hash_map<string, int64> name_to_count_;
absl::flat_hash_map<Operation*, Node*> nodes_;
- absl::flat_hash_map<const BlockArgument*, Node*> args_;
+ llvm::DenseMap<BlockArgument, Node*> args_;
// One single return operation can return multiple results, and each of them
// will be converted to one node in the graph.
typedef absl::InlinedVector<Node*, 4> NodeVector;
@@ -231,7 +235,7 @@
}
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
- BlockArgument* arg, unsigned index, llvm::StringRef name) {
+ BlockArgument arg, unsigned index, llvm::StringRef name) {
auto func = arg->getParentRegion()->getParentOfType<mlir::FuncOp>();
auto node_def = absl::make_unique<NodeDef>();
@@ -279,7 +283,7 @@
UniqueName(inst->getParentOfType<mlir::FuncOp>().getName().str()));
node_def->set_op(FunctionLibraryDefinition::kRetOp);
- auto* inst_op = inst->getOperand(index);
+ auto inst_op = inst->getOperand(index);
DataType dtype;
TF_RETURN_IF_ERROR(ConvertToDataType(
inst_op->getType().cast<mlir::TensorType>().getElementType(), &dtype));
@@ -292,9 +296,9 @@
return node_def;
}
-Status Exporter::AddEdgeBetweenNodes(Value* src, Node* dst_node,
+Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
unsigned dst_index) {
- if (auto* input_result = dyn_cast<mlir::OpResult>(src)) {
+ if (auto input_result = src->dyn_cast<mlir::OpResult>()) {
auto* input_inst = input_result->getOwner();
// replaces the input node by the sink one if it is an NextIteration source:
auto it = source_to_sink_.find(input_inst);
@@ -313,7 +317,7 @@
return Status::OK();
}
- auto* input_arg = cast<BlockArgument>(src);
+ auto input_arg = src->cast<BlockArgument>();
auto input_node_it = args_.find(input_arg);
TF_RET_CHECK(input_node_it != args_.end())
<< "Use of BlockArgument encounted before def!";
@@ -326,7 +330,7 @@
auto* dst_node = nodes_[inst];
bool is_return_op = isa<mlir::ReturnOp>(inst);
for (int index = 0, e = inst->getNumOperands(); index < e; index++) {
- auto* src = inst->getOperand(index);
+ auto src = inst->getOperand(index);
// For return operation, the edge is from the operand owner to one of the
// faked return nodes. The input index is always 0 for the return node.
if (is_return_op) {
@@ -361,14 +365,14 @@
return Status::OK();
}
-bool IsEntryFunctionArg(BlockArgument* arg) {
+bool IsEntryFunctionArg(BlockArgument arg) {
return arg->getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
"main";
}
// Creates argument nodes from Block argument. If a name is supplied, that
// name will be used instead of generating a unique name.
-Status Exporter::AddArgumentNode(BlockArgument* arg, unsigned index,
+Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
llvm::StringRef name) {
if (!IsEntryFunctionArg(arg) || !name.empty()) {
TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name));
@@ -395,9 +399,9 @@
builder.getContext());
OperationState state(loc, input_name.str());
state.attributes.append(input->getAttrs().begin(), input->getAttrs().end());
- for (auto* op : input->getOperands()) {
+ for (auto op : input->getOperands()) {
// Skip the argument in the new operation.
- if (llvm::isa<BlockArgument>(op)) continue;
+ if (op->isa<BlockArgument>()) continue;
state.operands.push_back(op);
}
state.types.append(input->getResultTypes().begin(),
@@ -405,7 +409,15 @@
auto* inst = builder.createOperation(state);
// If it is one of the specified input names, then the new
// instruction should have the same name.
- op_to_name_[inst].assign(op_to_name_[input]);
+ auto& mapped_name = op_to_name_[inst];
+ const auto& input_mapped_name = op_to_name_[input];
+ DCHECK(mapped_name.empty())
+ << "AddArgumentNode() attempted to change the op_to_name_ mapping for "
+ << inst << " from " << mapped_name << " to " << input_mapped_name << ".";
+ DCHECK(!input_mapped_name.empty())
+ << "AddArgumentNode() attempted to set the op_to_name_ mapping for "
+ << inst << " to an empty string.";
+ mapped_name.assign(input_mapped_name);
for (int index : llvm::seq<int>(0, input->getNumResults())) {
input->getResult(index)->replaceAllUsesWith(inst->getResult(index));
}
@@ -511,9 +523,15 @@
// Only assign defining op of operands of the return the output names if
// the main graph did not have its _Retval nodes lifted into the functions
// returns.
- if (!graph_as_function)
- exporter.op_to_name_[it.value()->getDefiningOp()] =
- output_names[it.index()];
+ if (!graph_as_function) {
+ auto defining_op = it.value()->getDefiningOp();
+ auto& mapped_name = exporter.op_to_name_[defining_op];
+ DCHECK(mapped_name.empty())
+ << "Convert() attempted to change the op_to_name_ mapping for "
+ << defining_op << " from " << mapped_name << " to output "
+ << it.index() << " name " << output_names[it.index()].str() << ".";
+ mapped_name = output_names[it.index()];
+ }
}
}
if (!input_names.empty()) {
@@ -522,16 +540,22 @@
exporter.name_to_count_[input_names[it.index()].str()] = 1;
// Only assign user of argument the input name if the main graph did not
// have its _Arg nodes lifted into the functions arguments.
- if (!graph_as_function)
- exporter.op_to_name_[*it.value()->user_begin()] =
- input_names[it.index()];
+ if (!graph_as_function) {
+ auto first_user = *it.value()->user_begin();
+ auto& mapped_name = exporter.op_to_name_[first_user];
+ DCHECK(mapped_name.empty())
+ << "Convert() attempted to change the op_to_name_ mapping for "
+ << first_user << " from " << mapped_name << " to input "
+ << it.index() << " name " << input_names[it.index()].str() << ".";
+ mapped_name = input_names[it.index()];
+ }
}
}
// Adds nodes for basic block (function) arguments.
for (auto it : llvm::enumerate(block.getArguments())) {
int index = it.index();
- auto* arg = it.value();
+ auto arg = it.value();
mlir::Type type = arg->getType();
if (!type.isa<mlir::TensorType>()) {
return errors::InvalidArgument(
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
index 0a1192c..3ff526d 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc
@@ -131,7 +131,7 @@
if (inst->getDialect() && inst->getDialect()->getNamespace() == "_tf") {
mlir::OperationState result(inst->getLoc(),
inst->getName().getStringRef().drop_front());
- for (mlir::Value* operand : inst->getOperands())
+ for (mlir::Value operand : inst->getOperands())
if (!operand->getType().isa<mlir::TFControlFlow::TFControlType>())
result.operands.push_back(operand);
@@ -160,6 +160,13 @@
TF_RETURN_IF_ERROR(GetUnregisteredAttrs(inst, &attrs_to_ignore));
}
+ if (inst->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
+ // TODO(b/146937733): Don't use <void> here.
+ llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments<
+ void>::getResultSegmentSizeAttr();
+ attrs_to_ignore.insert(attr_name.data());
+ }
+
TF_ASSIGN_OR_RETURN(auto node_def,
GetOperationNodeDef(attrs_to_ignore, inst, name));
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 4b0bebf..70547b8 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -48,6 +48,7 @@
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
+#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
@@ -264,7 +265,7 @@
mlir::Operation* createOperation(
const Node& node, llvm::StringRef node_type_name,
const mlir::OperationState& result,
- const llvm::SmallVectorImpl<mlir::Value*>& control_operands,
+ const llvm::SmallVectorImpl<mlir::Value>& control_operands,
bool convert_to_legacy_call = false);
// Converts one NodeDef from the input GraphDef into an Operation and
@@ -1174,7 +1175,7 @@
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes) {
auto* bb = &func.front();
- llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value*, 4>
+ llvm::SmallDenseMap<std::pair<Node*, int>, mlir::Value, 4>
arg_nodes_to_values;
for (int i = 0, e = arg_types.size(); i < e; ++i) {
auto& arg_node = arg_nodes[i];
@@ -1182,8 +1183,8 @@
// be converted to mlir operations and don't have a mapping.
mlir::Operation* island = node_values_.find(arg_node.node->id())->second;
- auto* bb_arg = bb->getArgument(i);
- mlir::Value* arg_def = bb_arg;
+ auto bb_arg = bb->getArgument(i);
+ mlir::Value arg_def = bb_arg;
if (island->getNumResults() != 2)
return errors::InvalidArgument(
@@ -1206,7 +1207,7 @@
island->erase();
}
- llvm::SmallVector<mlir::Value*, 8> inst_to_return;
+ llvm::SmallVector<mlir::Value, 8> inst_to_return;
for (const auto& ret : ret_nodes) {
auto* inst = node_values_[ret.node->id()];
auto op = absl::string_view(ret.node->type_string());
@@ -1318,15 +1319,21 @@
return create_location(node_def.name(), function_name_for_debug_info_);
} else {
// If the original nodes are defined, then we use them to get a list of
- // call sites, and then fuse them to a single fused location.
- llvm::SmallVector<mlir::Location, 4> node_call_sites;
- node_call_sites.reserve(original_nodes.size());
+ // call sites, and then fuse them to a single fused location, with the name
+ // of the node_def.
+ llvm::SmallVector<mlir::Location, 4> node_locations;
+ node_locations.reserve(original_nodes.size() + 1);
+
+ // store the names in the experimental_debug_info
for (int i = 0, e = original_nodes.size(); i != e; ++i) {
auto node_name = original_nodes[i];
auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
- node_call_sites.push_back(create_location(node_name, func_name));
+ node_locations.push_back(create_location(node_name, func_name));
}
- return mlir::FusedLoc::get(node_call_sites, context_);
+ // store the name of the node_def
+ node_locations.push_back(
+ create_location(node_def.name(), function_name_for_debug_info_));
+ return mlir::FusedLoc::get(node_locations, context_);
}
}
@@ -1347,14 +1354,14 @@
mlir::Operation* ImporterBase::createOperation(
const Node& node, llvm::StringRef node_type_name,
const mlir::OperationState& result,
- const llvm::SmallVectorImpl<mlir::Value*>& control_operands,
+ const llvm::SmallVectorImpl<mlir::Value>& control_operands,
bool convert_to_legacy_call) {
// For the tf.executor specific operations (not wrapped in an island), we
// have an extra returned value for the control result, and we concatenate
// control and non-control operands.
mlir::SmallVector<mlir::Type, 4> types(result.types);
types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext()));
- mlir::SmallVector<mlir::Value*, 4> operands(result.operands);
+ mlir::SmallVector<mlir::Value, 4> operands(result.operands);
operands.append(control_operands.begin(), control_operands.end());
auto loc = result.location;
@@ -1432,6 +1439,32 @@
inner_op = island_builder.createOperation(result);
}
+ if (inner_op->hasTrait<mlir::OpTrait::AttrSizedResultSegments>()) {
+ // The op has multiple variadic outputs.
+ // Calculate result segment sizes using the OpDef.
+ NameRangeMap output_ranges;
+ // This will fail only if the OpDef is syntactically invalid.
+ // TODO(jpienaar): Convert this CHECK into a properly propagated error.
+ TF_CHECK_OK(
+ NameRangesForNode(node, node.op_def(), nullptr, &output_ranges));
+ std::vector<mlir::Attribute> values;
+ values.reserve(node.op_def().output_arg_size());
+ for (const auto& output_arg : node.op_def().output_arg()) {
+ auto range = output_ranges[output_arg.name()];
+ values.push_back(
+ island_builder.getI32IntegerAttr(range.second - range.first));
+ }
+
+ // Add derived "result_segment_sizes" attr to the created operation.
+ // TODO(b/146937733): Don't use <void> here.
+ llvm::StringRef attr_name = mlir::OpTrait::AttrSizedResultSegments<
+ void>::getResultSegmentSizeAttr();
+ auto attr_type = mlir::VectorType::get(node.op_def().output_arg_size(),
+ builder_.getIntegerType(32));
+ auto attr_value = mlir::DenseElementsAttr::get(attr_type, values);
+ inner_op->setAttr(attr_name, attr_value);
+ }
+
// Add the terminator for the island
island_builder.create<mlir::tf_executor::YieldOp>(result.location,
inner_op->getResults());
@@ -1497,7 +1530,7 @@
result.operands.reserve(in_edges.size());
// Collect the control operands separately, they will be held by the island.
- mlir::SmallVector<mlir::Value*, 8> control_operands;
+ mlir::SmallVector<mlir::Value, 8> control_operands;
for (const auto* input_edge : in_edges) {
const Node& input_node = *input_edge->src();
@@ -1567,8 +1600,6 @@
}
result.attributes.push_back(builder_.getNamedAttr(
- "name", builder_.getStringAttr(std::string(node.name()))));
- result.attributes.push_back(builder_.getNamedAttr(
"device", builder_.getStringAttr(std::string(node_def.device()))));
// Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add
@@ -1648,7 +1679,7 @@
// Replaces the output uses of the old operation by the corresponding
// result of the new operation, and deletes the old operation.
for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
- auto* new_output = new_dst->getResult(i);
+ auto new_output = new_dst->getResult(i);
dst->getResult(i)->replaceAllUsesWith(new_output);
}
dst->dropAllReferences();
@@ -2533,7 +2564,7 @@
module.insert(module.getBody()->begin(), func);
func.addEntryBlock();
func.setName("__sm_exported_" + orig_func.getName().str());
- llvm::SmallVector<mlir::Value*, 4> args_as_values;
+ llvm::SmallVector<mlir::Value, 4> args_as_values;
for (auto block_argument : func.getArguments()) {
args_as_values.push_back(block_argument);
}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
index 86fbff9..ee769cf 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc
@@ -75,7 +75,7 @@
builder.setInsertionPointToEnd(&graph_op.GetBody());
auto island = builder.create<tf_executor::IslandOp>(
loc, getFunction().getType().getResults(),
- tf_executor::ControlType::get(&getContext()), ArrayRef<Value*>());
+ tf_executor::ControlType::get(&getContext()), ArrayRef<Value>());
// Create Fetch.
ValueRange to_fetch = island.getResults();
if (to_fetch.size() != 1) {
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
index e35b713..97f2486 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
@@ -65,8 +65,12 @@
debug_info->add_original_node_names(name_loc.getName().c_str());
}
} else if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
- for (auto loc : fused.getLocations()) {
- TF_RETURN_IF_ERROR(ConvertLocation(loc, debug_info));
+ auto locations = fused.getLocations();
+ if (locations.size() <= 1)
+ return errors::InvalidArgument("expected experimental debuf info.");
+ // skip the last one, which is the name of the node_def.
+ for (int i = 0; i < locations.size() - 1; ++i) {
+ TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
}
}
return Status::OK();
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index 5063267..e2caf2d 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -264,6 +264,12 @@
attributes.push_back(ConvertComparisonDirection(instruction));
MakeAndReturn(CompareOp);
}
+ case HloOpcode::kCholesky: {
+ attributes.push_back(builder_->getNamedAttr(
+ "lower",
+ builder_->getBoolAttr(instruction->cholesky_options().lower())));
+ MakeAndReturn(CholeskyOp);
+ }
case HloOpcode::kGather: {
auto gather_instruction = static_cast<HloGatherInstruction*>(instruction);
attributes.push_back(ConvertGatherDimensionNumbers(
@@ -284,7 +290,7 @@
return func_builder
->create<mlir::xla_hlo::DynamicUpdateSliceOp>(
loc, result_type, operands[0], operands[1],
- llvm::ArrayRef<Value*>(operands.begin() + 2, operands.end()))
+ llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
.getOperation();
}
case HloOpcode::kInfeed: {
@@ -371,6 +377,28 @@
ConvertDimensions(instruction->dimensions()))
.getOperation();
}
+ case HloOpcode::kRng: {
+ auto shape = func_builder->create<mlir::ConstantOp>(
+ loc, Convert(result_type.cast<RankedTensorType>().getShape()));
+ switch (instruction->random_distribution()) {
+ case xla::RNG_UNIFORM:
+ return func_builder
+ ->create<mlir::xla_hlo::RngUniformOp>(
+ loc, result_type, operands[0], operands[1], shape)
+ .getOperation();
+
+ case xla::RNG_NORMAL:
+ return func_builder
+ ->create<mlir::xla_hlo::RngNormalOp>(
+ loc, result_type, operands[0], operands[1], shape)
+ .getOperation();
+
+ default:
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Unsupported distribution: ",
+ RandomDistributionToString(instruction->random_distribution())));
+ }
+ }
case HloOpcode::kWhile: {
auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
loc, operands[0]->getType(), operands[0]);
@@ -473,10 +501,12 @@
NoAttributeCase(kPower, PowOp);
NoAttributeCase(kReal, RealOp);
NoAttributeCase(kRemainder, RemOp);
+ NoAttributeCase(kReplicaId, ReplicaIdOp);
// The dimensions attribute is not present on the HLO Reshape instruction.
// If dimensions are non-default, the XLA builder implements it as a
// separate transpose.
NoAttributeCase(kReshape, ReshapeOp);
+ NoAttributeCase(kRoundNearestAfz, RoundOp);
NoAttributeCase(kRsqrt, RsqrtOp);
NoAttributeCase(kSelect, SelectOp);
NoAttributeCase(kShiftLeft, ShiftLeftOp);
@@ -512,9 +542,9 @@
}
}
-StatusOr<llvm::SmallVector<mlir::Value*, 4>> HloFunctionImporter::GetOperands(
+StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
HloInstruction* instruction) {
- llvm::SmallVector<mlir::Value*, 4> operands;
+ llvm::SmallVector<mlir::Value, 4> operands;
for (const auto& operand : instruction->operands()) {
auto input_it = instruction_value_map_.find(operand);
if (input_it == instruction_value_map_.end()) {
@@ -602,8 +632,7 @@
return tensorflow::Status::OK();
}
-StatusOr<Value*> HloFunctionImporter::GetMlirValue(
- HloInstruction* instruction) {
+StatusOr<Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) {
auto lookup = instruction_value_map_.find(instruction);
if (lookup != instruction_value_map_.end()) {
return lookup->second;
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
index bd36c9b..ba62224 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h
@@ -71,7 +71,7 @@
mlir::OpBuilder* func_builder);
// Gets the MLIR operand values from an HLO Instruction.
- StatusOr<llvm::SmallVector<mlir::Value*, 4>> GetOperands(
+ StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands(
xla::HloInstruction* instruction);
// Converts xla Tensor type to the corresponding MLIR type.
@@ -89,7 +89,7 @@
llvm::SmallVectorImpl<mlir::Type>* types);
// Returns the Mlir Value for the corresponding HloInstruction.
- StatusOr<mlir::Value*> GetMlirValue(xla::HloInstruction* instruction);
+ StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction);
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction);
@@ -129,7 +129,7 @@
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map_;
// Mapping from HloInstructions to the associative MLIR values.
- std::unordered_map<xla::HloInstruction*, mlir::Value*> instruction_value_map_;
+ std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_;
};
} // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 41e561f..e819370 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -203,7 +203,7 @@
// AbsOp
//===----------------------------------------------------------------------===//
-void AbsOp::build(Builder* builder, OperationState& result, Value* operand) {
+void AbsOp::build(Builder* builder, OperationState& result, Value operand) {
auto shaped_type = operand->getType().cast<ShapedType>();
Type new_type;
if (!shaped_type.getElementType().isa<ComplexType>()) {
@@ -222,7 +222,7 @@
// ConvertOp
//===----------------------------------------------------------------------===//
-void ConvertOp::build(Builder* builder, OperationState& result, Value* operand,
+void ConvertOp::build(Builder* builder, OperationState& result, Value operand,
Type result_element_ty) {
Type result_ty;
Type operand_ty = operand->getType();
@@ -431,8 +431,8 @@
// ComplexOp
//===----------------------------------------------------------------------===//
-void ComplexOp::build(Builder* builder, OperationState& state, Value* lhs,
- Value* rhs) {
+void ComplexOp::build(Builder* builder, OperationState& state, Value lhs,
+ Value rhs) {
auto type = lhs->getType();
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
Type result_ty;
@@ -476,7 +476,7 @@
}
} // namespace
-void ImagOp::build(Builder* builder, OperationState& state, Value* val) {
+void ImagOp::build(Builder* builder, OperationState& state, Value val) {
build(builder, state, CreateRealType(val->getType()), val);
}
@@ -489,7 +489,7 @@
return {};
}
-void RealOp::build(Builder* builder, OperationState& state, Value* val) {
+void RealOp::build(Builder* builder, OperationState& state, Value val) {
build(builder, state, CreateRealType(val->getType()), val);
}
@@ -611,7 +611,7 @@
SmallVector<Type, 1> result_ty;
result_ty.reserve(operands.size());
- for (Value* operand : operands) {
+ for (Value operand : operands) {
result_ty.push_back(
GetReduceResultType(operand->getType(), dimensions, builder));
}
@@ -622,7 +622,7 @@
SmallVectorImpl<OpFoldResult>& results) {
// No dimensions to reduce.
if (dimensions().getNumElements() == 0) {
- for (Value* input : this->operands()) {
+ for (Value input : this->operands()) {
results.push_back(input);
}
return success();
@@ -758,8 +758,8 @@
} // namespace
#define BINARY_BUILDER(Op) \
- void Op::build(Builder* builder, OperationState& result, Value* left, \
- Value* right, DenseIntElementsAttr broadcast_dimensions) { \
+ void Op::build(Builder* builder, OperationState& result, Value left, \
+ Value right, DenseIntElementsAttr broadcast_dimensions) { \
auto type = GetBroadcastType(builder, left->getType().cast<ShapedType>(), \
right->getType().cast<ShapedType>(), \
getElementTypeOrSelf(right->getType()), \
@@ -790,7 +790,7 @@
// SliceOp
//===----------------------------------------------------------------------===//
-void SliceOp::build(Builder* builder, OperationState& result, Value* operand,
+void SliceOp::build(Builder* builder, OperationState& result, Value operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides) {
@@ -811,7 +811,7 @@
return llvm::divideCeil(end - start, stride);
}
-Type SliceOp::InferOutputTypes(Builder* builder, Value* operand,
+Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides) {
@@ -852,7 +852,7 @@
SmallVector<Type, 2> element_types;
element_types.reserve(operands.size());
- for (Value* operand : operands) element_types.push_back(operand->getType());
+ for (Value operand : operands) element_types.push_back(operand->getType());
state.addTypes(builder->getTupleType(element_types));
state.addRegion();
@@ -863,13 +863,13 @@
if (operands.empty()) return op.emitOpError("requires at least one input");
// TODO(antiagainst): verify partionally dynamic shapes
- if (llvm::all_of(operands, [](Value* operand) {
+ if (llvm::all_of(operands, [](Value operand) {
return operand->getType().cast<ShapedType>().hasRank();
})) {
ArrayRef<int64_t> input_shape =
(*operands.begin())->getType().cast<ShapedType>().getShape();
- if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value* operand) {
+ if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
return operand->getType().cast<ShapedType>().getShape() !=
input_shape;
}))
@@ -971,7 +971,7 @@
//===----------------------------------------------------------------------===//
void GetTupleElementOp::build(Builder* builder, OperationState& result,
- Value* tuple, int32_t index) {
+ Value tuple, int32_t index) {
if (auto tuple_type = tuple->getType().dyn_cast<TupleType>()) {
auto element_type = tuple_type.getType(index);
build(builder, result, element_type, tuple,
@@ -1011,8 +1011,8 @@
// CompareOp
//===----------------------------------------------------------------------===//
-void CompareOp::build(Builder* builder, OperationState& result, Value* lhs,
- Value* rhs, DenseIntElementsAttr broadcast_dimensions,
+void CompareOp::build(Builder* builder, OperationState& result, Value lhs,
+ Value rhs, DenseIntElementsAttr broadcast_dimensions,
StringAttr comparison_direction) {
auto new_type = GetBroadcastType(builder, lhs->getType(), rhs->getType(),
builder->getI1Type(), broadcast_dimensions);
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 90c3189..6d8c11e 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -76,6 +76,9 @@
// Any int, floating-point or complex tensor types
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, AnyComplex]>;
+// Any pred, int or floating-point tensor types
+def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
+
//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//
@@ -128,7 +131,7 @@
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
[NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_AbsOp {
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *operand"
+ "Builder *builder, OperationState &result, Value operand"
>];
}
@@ -140,7 +143,7 @@
BASE_HLO_ConvertOp {
let builders = [OpBuilder<
- "Builder *, OperationState &tblgen_state, Value *operand, "
+ "Builder *, OperationState &tblgen_state, Value operand, "
"Type result_element_ty"
>];
@@ -149,6 +152,10 @@
let hasCustomHLOConverter = 1;
}
+def HLO_ClzOp: HLO_UnaryElementwiseOp<"count_leading_zeros",
+ [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
+ BASE_HLO_ClzOp;
+
def HLO_CosOp: HLO_UnaryElementwiseOp<"cos",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
BASE_HLO_CosOp;
@@ -191,6 +198,9 @@
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
BASE_HLO_PopulationCountOp;
+def HLO_RoundOp: HLO_UnaryElementwiseOp<"round",
+ [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp;
+
def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
BASE_HLO_RsqrtOp;
@@ -220,7 +230,7 @@
[NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>,
BASE_HLO_ComplexOp {
let builders = [OpBuilder<
- "Builder *, OperationState &tblgen_state, Value *lhs, Value *rhs">];
+ "Builder *, OperationState &tblgen_state, Value lhs, Value rhs">];
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
let results = (outs HLO_ComplexTensor);
@@ -230,7 +240,7 @@
def HLO_ImagOp: HLO_Op<
"imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp {
let builders = [OpBuilder<
- "Builder *, OperationState &tblgen_state, Value *val">];
+ "Builder *, OperationState &tblgen_state, Value val">];
let arguments = (ins HLO_ComplexTensor);
let results = (outs HLO_FpTensor);
@@ -240,7 +250,7 @@
def HLO_RealOp: HLO_Op<
"real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp {
let builders = [OpBuilder<
- "Builder *, OperationState &tblgen_state, Value *val">];
+ "Builder *, OperationState &tblgen_state, Value val">];
let arguments = (ins HLO_ComplexTensor);
let results = (outs HLO_FpTensor);
@@ -261,7 +271,7 @@
);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *left, Value* right, "
+ "Builder *builder, OperationState &result, Value left, Value right, "
"DenseIntElementsAttr broadcast_dimensions"
>];
@@ -328,6 +338,15 @@
// XLA communication op definitions.
//===----------------------------------------------------------------------===//
+// Represents a unique identifier for each Send/Recv instruction pair or
+// optionally for collective instructions (AllReduce, CollectivePermute,
+// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
+def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
+ StructFieldAttr<"handle", I64Attr>,
+ StructFieldAttr<"type", I64Attr>]> {
+ let description = "two 64-bit integers 'handle' and 'type'";
+}
+
// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def HLO_InfeedOp : HLO_Op<"infeed", []> {
@@ -374,6 +393,42 @@
let hasCustomHLOConverter = 1;
}
+def HLO_SendOp : HLO_Op<"send", []> {
+
+ string summary = "Send operator";
+
+ string description = [{
+ Sends the given operand data to a Recv instruction in another computation
+ that shares the same channel handle. Does not return any data. Similar to
+ the Recv operation, Send operation represents synchronous communication,
+ and is internally decomposed into 2 HLO instructions (Send and SendDone) to
+ enable asynchronous data transfers.
+
+ See https://www.tensorflow.org/xla/operation_semantics#send.
+ }];
+
+ let arguments = (ins
+ HLO_TensorOrTuple:$operand,
+ HLO_Token:$token,
+ ChannelHandle:$channel_id,
+ DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
+ );
+
+ let results = (outs HLO_Token);
+ let hasCustomHLOConverter = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// XLA parallelism related op definitions.
+//===----------------------------------------------------------------------===//
+
+def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
+ BASE_HLO_ReplicaIdOp {
+ // TODO(prakalps): The output should be an unsigned 32-bit integer but mlir does
+ // not differentiate between signed and unsigned int.
+ let results = (outs I32Tensor);
+}
+
//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
//===----------------------------------------------------------------------===//
@@ -393,7 +448,6 @@
let arguments = (ins Variadic<HLO_Token>:$operands);
let results = (outs HLO_Token);
- let hasCustomHLOConverter = 1;
}
def HLO_ConditionalOp: HLO_Op<"conditional", [NoSideEffect]> {
@@ -440,15 +494,6 @@
let hasCustomHLOConverter = 1;
}
-// Represents a unique identifier for each Send/Recv instruction pair or
-// optionally for collective instructions (AllReduce, CollectivePermute,
-// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
-def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
- StructFieldAttr<"handle", I64Attr>,
- StructFieldAttr<"type", I64Attr>]> {
- let description = "two 64-bit integers 'handle' and 'type'";
-}
-
def HLO_AllReduceOp : HLO_Op<"all_reduce",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AllReduceOp {
@@ -508,7 +553,7 @@
let builders = [OpBuilder<
"Builder *builder, OperationState &results, "
- "Value* value, int32_t index">];
+ "Value value, int32_t index">];
}
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
@@ -519,8 +564,6 @@
"Builder *builder, OperationState &results, "
"ValueRange values">];
- // TupleOp has special conversion logic to HLO.
- let hasCustomHLOConverter = 1;
}
def HLO_CompareOp: HLO_Op<"compare",
@@ -532,14 +575,14 @@
HLO_ComparisonDirectionAttr:$comparison_direction
);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *left, Value* right, "
+ "Builder *builder, OperationState &result, Value left, Value right, "
"DenseIntElementsAttr broadcast_dimensions, "
"StringAttr comparison_direction"
>];
let results = (outs HLO_PredTensor);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
+ "Builder *builder, OperationState &result, Value lhs, Value rhs, "
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
>];
}
@@ -562,7 +605,7 @@
let results = (outs HLO_Tensor);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *operand, "
+ "Builder *builder, OperationState &result, Value operand, "
"DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
"DenseIntElementsAttr strides"
>];
@@ -570,7 +613,7 @@
let extraClassDeclaration = [{
// Infers output type for given operand and attributes. Result type is
// unranked if any of the attributes is illegal.
- static Type InferOutputTypes(Builder *builder, Value *operand,
+ static Type InferOutputTypes(Builder *builder, Value operand,
DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides);
@@ -684,6 +727,16 @@
let hasCustomHLOConverter = 1;
}
+def HLO_CholeskyOp : HLO_Op<"cholesky",
+ [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_CholeskyOp {
+ let arguments = (ins
+ HLO_FpOrComplexTensor:$a,
+ DefaultValuedAttr<BoolAttr, "false">:$lower
+ );
+
+ let results = (outs HLO_FpOrComplexTensor);
+}
+
def HLO_ClampOp : HLO_Op<"clamp",
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ClampOp {
let arguments = (ins
@@ -707,8 +760,6 @@
let hasFolder = 1;
- // TODO(b/129422361) ConcatOp has special conversion logic to HLO.
- let hasCustomHLOConverter = 1;
}
def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum",
@@ -758,8 +809,6 @@
let results = (outs HLO_Tensor);
- // TODO(b/129422361): Conv Op has special conversion logic to HLO.
- let hasCustomHLOConverter = 1;
}
def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> {
@@ -801,7 +850,9 @@
let results = (outs HLO_Tensor);
}
-def BASE_EinsumOp {
+// Define Base Einsum op within the HLO dialect as these are client ops and
+// therefore this class is not common between HLO and LHLO ops.
+class BASE_EinsumOp {
string summary = "Einsum operator";
string description = [{
@@ -810,7 +861,7 @@
}];
}
-def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]> {
+def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]>, BASE_EinsumOp {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
@@ -823,7 +874,7 @@
// side HLO ops.
}
-def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]> {
+def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]>, BASE_EinsumOp {
let arguments = (ins
HLO_Tensor:$operand,
StrAttr:$einsum_config
@@ -846,9 +897,6 @@
);
let results = (outs HLO_Tensor);
-
- // TODO(b/129422361) Attributes are not supported by the codegen.
- let hasCustomHLOConverter = 1;
}
def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
@@ -869,8 +917,6 @@
);
let results = (outs HLO_Tensor);
-
- let hasCustomHLOConverter = 1;
}
def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
@@ -976,9 +1022,6 @@
let results = (outs HLO_Tensor);
let hasFolder = 1;
-
- // TODO(b/129422361): ReverseOp has a custom constructor for HLO.
- let hasCustomHLOConverter = 1;
}
def HLO_PadOp: HLO_Op<"pad",
@@ -1079,12 +1122,24 @@
//===----------------------------------------------------------------------===//
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
let arguments = (ins
- HLO_Tensor:$a,
- HLO_Tensor:$b,
+ HLO_PredIntOrFpTensor:$a,
+ HLO_PredIntOrFpTensor:$b,
I64Tensor:$shape
);
- let results = (outs HLO_Tensor);
+ let results = (outs HLO_PredIntOrFpTensor);
+
+ let hasCustomHLOConverter = 1;
+}
+
+def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
+ let arguments = (ins
+ HLO_FpTensor:$mu,
+ HLO_FpTensor:$sigma,
+ I64Tensor:$shape
+ );
+
+ let results = (outs HLO_FpTensor);
let hasCustomHLOConverter = 1;
}
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index 3be2c26..ac7351e 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -68,6 +68,17 @@
}];
}
+class BASE_HLO_ClzOp {
+ string summary = "Count-leading-zeros (Clz) operator";
+
+ string description = [{
+ Returns the number of leading zeros in each operand element-wise.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
class BASE_HLO_ComplexOp {
string summary = "Complex operator";
@@ -228,6 +239,18 @@
}];
}
+class BASE_HLO_RoundOp {
+ string summary = "Round operator";
+
+ string description = [{
+ Returns `Round(operand)` element-wise, rounding to nearest integer with
+ half-way cases rounding away from zero.
+
+ See
+ https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }];
+}
+
class BASE_HLO_RsqrtOp {
string summary = "Reciprocal Square-root operator";
@@ -465,6 +488,26 @@
}];
}
+//===----------------------------------------------------------------------===//
+// XLA parallelism related op definitions.
+//===----------------------------------------------------------------------===//
+
+class BASE_HLO_ReplicaIdOp {
+ string summary = "ReplicaId operator";
+
+ string description = [{
+ Returns the unique ID (int32 scalar) of the replica.
+
+ The unique ID of each replica is an unsigned integer in the interval [0, N),
+ where N is the number of replicas. Since all the replicas are running the
+ same program, a ReplicaId() call in the program will return a different
+ value on each replica.
+
+ See https://www.tensorflow.org/xla/operation_semantics#replicaid.
+ }];
+}
+
+
class BASE_HLO_AllReduceOp {
string summary = "AllReduce operator";
@@ -707,6 +750,32 @@
}];
}
+class BASE_HLO_CholeskyOp {
+ string summary = "Cholesky operator";
+
+ string description = [{
+ Computes the Cholesky decomposition of a batch of symmetric (Hermitian)
+ positive definite matrices.
+
+ If lower is true, computes lower-triangular matrices l such that
+ `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such
+ that `a=Transpose(u).u`.
+
+ Input data is read only from the lower/upper triangle of a, depending on the
+ value of lower. Values from the other triangle are ignored. Output data is
+ returned in the same triangle; the values in the other triangle are
+ implementation-defined and may be anything.
+
+ If the rank of a is greater than 2, a is treated as a batch of matrices, where
+ all except the minor 2 dimensions are batch dimensions.
+
+ If a is not symmetric (Hermitian) positive definite, the result is
+ implementation-defined.
+
+ See https://www.tensorflow.org/xla/operation_semantics#cholesky.
+ }];
+}
+
class BASE_HLO_ClampOp {
string summary = "Clamp operator";
@@ -895,11 +964,26 @@
string summary = "RNG with uniform distribution.";
string description = [{
- Constructs an output of a given shape with random numbers generated following
- the uniform distribution over the interval `[a,b)`.
+ Constructs an output of a given shape with random numbers generated
+ following the uniform distribution over the interval `[a,b)`. The parameters
+ and output element type have to be a boolean type, an integral type or a
+ floating point type, and the types have to be consistent.
See https://www.tensorflow.org/xla/operation_semantics#rnguniform.
}];
}
+class BASE_HLO_RngNormalOp {
+ string summary = "RNG with normal distribution.";
+
+ string description = [{
+ Constructs an output of a given shape with random numbers generated
+ following the normal distribution with parameters `mu` and `sigma`. The
+ parameters and output shape have to have a floating point element type.
+ The parameters furthermore have to be scalar valued.
+
+ See https://www.tensorflow.org/xla/operation_semantics#rngnormal.
+ }];
+}
+
#endif // HLO_OPS_BASE
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
index 7d3e2ca..794b8aa 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.cc
@@ -22,8 +22,7 @@
namespace mlir {
namespace xla {
-DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value *x,
- Value *y) {
+DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y) {
TensorType xType = x->getType().dyn_cast<RankedTensorType>();
TensorType yType = y->getType().dyn_cast<RankedTensorType>();
if (xType == yType || !xType || !yType) return {};
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
index 86c90b4..ce03231 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h
@@ -29,12 +29,12 @@
// Computes the broadcast dimensions attr for an elementwise binary operator
// between two ranked tensors.
mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
- mlir::Value* x,
- mlir::Value* y);
+ mlir::Value x,
+ mlir::Value y);
/// Get a constant splat for the given value type.
template <typename T>
-static ElementsAttr getSplat(Builder* b, Value* val, T constant) {
+static ElementsAttr getSplat(Builder* b, Value val, T constant) {
auto valType = val->getType().cast<TensorType>();
auto valElementType = getElementTypeOrSelf(val->getType());
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index dc248d6..f5bec22 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -91,6 +91,8 @@
return value.convertToDouble();
}
+static inline bool Convertbool(bool value) { return value; }
+
static absl::string_view ConvertStringRef(mlir::StringRef value) {
return {value.data(), value.size()};
}
@@ -115,6 +117,15 @@
return ConvertDenseIntAttr(*broadcast_dimensions);
}
+// Converts the string form of an fft_type attribute to the xla::FftType enum.
+static xla::FftType Convert_fft_type(llvm::StringRef fft_type_str) {
+ xla::FftType fft_type_enum;
+ // An illegal fft_type string would be caught by the verifier, so the
+ // FftType_Parse call below should never fail; if it does, fall back to FFT.
+ if (!FftType_Parse(fft_type_str, &fft_type_enum)) return xla::FftType::FFT;
+ return fft_type_enum;
+}
+
// Convert a nx2 dense attribute to a list of tuples. This is the way padding
// is defined in hlo.
static std::vector<std::pair<int64, int64>> Convert_padding(
@@ -151,10 +162,10 @@
return result;
}
-#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \
- static std::vector<int64> Convert_##attribute( \
- mlir::DenseIntElementsAttr attribute) { \
- return ConvertDenseIntAttr(attribute); \
+#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \
+ static std::vector<int64> Convert_##attribute( \
+ llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
+ return ConvertDenseIntAttr(attribute); \
}
I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes);
@@ -163,6 +174,11 @@
I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(strides);
I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
+I64_ELEMENTS_ATTR_TO_VECTOR(fft_length);
+I64_ELEMENTS_ATTR_TO_VECTOR(dimensions);
+I64_ELEMENTS_ATTR_TO_VECTOR(window_strides);
+I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation);
+I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation);
#undef I64_ELEMENTS_ATTR_TO_VECTOR
@@ -230,7 +246,7 @@
return dot_dimension_numbers;
}
-static xla::ConvolutionDimensionNumbers Convert_convolution_dimension_numbers(
+static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
mlir::xla_hlo::ConvDimensionNumbers input) {
xla::ConvolutionDimensionNumbers output;
@@ -281,7 +297,7 @@
.ValueOrDie();
}
-static xla::GatherDimensionNumbers Convert_gather_dimension_numbers(
+static xla::GatherDimensionNumbers Convert_dimension_numbers(
mlir::xla_hlo::GatherDimensionNumbers input) {
xla::GatherDimensionNumbers output;
@@ -335,7 +351,7 @@
namespace {
class ConvertToHloModule {
public:
- using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
+ using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>;
using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
// If use_tuple_args is true, then the entry function's arguments are
@@ -417,7 +433,7 @@
namespace {
struct OpLoweringContext {
- llvm::DenseMap<mlir::Value*, xla::XlaOp>* values;
+ llvm::DenseMap<mlir::Value, xla::XlaOp>* values;
mlir::ConvertToHloModule* converter;
xla::XlaBuilder* builder;
};
@@ -425,7 +441,7 @@
llvm::SmallVector<xla::XlaOp, 4> GetTuple(mlir::Operation::operand_range values,
OpLoweringContext ctx) {
llvm::SmallVector<xla::XlaOp, 4> ops;
- for (mlir::Value* value : values) {
+ for (mlir::Value value : values) {
ops.push_back((*ctx.values)[value]);
}
return ops;
@@ -437,16 +453,6 @@
namespace xla_hlo {
namespace {
-LogicalResult ExportXlaOp(AfterAllOp op, OpLoweringContext ctx) {
- auto& value_map = *ctx.values;
- std::vector<xla::XlaOp> tokens(op.operands().size());
- for (auto index_and_value : llvm::enumerate(op.operands())) {
- tokens[index_and_value.index()] = value_map[index_and_value.value()];
- }
- value_map[op] = xla::AfterAll(ctx.builder, tokens);
- return mlir::success();
-}
-
LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation computation;
@@ -485,13 +491,6 @@
return success();
}
-LogicalResult ExportXlaOp(ConcatenateOp op, OpLoweringContext ctx) {
- auto& value_map = *ctx.values;
- value_map[op] = xla::ConcatInDim(ctx.builder, GetTuple(op.val(), ctx),
- op.dimension().getSExtValue());
- return success();
-}
-
LogicalResult ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) {
xla::XlaComputation true_branch;
xla::XlaComputation false_branch;
@@ -514,21 +513,6 @@
return failure();
}
-LogicalResult ExportXlaOp(ConvOp op, OpLoweringContext ctx) {
- auto& value_map = *ctx.values;
- value_map[op] = xla::ConvGeneralDilated(
- value_map[op.lhs()], value_map[op.rhs()],
- Convert_broadcast_dimensions(op.window_strides()),
- Convert_padding(op.padding()),
- Convert_broadcast_dimensions(op.lhs_dilation()),
- Convert_broadcast_dimensions(op.rhs_dilation()),
- Convert_convolution_dimension_numbers(op.dimension_numbers()),
- op.feature_group_count().getSExtValue(),
- op.batch_group_count().getSExtValue(),
- Convert_precision_config(op.precision_config()).get());
- return success();
-}
-
LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
value_map[op] = xla::ConvertElementType(
@@ -537,22 +521,6 @@
return success();
}
-LogicalResult ExportXlaOp(CopyOp op, OpLoweringContext ctx) {
- return failure();
-}
-
-LogicalResult ExportXlaOp(FftOp op, OpLoweringContext ctx) { return failure(); }
-
-LogicalResult ExportXlaOp(GatherOp op, OpLoweringContext ctx) {
- auto& value_map = *ctx.values;
- xla::GatherDimensionNumbers dimension_numbers =
- Convert_gather_dimension_numbers(op.dimension_numbers());
- value_map[op] = xla::Gather(
- value_map[op.operand()], value_map[op.start_indices()], dimension_numbers,
- Convert_slice_sizes(op.slice_sizes()), op.indices_are_sorted());
- return success();
-}
-
LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
// The shape argument expected by the xla client API is the type of the first
@@ -645,10 +613,10 @@
return failure();
}
-LogicalResult ExportXlaOp(ReverseOp op, OpLoweringContext ctx) {
+LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
- value_map[op] = xla::Rev(value_map[op.operand()],
- Convert_broadcast_dimensions(op.dimensions()));
+ value_map[op] = xla::RngNormal(value_map[op.mu()], value_map[op.sigma()],
+ xla::TypeToShape(op.getType()));
return success();
}
@@ -692,6 +660,21 @@
return success();
}
+LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) {
+ auto& value_map = *ctx.values;
+ if (op.is_host_transfer()) {
+ value_map[op] =
+ xla::SendToHost(value_map[op.operand()], value_map[op.token()],
+ xla::TypeToShape(op.operand().getType()),
+ Convert_channel_handle(op.channel_id()));
+ return success();
+ }
+ value_map[op] =
+ xla::SendWithToken(value_map[op.operand()], value_map[op.token()],
+ Convert_channel_handle(op.channel_id()));
+ return success();
+}
+
LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
return failure();
}
@@ -708,12 +691,6 @@
return success();
}
-LogicalResult ExportXlaOp(TupleOp op, OpLoweringContext ctx) {
- auto& value_map = *ctx.values;
- value_map[op] = xla::Tuple(ctx.builder, GetTuple(op.val(), ctx));
- return success();
-}
-
LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
// Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
// operands.
@@ -914,7 +891,7 @@
}
} else {
for (auto& it : llvm::enumerate(bb.getArguments())) {
- auto* arg = it.value();
+ auto arg = it.value();
auto num = it.index();
xla::Shape shape = xla::TypeToShape(arg->getType());
lowering[arg] =
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
index 3dffe2b..af32c50 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
@@ -37,7 +37,7 @@
// from `value_lowering` map.
llvm::Optional<xla::XlaOp> CreateXlaOperator(
mlir::Operation* op,
- llvm::DenseMap<mlir::Value*, xla::XlaOp>* value_lowering);
+ llvm::DenseMap<mlir::Value, xla::XlaOp>* value_lowering);
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
index acc3c17..7c5694e 100644
--- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
+++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc
@@ -17,6 +17,7 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
@@ -42,14 +43,31 @@
Attribute attr = named_attr.attr;
StringRef storage_type = attr.getStorageType();
// For some attribute types we have a general conversion, so use that.
- if (!attr.isEnumAttr() && (storage_type.endswith("IntegerAttr") ||
+ if (!attr.isEnumAttr() && (storage_type.endswith("BoolAttr") ||
storage_type.endswith("FloatAttr") ||
+ storage_type.endswith("IntegerAttr") ||
storage_type.endswith("StringAttr"))) {
return "Convert" + attr.getReturnType().str();
}
return "Convert_" + named_attr.name.str();
}
+static std::string GetClientBuilder(const Operator& op) {
+ static const auto* kOpToXLABuilderMap =
+ new llvm::StringMap<StringRef>{{"ReverseOp", "Rev"},
+ {"ConcatenateOp", "ConcatInDim"},
+ {"ConvOp", "ConvGeneralDilated"}};
+
+ StringRef op_name = op.getCppClassName();
+
+ // Default: the client builder name is the op class name minus the "Op"
+ // suffix, e.g. AddOp -> xla::Add.
+ if (!kOpToXLABuilderMap->count(op_name)) return op_name.drop_back(2);
+
+ // Otherwise, use the explicit op -> client builder mapping from the table.
+ return kOpToXLABuilderMap->lookup(op_name);
+}
+
static void BuildOperator(const Operator& op, raw_ostream* output) {
auto& os = *output;
os << " auto& value_map = *lowering_context.values;\n"
@@ -71,7 +89,7 @@
}
// Otherwise, this is a varidiac operand list.
- os << " std::vector<xla::XlaOp> xla_arg_" << index << ";"
+ os << " std::vector<xla::XlaOp> xla_arg_" << index << ";\n"
<< " for (auto operand : xla_op.getODSOperands(" << operand_number++
<< "))\n xla_arg_" << index
<< ".push_back(value_map[operand]);\n";
@@ -85,10 +103,15 @@
<< op.getArgName(index) << "());\n";
}
- // Assumes that the client builder method names closely follow the op names
- // in the dialect. For e.g., AddOp -> xla::Add method.
- StringRef op_name = op.getCppClassName();
- os << " auto xla_result = xla::" << op_name.drop_back(2) << "(";
+ // Emit call to client API
+ os << " auto xla_result = xla::" << GetClientBuilder(op) << "(";
+
+ // If all operands are variadic, then pass the builder explicitly to xla
+ // client API call
+ if (op.getNumOperands() == op.getNumVariadicOperands()) {
+ os << "lowering_context.builder";
+ if (op.getNumArgs() != 0) os << ", ";
+ }
// Emit each of the arguments.
interleaveComma(llvm::seq<int>(0, op.getNumArgs()), os,
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index ce9a0d4..7e743ca 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -6,7 +6,7 @@
// CHECK-LABEL: fusedBatchNorm_notraining
func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
- // CHECK-NEXT: "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK: "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %0#0 : tensor<8x8x8x8xf32>
}
@@ -14,22 +14,22 @@
// CHECK-LABEL: fusedBatchNorm_training
func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
// TODO(riverriddle) Support training.
- // CHECK-NEXT: "tf.FusedBatchNorm"
+ // CHECK: "tf.FusedBatchNorm"
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %0#0 : tensor<8x8x8x8xf32>
}
// CHECK-LABEL: fusedBatchNormV3_noTraining
func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
- // CHECK-NEXT: "xla_hlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK: "xla_hlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %0#0 : tensor<8x8x8x8xf32>
}
//CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
- // CHECK-NEXT: %[[RESULT0:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
- // CHECK-NEXT: %[[RESULT1:.*]] = "xla_hlo.batch_norm_inference"(%[[RESULT0]], %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK: %[[RESULT0:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
+ // CHECK: %[[RESULT1:.*]] = "xla_hlo.batch_norm_inference"(%[[RESULT0]], %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// CHECK-NEXT: "xla_hlo.convert"(%[[RESULT1]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
return %0#0 : tensor<8x8x8x8xbf16>
@@ -37,19 +37,19 @@
//CHECK-LABEL: fusedBatchNormV3_training
func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
- // CHECK-NEXT: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+ // CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
- // CHECK-NEXT: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
- // CHECK-NEXT: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
- // CHECK-NEXT: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
- // CHECK-NEXT: xla_hlo.constant
- // CHECK-NEXT: "xla_hlo.mul"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
+ // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK: xla_hlo.constant
+ // CHECK: "xla_hlo.mul"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
return %0#0 : tensor<8x8x8x8xf32>
}
//CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision
func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
- // CHECK-NEXT: "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
+ // CHECK: "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// CHECK: "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
return %0#0 : tensor<8x8x8x8xbf16>
@@ -57,11 +57,289 @@
//CHECK-LABEL: fusedBatchNormV3_NCHW
func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
- // CHECK-NEXT: "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+ // CHECK: "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %0#0 : tensor<8x8x8x8xf32>
}
+// CHECK-LABEL: fusedBatchNormGrad_noTraining
+func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
+
+ // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32>
+ // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64>
+ // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( {
+ // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors
+ // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg5, %arg6 : tensor<f32>
+ // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
+ // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32>
+ // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+
+ // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32>
+
+ // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64>
+ // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( {
+ // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors
+ // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg5, %arg6 : tensor<f32>
+ // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
+ // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+ %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// CHECK-LABEL: fusedBatchNormGrad_Training
+func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+ // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+ %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV2_noTraining
+func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
+
+ // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32>
+ // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64>
+ // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( {
+ // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors
+ // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg5, %arg6 : tensor<f32>
+ // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
+ // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32>
+ // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+
+ // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32>
+
+ // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64>
+ // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( {
+ // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors
+ // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg5, %arg6 : tensor<f32>
+ // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
+ // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+ %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV2_Training
+func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+ // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+ %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision
+func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
+
+ // CHECK: %[[x_backprop:.*]] = "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>
+
+ %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xbf16>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision
+func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+ // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>
+
+ %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xbf16>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV3_noTraining
+func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
+
+ // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32>
+ // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64>
+ // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( {
+ // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
+ // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg6, %arg7 : tensor<f32>
+ // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
+ // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32>
+ // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+
+ // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32>
+
+ // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64>
+ // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( {
+ // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
+ // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg6, %arg7 : tensor<f32>
+ // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
+ // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+ %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV3_Training
+func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+ // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+ %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision
+func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
+
+ // CHECK: %[[x_backprop:.*]] = "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>
+
+ %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xbf16>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision
+func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[training:.*]] = "xla_hlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+ // CHECK-NEXT: %[[tact:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[scale_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>
+
+ %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xbf16>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW
+func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK-NEXT: %[[grad:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
+
+ // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.sub"(%[[act]], %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32>
+ // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64>
+ // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-NEXT: %[[red1:.*]] = "xla_hlo.reduce"(%[[cmul]], %[[init]]) ( {
+ // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
+ // CHECK-NEXT: %[[reduced:.*]] = xla_hlo.add %arg6, %arg7 : tensor<f32>
+ // CHECK-NEXT: "xla_hlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
+ // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.mul %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32>
+ // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.mul"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+
+ // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.mul %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32>
+
+ // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64>
+ // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: %[[init2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-NEXT: %[[red2:.*]] = "xla_hlo.reduce"(%[[cgrad]], %[[init2]]) ( {
+ // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
+ // CHECK-NEXT: %[[reduced1:.*]] = xla_hlo.add %arg6, %arg7 : tensor<f32>
+ // CHECK-NEXT: "xla_hlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
+ // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
+ // CHECK-NEXT: %[[offset_backprop:.*]] = "xla_hlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>
+
+ // CHECK-NEXT: %[[x_backprop:.*]] = "xla_hlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
+ // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>
+
+ %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW
+func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+ // CHECK: %{{.*}} = "xla_hlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+ %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+ return %0#0 : tensor<8x8x8x8xf32>
+}
+
//===----------------------------------------------------------------------===//
// Bias op legalizations.
//===----------------------------------------------------------------------===//
@@ -1412,6 +1690,27 @@
return %0 : tensor<*xf32>
}
+// CHECK-LABEL: @log1p
+func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @log1p_dynamic
+func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ %0 = "tf.Log1p"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @log1p_unranked
+func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+ %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
// CHECK-LABEL: func @not_op_unranked
func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> {
// CHECK: "xla_hlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
@@ -1580,6 +1879,18 @@
return %0 : tensor<1x2xf32>
}
+// CHECK-LABEL: func @sign
+// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32>
+func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
+ // CHECK: [[PRED:%.*]] = "xla_hlo.compare"([[ARG]], [[ARG]])
+ // CHECK: [[ZEROS:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
+ // CHECK: [[SIGN:%.*]] = "xla_hlo.sign"([[ARG]])
+ // CHECK: [[SELECT:%.*]] = "xla_hlo.select"([[PRED]], [[ZEROS]], [[SIGN]])
+ // CHECK: return [[SELECT]] : tensor<1x2x3x4xf32>
+ %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>)
+ return %0 : tensor<1x2x3x4xf32>
+}
+
// CHECK-LABEL: slice_constant_start
func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
@@ -2500,3 +2811,20 @@
// CHECK: return [[PAD]]
return %0: tensor<4x128x1024xf32>
}
+
+// CHECK-LABEL: @tensor_scatter_update
+func @tensor_scatter_update(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
+ // CHECK: "xla_hlo.scatter"(%arg0, %arg1, %arg2) ( {
+ // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ // CHECK: "xla_hlo.return"(%arg4) : (tensor<f32>) -> ()
+ // CHECK: })
+ // CHECK-SAME: indices_are_sorted = false
+ // CHECK-SAME: scatter_dimension_numbers
+ // CHECK-SAME: index_vector_dim = 1 : i64
+ // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
+ // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
+ // CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
+ // CHECK-SAME: unique_indices = false
+ %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
index dae20d0..1d2cf76 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir
@@ -32,10 +32,10 @@
// CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32>
%2 = "xla_hlo.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
- // CHECK-NEXT: %3 = divis %2, %arg1 : tensor<4xi32>
+ // CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32>
%3 = "xla_hlo.div"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
- // CHECK-NEXT: %4 = remis %3, %arg1 : tensor<4xi32>
+ // CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: return %4 : tensor<4xi32>
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir
index d4ee0fd..74fea0c 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir
@@ -59,7 +59,7 @@
// CHECK-LABEL: func @int_div_op
func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
- // CHECK: divis %{{.*}}, %{{.*}} : i32
+ // CHECK: divi_signed %{{.*}}, %{{.*}} : i32
"xla_lhlo.div"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index a315a23..c33ab80 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -189,6 +189,15 @@
// -----
+func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
+ %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
+ // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit integer or floating-point values, but got 'tensor<complex<f32>>'}}
+ %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
+ return %0 : tensor<2x3x5xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @select
func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index 090d082..8a4d0b8 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -356,6 +356,18 @@
// -----
// CHECK: HloModule
+func @main(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>> {
+ %0 = "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
+ return %0 : tensor<3x5xcomplex<f32>>
+}
+
+// CHECK: ENTRY
+// CHECK: [[ARG:%.*]] = f32[3,9] parameter(0)
+// CHECK: c64[3,5] fft(f32[3,9] [[ARG]]), fft_type=RFFT, fft_length={9}
+
+// -----
+
+// CHECK: HloModule
func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x300xf32> {
// CHECK: [[ARG0:%.*]] = f32[200,100,300] parameter(0)
// CHECK: [[ARG1:%.*]] = s32[10,2] parameter(1)
@@ -530,6 +542,20 @@
// -----
// CHECK: HloModule
+func @main(%mu: tensor<f32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
+ %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
+ %0 = "xla_hlo.rng_normal"(%mu, %sigma, %shape) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
+ return %0 : tensor<2x3x5xf32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[MU:.*]] = f32[] parameter(0)
+// CHECK: %[[SIGMA:.*]] = f32[] parameter(1)
+// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[MU]], f32[] %[[SIGMA]]), distribution=rng_normal
+
+// -----
+
+// CHECK: HloModule
func @main() -> tensor<2x3x5xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
@@ -625,6 +651,48 @@
// -----
// CHECK: HloModule
+func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
+ %0 = "xla_hlo.send"(%arg, %token) {
+ channel_id = {
+ handle = 5 : i64,
+ type = 2 : i64 // Device to host channel
+ },
+ is_host_transfer = true
+ } : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token
+ return %0 : !xla_hlo.token
+}
+
+// CHECK: ENTRY
+// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0)
+// CHECK: [[TOKEN:%.*]] = token[] parameter(1)
+// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer=true
+// CHECK: ROOT
+// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5, is_host_transfer=true
+
+// -----
+
+// CHECK: HloModule
+func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
+ %0 = "xla_hlo.send"(%arg, %token) {
+ channel_id = {
+ handle = 5 : i64,
+ type = 1 : i64 // Device to device channel
+ },
+ is_host_transfer = false
+ } : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token
+ return %0 : !xla_hlo.token
+}
+
+// CHECK: ENTRY
+// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0)
+// CHECK: [[TOKEN:%.*]] = token[] parameter(1)
+// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5
+// CHECK: ROOT
+// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5
+
+// -----
+
+// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index 41f009e..b6900bc 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -95,6 +95,15 @@
ROOT %call.2 = s64[] call(%arg0.1), to_apply=%call
}
+// CHECK-LABEL: func @test_cholesky
+// CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
+%test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] {
+ %a = f32[1,291,291] parameter(0)
+ // CHECK-NEXT: "xla_hlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
+ ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true
+}
+
+
// CHECK-LABEL: func @test_clamp(
%test_clamp (Arg_0.1: f32[], Arg_1.2: f32[4], Arg_1.3: f32[]) -> f32[4] {
%Arg_0.1 = f32[] parameter(0)
@@ -508,6 +517,26 @@
ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
+// CHECK-LABEL: func @test_rng_normal
+// CHECK-SAME: ([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>) -> tensor<2x3x5xf32>
+%test_rng_normal (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[2,3,5] {
+ %Arg_0.1 = f32[] parameter(0)
+ %Arg_1.2 = f32[] parameter(1)
+ // CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64>
+ // CHECK: "xla_hlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]])
+ ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_normal
+}
+
+// CHECK-LABEL: func @test_rng_uniform
+// CHECK-SAME: ([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>) -> tensor<2x3x5xf32>
+%test_rng_uniform (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[2,3,5] {
+ %Arg_0.1 = f32[] parameter(0)
+ %Arg_1.2 = f32[] parameter(1)
+ // CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64>
+ // CHECK: "xla_hlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]])
+ ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_uniform
+}
+
// CHECK-LABEL: func @test_real
%test_real (Arg_0.1: c64[4]) -> f32[4] {
%Arg_0.1 = c64[4] parameter(0)
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index 4a74fe4..d14e9d0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -39,7 +39,7 @@
constexpr StringRef kTempBufferAttr = "temp";
-Value* GetTensorStoreOrReturnMemRef(Value* value) {
+Value GetTensorStoreOrReturnMemRef(Value value) {
for (const auto& user : value->getUsers()) {
if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
if (tensor_store.getOperand(0) == value) {
@@ -56,7 +56,7 @@
return nullptr;
}
-Operation* GetLastUse(Value* value) {
+Operation* GetLastUse(Value value) {
Operation* last = value->getDefiningOp();
for (auto& user : value->getUses()) {
Operation* user_op = user.getOwner();
@@ -67,8 +67,8 @@
return last;
}
-Value* InsertAllocAndDealloc(Location loc, Value* result,
- ConversionPatternRewriter* rewriter) {
+Value InsertAllocAndDealloc(Location loc, Value result,
+ ConversionPatternRewriter* rewriter) {
auto result_type = result->getType().dyn_cast<ShapedType>();
if (!result_type || !result_type.hasStaticShape()) {
emitError(loc,
@@ -93,8 +93,8 @@
/// For every tensor-type value that is produced in the original function,
/// this function returns the buffer that can be used in the converted
/// function to store that values held in the tensor.
-Value* GetBufferForResultValue(Location loc, Value* result,
- ConversionPatternRewriter* rewriter) {
+Value GetBufferForResultValue(Location loc, Value result,
+ ConversionPatternRewriter* rewriter) {
if (auto existing_memref = GetTensorStoreOrReturnMemRef(result)) {
return existing_memref;
}
@@ -108,7 +108,7 @@
: ConversionPattern(HloOpTy::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
- Operation* op, ArrayRef<Value*> operands,
+ Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
if (op->getParentRegion()->getBlocks().size() != 1) {
emitError(op->getLoc(),
@@ -116,14 +116,14 @@
"region containing the operation");
}
const auto& original_results = op->getResults();
- SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end());
+ SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(
GetBufferForResultValue(op->getLoc(), result, &rewriter));
}
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
op->getAttrs());
- rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
+ rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()),
original_results);
return matchSuccess();
}
@@ -135,7 +135,7 @@
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
- xla_hlo::ReduceOp op, ArrayRef<Value*> operands,
+ xla_hlo::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
@@ -146,7 +146,7 @@
"region containing the operation");
}
const auto& original_results = op.getResults();
- SmallVector<Value*, 4> buffer_args(operands.begin(), operands.end());
+ SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(GetBufferForResultValue(loc, result, &rewriter));
}
@@ -178,7 +178,7 @@
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<xla_lhlo::TerminatorOp>(loc);
- rewriter.replaceOp(op, ArrayRef<Value*>(buffer_args).slice(operands.size()),
+ rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()),
original_results);
return matchSuccess();
@@ -191,7 +191,7 @@
: ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
- Operation* op, ArrayRef<Value*> operands,
+ Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOp(op, operands, op->getResults());
return matchSuccess();
@@ -205,7 +205,7 @@
: ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
- Operation* op, ArrayRef<Value*> operands,
+ Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.eraseOp(op);
return matchSuccess();
@@ -218,7 +218,7 @@
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
- xla_hlo::ReturnOp op, ArrayRef<Value*> operands,
+ xla_hlo::ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.eraseOp(op);
return matchSuccess();
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
index 8a8afc0..4c7de6c 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc
@@ -171,8 +171,8 @@
auto cond_value = builder.create<mlir::ExtractElementOp>(loc, return_value);
// Get the body block arguments.
- llvm::SmallVector<Value*, 4> successor_args(cond_block->args_begin(),
- cond_block->args_end());
+ llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
+ cond_block->args_end());
builder.create<mlir::CondBranchOp>(loc, cond_value, body_block,
successor_args, tail_block,
successor_args);
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 7551b9f..44baf8c 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -55,31 +55,20 @@
class LegalizeTF : public FunctionPass<LegalizeTF> {
public:
- struct Options : public PassOptions<Options> {
- Option<bool> allow_partial_conversion{
- *this, "allow-partial-conversion",
- llvm::cl::desc("Allow operations that can't be legalized."),
- llvm::cl::init(false)};
- };
-
- explicit LegalizeTF(bool allow_partial_conversion)
- : FunctionPass<LegalizeTF>(),
- allow_partial_conversion_(allow_partial_conversion) {}
-
- explicit LegalizeTF(const Options &option)
- : LegalizeTF(option.allow_partial_conversion) {}
+ LegalizeTF() = default;
+ LegalizeTF(const LegalizeTF &) {}
+ explicit LegalizeTF(bool allow_partial_conversion) {
+ allow_partial_conversion_ = allow_partial_conversion;
+ }
/// Performs the lowering to XLA dialect.
void runOnFunction() override;
- /// Print this pass for a textual pipeline. It must round-trip.
- void printAsTextualPipeline(raw_ostream &os) override {
- os << "xla-legalize-tf{allow-partial-conversion="
- << (allow_partial_conversion_ ? "true" : "false") << "}";
- }
-
private:
- bool allow_partial_conversion_;
+ Option<bool> allow_partial_conversion_{
+ *this, "allow-partial-conversion",
+ llvm::cl::desc("Allow operations that can't be legalized."),
+ llvm::cl::init(false)};
};
/// Returns if the given TF data format string is the default format.
@@ -133,7 +122,7 @@
// corresponding to the tensorflow axis. In particular, the tensorflow axis can
// be negative, in which case, the corresponding HLO axis is
// (axis + rank-of-the-tensor).
-static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value *value,
+static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
int64_t rank) {
DenseIntElementsAttr attrs;
if (!matchPattern(value, m_Constant(&attrs)) ||
@@ -146,7 +135,7 @@
/// Returns a `ConvertOp` that casts the elements to a i64 type while retaining
/// the shape of the input value.
-static ConvertOp CastValueToI64(Location loc, Value *value,
+static ConvertOp CastValueToI64(Location loc, Value value,
PatternRewriter *rewriter) {
return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}
@@ -230,12 +219,38 @@
builder->create<ReturnOp>(loc, reducer.getResult());
}
+// Builds region taking two arguments and returning second argument as the
+// result. Corresponds to the function f(x, y) = y.
+// Used in Scatter op's computation to update specific elements.
+static void BuildBinaryAssignmentRegion(Type element_type, Region *region,
+                                        OpBuilder *builder) {
+  // Restore the caller's insertion point once the region body is built.
+  OpBuilder::InsertionGuard guard(*builder);
+  Block *block = builder->createBlock(region);
+  // Both block arguments are scalars of the given element type.
+  Type type = RankedTensorType::get(/*shape=*/{}, element_type);
+  block->addArguments({type, type});
+  // f(x, y) = y: the update value replaces the current element.
+  builder->create<ReturnOp>(region->getLoc(), block->getArgument(1));
+}
+
+// Builds a set of operations for applying reduction on the input value. A
+// tf.Sum op (keep_dims = false) is created; like the other TF ops emitted by
+// patterns in this pass, it is expected to be legalized to HLO ops afterwards.
+static Value ApplyReduction(Location loc, Value input,
+                            DenseIntElementsAttr reduce_dims,
+                            OpBuilder *builder) {
+  auto reduce_dims_op = builder->create<ConstOp>(loc, reduce_dims);
+  return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
+                                    builder->getBoolAttr(false));
+}
+
//===----------------------------------------------------------------------===//
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//
static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
- Value *input) {
+ Value input) {
return b.getI64IntegerAttr(
getFeatureDimension(format, input->getType().cast<RankedTensorType>()));
}
@@ -248,7 +253,7 @@
// Requires input to have ranked tensor.
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
StringAttr format,
- Value *input) {
+ Value input) {
auto inputType = input->getType().cast<RankedTensorType>();
size_t featureDim = getFeatureDimension(format, inputType);
RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64));
@@ -313,7 +318,7 @@
// same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions
// must be broadcasted with a size 1 tensor or another dynamic dimension.
// Returns false on rankless.
-static bool AreBroadcastCompatible(Value *x, Value *y) {
+static bool AreBroadcastCompatible(Value x, Value y) {
auto x_rankless = x->getType().dyn_cast<RankedTensorType>();
auto y_rankless = y->getType().dyn_cast<RankedTensorType>();
if (!x_rankless || !y_rankless) {
@@ -394,16 +399,16 @@
Location loc = body->getLoc();
StringAttr compare_direction =
StringAttr::get(direction, builder->getContext());
- Value *compare = builder->create<CompareOp>(
+ Value compare = builder->create<CompareOp>(
loc, block->getArgument(0), block->getArgument(2),
/*broadcast_dimensions=*/nullptr, compare_direction);
- Value *selected_input = builder->create<SelectOp>(
+ Value selected_input = builder->create<SelectOp>(
loc, input_type, compare, block->getArgument(0), block->getArgument(2));
- Value *selected_index = builder->create<SelectOp>(
+ Value selected_index = builder->create<SelectOp>(
loc, index_type, compare, block->getArgument(1), block->getArgument(3));
- Value *return_values[] = {selected_input, selected_index};
+ Value return_values[] = {selected_input, selected_index};
builder->create<ReturnOp>(loc, return_values);
}
@@ -411,7 +416,7 @@
// Slice op utilities.
//===----------------------------------------------------------------------===//
-static bool CanBeTranslatedToDynamicSlice(Value *input, Value *start_indices,
+static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
DenseIntElementsAttr slice_sizes) {
auto input_ty = input->getType().dyn_cast<RankedTensorType>();
int64_t input_rank = input_ty.getRank();
@@ -452,7 +457,7 @@
// the end. HLO slice size can't be -1. As such, we need to translate TF slice
// size -1 to HLO slice size.
static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
- Value *input, Value *start_indices, DenseIntElementsAttr slice_sizes,
+ Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
Builder *builder) {
DenseIntElementsAttr constant_start_indices;
if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
@@ -502,7 +507,7 @@
Location loc = body->getLoc();
StringAttr compare_direction =
StringAttr::get(direction, builder->getContext());
- Value *compare = builder->create<xla_hlo::CompareOp>(
+ Value compare = builder->create<xla_hlo::CompareOp>(
loc, block->getArgument(0), block->getArgument(1),
/*broadcast_dimensions=*/nullptr, compare_direction);
@@ -661,7 +666,7 @@
auto paddings_attr = rewriter.getNamedAttr(
"padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
- SmallVector<Value *, 2> operands(op.getOperands());
+ SmallVector<Value, 2> operands(op.getOperands());
NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
dimension_numbers_attr, feature_group_count_attr,
batch_group_count_attr, paddings_attr};
@@ -738,6 +743,120 @@
}
};
+// Base pattern converting TensorFlow FusedBatchNormGrad*Op to HLO: a single
+// BatchNormGradOp when training, element-wise ops and reductions otherwise.
+// TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
+template <typename FusedBatchNormGradOpT>
+class ConvertFusedBatchNormGradBase
+ : public OpRewritePattern<FusedBatchNormGradOpT> {
+ public:
+ using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(FusedBatchNormGradOpT op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value grad = op.y_backprop();
+ Value act = op.x();
+ Value scale = op.scale();
+ Value mean = op.reserve_space_1();
+ Value var = op.reserve_space_2();
+
+ // TODO(b/141785544): Update this to not require static shapes.
+ // Only rankedness is checked here: the rank is needed to resolve the
+ // feature dimension and the non-feature reduction dimensions below.
+ RankedTensorType act_type =
+ act->getType().template dyn_cast<RankedTensorType>();
+ if (!act_type) return Pattern::matchFailure(); // Unranked input: no match.
+ Type act_ele_type = act_type.getElementType();
+ // To support mixed precision, the statistics type, which may be more
+ // precise than the input types, is used for all computations in this op.
+ Type kernel_type =
+ scale->getType().template cast<TensorType>().getElementType();
+ grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
+ act = rewriter.create<ConvertOp>(loc, act, kernel_type);
+
+ auto feature_dim_attr =
+ getFeatureDimensionAttr(rewriter, op.data_formatAttr(), act);
+ auto feature_dim = feature_dim_attr.getValue().getSExtValue();
+
+ // The three gradient results, computed in kernel_type.
+ Value x_backprop, scale_backprop, offset_backprop;
+ if (op.is_training()) { // training
+ // TODO(b/145536565): handle GPU logic separately.
+ // Infers the output type with the converted `act`.
+ Type feature_type = RankedTensorType::get(
+ {GetDimSize(act_type, feature_dim)}, kernel_type);
+ Type result_type = TupleType::get(
+ {act->getType(), feature_type, feature_type}, rewriter.getContext());
+
+ auto training_op = rewriter.create<BatchNormGradOp>(
+ loc, result_type, act, scale, mean, var, grad, op.epsilon(),
+ feature_dim_attr.getValue());
+
+ x_backprop =
+ rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0);
+
+ scale_backprop =
+ rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 1);
+
+ offset_backprop =
+ rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 2);
+ } else { // inference
+ SmallVector<int64_t, 4> non_feature_dims; // Every dim except feature_dim.
+ for (int64_t i = 0; i < act_type.getRank(); ++i) {
+ if (i == feature_dim) continue;
+ non_feature_dims.push_back(i);
+ }
+ auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
+ auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter);
+ auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
+
+ // scratch1 = rsqrt(var + epsilon)
+ RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type); // Rank-0.
+ auto epsilon = rewriter.create<ConstOp>(
+ loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
+ auto add_op = rewriter.create<AddOp>(loc, var, epsilon.getResult(),
+ no_broadcast_dims);
+ Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
+
+ // scratch2 = sum(y_backprop * (x - mean))
+ auto sub_op = rewriter.create<SubOp>(loc, act, mean, broadcast_dims);
+ auto weighted_grad =
+ rewriter.create<MulOp>(loc, grad, sub_op, no_broadcast_dims);
+ Value scratch2 =
+ ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
+
+ // x_backprop = y_backprop * (scale * scratch1)
+ auto scaled_grad =
+ rewriter.create<MulOp>(loc, op.scale(), scratch1, no_broadcast_dims);
+ x_backprop =
+ rewriter.create<MulOp>(loc, grad, scaled_grad, broadcast_dims);
+
+ // scale_backprop = scratch2 * scratch1
+ scale_backprop =
+ rewriter.create<MulOp>(loc, scratch1, scratch2, no_broadcast_dims);
+
+ // offset_backprop = sum(y_backprop)
+ offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
+ }
+
+ x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type); // Back to input type.
+ // The last 2 results' values don't matter; op.x() is used as a placeholder.
+ rewriter.replaceOp(op,
+ {/*x_backprop=*/x_backprop,
+ /*scale_backprop=*/scale_backprop,
+ /*offset_backprop=*/offset_backprop, op.x(), op.x()});
+ return Pattern::matchSuccess();
+ }
+};
+
+using ConvertFusedBatchNormGradOp =
+ ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
+using ConvertFusedBatchNormGradV2Op =
+ ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
+using ConvertFusedBatchNormGradV3Op =
+ ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;
+
// Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
// HLO BatchNormInferenceOp, depending on the value of the 'is_training'
// parameter.
@@ -757,30 +876,13 @@
auto scale_type_tensor = op.scale()->getType().dyn_cast<TensorType>();
auto scale_element_type = scale_type_tensor.getElementType();
- // The TF FusedBatchNormV3 op supports mixed precision. If the input type
- // differs, convert it to have the precision of the other types for the
- // HLO op.
- bool is_mixed_precision = false;
- Value *bn_train_input;
- TensorType bn_train_input_type_tensor;
- Type bn_train_input_element_type;
- if (input_element_type != scale_element_type) {
- // TODO(b/69928690): Support mixed precision in the XLA batch
- // normalization operators. As a workaround, create a new x with the same
- // element type as scale (which may be more precise than the input type).
- is_mixed_precision = true;
- bn_train_input = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), op.x(),
- scale_element_type);
- bn_train_input_type_tensor =
- ChangeTensorElementType(&rewriter, input_type_tensor,
- scale_element_type)
- .dyn_cast<TensorType>();
- bn_train_input_element_type = scale_element_type;
- } else {
- bn_train_input = op.x();
- bn_train_input_type_tensor = input_type_tensor;
- bn_train_input_element_type = input_element_type;
- }
+ // TODO(b/69928690): Support mixed precision in the XLA batch
+ // normalization operators. As a workaround, create a new x with the same
+ // element type as scale (which may be more precise than the input type).
+ Value bn_train_input = rewriter.create<xla_hlo::ConvertOp>(
+ op.getLoc(), op.x(), scale_element_type);
+ TensorType bn_train_input_type_tensor =
+ bn_train_input.getType().cast<TensorType>();
if (op.is_training()) {
// Training case.
@@ -790,7 +892,7 @@
// This shape must be constructed manually because the mean and variance
// inputs are empty in the training case.
Type mean_var_type = RankedTensorType::get(
- {operand_shape[feature_dim.getInt()]}, bn_train_input_element_type);
+ {operand_shape[feature_dim.getInt()]}, scale_element_type);
// Op result type is a tuple of 3 values: output with same shape as input;
// batch_mean, and batch_var.
SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
@@ -802,11 +904,11 @@
op.epsilon(), feature_dim.getValue());
// HLO op outputs a tuple of tensors. Extract those results.
auto bn_train_op_result = bn_train_op.getResult();
- Value *y_out = rewriter.create<xla_hlo::GetTupleElementOp>(
+ Value y_out = rewriter.create<xla_hlo::GetTupleElementOp>(
op.getLoc(), bn_train_op_result, 0);
- Value *batch_mean = rewriter.create<xla_hlo::GetTupleElementOp>(
+ Value batch_mean = rewriter.create<xla_hlo::GetTupleElementOp>(
op.getLoc(), bn_train_op_result, 1);
- Value *batch_variance = rewriter.create<xla_hlo::GetTupleElementOp>(
+ Value batch_variance = rewriter.create<xla_hlo::GetTupleElementOp>(
op.getLoc(), bn_train_op_result, 2);
// Apply Bessel's correction on the variance.
@@ -823,12 +925,10 @@
op.getLoc(), batch_variance->getType(), batch_variance,
factor_const_op, /*DenseIntElementsAttr=*/DenseIntElementsAttr());
- if (is_mixed_precision) {
- // Convert back to input type to stay aligned with expected output type
- // for TF op.
- y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), y_out,
- input_element_type);
- }
+ // Convert back to input type to stay aligned with expected output type
+ // for TF op.
+ y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), y_out,
+ input_element_type);
// TF FusedBatchNormV3 op expects 5 outputs. Outputs 3 and 4 are
// currently marked as "reserved spaces 1 and 2". They are used to
@@ -848,15 +948,10 @@
op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
feature_dim.getValue());
- Value *y_out;
- if (is_mixed_precision) {
- // Convert back to input type to stay aligned with expected output type
- // for TF op.
- y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), bn_train_op,
- input_element_type);
- } else {
- y_out = bn_train_op;
- }
+ // Convert back to input type to stay aligned with expected output type
+ // for TF op.
+ auto y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), bn_train_op,
+ input_element_type);
// The mean, variance, and reserved space outputs of the batch norm op are
// not used for inference. It doesn't matter what values we provide for
@@ -1030,7 +1125,7 @@
PatternMatchResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- Value *logits = op.logits();
+ Value logits = op.logits();
// Softmax converter requires ranked type because the XLA reduce ops used
// while lowering requires dimensions attribute to reduce along.
@@ -1060,16 +1155,16 @@
rewriter.create<SubOp>(loc, type, logits, max_logits, batch_dims);
// Exponentiate the inputs.
- Value *exp = rewriter.create<ExpOp>(loc, type, shifted_logits);
+ Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits);
// Compute summation of the exponentials.
auto exp_sum =
rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
/*keep_dims=*/rewriter.getBoolAttr(false));
- Value *sum = exp_sum.getResult();
+ Value sum = exp_sum.getResult();
if (use_log) {
- Value *log = rewriter.create<LogOp>(loc, sum);
+ Value log = rewriter.create<LogOp>(loc, sum);
rewriter.replaceOpWithNewOp<SubOp>(op, shifted_logits, log, batch_dims);
} else {
rewriter.replaceOpWithNewOp<DivOp>(op, exp, sum, batch_dims);
@@ -1106,7 +1201,7 @@
PatternMatchResult matchAndRewrite(TF::SizeOp op,
PatternRewriter &rewriter) const override {
- Value *input = op.input();
+ Value input = op.input();
auto input_ty = input->getType().dyn_cast<RankedTensorType>();
if (!input_ty) return Pattern::matchFailure();
@@ -1203,7 +1298,7 @@
SmallVector<int64_t, 4> strides(input_rank, 1);
// All HLO slice results used to replace the original tf.Split op.
- SmallVector<Value *, 4> slices;
+ SmallVector<Value, 4> slices;
slices.reserve(num_splits);
for (int i = 0; i < num_splits; ++i) {
@@ -1315,7 +1410,7 @@
SmallVector<int64_t, 4> strides(input_rank, 1);
// All HLO slice results used to replace the original tf.Split op.
- SmallVector<Value *, 4> slices;
+ SmallVector<Value, 4> slices;
slices.reserve(op.getNumResults());
for (int i = 0; i < op.getNumResults(); ++i) {
@@ -1457,7 +1552,7 @@
&strides))
return matchFailure();
- Value *grad = op.dy();
+ Value grad = op.dy();
Type element_type = grad->getType().cast<ShapedType>().getElementType();
// Perform reshape to undo any new/shrink axies done by strided slice.
@@ -1599,14 +1694,14 @@
rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);
// Each reduction op can have a different initial value.
- Value *init = Derived::GetInitialValue(reduce_element_type, loc, rewriter);
+ Value init = Derived::GetInitialValue(reduce_element_type, loc, rewriter);
auto reduction = rewriter.create<ReduceOp>(
loc, casted_input.getResult(), init,
GetI64ElementsAttr(xla_dimensions, &rewriter));
BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
&rewriter);
- Value *result = reduction.getResult(0);
+ Value result = reduction.getResult(0);
// The mean op needs to divide by the product of the reduced dimensions.
if (std::is_same<OpTy, TF::MeanOp>::value) {
@@ -1650,8 +1745,8 @@
: public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
}
};
@@ -1666,8 +1761,8 @@
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
}
};
@@ -1683,8 +1778,8 @@
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetMinValueForType(reduce_element_type, loc, &rewriter);
}
};
@@ -1698,8 +1793,8 @@
: public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 1, &rewriter);
}
};
@@ -1713,8 +1808,8 @@
: public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
public:
using GenericConvertReductionOp::GenericConvertReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
}
};
@@ -1742,7 +1837,7 @@
if (!input_element_type.isIntOrFloat()) return this->matchFailure();
Location loc = op.getLoc();
- Value *init_value =
+ Value init_value =
Derived::GetInitialValue(input_element_type, loc, rewriter);
RankedTensorType output_type =
@@ -1752,7 +1847,7 @@
}
Type index_element_type = output_type.getElementType();
- Value *index_init_value =
+ Value index_init_value =
GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
RankedTensorType index_type =
@@ -1767,21 +1862,21 @@
IntegerAttr iota_dimension =
IntegerAttr::get(rewriter.getIntegerType(64), axis);
- Value *index_values =
+ Value index_values =
rewriter.create<IotaOp>(loc, index_type, iota_dimension);
std::vector<int64_t> dimensions = input_type.getShape();
dimensions.erase(dimensions.begin() + axis);
ArrayRef<int64_t> reduction_result_shape(dimensions);
- Value *operands[] = {op.input(), index_values};
- Value *init_values[] = {init_value, index_init_value};
+ Value operands[] = {op.input(), index_values};
+ Value init_values[] = {init_value, index_init_value};
DenseIntElementsAttr reduction_dimensions =
GetI64ElementsAttr({axis}, &rewriter);
auto reduction = rewriter.create<ReduceOp>(
- loc, llvm::ArrayRef<Value *>(operands),
- llvm::ArrayRef<Value *>(init_values), reduction_dimensions);
+ loc, llvm::ArrayRef<Value>(operands),
+ llvm::ArrayRef<Value>(init_values), reduction_dimensions);
StringRef direction = Derived::GetDirection();
BuildArgMinMaxReductionBody(input_element_type, index_element_type,
direction, &reduction.body(), &rewriter);
@@ -1803,14 +1898,70 @@
public:
using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetMinValueForType(reduce_element_type, loc, &rewriter);
}
static StringRef GetDirection() { return "GT"; }
};
+// Converts TF TensorScatterUpdate op into Scatter Op with assignment:
+//
+// %result = "xla_hlo.scatter"(%tensor, %indices, %updates)
+// { dimensions = ... }
+//
+class ConvertTensorScatterUpdateOp
+ : public OpRewritePattern<TF::TensorScatterUpdateOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(TF::TensorScatterUpdateOp op,
+ PatternRewriter &rewriter) const override {
+ auto tensor_ty = op.tensor()->getType().dyn_cast<RankedTensorType>();
+ auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>();
+ auto updates_ty = op.updates()->getType().dyn_cast<RankedTensorType>();
+
+ if (!tensor_ty || !indices_ty || !updates_ty) return matchFailure();
+ // Last dimension of the indices needs to known at compile time for
+ // computation of the 'update_window_dims' attribute in the dimensions
+ // struct.
+ int64_t num_index_dims = indices_ty.getShape().back();
+ if (ShapedType::isDynamic(num_index_dims)) return matchFailure();
+
+ int64_t tensor_rank = tensor_ty.getRank();
+ int64_t indices_rank = indices_ty.getRank();
+ int64_t updates_rank = updates_ty.getRank();
+
+ int64_t window_dims = tensor_rank - num_index_dims;
+ auto dims_attr = ScatterDimensionNumbers::get(
+ GetI64ElementsAttrForSeq(updates_rank - window_dims, updates_rank,
+ &rewriter),
+ GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
+ GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
+ rewriter.getI64IntegerAttr(indices_rank - 1), rewriter.getContext());
+
+ Location loc = op.getLoc();
+ auto scatter = rewriter.create<ScatterOp>(
+ loc, op.getType(), op.tensor(), op.indices(), op.updates(), dims_attr);
+
+ // Build region to assign the new value.
+ [&](Region *region) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block *block = rewriter.createBlock(region);
+
+ // Block arguments are scalars of the given element type.
+ Type type =
+ RankedTensorType::get(/*shape=*/{}, tensor_ty.getElementType());
+ block->addArguments({type, type});
+ rewriter.create<ReturnOp>(loc, block->getArgument(1));
+ }(&scatter.update_computation());
+
+ rewriter.replaceOp(op, scatter.getResult());
+ return matchSuccess();
+ }
+};
+
// Converts Tile op to HLO BroadcastInDim and Reshape ops.
// For shape [S1, S2] and multiples [M1, M2],
// MS1 = M1 * S1; MS2 = M2 * S2
@@ -1867,7 +2018,7 @@
RankedTensorType::get(broadcasted_shape, element_type);
Type output_type = op.getType();
- Value *result = rewriter.create<BroadcastInDimOp>(
+ Value result = rewriter.create<BroadcastInDimOp>(
loc, broadcasted_type, op.input(),
GetI64ElementsAttr(broadcast_dimensions, &rewriter));
@@ -2024,7 +2175,7 @@
auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_paddings);
auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter);
- Value *filter = op.filter();
+ Value filter = op.filter();
if (feature_group_count != 1) {
/*
@@ -2041,7 +2192,7 @@
// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
- Value *result = rewriter.create<ConvOp>(
+ Value result = rewriter.create<ConvOp>(
loc, op.getType(), op.out_backprop(), filter,
/*window_strides=*/GetI64ElementsAttr(ones, &rewriter),
/*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
@@ -2242,7 +2393,7 @@
auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
Location loc = op.getLoc();
- Value *result = rewriter.create<ConvOp>(
+ Value result = rewriter.create<ConvOp>(
loc, op.getType(), op.input(), op.out_backprop(),
/*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
/*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
@@ -2305,21 +2456,21 @@
Location loc = op.getLoc();
auto index_type = RankedTensorType::get(output_dims, element_type);
- Value *compare = rewriter.create<CompareOp>(
+ Value compare = rewriter.create<CompareOp>(
loc, op.indices(),
rewriter.create<IotaOp>(
loc, index_type,
IntegerAttr::get(rewriter.getIntegerType(64), axis)),
GetI64ElementsAttr(broadcast_dims, &rewriter),
StringAttr::get("EQ", rewriter.getContext()));
- Value *on_value = rewriter.create<BroadcastOp>(
+ Value on_value = rewriter.create<BroadcastOp>(
loc, op.getType(), op.on_value(),
GetI64ElementsAttr(output_dims, &rewriter));
- Value *off_value = rewriter.create<BroadcastOp>(
+ Value off_value = rewriter.create<BroadcastOp>(
loc, op.getType(), op.off_value(),
GetI64ElementsAttr(output_dims, &rewriter));
- Value *result = rewriter.create<SelectOp>(loc, op.getType(), compare,
- on_value, off_value);
+ Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
+ on_value, off_value);
rewriter.replaceOp(
op, {result},
@@ -2381,14 +2532,14 @@
// Create an Itoa op for indices.
auto i32_type = rewriter.getIntegerType(32);
Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
- Value *iota_op = rewriter.create<xla_hlo::IotaOp>(
+ Value iota_op = rewriter.create<xla_hlo::IotaOp>(
op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
// Create the sort op. It takes two inputs, one for the original input, the
// other for the indices.
auto sort_op = rewriter.create<xla_hlo::SortOp>(
- op.getLoc(), llvm::ArrayRef<Value *>{op.input(), iota_op},
- last_dim_index, /*is_stable=*/true);
+ op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index,
+ /*is_stable=*/true);
BuildSortComparisonBody({input_type.getElementType(), i32_type},
/*direction=*/"GT", &sort_op.comparator(),
&rewriter);
@@ -2407,13 +2558,13 @@
// Get the slice for the top K elements.
- Value *values = rewriter.create<xla_hlo::SliceOp>(
+ Value values = rewriter.create<xla_hlo::SliceOp>(
op.getLoc(), tuple_first_element,
GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter));
- Value *indices = rewriter.create<xla_hlo::SliceOp>(
+ Value indices = rewriter.create<xla_hlo::SliceOp>(
op.getLoc(), tuple_second_element,
GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
@@ -2449,7 +2600,7 @@
SmallVector<int64_t, 4> strides(value_rank, 1);
// All HLO slice+reshape results used to replace the original tf.Unpack op.
- SmallVector<Value *, 4> results;
+ SmallVector<Value, 4> results;
results.reserve(op.getNumResults());
for (int i = 0; i < op.getNumResults(); ++i) {
@@ -2518,22 +2669,20 @@
// Broadccast the initial value for reduction. This will become the
// 'operand' parameter to scatter to for the final scatter op.
- Value *init = ConcreteClass::GetInitialValue(data_type.getElementType(),
- op.getLoc(), rewriter);
+ Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
+ op.getLoc(), rewriter);
auto broadcasted_init = rewriter.create<xla_hlo::BroadcastOp>(
op.getLoc(), output_type, init,
GetI64ElementsAttr(output_shape, &rewriter));
// Parameters for the generated scatter op.
- auto range = llvm::seq<int64_t>(segment_ids_rank, data_rank);
- SmallVector<int64_t, 4> update_window_dims(range.begin(), range.end());
SmallVector<int64_t, 1> inserted_window_dims(1, 0);
SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
int64_t index_vector_dim = segment_ids_rank;
// Put all parameters in a StructAttr.
auto dims_attr = ScatterDimensionNumbers::get(
- GetI64ElementsAttr(update_window_dims, &rewriter),
+ GetI64ElementsAttrForSeq(segment_ids_rank, data_rank, &rewriter),
GetI64ElementsAttr(inserted_window_dims, &rewriter),
GetI64ElementsAttr(scatter_dims_to_operand_dims, &rewriter),
rewriter.getI64IntegerAttr(index_vector_dim), rewriter.getContext());
@@ -2556,8 +2705,8 @@
using GenericConvertUnsortedSegmentReductionOp::
GenericConvertUnsortedSegmentReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetMinValueForType(reduce_element_type, loc, &rewriter);
}
};
@@ -2569,8 +2718,8 @@
using GenericConvertUnsortedSegmentReductionOp::
GenericConvertUnsortedSegmentReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetMaxValueForType(reduce_element_type, loc, &rewriter);
}
};
@@ -2582,8 +2731,8 @@
using GenericConvertUnsortedSegmentReductionOp::
GenericConvertUnsortedSegmentReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 1, &rewriter);
}
};
@@ -2595,8 +2744,8 @@
using GenericConvertUnsortedSegmentReductionOp::
GenericConvertUnsortedSegmentReductionOp;
- static Value *GetInitialValue(Type reduce_element_type, Location loc,
- PatternRewriter &rewriter) {
+ static Value GetInitialValue(Type reduce_element_type, Location loc,
+ PatternRewriter &rewriter) {
return GetScalarConstOfType(reduce_element_type, loc, 0, &rewriter);
}
};
@@ -2617,14 +2766,16 @@
patterns.insert<
ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBF16FloorDivOp,
ConvertConv2D, ConvertConv2DBackpropFilterOp,
- ConvertConv2DBackpropInputOp, ConvertEinsumOp, ConvertFusedBatchNormV3Op,
- ConvertMaxOp, ConvertMaxPoolOp, ConvertMaxPoolGradOp, ConvertMeanOp,
- ConvertOneHotOp, ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp,
+ ConvertConv2DBackpropInputOp, ConvertEinsumOp,
+ ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
+ ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, ConvertMaxOp,
+ ConvertMaxPoolOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
+ ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
- ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
- ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
+ ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
+ ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp>(
op->getContext());
@@ -2647,7 +2798,7 @@
signalPassFailure();
}
-static PassRegistration<LegalizeTF, LegalizeTF::Options> pass(
+static PassRegistration<LegalizeTF> pass(
"xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect");
} // end namespace
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
index ac14bca..e78b9b6 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc
@@ -64,8 +64,7 @@
namespace {
-void Detuple(Value* tuple, Operation::result_range replace,
- OpBuilder* builder) {
+void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) {
// De-tuple the results of the xla hlo conditional result.
for (auto result_it : llvm::enumerate(replace)) {
auto get_tuple_value = builder->create<xla_hlo::GetTupleElementOp>(
@@ -87,7 +86,7 @@
auto entry_block = builder.createBlock(dest_region);
auto tuple_arg = entry_block->addArgument(
builder.getTupleType(func.getType().getInputs()));
- llvm::SmallVector<Value*, 4> detupled_args;
+ llvm::SmallVector<Value, 4> detupled_args;
detupled_args.reserve(func.getNumArguments());
for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) {
@@ -110,12 +109,12 @@
// XLA prefers tuple arguments for control flow due to XLA not supporting
// multiple return values.
- SmallVector<Value*, 3> inputs(op.input());
+ SmallVector<Value, 3> inputs(op.input());
builder.setInsertionPoint(op);
auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
// Create the new conditional op with tuple inputs.
- SmallVector<Value*, 3> operands(op.getOperands());
+ SmallVector<Value, 3> operands(op.getOperands());
SmallVector<Type, 4> types(op.getResultTypes());
auto result_type = builder.getTupleType(types);
auto conditional = builder.create<xla_hlo::ConditionalOp>(
@@ -142,12 +141,12 @@
// XLA prefers tuple arguments for control flow due to XLA not supporting
// multiple return values.
- SmallVector<Value*, 3> inputs(op.input());
+ SmallVector<Value, 3> inputs(op.input());
builder.setInsertionPoint(op);
- Value* tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
+ Value tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
// Create the new while op with tuple inputs.
- SmallVector<Value*, 3> operands(op.getOperands());
+ SmallVector<Value, 3> operands(op.getOperands());
SmallVector<Type, 4> types(op.getResultTypes());
auto while_op = builder.create<xla_hlo::WhileOp>(
loc, builder.getTupleType(types), tuple_input);
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index ad91cf0..ed5e10d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -428,6 +428,7 @@
[TF_ImagOp, HLO_ImagOp],
[TF_IsFiniteOp, HLO_IsFiniteOp],
[TF_LogOp, HLO_LogOp],
+ [TF_Log1pOp, HLO_Log1pOp],
[TF_LogicalNotOp, HLO_NotOp],
[TF_NegOp, HLO_NegOp],
[TF_RealOp, HLO_RealOp],
@@ -457,6 +458,19 @@
(HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;
}
+// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0.
+def : Pat<(TF_SignOp $x),
+ (HLO_SelectOp
+ (HLO_CompareOp
+ $x,
+ $x,
+ (NullDenseIntElementsAttr),
+ HLO_COMPARISON_DIRECTION_NE
+ ),
+ (HLO_ConstOp (ConstantSplat<"0"> $x)),
+ (HLO_SignOp $x)
+ )>;
+
//===----------------------------------------------------------------------===//
// RngUniform.
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td
index 43c57b9..1d009a3 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td
@@ -74,9 +74,9 @@
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r,
IsNullAttr:$broadcast_dimensions),
- (DivISOp $l, $r),
+ (SignedDivIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r,
IsNullAttr:$broadcast_dimensions),
- (RemISOp $l, $r),
+ (SignedRemIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
index a8a2eb7..5ed88db 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc
@@ -42,7 +42,7 @@
// tiled. In order to greedily fuse the ops, we have to start from the tiled
// root linalg ops, i.e. linalg ops that write to output buffers of the
// function.
- llvm::SmallDenseSet<Value*> func_args;
+ llvm::SmallDenseSet<Value> func_args;
for (auto func_arg : func.getArguments()) {
func_args.insert(func_arg);
}
@@ -52,7 +52,7 @@
const SmallVector<int64_t, 2> tile_sizes(
generic_op.getNumInputsAndOutputs(), 1);
auto op = cast<LinalgOp>(generic_op.getOperation());
- for (const Value* result : op.getOutputs()) {
+ for (const Value result : op.getOutputs()) {
if (!func_args.count(result)) continue;
if (linalg::tileLinalgOp(b, op, tile_sizes, /*permutation=*/{},
&folder)) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
index f3b8ab9..42b340d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc
@@ -47,7 +47,7 @@
return this->matchFailure();
}
const auto& shape = lhs_type.getShape();
- SmallVector<Value*, 4> induction_vars;
+ SmallVector<Value, 4> induction_vars;
const auto loc = op.getLoc();
for (int i = 0; i < shape.size(); ++i) {
auto forOp = rewriter.create<AffineForOp>(loc, 0, shape[i]);
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
index 9f1f90c..5a94707 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
@@ -49,7 +49,7 @@
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
- ReduceOp reduce_op, ArrayRef<Value*> args,
+ ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = reduce_op.getLoc();
// Only support 1d reductions for now.
@@ -105,7 +105,7 @@
loc, mapping.lookup(std::get<0>(pair)));
rewriter.create<mlir::StoreOp>(loc, init_value,
mapping.lookup(std::get<1>(pair)),
- ArrayRef<Value*>{index});
+ ArrayRef<Value>{index});
}
// Insert a loop into the body to compute the reduction. The loop ranges
@@ -133,8 +133,8 @@
MemRefType::getDynamicStrideOrOffset(),
rewriter.getContext()));
auto accumulator = rewriter.create<mlir::linalg::SliceOp>(
- loc, resType, output, ArrayRef<Value*>{launch_op.getThreadIds().x});
- llvm::SmallVector<Value*, 4> indexings;
+ loc, resType, output, ArrayRef<Value>{launch_op.getThreadIds().x});
+ llvm::SmallVector<Value, 4> indexings;
auto input_buffer = *reduce_op.operands().begin();
auto input_type = input_buffer->getType().cast<MemRefType>();
for (int64_t dim = 0; dim < input_type.getRank(); ++dim) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc
index af7383c..1e3da7d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc
@@ -53,7 +53,7 @@
using OpConversionPattern<LhloOp>::OpConversionPattern;
PatternMatchResult matchAndRewrite(
- LhloOp lhlo_op, ArrayRef<Value*> args,
+ LhloOp lhlo_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = lhlo_op.getLoc();
auto argType =
@@ -101,7 +101,7 @@
block->addArguments(bodyArgTypes);
block->addArguments(bodyResultTypes);
- SmallVector<Value*, 4> bodyArgs;
+ SmallVector<Value, 4> bodyArgs;
for (int i = 0, e = bodyArgTypes.size(); i < e; ++i) {
bodyArgs.push_back(block->getArgument(i));
}
@@ -121,7 +121,7 @@
using OpConversionPattern<LhloOp>::OpConversionPattern;
PatternMatchResult matchAndRewrite(
- LhloOp lhlo_op, ArrayRef<Value*> args,
+ LhloOp lhlo_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = lhlo_op.getLoc();
auto argType =
@@ -136,7 +136,7 @@
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
Operation* op = MapLhloOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(lhlo_op), argType.getElementType(),
- llvm::ArrayRef<Value*>{lhs, rhs}, rewriter);
+ llvm::ArrayRef<Value>{lhs, rhs}, rewriter);
rewriter.create<StoreOp>(loc, op->getResult(0), lhlo_op.out());
rewriter.eraseOp(lhlo_op);
return ConversionPattern::matchSuccess();
@@ -148,7 +148,7 @@
using OpConversionPattern<BroadcastInDimOp>::OpConversionPattern;
PatternMatchResult matchAndRewrite(
- BroadcastInDimOp broadcastOp, ArrayRef<Value*> args,
+ BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto operandMemrefType =
broadcastOp.operand()->getType().dyn_cast<MemRefType>();
@@ -167,7 +167,7 @@
private:
PatternMatchResult emitScalarBroadcast(
- BroadcastInDimOp broadcastOp, ArrayRef<Value*> args,
+ BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
MemRefType resultMemrefType, ConversionPatternRewriter* rewriter) const {
unsigned nloops = resultMemrefType.getRank();
SmallVector<Attribute, 1> indexingMaps{
@@ -195,7 +195,7 @@
}
PatternMatchResult emitNonScalarBroadcast(
- BroadcastInDimOp broadcastOp, ArrayRef<Value*> args,
+ BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
MemRefType operandMemrefType, MemRefType resultMemrefType,
ConversionPatternRewriter* rewriter) const {
SmallVector<Type, 4> bodyArgTypes{operandMemrefType.getElementType()};
@@ -250,7 +250,7 @@
using OpConversionPattern<IotaOp>::OpConversionPattern;
PatternMatchResult matchAndRewrite(
- IotaOp iotaOp, ArrayRef<Value*> args,
+ IotaOp iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType =
iotaOp.getOperand()->getType().dyn_cast<MemRefType>();
@@ -301,7 +301,7 @@
using OpConversionPattern<ConstOp>::OpConversionPattern;
PatternMatchResult matchAndRewrite(
- ConstOp constOp, ArrayRef<Value*> args,
+ ConstOp constOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = constOp.getLoc();
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc
index 515f818..7b72b70 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lower_general_dot.cc
@@ -44,11 +44,11 @@
namespace {
-Value *TransposeReshape(Value *arg, mlir::Location loc,
- llvm::ArrayRef<int64_t> left_dims,
- llvm::ArrayRef<int64_t> right_dims,
- llvm::ArrayRef<int64_t> arg_shape,
- PatternRewriter *rewriter) {
+Value TransposeReshape(Value arg, mlir::Location loc,
+ llvm::ArrayRef<int64_t> left_dims,
+ llvm::ArrayRef<int64_t> right_dims,
+ llvm::ArrayRef<int64_t> arg_shape,
+ PatternRewriter *rewriter) {
auto element_type = mlir::getElementTypeOrSelf(arg->getType());
int64_t left_size = 1;
@@ -91,9 +91,9 @@
transpose_result);
}
-Value *ProcessDotArg(Value *arg, mlir::Location loc,
- ElementsAttr contract_dims_attr, bool outer_dims_first,
- PatternRewriter *rewriter) {
+Value ProcessDotArg(Value arg, mlir::Location loc,
+ ElementsAttr contract_dims_attr, bool outer_dims_first,
+ PatternRewriter *rewriter) {
auto shape = arg->getType().cast<mlir::ShapedType>().getShape();
llvm::SmallVector<bool, 5> is_outer_dim;
diff --git a/tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h
index 11e3af7..883424f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/xla/transforms/map_lhlo_to_scalar_op.h
@@ -40,7 +40,7 @@
template <>
struct ScalarOp<xla_lhlo::DivOp> {
using FOp = ::mlir::DivFOp;
- using IOp = ::mlir::DivISOp;
+ using IOp = ::mlir::SignedDivIOp;
};
template <>
struct ScalarOp<xla_lhlo::MulOp> {
@@ -60,7 +60,7 @@
template <typename LhloOp>
Operation* MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef<Type> result_types,
- ArrayRef<Value*> block_args, OpBuilder b) {
+ ArrayRef<Value> block_args, OpBuilder b) {
Type element_type = block_args.front()->getType();
if (element_type.isa<IntegerType>()) {
return b.template create<ScalarIOp<LhloOp>>(lhlo_op.getLoc(), result_types,
@@ -76,7 +76,7 @@
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
xla_lhlo::MaxOp lhlo_op, ArrayRef<Type> result_types,
- ArrayRef<Value*> block_args, OpBuilder b) {
+ ArrayRef<Value> block_args, OpBuilder b) {
const auto& lhs = block_args[0];
const auto& rhs = block_args[1];
Type element_type = lhs->getType();
@@ -96,7 +96,7 @@
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
xla_lhlo::MinOp lhlo_op, ArrayRef<Type> result_types,
- ArrayRef<Value*> block_args, OpBuilder b) {
+ ArrayRef<Value> block_args, OpBuilder b) {
const auto& lhs = block_args[0];
const auto& rhs = block_args[1];
Type element_type = lhs->getType();
@@ -116,7 +116,7 @@
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
xla_lhlo::AndOp lhlo_op, ArrayRef<Type> result_types,
- ArrayRef<Value*> block_args, OpBuilder b) {
+ ArrayRef<Value> block_args, OpBuilder b) {
Type element_type = block_args.front()->getType();
return element_type.isa<IntegerType>()
? b.create<::mlir::AndOp>(lhlo_op.getLoc(), result_types,
@@ -150,7 +150,7 @@
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
xla_lhlo::CompareOp lhlo_op, ArrayRef<Type> result_types,
- ArrayRef<Value*> block_args, OpBuilder b) {
+ ArrayRef<Value> block_args, OpBuilder b) {
const auto& lhs = block_args[0];
const auto& rhs = block_args[1];
Type element_type = lhs->getType();
@@ -172,7 +172,7 @@
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
xla_lhlo::SelectOp lhlo_op, ArrayRef<Type> result_types,
- ArrayRef<Value*> block_args, OpBuilder b) {
+ ArrayRef<Value> block_args, OpBuilder b) {
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), result_types, block_args,
mlir::None);
}
@@ -180,7 +180,7 @@
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
xla_lhlo::ExpOp lhlo_op, ArrayRef<Type> result_types,
- ArrayRef<Value*> block_args, OpBuilder b) {
+ ArrayRef<Value> block_args, OpBuilder b) {
Type element_type = block_args.front()->getType();
return element_type.isa<FloatType>()
? b.create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types,
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index a49985f..0f0ea50 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -68,21 +68,21 @@
Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
convolution parameters.
"""
- input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
- [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
- [3, 299, 299, 3], [5, 183, 183, 1]]
- filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
- [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3,
- 8], [5, 5, 1, 2]]
- out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
- [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
+ input_sizes = [[4, 5, 5, 48], [2, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48],
+ [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2],
+ [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]]
+ filter_sizes = [[1, 1, 48, 2], [2, 2, 48, 8], [1, 3, 84, 1], [3, 1, 48, 4],
+ [3, 3, 8, 1], [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8],
+ [2, 2, 3, 8], [5, 5, 1, 2]]
+ out_sizes = [[4, 5, 5, 96], [2, 5, 5, 384], [4, 8, 8, 84], [4, 17, 17, 192],
+ [4, 9, 27, 8], [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
[3, 150, 150, 24], [5, 92, 92, 2]]
- strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
+ strides = [1, 1, 1, 1, 1, 1, 1, 3, 2, 2]
# pylint: disable=invalid-name
VALID = "VALID"
SAME = "SAME"
# pylint: enable=invalid-name
- paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
+ paddings = [SAME, SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
paddings):
yield i, f, o, s, p
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index f6e9780..65679bd 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -500,7 +500,8 @@
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib_proto_parsing",
- ],
+ "//tensorflow/core:lib",
+ ] + if_tensorrt([":tensorrt_lib"]),
)
tf_proto_library(
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index 0735994..4e76287 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -51,6 +51,7 @@
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/strided_slice_op.h"
#if GOOGLE_CUDA
@@ -200,18 +201,6 @@
return this->at(key)->i();
}
-template <typename TensorShapeType>
-inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
- bool ignore_first_dim) {
- nvinfer1::Dims trt_dims;
- const int offset = (ignore_first_dim ? 1 : 0);
- for (int i = offset; i < shape.dims(); i++) {
- trt_dims.d[i - offset] = shape.dim_size(i);
- }
- trt_dims.nbDims = shape.dims() - offset;
- return trt_dims;
-}
-
template <typename Container>
Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
bool ignore_first_dim = false) {
@@ -314,66 +303,6 @@
return Status::OK();
}
-string DebugString(const nvinfer1::DimensionType type) {
- switch (type) {
- case nvinfer1::DimensionType::kSPATIAL:
- return "kSPATIAL";
- case nvinfer1::DimensionType::kCHANNEL:
- return "kCHANNEL";
- case nvinfer1::DimensionType::kINDEX:
- return "kINDEX";
- case nvinfer1::DimensionType::kSEQUENCE:
- return "kSEQUENCE";
- default:
- return StrCat(static_cast<int>(type), "=unknown");
- }
-}
-
-string DebugString(const nvinfer1::DataType trt_dtype) {
- switch (trt_dtype) {
- case nvinfer1::DataType::kFLOAT:
- return "kFLOAT";
- case nvinfer1::DataType::kHALF:
- return "kHALF";
- case nvinfer1::DataType::kINT8:
- return "kINT8";
- case nvinfer1::DataType::kINT32:
- return "kINT32";
- default:
- return "Invalid TRT data type";
- }
-}
-
-string DebugString(const nvinfer1::Dims& dims) {
- string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
- for (int i = 0; i < dims.nbDims; ++i) {
- StrAppend(&out, dims.d[i]);
- if (VLOG_IS_ON(2)) {
- StrAppend(&out, "[", DebugString(dims.type[i]), "],");
- } else {
- StrAppend(&out, ",");
- }
- }
- StrAppend(&out, ")");
- return out;
-}
-
-string DebugString(const nvinfer1::Permutation& permutation, int len) {
- string out = "nvinfer1::Permutation(";
- for (int i = 0; i < len; ++i) {
- StrAppend(&out, permutation.order[i], ",");
- }
- StrAppend(&out, ")");
- return out;
-}
-
-string DebugString(const nvinfer1::ITensor& tensor) {
- return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
- ", name=", tensor.getName(),
- ", dtype=", DebugString(tensor.getType()),
- ", dims=", DebugString(tensor.getDimensions()), ")");
-}
-
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
const bool check_feasibility,
@@ -581,14 +510,6 @@
return dims;
}
-inline bool HasStaticShape(const nvinfer1::Dims& dims) {
- if (dims.nbDims < 0) return false;
- for (int d = 0; d < dims.nbDims; ++d) {
- if (dims.d[d] < 0) return false;
- }
- return true;
-}
-
int64_t Prod(const nvinfer1::Dims& dims) {
int64_t count = 1;
for (int d = 0; d < dims.nbDims; ++d) {
@@ -732,9 +653,10 @@
}
string TRT_ShapedWeights::DebugString() const {
- return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
- ", type=", convert::DebugString(type_),
- ", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
+ return StrCat(
+ "TRT_ShapedWeights(shape=", tensorflow::tensorrt::DebugString(shape_),
+ ", type=", tensorflow::tensorrt::DebugString(type_),
+ ", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
}
// A fake ITensor implementation used to check whether the TF-TRT converter can
@@ -858,7 +780,7 @@
string TRT_TensorOrWeights::DebugString() const {
string output = "TRT_TensorOrWeights(type=";
if (is_tensor()) {
- StrAppend(&output, "tensor=", convert::DebugString(*tensor()),
+ StrAppend(&output, "tensor=", tensorflow::tensorrt::DebugString(*tensor()),
", batch_size=", batch_size_);
} else {
StrAppend(&output, "weights=", weights_.DebugString());
@@ -1210,11 +1132,8 @@
mutex_lock lock(plugin_mutex);
if (plugin_initialized) return;
- LOG(INFO) << "Linked TensorRT version: " << NV_TENSORRT_MAJOR << "."
- << NV_TENSORRT_MINOR << "." << NV_TENSORRT_PATCH;
- const int loaded_version = getInferLibVersion();
- LOG(INFO) << "Loaded TensorRT version: " << loaded_version / 1000 << "."
- << (loaded_version / 100) % 10 << "." << loaded_version % 100;
+ LOG(INFO) << "Linked TensorRT version: " << GetLinkedTensorRTVersion();
+ LOG(INFO) << "Loaded TensorRT version: " << GetLoadedTensorRTVersion();
plugin_initialized = initLibNvInferPlugins(trt_logger, "");
if (!plugin_initialized) {
@@ -1451,6 +1370,19 @@
}
}
+#if IS_TRT_VERSION_GE(6, 0, 0, 0)
+ string precision_mode_str;
+ TF_RETURN_IF_ERROR(
+ TrtPrecisionModeToName(precision_mode_, &precision_mode_str));
+ string trt_network_name = StrCat(
+ "TF:", TF_VERSION_STRING, ", ", "TRT:", GetLoadedTensorRTVersion(), "-",
+ "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_,
+ ", ", "Max-Batch-Size:", max_batch_size, ", ",
+ "Max-Workspace-Size:", max_workspace_size_bytes);
+ VLOG(1) << "Setting TensorRT network name to " << trt_network_name;
+ network()->setName(trt_network_name.c_str());
+#endif // #if IS_TRT_VERSION_GE(6, 0, 0, 0)
+
VLOG(1) << "Building TensorRT engine";
engine->reset(trt_builder_->buildCudaEngine(*network()));
#endif
@@ -2234,23 +2166,22 @@
// argument output_shape and thus the TRT output shape could be wrong
// in case of strides>1.
if (is_conv2d_backprop_input) {
- auto tf_output_shape = backprop_output_size.GetTrtDims();
+ auto tf_output_shape =
+ static_cast<int*>(backprop_output_size.weights().GetValues());
nvinfer1::Dims trt_output_shape = output_tensor->getDimensions();
// What determines the padding size is the difference between the given
// input_sizes (tf_output_shape) and TRT computed size.
- const int height_diff =
- tf_output_shape.d[h_index - 1] - trt_output_shape.d[1];
- const int width_diff =
- tf_output_shape.d[w_index - 1] - trt_output_shape.d[2];
+ const int height_diff = tf_output_shape[h_index] - trt_output_shape.d[1];
+ const int width_diff = tf_output_shape[w_index] - trt_output_shape.d[2];
if ((height_diff < 0) || (width_diff < 0)) {
return errors::InvalidArgument(
"input_sizes argument of Conv2DBackprop (i.e. output_shape argument "
- "of conv2d_transpose)",
+ "of conv2d_transpose) ",
"is too small for the given out_backprop argument of Conv2DBackprop "
- "(i.e. input argument of conv2d_transpose).",
- "(", tf_output_shape.d[h_index - 1], ", ",
- tf_output_shape.d[w_index - 1], ") >= ", "(", trt_output_shape.d[1],
- ", ", trt_output_shape.d[2], ")", node_def.name());
+ "(i.e. input argument of conv2d_transpose). Expect: ",
+ "(", tf_output_shape[h_index], ", ", tf_output_shape[w_index],
+ ") >= ", "(", trt_output_shape.d[1], ", ", trt_output_shape.d[2],
+ ") for op ", node_def.name());
}
// Only add a padding layer if padding sizes are larger than 0
if ((height_diff > 0) || (width_diff > 0)) {
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index 6090296..a9f579c 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -42,14 +42,6 @@
namespace convert {
using ::stream_executor::port::StatusOr;
-#define IS_TRT_VERSION_GE(major, minor, patch, build) \
- ((NV_TENSORRT_MAJOR > major) || \
- (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
- (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
- NV_TENSORRT_PATCH > patch) || \
- (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
- NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
-
struct EngineConnection {
// Constructs a non-control edge.
EngineConnection(const string& outside, int out_id, int out_port,
@@ -164,11 +156,6 @@
bool operator()(const Edge* out_edge) const;
};
-string DebugString(const nvinfer1::DimensionType type);
-string DebugString(const nvinfer1::DataType trt_dtype);
-string DebugString(const nvinfer1::Dims& dims);
-string DebugString(const nvinfer1::Permutation& permutation, int len);
-string DebugString(const nvinfer1::ITensor& tensor);
int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 23f4852..fa361c2 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -1714,15 +1714,14 @@
};
// Reshape at batch dimension, should fail.
- const int kReshapeBatchDimsCases = 5;
- TestParams params[kReshapeBatchDimsCases] = {
+ std::vector<TestParams> params = {
TestParams{1, {1, 2, 3}, {3, 1, 1, 2}},
TestParams{1, {1, 2, -1}, {-1, 1, 1, 2}},
TestParams{1, {1, 2, 3}, {-1, 1, 1, 2}},
TestParams{-1, {1, 2, 3}, {1, 1, 1, 2}},
TestParams{-1, {-1, 2, 3}, {1, 1, 1, 6}}, // TODO(laigd): it should pass.
};
- for (int i = 0; i < kReshapeBatchDimsCases; ++i) {
+ for (int i = 0; i < params.size(); ++i) {
Reset();
const std::vector<int>& dims = params[i].tensor_dims;
AddTestTensor("input", dims, params[i].batch_size);
@@ -1734,8 +1733,7 @@
}
// Reshape on non batch dimensions, ok.
- const int kReshapeOKCases = 8;
- TestParams ok_params[kReshapeOKCases] = {
+ std::vector<TestParams> ok_params = {
TestParams{-1, {1, 2, 3}, {-1, 1, 3, 2}},
TestParams{1, {1, 2, 3}, {-1, 1, 3, 2}},
TestParams{1, {1, 2, 3}, {1, 1, 3, 2}},
@@ -1745,7 +1743,7 @@
TestParams{2, {1, 1}, {2}},
TestParams{2, {}, {2, 1}},
};
- for (int i = 0; i < kReshapeOKCases; ++i) {
+ for (int i = 0; i < ok_params.size(); ++i) {
const int batch_size = std::max(1, ok_params[i].batch_size);
const auto& shape = ok_params[i].shape;
Reset();
@@ -2549,14 +2547,13 @@
};
// Ok.
- const int kCombinedNMSOKCases = 1;
- TestParams ok_params[kCombinedNMSOKCases] = {
+ std::vector<TestParams> ok_params = {
// TODO(aaroey): there is a bug in TRT's CombinedNonMaxSuppression
// implementation that, the extra output classes that are outside of the
// range specified by valid_detections[i] are not zeros but -1s.
TestParams{{1, 1, 4}, {1, 3}, 3, 2, .5f, 0, {2, 4}, {2}, {2}}};
- for (int i = 0; i < kCombinedNMSOKCases; ++i) {
+ for (int i = 0; i < ok_params.size(); ++i) {
Reset();
AddTestTensor("boxes", ok_params[i].boxes_tensor_dims);
@@ -2814,14 +2811,13 @@
};
// Ok.
- const int kExpandDimsOKCases = 8;
- TestParams ok_params[kExpandDimsOKCases] = {
+ std::vector<TestParams> ok_params = {
TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}},
TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}},
TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}},
TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}},
};
- for (int i = 0; i < kExpandDimsOKCases; ++i) {
+ for (int i = 0; i < ok_params.size(); ++i) {
Reset();
AddTestTensor("input", ok_params[i].input_dims);
AddTestWeights<int32>("weights", {1}, {ok_params[i].axis});
@@ -2931,8 +2927,7 @@
};
// Ok.
- const int kSqueezeOKCases = 10;
- TestParams ok_params[kSqueezeOKCases] = {
+ std::vector<TestParams> ok_params = {
TestParams{{1, 2, 3}, {1}, {2, 3}},
TestParams{{1, 2, 3}, {-3}, {2, 3}},
TestParams{{2, 3, 1}, {3}, {2, 3}},
@@ -2944,7 +2939,7 @@
TestParams{{1, 6}, {1}, {6}},
TestParams{{6, 1}, {2}, {6}},
};
- for (int i = 0; i < kSqueezeOKCases; ++i) {
+ for (int i = 0; i < ok_params.size(); ++i) {
Reset();
NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis);
AddTestTensor("input", ok_params[i].input_dims);
@@ -3114,13 +3109,8 @@
// Same input is used for all tests.
const std::vector<float> ok_input = {1, 2, 3, 4, 5, 6};
-#if IS_TRT_VERSION_GE(5, 1, 3, 1)
- const int kStridedSliceOKCases = 31;
-#else
- const int kStridedSliceOKCases = 27;
-#endif
// Ok.
- TestParams ok_params[kStridedSliceOKCases] = {
+ std::vector<TestParams> ok_params = {
// 2D Crop.
TestParams{
/*input_dims=*/{1, 2, 3},
@@ -3484,6 +3474,7 @@
/*expected_output_dims=*/{1, 2, 1},
/*expected_output=*/{2, 5},
},
+#if IS_TRT_VERSION_GE(5, 1, 3, 1)
TestParams{
/*input_dims=*/{1, 2, 3},
/*begin=*/{0, 0, 0, 0, 1},
@@ -3537,9 +3528,10 @@
/*expected_output_dims=*/{},
/*expected_output=*/{1},
},
+#endif // IS_TRT_VERSION_GE(5, 1, 3, 1)
};
- for (int i = 0; i < kStridedSliceOKCases; i++) {
+ for (int i = 0; i < ok_params.size(); i++) {
Reset();
NodeDef node_def = get_strided_slice_nodedef(
ok_params[i].begin_mask, ok_params[i].end_mask,
@@ -3672,8 +3664,7 @@
};
// Ok.
- const int kSliceOKCases = 5;
- TestParams ok_params[kSliceOKCases] = {
+ std::vector<TestParams> ok_params = {
TestParams{{1, 2, 3},
{0, 0, 0, 0},
{-1, -1, -1, -1},
@@ -3687,7 +3678,7 @@
TestParams{{6}, {0, 1}, {-1, 3}, {3}, {2, 3, 4}},
};
- for (int i = 0; i < kSliceOKCases; i++) {
+ for (int i = 0; i < ok_params.size(); i++) {
Reset();
NodeDef node_def = get_slice_nodedef();
AddTestTensor("input", ok_params[i].input_dims);
@@ -3856,8 +3847,7 @@
};
// Ok.
- const int kConv2DOKCases = 7;
- TestParams ok_params[kConv2DOKCases] = {
+ std::vector<TestParams> ok_params = {
// Basic
TestParams{/*input_dims=*/{1, 2, 3},
/*input=*/{0, 1, 2, 3, 3, 4},
@@ -3969,7 +3959,7 @@
};
- for (int i = 0; i < kConv2DOKCases; i++) {
+ for (int i = 0; i < ok_params.size(); i++) {
Reset();
NodeDef node_def = get_conv2d_nodedef(
ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format,
@@ -3978,8 +3968,10 @@
AddTestWeights<float>("weights", ok_params[i].filter_dims,
ok_params[i].filter);
if (ok_params[i].is_conv2d_backprop_input) {
- AddTestWeights<float>("input_sizes", ok_params[i].expected_output_dims,
- ok_params[i].expected_output);
+ std::vector<int> tf_input_sizes = ok_params[i].expected_output_dims;
+ tf_input_sizes.insert(tf_input_sizes.begin(), 1); // Add batch dimension.
+ QCHECK_EQ(4, tf_input_sizes.size());
+ AddTestWeights<int>("input_sizes", {4}, tf_input_sizes);
}
RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
@@ -4164,8 +4156,7 @@
};
// Start here
- const int kConv3DOKCases = 8;
- TestParams ok_params[kConv3DOKCases] = {
+ std::vector<TestParams> ok_params = {
// Basic - just 1x1 conv - input = output
TestParams{
/*input_dims=*/{1, 3, 3, 3}, // CDHW
@@ -4300,7 +4291,7 @@
};
- for (int i = 0; i < kConv3DOKCases; i++) {
+ for (int i = 0; i < ok_params.size(); i++) {
Reset();
NodeDef node_def = get_conv3d_nodedef(
ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format,
@@ -4384,8 +4375,7 @@
const std::vector<float> common_array{-4, 2, 15, 3, 6, -3, 22, 1, 88,
56, 36, 1, 1, 105, 1, 16, -28, 1,
42, 9, 3, 1, 7, 1, 11, 61, 5};
- const int kPool3DOKCases = 10;
- TestParams ok_params[kPool3DOKCases] = {
+ std::vector<TestParams> ok_params = {
// Basic - just 1x1 max pooling - input = output
TestParams{/*input_dims=*/{1, 3, 3, 3},
/*input=*/common_array,
@@ -4495,7 +4485,7 @@
// the corners
}};
- for (int i = 0; i < kPool3DOKCases; i++) {
+ for (int i = 0; i < ok_params.size(); i++) {
Reset();
NodeDef node_def = get_pool3d_nodedef(
ok_params[i].ksize, ok_params[i].strides, ok_params[i].padding,
@@ -4595,10 +4585,9 @@
};
// Input is the same {1, 2, 3, 4, 5, 6} for all cases.
- const int kGatherOKCases = 11;
const std::vector<CType> params_input = {CType(1), CType(2), CType(3),
CType(4), CType(5), CType(6)};
- TestParams ok_params[kGatherOKCases] = {
+ std::vector<TestParams> ok_params = {
// Vector indices, and output rank is rank(params).
TestParams{
/*params_shape=*/{1, 1, 2, 3},
@@ -4703,7 +4692,7 @@
};
// Ok.
- for (int i = 0; i < kGatherOKCases; i++) {
+ for (int i = 0; i < ok_params.size(); i++) {
test->Reset();
const auto& params_shape = ok_params[i].params_shape;
if (ok_params[i].params_is_tensor) {
@@ -5016,8 +5005,7 @@
InitTestVector<CType>(6, /*start_value=*/CType(6))};
// TODO(hinsu): Use std::vector instead of an array to avoid use of explicit
// size.
- const int kConcatOKCases = 4;
- TestParams ok_params[kConcatOKCases] = {
+ std::vector<TestParams> ok_params = {
{
/*input_shapes=*/{{1, 2, 3}, {1, 2, 3}},
/*input_values=*/common_input,
@@ -5057,7 +5045,7 @@
},
};
- for (int i = 0; i < kConcatOKCases; ++i) {
+ for (int i = 0; i < ok_params.size(); ++i) {
test->Reset();
const int num_inputs = ok_params[i].input_shapes.size();
EXPECT_EQ(num_inputs, ok_params[i].input_values.size());
@@ -5190,8 +5178,7 @@
};
const std::vector<CType> common_input = InitTestVector<CType>(6);
- const int kSplitOKCases = 4;
- TestParams ok_params[kSplitOKCases] = {
+ std::vector<TestParams> ok_params = {
// Identity (num_split = 1)
{/*input_shape=*/{1, 2, 3}, /*value=*/common_input, /*axis=*/1,
/*num_split=*/1, /*expected_output_dims=*/{1, 2, 3},
@@ -5224,7 +5211,7 @@
{InitTestVector<CType>(3), InitTestVector<CType>(3, CType(3))}},
};
- for (int i = 0; i < kSplitOKCases; ++i) {
+ for (int i = 0; i < ok_params.size(); ++i) {
test->Reset();
NodeDef node_def = get_split_nodedef(dtype, ok_params[i].num_split);
// Create inputs.
@@ -5366,8 +5353,7 @@
};
const std::vector<CType> common_input = InitTestVector<CType>(6);
- const int kUnpackOKCases = 4;
- TestParams ok_params[kUnpackOKCases] = {
+ std::vector<TestParams> ok_params = {
{/*input_shape=*/{1, 2, 3}, /*value=*/common_input, /*axis=*/1,
/*num=*/1, /*expected_output_dims=*/{2, 3},
/*expected_outputs=*/{InitTestVector<CType>(6)}},
@@ -5404,7 +5390,7 @@
{CType(5)}}},
};
- for (int i = 0; i < kUnpackOKCases; ++i) {
+ for (int i = 0; i < ok_params.size(); ++i) {
test->Reset();
NodeDef node_def =
get_unpack_nodedef(dtype, ok_params[i].num, ok_params[i].axis);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc
index ca21c19..d142bc5 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc
@@ -17,6 +17,8 @@
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace tensorrt {
@@ -51,5 +53,101 @@
return Status::OK();
}
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+using absl::StrAppend;
+using absl::StrCat;
+
+string DebugString(const nvinfer1::DimensionType type) {
+ switch (type) {
+ case nvinfer1::DimensionType::kSPATIAL:
+ return "kSPATIAL";
+ case nvinfer1::DimensionType::kCHANNEL:
+ return "kCHANNEL";
+ case nvinfer1::DimensionType::kINDEX:
+ return "kINDEX";
+ case nvinfer1::DimensionType::kSEQUENCE:
+ return "kSEQUENCE";
+ default:
+ return StrCat(static_cast<int>(type), "=unknown");
+ }
+}
+
+string DebugString(const nvinfer1::Dims& dims) {
+ string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
+ for (int i = 0; i < dims.nbDims; ++i) {
+ StrAppend(&out, dims.d[i]);
+ if (VLOG_IS_ON(2)) {
+ StrAppend(&out, "[", DebugString(dims.type[i]), "],");
+ } else {
+ StrAppend(&out, ",");
+ }
+ }
+ StrAppend(&out, ")");
+ return out;
+}
+
+string DebugString(const nvinfer1::DataType trt_dtype) {
+ switch (trt_dtype) {
+ case nvinfer1::DataType::kFLOAT:
+ return "kFLOAT";
+ case nvinfer1::DataType::kHALF:
+ return "kHALF";
+ case nvinfer1::DataType::kINT8:
+ return "kINT8";
+ case nvinfer1::DataType::kINT32:
+ return "kINT32";
+ default:
+ return "Invalid TRT data type";
+ }
+}
+
+string DebugString(const nvinfer1::Permutation& permutation, int len) {
+ string out = "nvinfer1::Permutation(";
+ for (int i = 0; i < len; ++i) {
+ StrAppend(&out, permutation.order[i], ",");
+ }
+ StrAppend(&out, ")");
+ return out;
+}
+
+string DebugString(const nvinfer1::ITensor& tensor) {
+ return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
+ ", name=", tensor.getName(),
+ ", dtype=", DebugString(tensor.getType()),
+ ", dims=", DebugString(tensor.getDimensions()), ")");
+}
+
+#endif
+
+string GetLinkedTensorRTVersion() {
+ int major, minor, patch;
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ major = NV_TENSORRT_MAJOR;
+ minor = NV_TENSORRT_MINOR;
+ patch = NV_TENSORRT_PATCH;
+#else
+ major = 0;
+ minor = 0;
+ patch = 0;
+#endif
+ return absl::StrCat(major, ".", minor, ".", patch);
+}
+
+string GetLoadedTensorRTVersion() {
+ int major, minor, patch;
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ int ver = getInferLibVersion();
+ major = ver / 1000;
+ ver = ver - major * 1000;
+ minor = ver / 100;
+ patch = ver - minor * 100;
+#else
+ major = 0;
+ minor = 0;
+ patch = 0;
+#endif
+ return absl::StrCat(major, ".", minor, ".", patch);
+}
+
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h
index eb60829..9015c24 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/utils.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h
@@ -17,9 +17,15 @@
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
#include <memory>
+#include <vector>
+#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+#include "third_party/tensorrt/NvInfer.h"
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+
namespace tensorflow {
namespace tensorrt {
@@ -45,6 +51,60 @@
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
+// Define a hash function for vector<TensorShape> because it is used as the key
+// for the engine cache.
+struct VectorTensorShapeHasher {
+ std::size_t operator()(const std::vector<TensorShape>& key) const {
+ return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
+ }
+};
+
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+
+#define IS_TRT_VERSION_GE(major, minor, patch, build) \
+ ((NV_TENSORRT_MAJOR > major) || \
+ (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
+ (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
+ NV_TENSORRT_PATCH > patch) || \
+ (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
+ NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
+
+string DebugString(const nvinfer1::DimensionType type);
+string DebugString(const nvinfer1::Dims& dims);
+string DebugString(const nvinfer1::DataType trt_dtype);
+string DebugString(const nvinfer1::Permutation& permutation, int len);
+string DebugString(const nvinfer1::ITensor& tensor);
+
+inline bool HasStaticShape(const nvinfer1::Dims& dims) {
+ if (dims.nbDims < 0) return false;
+ for (int d = 0; d < dims.nbDims; ++d) {
+ if (dims.d[d] < 0) return false;
+ }
+ return true;
+}
+
+template <typename TensorShapeType>
+inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
+ bool ignore_first_dim) {
+ nvinfer1::Dims trt_dims;
+ const int offset = (ignore_first_dim ? 1 : 0);
+ for (int i = offset; i < shape.dims(); i++) {
+ trt_dims.d[i - offset] = shape.dim_size(i);
+ }
+ trt_dims.nbDims = shape.dims() - offset;
+ return trt_dims;
+}
+
+// Return a string that includes compile time
+// TensorRT library version information {Maj, Min, Patch}.
+string GetLinkedTensorRTVersion();
+
+// Return a string that includes runtime time
+// TensorRT library version information {Maj, Min, Patch}.
+string GetLoadedTensorRTVersion();
+
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index 9adeed7..c14de3a 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -529,6 +529,25 @@
EngineContext* engine_context) {
VLOG(1) << "Executing TRT engine: " << name();
auto& cuda_engine = engine_context->cuda_engine;
+
+ if (VLOG_IS_ON(2)) {
+#if IS_TRT_VERSION_GE(6, 0, 0, 0)
+ VLOG(2) << " Network name: " << cuda_engine->getName();
+#endif // #if IS_TRT_VERSION_GE(6, 0, 0, 0)
+ VLOG(2) << " Activation size: " << cuda_engine->getDeviceMemorySize()
+ << " bytes";
+ VLOG(2) << " Workspace size: " << cuda_engine->getWorkspaceSize()
+ << " bytes";
+ VLOG(2) << " Datatype of " << cuda_engine->getNbBindings()
+ << " inputs/outputs";
+ string binding_types = "";
+ for (int i = 0; i < cuda_engine->getNbBindings(); i++) {
+ binding_types += " " + string(cuda_engine->getBindingName(i)) + ": " +
+ DebugString(cuda_engine->getBindingDataType(i)) + "\n";
+ }
+ VLOG(2) << binding_types;
+ }
+
const bool kRetry = true;
// All inputs must have the same batch size, so just get it from the first
// input.
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
index 8d603ac..808b689 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h
@@ -114,14 +114,6 @@
}
};
-// Define a hash function for vector<TensorShape> because it is used as the key
-// for the engine cache.
-struct VectorTensorShapeHasher {
- std::size_t operator()(const std::vector<TensorShape>& key) const {
- return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
- }
-};
-
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 242448e..dbc8397 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -48,6 +48,7 @@
"function_ops.cc",
"gather_op.cc",
"gather_op_helpers.h",
+ "gather_scatter_ops.cc",
"identity_op.cc",
"image_ops.cc",
"image_resize_ops.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
index 4f79ce1..dda0d79 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -512,22 +512,26 @@
filter_in_depth = filter_shape.dimensions(attrs.num_spatial_dims),
feature_group_count = in_depth / filter_in_depth;
+ // In the case of depthwise convolutions, the computation can be done by the
+ // batch_group_count parameter.
+ bool use_batch_group_count = in_depth > 1 && in_depth == filter_in_depth &&
+ (feature_group_count != 1 || attrs.depthwise);
+
+ if (use_batch_group_count) {
+ feature_group_count = 1;
+ }
+
// The activations (inputs) form the LHS of the convolution.
// Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
// For the gradient computation, we need to:
// 1. In the case of group convolution, move the num_groups dimension before
// the batch dimension
// 2. Swap the roles of the batch and feature dimensions.
- if (feature_group_count != 1 && !attrs.depthwise) {
+ if (!use_batch_group_count && feature_group_count != 1 && !attrs.depthwise) {
activations = TransposeInputForGroupConvolutionBackpropFilter(
activations, input_shape, feature_group_count, n_dim, c_dim);
}
- // In the case of depthwise convolution with no multiplier,
- // the computation can be done by the batch_group_count parameter.
- bool use_batch_group_count =
- filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise;
-
std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
std::vector<int64> window_strides(attrs.num_spatial_dims);
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
new file mode 100644
index 0000000..19aa85f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class GatherOp : public XlaOpKernel {
+ public:
+ explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing gather dimension numbers"));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<int64> slice_sizes;
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputAsIntVector("slice_sizes", &slice_sizes));
+ xla::XlaOp result =
+ xla::Gather(ctx->Input("operand"), ctx->Input("start_indices"), dnums_,
+ slice_sizes, indices_are_sorted_);
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ xla::GatherDimensionNumbers dnums_;
+ bool indices_are_sorted_;
+};
+
+REGISTER_XLA_OP(Name("XlaGather"), GatherOp);
+
+class ScatterOp : public XlaOpKernel {
+ public:
+ explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ OP_REQUIRES_OK(
+ context, context->GetAttr("update_computation", &update_computation_));
+ string dnums_attr;
+ OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+ OP_REQUIRES(
+ context, dnums_.ParsePartialFromString(dnums_attr),
+ errors::InvalidArgument("Error parsing scatter dimension numbers"));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("indices_are_sorted", &indices_are_sorted_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const DataType dtype = ctx->input_type(0);
+
+ XlaCompiler::Argument update_computation_arg;
+ update_computation_arg.kind = XlaCompiler::Argument::kParameter;
+ update_computation_arg.type = dtype;
+ update_computation_arg.shape = TensorShape();
+
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.use_tuple_arg = false;
+ compile_options.always_return_tuple = false;
+ compile_options.is_entry_computation = false;
+ XlaCompiler::CompilationResult update_computation;
+ OP_REQUIRES_OK(ctx, ctx->compiler()->CompileFunction(
+ compile_options, *update_computation_,
+ {update_computation_arg, update_computation_arg},
+ &update_computation));
+
+ xla::XlaOp result =
+ xla::Scatter(ctx->Input("operand"), ctx->Input("scatter_indices"),
+ ctx->Input("updates"), *update_computation.computation,
+ dnums_, indices_are_sorted_);
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ const NameAttrList* update_computation_;
+ xla::ScatterDimensionNumbers dnums_;
+ bool indices_are_sorted_;
+};
+
+REGISTER_XLA_OP(Name("XlaScatter"), ScatterOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 33b740a..6b71cca 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -665,5 +665,50 @@
})
.Doc("Replica ID.");
+REGISTER_OP("XlaGather")
+ .Input("operand: T")
+ .Input("start_indices: Tindices")
+ .Input("slice_sizes: Tindices")
+ .Attr("dimension_numbers: string")
+ .Attr("indices_are_sorted: bool")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA Gather operator documented at
+ https://www.tensorflow.org/xla/operation_semantics#gather
+operand: The array we're gathering from.
+start_indices: Array containing the starting indices of the slices we gather.
+dimension_numbers: A serialized xla::GatherDimensionNumbers proto.
+slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i.
+indices_are_sorted: Boolean indicating if the indices are sorted.
+)doc");
+
+REGISTER_OP("XlaScatter")
+ .Input("operand: T")
+ .Input("scatter_indices: Tindices")
+ .Input("updates: T")
+ .Attr("update_computation: func")
+ .Attr("dimension_numbers: string")
+ .Attr("indices_are_sorted: bool")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Output("output: T")
+ .SetShapeFn(UnchangedRank)
+ .Doc(R"doc(
+Wraps the XLA Scatter operator documented at
+ https://www.tensorflow.org/xla/operation_semantics#scatter.
+
+operand: Array to be scattered into.
+scatter_indices: Array containing the starting indices of the slices that must
+ be scattered to.
+updates: Array containing the values that must be used for scattering.
+update_computation: Computation to be used for combining the existing values in
+ the input array and the updates during scatter.
+dimension_numbers: A serialized xla::ScatterDimensionNumbers proto.
+indices_are_sorted: Boolean indicating if the indices are sorted.
+)doc");
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index eff6f82..bf25848 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -416,3 +416,27 @@
key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
+
+
+def gather(operand, start_indices, dimension_numbers, slice_sizes,
+ indices_are_sorted=False, name=None):
+ return gen_xla_ops.xla_gather(  # Thin wrapper around the generated XlaGather op.
+ operand,
+ start_indices,
+ slice_sizes=slice_sizes,  # By keyword: its position in this signature differs from the op's input order.
+ dimension_numbers=dimension_numbers.SerializeToString(),  # The op's string attr takes the serialized GatherDimensionNumbers proto.
+ indices_are_sorted=indices_are_sorted,
+ name=name)
+
+
+def scatter(operand, scatter_indices, updates, update_computation,
+ dimension_numbers, indices_are_sorted=False, name=None):
+ return gen_xla_ops.xla_scatter(  # Thin wrapper around the generated XlaScatter op.
+ operand,
+ scatter_indices,
+ updates,
+ update_computation=update_computation,  # Bound to the op's func-typed attr that combines existing values with updates.
+ dimension_numbers=dimension_numbers.SerializeToString(),  # The op's string attr takes the serialized ScatterDimensionNumbers proto.
+ indices_are_sorted=indices_are_sorted,
+ name=name)
+
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index fd31fb1..4581d85 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -253,6 +253,7 @@
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index 8c85482..9153ac9 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -15,9 +15,7 @@
#include "tensorflow/compiler/xla/client/lib/math.h"
-// This macro is required to make MSVC defines math constants in math.h
-#define _USE_MATH_DEFINES
-#include <math.h>
+#include <cmath>
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
diff --git a/tensorflow/compiler/xla/client/lib/tridiagonal.cc b/tensorflow/compiler/xla/client/lib/tridiagonal.cc
index d2ea6d5..13cc363 100644
--- a/tensorflow/compiler/xla/client/lib/tridiagonal.cc
+++ b/tensorflow/compiler/xla/client/lib/tridiagonal.cc
@@ -36,6 +36,8 @@
struct TridiagonalSystemShape {
const int64 rank;
const int64 num_equations;
+ TridiagonalSystemShape(int64 rk, int64 num_eqs)  // Explicit ctor so callers need not use {.rank = ...} designated initializers (a C++20 feature).
+ : rank(rk), num_equations(num_eqs) {}
};
Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
@@ -109,9 +111,7 @@
TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
"upper diagonal"));
- TridiagonalSystemShape result = {.rank = rank,
- .num_equations = num_equations};
- return result;
+ return TridiagonalSystemShape(rank, num_equations);
}
XlaOp Coefficient(XlaOp operand, int64 i) {
diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md
index 39715fb..38c6672 100644
--- a/tensorflow/compiler/xla/g3doc/index.md
+++ b/tensorflow/compiler/xla/g3doc/index.md
@@ -81,32 +81,19 @@
### Explicit compilation
Explicit compilation API offers a more fine-grained control for choosing which
-functions should be compiled with XLA. However, it requires restructuring source
-code, as not all TensorFlow operations can be represented in XLA. That is, using
-explicit compilation on API on functions which can not be represented in XLA
-results in an exception.
+functions should be compiled with XLA. However, it might require restructuring
+of the source code, as not all TensorFlow operations can be represented in XLA.
-#### TF2: Use `@tf.function(experimental_compile=True)`
+Note: Using the explicit compilation API on functions which cannot be
+represented in XLA results in an exception.
Optimizing sections of the program using
[`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) is a
-standard approach for
-[improving performance](https://www.tensorflow.org/tutorials/customization/performance)
-of TF2 programs. You can enable compilation with XLA by setting the
-`experimental_compile` argument of `tf.function` to `True`.
-
-Note: `experimental_compile` only works in
-[eager](https://www.tensorflow.org/guide/eager) mode.
-
-#### TF1: Use `xla.compile`
-
-If you are using TF1, you can use the `xla.compile` API for explicit compilation
-using XLA. See the [tutorial colab](./tutorials/xla_compile.ipynb) for usage
-examples.
-
-Note: Gradient computation of graph in `xla.compile()` is prohibited because it
-can cause performance degradation. To avoid this issue, move gradient
-computation inside `xla.compile()`.
+standard approach for [improving
+performance](https://www.tensorflow.org/tutorials/customization/performance) of
+TF2 programs. You can enable compilation with XLA by setting the
+`experimental_compile` argument of `tf.function` to `True`. See the [tutorial
+colab](./tutorials/experimental_compile.ipynb) for usage examples.
### AOT (Ahead-of-time) compilation for CPU with `tfcompile`
diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md
index ee7b2b2..0185bb4 100644
--- a/tensorflow/compiler/xla/g3doc/operation_semantics.md
+++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md
@@ -2053,8 +2053,8 @@
: : : as to have the same output shape :
: : : as input if the stride is 1, or :
: : : Padding\:\:kValid, which uses no :
-: : : no padding and "stops" the :
-: : : window once it no longer fits) :
+: : : padding and "stops" the window :
+: : : once it no longer fits) :
Below code and figure shows an example of using `ReduceWindow`. Input is a
matrix of size [4x6] and both window_dimensions and window_stride_dimensions are
diff --git a/tensorflow/compiler/xla/g3doc/tutorials/experimental_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/experimental_compile.ipynb
new file mode 100644
index 0000000..c8c08fc
--- /dev/null
+++ b/tensorflow/compiler/xla/g3doc/tutorials/experimental_compile.ipynb
@@ -0,0 +1,268 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Using XLA with tf.function",
+ "provenance": [],
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ }
+ },
+ "cells": [
+ {
+ "metadata": {
+ "colab_type": "text",
+ "id": "f4TSNCvpENrW"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "##### Copyright 2019 The TensorFlow Authors."
+ ]
+ },
+ {
+ "metadata": {
+ "cellView": "form",
+ "colab_type": "code",
+ "id": "vamNSA0vEP-m",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "e1oSi4lHFt3z"
+ },
+ "source": [
+ "# Using XLA via `tf.function` and `experimental_compile`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "sDy5lSBd4BDE",
+ "colab_type": "text"
+ },
+ "source": [
+ "In this colab, we train a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.\n",
+ "\n",
+ "We start by loading TensorFlow, with eager execution enabled."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "b7noD9NjFRL-"
+ },
+ "source": [
+ "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+ " <td>\n",
+ " <a target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/experimental_compile\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+ " </td>\n",
+ " <td>\n",
+ " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/experimental_compile.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+ " </td>\n",
+ " <td>\n",
+ " <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/experimental_compile.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+ " </td>\n",
+ "</table>"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "45kUPj5ZFrRa"
+ },
+ "source": [
+ "import tensorflow as tf\n",
+ "\n",
+ "tf.compat.v1.enable_eager_execution()"
+ ], "execution_count": 0, "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "GZVNiRmTDV-5"
+ },
+ "source": [
+ "Then, we define some necessary constants and prepare the MNIST dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "f37TSEGvGX4_",
+ "colab": {}
+ },
+ "source": [
+ "# Size of each input image, 28 x 28 pixels\n",
+ "IMAGE_SIZE = 28 * 28\n",
+ "# Number of distinct number labels, [0..9]\n",
+ "NUM_CLASSES = 10\n",
+ "# Number of examples in each training batch (step)\n",
+ "TRAIN_BATCH_SIZE = 100\n",
+ "# Number of training steps to run\n",
+ "TRAIN_STEPS = 1000\n",
+ "\n",
+ "# Loads MNIST dataset.\n",
+ "train, test = tf.keras.datasets.mnist.load_data()\n",
+ "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n",
+ "\n",
+ "# Casting from raw data to the required datatypes.\n",
+ "def cast(images, labels):\n",
+ " images = tf.cast(\n",
+ " tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)\n",
+ " labels = tf.cast(labels, tf.int64)\n",
+ " return (images, labels)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lv7I-u_82v1S",
+ "colab_type": "text"
+ },
+ "source": [
+ "Finally, we define the model and the optimizer. For the model, we shall use a single dense layer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "7O2NcEfG206Q",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "layer = tf.keras.layers.Dense(NUM_CLASSES)\n",
+ "optimizer = tf.keras.optimizers.Adam()\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "x_ZehpZP-SfS"
+ },
+ "source": [
+ "# Define the training function\n",
+ "\n",
+ "In the training function, we get predicted labels using the layer defined above, compute the gradients of the loss, and then apply them to the layer's variables using the optimizer. In order to compile the computation using XLA, we place it inside `tf.function` with `experimental_compile=True`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "ZbhJl_WvGa3g",
+ "colab": {}
+ },
+ "source": [
+ "@tf.function(experimental_compile=True)\n",
+ "def train_mnist(images, labels):\n",
+ " images, labels = cast(images, labels)\n",
+ "\n",
+ " with tf.GradientTape() as tape:\n",
+ " predicted_labels = layer(images)\n",
+ " loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
+ " logits=predicted_labels, labels=labels\n",
+ " ))\n",
+ " layer_variables = layer.trainable_variables\n",
+ " grads = tape.gradient(loss, layer_variables)\n",
+ " optimizer.apply_gradients(zip(grads, layer_variables))\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "EZD1m_n1DxAF"
+ },
+ "source": [
+ "# Train and test the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gukC2Hol3sFZ",
+ "colab_type": "text"
+ },
+ "source": [
+ "Once we have defined the training function, we can train the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "qe28bAHNHUG2",
+ "colab": {}
+ },
+ "source": [
+ "for images, labels in train_ds:\n",
+ " if optimizer.iterations >= TRAIN_STEPS:\n",
+ " break\n",
+ " train_mnist(images, labels)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qgsKmz3n2UiW"
+ },
+ "source": [
+ "And, finally, check the accuracy:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "_GxF6jTRHVuA"
+ },
+ "source": [
+ "images, labels = cast(test[0], test[1])\n",
+ "predicted_labels = layer(images)\n",
+ "correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)\n",
+ "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
+ "print(\"Prediction accuracy after training: %s\" % accuracy)"
+ ],
+ "execution_count": 0, "outputs": []
+ }
+ ]
+}
diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb
deleted file mode 100644
index 715585d..0000000
--- a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb
+++ /dev/null
@@ -1,373 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "The XLA compile API",
- "version": "0.3.2",
- "provenance": [],
- "collapsed_sections": [],
- "toc_visible": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- }
- },
- "cells": [
- {
- "metadata": {
- "colab_type": "text",
- "id": "f4TSNCvpENrW"
- },
- "cell_type": "markdown",
- "source": [
- "##### Copyright 2018 The TensorFlow Authors."
- ]
- },
- {
- "metadata": {
- "cellView": "form",
- "colab_type": "code",
- "id": "vamNSA0vEP-m",
- "colab": {}
- },
- "cell_type": "code",
- "source": [
- "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
- "# you may not use this file except in compliance with the License.\n",
- "# You may obtain a copy of the License at\n",
- "#\n",
- "# https://www.apache.org/licenses/LICENSE-2.0\n",
- "#\n",
- "# Unless required by applicable law or agreed to in writing, software\n",
- "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
- "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
- "# See the License for the specific language governing permissions and\n",
- "# limitations under the License."
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "e1oSi4lHFt3z"
- },
- "cell_type": "markdown",
- "source": [
- "# The XLA compile API"
- ]
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "b7noD9NjFRL-"
- },
- "cell_type": "markdown",
- "source": [
- "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
- " <td>\n",
- " <a target=\"_blank\" href=\"https://www.tensorflow.org/xla/tutorials/xla_compile\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
- " </td>\n",
- " <td>\n",
- " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
- " </td>\n",
- " <td>\n",
- " <a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
- " </td>\n",
- "</table>"
- ]
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "v9YbsuLZaBXy"
- },
- "cell_type": "markdown",
- "source": [
- "\n",
- "\n",
- "Import TensorFlow and the XLA library. XLA contains `xla.compile()`, an API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/)."
- ]
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "45kUPj5ZFrRa",
- "colab": {}
- },
- "cell_type": "code",
- "source": [
- "import tensorflow as tf\n",
- "\n",
- "from tensorflow.contrib.compiler import xla"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "GZVNiRmTDV-5"
- },
- "cell_type": "markdown",
- "source": [
- "Define some necessary constants and prepare the MNIST dataset."
- ]
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "f37TSEGvGX4_",
- "colab": {}
- },
- "cell_type": "code",
- "source": [
- "# Size of each input image, 28 x 28 pixels\n",
- "IMAGE_SIZE = 28 * 28\n",
- "# Number of distinct number labels, [0..9]\n",
- "NUM_CLASSES = 10\n",
- "# Number of examples in each training batch (step)\n",
- "TRAIN_BATCH_SIZE = 100\n",
- "# Number of training steps to run\n",
- "TRAIN_STEPS = 1000"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "TiVXchblG5hK",
- "colab": {}
- },
- "cell_type": "code",
- "source": [
- "# Loads MNIST dataset.\n",
- "train, test = tf.keras.datasets.mnist.load_data()\n",
- "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n",
- "test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)\n",
- "\n",
- "iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)\n",
- "images, labels = iterator.get_next()\n",
- "images = tf.reshape(images, [-1, IMAGE_SIZE])\n",
- "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "x_ZehpZP-SfS"
- },
- "cell_type": "markdown",
- "source": [
- "# Define the model constructing function\n",
- "\n",
- "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n",
- "\n",
- "When called, it returns two values. `y` is a `tf.Tensor` representing predicted probability of each target class, `train_step` is a `tf.Operation` that increments `global_step` and applies variable update."
- ]
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "ZbhJl_WvGa3g",
- "colab": {}
- },
- "cell_type": "code",
- "source": [
- "def build_mnist_model(x, y_):\n",
- " y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n",
- "\n",
- " cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)\n",
- " train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n",
- "\n",
- " return y, train_step"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "7Jh3lyQHDfM9"
- },
- "cell_type": "markdown",
- "source": [
- "# Enable XLA\n",
- "\n",
- "Use `xla.compile` with the `build_mnist_model` function to enable XLA. Following code block wraps the model with `xla.compile()`, which allows the target function with provided inputs to be executed by XLA."
- ]
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "kYpCXCdRHNuN",
- "colab": {}
- },
- "cell_type": "code",
- "source": [
- "[y] = xla.compile(build_mnist_model, inputs=[images, labels])"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "4giQh62IrZGF"
- },
- "cell_type": "markdown",
- "source": [
- "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n",
- "\n",
- "xla.compile does not return any\n",
- "`tf.Operation` nodes that can be executed independently from the generated XLA ops. Instead, returned `tf.Operation` nodes from the target function are added as control dependencies of all returned `tf.Tensor` values. This triggers execution of the `tf.Operation` nodes when the returned tensors are evaluated.\n",
- "\n",
- "In pseudo-code, xla.compile's implementation looks as follows:\n",
- "\n",
- "---\n",
- "```\n",
- "# Ask Tensorflow to execute code in XLA-friendly manner\n",
- "\n",
- "y, train_step = build_mnist_model(images, labels)\n",
- "with tf.control_dependencies([train_step]):\n",
- " y = tf.identity(y)\n",
- "\n",
- "# Ask Tensorflow to STOP executing code in XLA-friendly manner\n",
- "```\n",
- "---\n",
- "\n",
- "xla.compile() always returns a list of `tf.Tensor`'s (even if there is only one-element)."
- ]
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "TPGas4jjFLZl"
- },
- "cell_type": "markdown",
- "source": [
- "If you were to print the constructed graph now, you will see that it is not much different from a normal Tensorflow graph and you won't be able to find XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`. At that time, Tensorflow triggers a series of graph rewrite passes that actually generate XLA ops, which compiles and executes computation when all inputs are ready."
- ]
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "EZD1m_n1DxAF"
- },
- "cell_type": "markdown",
- "source": [
- "# Train and test the model"
- ]
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "qe28bAHNHUG2",
- "colab": {}
- },
- "cell_type": "code",
- "source": [
- "# Creates session and initialize all variables.\n",
- "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n",
- "sess = tf.Session()\n",
- "sess.run(tf.global_variables_initializer())"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "colab_type": "text",
- "id": "qgsKmz3n2UiW"
- },
- "cell_type": "markdown",
- "source": [
- "Following code block trains model. Evaluating `y` also triggers its control dependency node `train_step`, which updates model variables."
- ]
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "_GxF6jTRHVuA",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 34
- },
- "outputId": "fbf299ca-02d5-4e95-f9fe-8f3c0432d132"
- },
- "cell_type": "code",
- "source": [
- "# Feeds training dataset\n",
- "sess.run(iterator.make_initializer(train_ds))\n",
- "\n",
- "# Runs TRAIN_STEPS steps\n",
- "for i in range(TRAIN_STEPS):\n",
- " sess.run(y)\n",
- "\n",
- "print(\"Model trained for %s steps.\" % TRAIN_STEPS)"
- ],
- "execution_count": 21,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "Model trained for 1000 steps.\n"
- ],
- "name": "stdout"
- }
- ]
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "dHlQlRSRHXD1",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 34
- },
- "outputId": "9c3677a2-ec84-406f-9d2c-d722844f3093"
- },
- "cell_type": "code",
- "source": [
- "# Tests trained model\n",
- "\n",
- "# Feeds testing dataset\n",
- "sess.run(iterator.make_initializer(test_ds))\n",
- "\n",
- "# Calculates accuracy\n",
- "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n",
- "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
- "print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))"
- ],
- "execution_count": 22,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "Prediction accuracy after training: 0.91\n"
- ],
- "name": "stdout"
- }
- ]
- },
- {
- "metadata": {
- "colab_type": "code",
- "id": "ynJQIuzjHYOb",
- "colab": {}
- },
- "cell_type": "code",
- "source": [
- "# Cleans up session\n",
- "sess.close()"
- ],
- "execution_count": 0,
- "outputs": []
- }
- ]
-}
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 3d6310c..da172c7 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -738,14 +738,14 @@
const Shape& result_shape, absl::Span<const int64> start_indices) const {
Literal result_literal(result_shape);
DimensionVector new_indices(result_shape.rank());
- result_literal.EachCell<NativeT>(
- [&](absl::Span<const int64> indices, NativeT /*value*/) {
- for (int64 i = 0; i < result_shape.rank(); ++i) {
- new_indices[i] = indices[i] + start_indices[i];
- }
- NativeT value = Get<NativeT>(new_indices);
- result_literal.Set<NativeT>(indices, value);
- });
+ CHECK(result_literal
+ .Populate<NativeT>([&](absl::Span<const int64> indices) {
+ for (int64 i = 0; i < result_shape.rank(); ++i) {
+ new_indices[i] = indices[i] + start_indices[i];
+ }
+ return Get<NativeT>(new_indices);
+ })
+ .ok());
return result_literal;
}
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 9b17cb7..f2784c7 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -2061,6 +2061,11 @@
EXPECT_FALSE(c6.GetAsComplex128({}).has_value());
}
+TEST_F(LiteralUtilTest, SliceOnBool) {  // Slice() must work on bool-element literals.
+ Literal c1 = LiteralUtil::CreateR1<bool>({true, true, false});
+ EXPECT_EQ(c1, c1.Slice({0}, {3}));  // A full-range slice must reproduce the source literal.
+}
+
TEST_F(LiteralUtilTest, IsEqualAt) {
double val_double = 10.0;
int val_integral = 10;
diff --git a/tensorflow/compiler/xla/python/tpu_driver/BUILD b/tensorflow/compiler/xla/python/tpu_driver/BUILD
index 96c6636..b796fe8 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/BUILD
+++ b/tensorflow/compiler/xla/python/tpu_driver/BUILD
@@ -61,6 +61,7 @@
hdrs = ["grpc_tpu_driver.h"],
deps = [
":tpu_driver",
+ "//tensorflow:grpc++",
"//tensorflow/core/platform:logging",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
@@ -73,6 +74,25 @@
)
cc_library(
+ name = "external_tpu_driver",
+ srcs = ["external_tpu_driver.cc"],
+ deps = [
+ ":tpu_driver",
+ "@com_google_absl//absl/strings:str_format",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/core/platform:logging",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto_cc",
+ "//tensorflow/compiler/xla/service:hlo_proto_cc",
+ ":tpu_service_proto_cc",
+ ":tpu_driver_proto_cc",
+ "//tensorflow/compiler/xla/python/tpu_driver/client:c_api",
+ ] + external_deps(),
+ alwayslink = 1,
+)
+
+cc_library(
name = "recording_tpu_driver",
srcs = [
"recording_tpu_driver.cc",
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h b/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h
index 7e301de..228128c 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/c_api.h
@@ -32,17 +32,81 @@
typedef struct TpuBufferHandleInternal TpuBufferHandleInternal;
+typedef struct TpuCompiledProgramHandleInternal
+ TpuCompiledProgramHandleInternal;
+
+typedef struct TpuLoadedProgramHandleInternal TpuLoadedProgramHandleInternal;
+
typedef struct TpuBufferHandle {
TpuBufferHandleInternal* internal_handle;
TpuEvent* event;
+ int64_t size_in_bytes;  // NOTE(review): presumably the size requested at allocation time; confirm.
} TpuBufferHandle;
+typedef struct TpuCompiledProgramHandle {
+ TpuCompiledProgramHandleInternal* internal_handle;
+ TpuEvent* event;
+} TpuCompiledProgramHandle;
+
+typedef struct TpuLoadedProgramHandle {
+ TpuLoadedProgramHandleInternal* internal_handle;
+ TpuEvent* event;
+} TpuLoadedProgramHandle;
+
+typedef struct HloProto {
+ void* bytes;  // Presumably serialized xla::HloProto bytes; ownership is not specified here -- TODO confirm.
+ int32_t size;
+} HloProto;
+
+typedef struct DeviceAssignmentProto {
+ void* bytes;  // Presumably serialized xla::DeviceAssignmentProto bytes; ownership is not specified here -- TODO confirm.
+ int32_t size;
+} DeviceAssignmentProto;
+
+typedef struct TpuStatus {
+ int32_t code;  // NOTE(review): presumably a tensorflow::error::Code value; confirm.
+ char* msg;
+} TpuStatus;
+
+typedef struct CompiledProgramShape {
+ struct TpuStatus* status;  // NOTE(review): presumably reports whether bytes/size are valid; confirm.
+ void* bytes;
+ int32_t size;
+} CompiledProgramShape;
+
typedef void(PrototypeTpuDriver_Initialize)(struct TpuDriverFn* driver_fn);
typedef struct TpuDriver*(PrototypeTpuDriver_Open)(const char* worker);
typedef void(PrototypeTpuDriver_Close)(struct TpuDriver* driver);
+// TODO(frankchn): Make this not a hard-coded constant.
const int32_t MemoryRegion_HBM = 1;
+typedef struct TpuCompiledProgramHandle*(PrototypeTpuDriver_CompileProgram)(
+ struct TpuDriver* driver, const struct HloProto& source,  // NOTE(review): '&' is C++-only; invalid if this header is compiled as C.
+ int32_t num_replicas, int32_t eventc, struct TpuEvent** eventv);  // eventc/eventv: presumably events to wait on first -- TODO confirm.
+
+typedef struct TpuLoadedProgramHandle*(PrototypeTpuDriver_LoadProgram)(
+ struct TpuDriver* driver, int32_t core_id,
+ const struct TpuCompiledProgramHandle* compiled_program_handle,
+ int32_t eventc, struct TpuEvent** eventv);
+
+typedef struct TpuEvent*(PrototypeTpuDriver_UnloadProgram)(
+ struct TpuDriver* driver,
+ struct TpuLoadedProgramHandle* loaded_program_handle, int32_t eventc,
+ struct TpuEvent** eventv);
+
+typedef struct TpuEvent*(PrototypeTpuDriver_ExecuteProgram)(
+ struct TpuDriver* driver, struct TpuLoadedProgramHandle* handle,
+ int32_t inputc, struct TpuBufferHandle** input_buffer_handle,
+ int32_t outputc, struct TpuBufferHandle** output_buffer_handle,
+ const struct DeviceAssignmentProto& device_assignment, int32_t eventc,  // NOTE(review): '&' is C++-only; invalid if this header is compiled as C.
+ struct TpuEvent** eventv);
+
+typedef struct TpuBufferHandle*(PrototypeTpuDriver_AllocateTuple)(
+ struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
+ int64_t num_bytes, int32_t bufferc, struct TpuBufferHandle** buffer_handle,
+ int32_t eventc, struct TpuEvent** eventv);
+
typedef struct TpuBufferHandle*(PrototypeTpuDriver_Allocate)(
struct TpuDriver* driver, int32_t core_id, int32_t memory_region,
int64_t num_bytes, int32_t eventc, struct TpuEvent** eventv);
@@ -51,16 +115,69 @@
struct TpuDriver* driver, struct TpuBufferHandle* buffer_handle,
int32_t eventc, struct TpuEvent** eventv);
+typedef struct TpuEvent*(PrototypeTpuDriver_TransferToDevice)(
+ struct TpuDriver* driver, const void* src, struct TpuBufferHandle* dst,  // NOTE(review): no length param; presumably dst->size_in_bytes bytes are read from src -- confirm.
+ int32_t eventc, struct TpuEvent** eventv);
+
+typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDevice)(
+ struct TpuDriver* driver, struct TpuBufferHandle* src, void* dst,
+ int32_t eventc, struct TpuEvent** eventv);
+
+typedef struct TpuEvent*(PrototypeTpuDriver_TransferFromDeviceToDevice)(
+ struct TpuDriver* driver, struct TpuBufferHandle* src,
+ struct TpuBufferHandle* dst, int32_t eventc, struct TpuEvent** eventv);
+
+typedef struct CompiledProgramShape*(
+ PrototypeTpuDriver_GetCompiledProgramShape)(
+ struct TpuCompiledProgramHandle* handle);  // Result presumably released with TpuDriver_FreeCompiledProgramShape -- confirm.
+
+typedef void(PrototypeTpuDriver_FreeCompiledProgramShape)(
+ struct CompiledProgramShape* shape);
+
+typedef void(PrototypeTpuDriver_EventAddCallback)(
+ struct TpuEvent* event,
+ void (*callback_fn)(struct TpuStatus*, void* additional_info),  // Ownership of the TpuStatus* passed to the callback is not specified here.
+ void* additional_info);
+
+typedef struct TpuStatus*(PrototypeTpuDriver_EventAwait)(struct TpuEvent* event,
+ int64_t timeout_in_us);  // NOTE(review): caller presumably releases the result via TpuDriver_FreeStatus -- confirm.
+
typedef void(PrototypeTpuDriver_FreeEvent)(struct TpuEvent* event);
+typedef void(PrototypeTpuDriver_FreeStatus)(struct TpuStatus* status);
+
typedef const char*(PrototypeTpuDriver_Version)();
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Initialize TpuDriver_Initialize;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Open TpuDriver_Open;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Close TpuDriver_Close;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_CompileProgram
+ TpuDriver_CompileProgram;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_LoadProgram
+ TpuDriver_LoadProgram;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_UnloadProgram
+ TpuDriver_UnloadProgram;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_ExecuteProgram
+ TpuDriver_ExecuteProgram;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_AllocateTuple
+ TpuDriver_AllocateTuple;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Allocate TpuDriver_Allocate;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Deallocate TpuDriver_Deallocate;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferToDevice
+ TpuDriver_TransferToDevice;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDevice
+ TpuDriver_TransferFromDevice;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_TransferFromDeviceToDevice
+ TpuDriver_TransferFromDeviceToDevice;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_GetCompiledProgramShape
+ TpuDriver_GetCompiledProgramShape;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeCompiledProgramShape
+ TpuDriver_FreeCompiledProgramShape;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAddCallback
+ TpuDriver_EventAddCallback;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_EventAwait TpuDriver_EventAwait;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeEvent TpuDriver_FreeEvent;
+TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_FreeStatus TpuDriver_FreeStatus;
TPUDRIVER_CAPI_EXPORT extern PrototypeTpuDriver_Version TpuDriver_Version;
#ifdef __cplusplus
@@ -68,12 +185,32 @@
#endif
+// Table of driver entry points, filled in by the loaded library's
+// TpuDriver_Initialize. NOTE(review): the field layout is shared with
+// separately built driver libraries; confirm the ABI policy before reordering.
struct TpuDriverFn {
- PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT
- PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT
- PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT
- PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT
- PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT
- PrototypeTpuDriver_Version* TpuDriver_Version; // NOLINT
+ PrototypeTpuDriver_Open* TpuDriver_Open; // NOLINT
+ PrototypeTpuDriver_Close* TpuDriver_Close; // NOLINT
+ PrototypeTpuDriver_CompileProgram* TpuDriver_CompileProgram; // NOLINT
+ PrototypeTpuDriver_LoadProgram* TpuDriver_LoadProgram; // NOLINT
+ PrototypeTpuDriver_UnloadProgram* TpuDriver_UnloadProgram; // NOLINT
+ PrototypeTpuDriver_ExecuteProgram* TpuDriver_ExecuteProgram; // NOLINT
+ PrototypeTpuDriver_AllocateTuple* TpuDriver_AllocateTuple; // NOLINT
+ PrototypeTpuDriver_Allocate* TpuDriver_Allocate; // NOLINT
+ PrototypeTpuDriver_Deallocate* TpuDriver_Deallocate; // NOLINT
+ PrototypeTpuDriver_TransferToDevice* TpuDriver_TransferToDevice; // NOLINT
+ PrototypeTpuDriver_TransferFromDevice*
+ TpuDriver_TransferFromDevice; // NOLINT
+ PrototypeTpuDriver_TransferFromDeviceToDevice*
+ TpuDriver_TransferFromDeviceToDevice; // NOLINT
+ PrototypeTpuDriver_GetCompiledProgramShape*
+ TpuDriver_GetCompiledProgramShape; // NOLINT
+ PrototypeTpuDriver_FreeCompiledProgramShape*
+ TpuDriver_FreeCompiledProgramShape; // NOLINT
+ PrototypeTpuDriver_EventAddCallback* TpuDriver_EventAddCallback; // NOLINT
+ PrototypeTpuDriver_EventAwait* TpuDriver_EventAwait; // NOLINT
+ PrototypeTpuDriver_FreeEvent* TpuDriver_FreeEvent; // NOLINT
+ PrototypeTpuDriver_FreeStatus* TpuDriver_FreeStatus; // NOLINT
+ PrototypeTpuDriver_Version* TpuDriver_Version; // NOLINT
};
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_C_API_H_
diff --git a/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc
new file mode 100644
index 0000000..8a8e868
--- /dev/null
+++ b/tensorflow/compiler/xla/python/tpu_driver/external_tpu_driver.cc
@@ -0,0 +1,424 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ==============================================================================
+
+// TpuDriver implementation that forwards every call to a TPU driver shared
+// library loaded at runtime through the C API declared in client/c_api.h.
+// Selected by worker addresses of the form "external://<shared library path>".
+
+#include <dlfcn.h>
+
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "absl/time/time.h"
+#include "tensorflow/compiler/xla/python/tpu_driver/client/c_api.h"
+#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
+#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace tpu_driver {
+namespace {
+
+class ExternalTpuDriver;
+
+// Converts a C API status into an xla::Status. Does not take ownership of
+// `status`; releasing it remains the caller's job.
+xla::Status ConvertTpuStatus(const ::TpuStatus* status) {
+  return xla::Status(tensorflow::error::Code(status->code),
+                     absl::StrFormat("%s", status->msg));
+}
+
+// Owns a ::TpuEvent handed out by the driver and releases it with
+// TpuDriver_FreeEvent on destruction.
+class ExternalEvent : public Event {
+ public:
+  explicit ExternalEvent(::TpuDriverFn* driver_fn, ::TpuEvent* event)
+      : driver_fn_(driver_fn), event_(event) {}
+
+  // Copying would free the wrapped event twice.
+  ExternalEvent(const ExternalEvent&) = delete;
+  ExternalEvent& operator=(const ExternalEvent&) = delete;
+
+  ~ExternalEvent() override { driver_fn_->TpuDriver_FreeEvent(event_); }
+
+  xla::Status Await() override {
+    // -1 is passed as the timeout, presumably meaning "no deadline".
+    ::TpuStatus* tpu_status = driver_fn_->TpuDriver_EventAwait(event_, -1);
+    xla::Status status = ConvertTpuStatus(tpu_status);
+    driver_fn_->TpuDriver_FreeStatus(tpu_status);
+    return status;
+  }
+
+  // Returns nullopt when TpuDriver_EventAwait yields no status, which
+  // presumably means `duration` elapsed before the event completed.
+  absl::optional<xla::Status> AwaitWithTimeout(
+      absl::Duration duration) override {
+    ::TpuStatus* tpu_status = driver_fn_->TpuDriver_EventAwait(
+        event_, absl::ToInt64Microseconds(duration));
+    if (tpu_status == nullptr) {
+      return absl::nullopt;
+    }
+    xla::Status status = ConvertTpuStatus(tpu_status);
+    driver_fn_->TpuDriver_FreeStatus(tpu_status);
+    return status;
+  }
+
+  void AddCallback(std::function<void(xla::Status)> callback) override {
+    // The C API takes a plain function pointer plus an opaque argument, and
+    // C++11 lambdas with captures cannot be converted to function pointers.
+    // The std::function is therefore moved to the heap and reclaimed by the
+    // trampoline after it has been invoked.
+    auto* callback_addr =
+        new std::function<void(xla::Status)>(std::move(callback));
+    driver_fn_->TpuDriver_EventAddCallback(
+        event_,
+        [](struct TpuStatus* status, void* additional_info) {
+          std::unique_ptr<std::function<void(xla::Status)>> callback_owner(
+              static_cast<std::function<void(xla::Status)>*>(additional_info));
+          // NOTE(review): `status` is not released here; confirm whether the
+          // driver or the callback is expected to free it.
+          (*callback_owner)(ConvertTpuStatus(status));
+        },
+        callback_addr);
+  }
+
+ private:
+  ::TpuDriverFn* driver_fn_;
+  ::TpuEvent* event_;
+
+  friend ExternalTpuDriver;
+};
+
+// Wraps a ::TpuBufferHandle returned by the driver; its readiness event is
+// owned by `event_`.
+class ExternalBufferHandle : public BufferHandle {
+ public:
+  explicit ExternalBufferHandle(::TpuDriverFn* driver_fn,
+                                ::TpuBufferHandle* handle)
+      : handle_(handle), event_(new ExternalEvent(driver_fn, handle->event)) {}
+
+  std::shared_ptr<Event> OnReady() override { return event_; }
+
+  int64_t size_in_bytes() override { return handle_->size_in_bytes; }
+
+  absl::optional<xla::ShapeProto> shape() override {
+    LOG(FATAL) << "Unimplemented.";
+    return absl::nullopt;
+  }
+
+ private:
+  ::TpuBufferHandle* handle_;
+  std::shared_ptr<ExternalEvent> event_;
+
+  friend ExternalTpuDriver;
+};
+
+// Wraps a ::TpuCompiledProgramHandle returned by TpuDriver_CompileProgram.
+class ExternalCompiledProgramHandle : public CompiledProgramHandle {
+ public:
+  explicit ExternalCompiledProgramHandle(::TpuDriverFn* driver_fn,
+                                         ::TpuCompiledProgramHandle* handle)
+      : handle_(handle),
+        driver_fn_(driver_fn),
+        event_(new ExternalEvent(driver_fn, handle->event)) {}
+
+  std::shared_ptr<Event> OnReady() override { return event_; }
+
+  int64_t size_in_bytes() override {
+    LOG(FATAL) << "Unimplemented.";
+    return 0;
+  }
+
+  // Fills `program_shape` from the serialized shape reported by the driver and
+  // returns the status attached to it. A shape that fails to parse is reported
+  // as an error instead of silently leaving `program_shape` incomplete.
+  xla::Status program_shape(xla::ProgramShapeProto* program_shape) override {
+    struct CompiledProgramShape* shape =
+        driver_fn_->TpuDriver_GetCompiledProgramShape(handle_);
+    const bool parsed =
+        program_shape->ParseFromArray(shape->bytes, shape->size);
+    xla::Status status = ConvertTpuStatus(shape->status);
+    driver_fn_->TpuDriver_FreeCompiledProgramShape(shape);
+    if (status.ok() && !parsed) {
+      return xla::InternalError("Unable to parse compiled program shape.");
+    }
+    return status;
+  }
+
+ private:
+  ::TpuCompiledProgramHandle* handle_;
+  ::TpuDriverFn* driver_fn_;
+  std::shared_ptr<ExternalEvent> event_;
+
+  friend ExternalTpuDriver;
+};
+
+// Wraps a ::TpuLoadedProgramHandle returned by TpuDriver_LoadProgram.
+class ExternalLoadedProgramHandle : public LoadedProgramHandle {
+ public:
+  explicit ExternalLoadedProgramHandle(::TpuDriverFn* driver_fn,
+                                       ::TpuLoadedProgramHandle* handle)
+      : handle_(handle), event_(new ExternalEvent(driver_fn, handle->event)) {}
+
+  std::shared_ptr<Event> OnReady() override { return event_; }
+
+  int64_t size_in_bytes() override {
+    LOG(FATAL) << "Unimplemented.";
+    return 0;
+  }
+
+ private:
+  ::TpuLoadedProgramHandle* handle_;
+  std::shared_ptr<ExternalEvent> event_;
+
+  friend ExternalTpuDriver;
+};
+
+// TpuDriver that forwards every operation through the function table exported
+// by a dynamically loaded driver library.
+class ExternalTpuDriver : public TpuDriver {
+ public:
+  // Loads `so_path`, fills `driver_fn_` through the library's
+  // TpuDriver_Initialize entry point and opens the "local://" driver. Dies if
+  // the library or the entry point cannot be resolved.
+  explicit ExternalTpuDriver(const std::string& so_path) {
+    void* handle = dlopen(so_path.c_str(), RTLD_NOW);
+    if (!handle) {
+      LOG(FATAL) << "Unable to load shared library: " << dlerror();
+    }
+
+    PrototypeTpuDriver_Initialize* initialize_fn;
+    *reinterpret_cast<void**>(&initialize_fn) =
+        dlsym(handle, "TpuDriver_Initialize");
+    if (initialize_fn == nullptr) {
+      const char* error = dlerror();
+      LOG(FATAL) << "Unable to resolve TpuDriver_Initialize in " << so_path
+                 << ": " << (error != nullptr ? error : "null symbol");
+    }
+    initialize_fn(&driver_fn_);
+
+    driver_ = driver_fn_.TpuDriver_Open("local://");
+  }
+
+  // NOTE(review): neither `driver_` nor the dlopen() handle is released here;
+  // confirm whether TpuDriver_Close should be called on destruction.
+  ~ExternalTpuDriver() override {}
+
+  void QuerySystemInfo(SystemInfo* system_info) override {
+    LOG(FATAL) << "Unimplemented.";
+  }
+
+  xla::Status Reset() override {
+    LOG(FATAL) << "Unimplemented.";
+    return xla::Unimplemented("Reset is not implemented.");
+  }
+
+  std::unique_ptr<BufferHandle> Allocate(
+      int32_t core_id, MemoryRegion region, int64_t num_bytes,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    return absl::make_unique<ExternalBufferHandle>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_Allocate(driver_, core_id, region, num_bytes,
+                                      wait_for.size(), tpu_events.get()));
+  }
+
+  std::unique_ptr<BufferHandle> Allocate(
+      int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
+      absl::Span<Event* const> wait_for) override {
+    LOG(FATAL) << "Unimplemented.";
+    return nullptr;
+  }
+
+  std::unique_ptr<BufferHandle> AllocateTuple(
+      int32_t core_id, MemoryRegion region,
+      absl::Span<BufferHandle* const> children,
+      absl::Span<Event* const> wait_for) override {
+    LOG(FATAL) << "Unimplemented.";
+    return nullptr;
+  }
+
+  std::shared_ptr<Event> Deallocate(
+      std::unique_ptr<BufferHandle> handle,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    return std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_Deallocate(
+            driver_, static_cast<ExternalBufferHandle*>(handle.get())->handle_,
+            wait_for.size(), tpu_events.get()));
+  }
+
+  std::shared_ptr<Event> TransferToDevice(
+      const void* src, BufferHandle* dst,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    return std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_TransferToDevice(
+            driver_, src, static_cast<ExternalBufferHandle*>(dst)->handle_,
+            wait_for.size(), tpu_events.get()));
+  }
+
+  std::shared_ptr<Event> TransferFromDevice(
+      const BufferHandle* src, void* dst,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    return std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_TransferFromDevice(
+            driver_, static_cast<const ExternalBufferHandle*>(src)->handle_,
+            dst, wait_for.size(), tpu_events.get()));
+  }
+
+  std::shared_ptr<Event> TransferFromDeviceToDevice(
+      const BufferHandle* src, BufferHandle* dst,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    return std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_TransferFromDeviceToDevice(
+            driver_, static_cast<const ExternalBufferHandle*>(src)->handle_,
+            static_cast<ExternalBufferHandle*>(dst)->handle_, wait_for.size(),
+            tpu_events.get()));
+  }
+
+  std::unique_ptr<CompiledProgramHandle> CompileProgram(
+      const xla::HloProto& source, int32_t num_replicas,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+
+    struct HloProto hlo;
+    hlo.size = source.ByteSizeLong();
+    hlo.bytes = malloc(hlo.size);
+    if (!source.SerializeToArray(hlo.bytes, hlo.size)) {
+      LOG(ERROR) << "Unable to serialize HLO to array.";
+      free(hlo.bytes);
+      return nullptr;
+    }
+
+    auto handle = absl::make_unique<ExternalCompiledProgramHandle>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_CompileProgram(driver_, hlo, num_replicas,
+                                            wait_for.size(), tpu_events.get()));
+
+    free(hlo.bytes);
+    return handle;
+  }
+
+  std::unique_ptr<LoadedProgramHandle> LoadProgram(
+      int32_t core_id, const CompiledProgramHandle* handle,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    return absl::make_unique<ExternalLoadedProgramHandle>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_LoadProgram(
+            driver_, core_id,
+            static_cast<const ExternalCompiledProgramHandle*>(handle)->handle_,
+            wait_for.size(), tpu_events.get()));
+  }
+
+  std::shared_ptr<Event> UnloadProgram(
+      std::unique_ptr<LoadedProgramHandle> handle,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+    return std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_UnloadProgram(
+            driver_,
+            static_cast<ExternalLoadedProgramHandle*>(handle.get())->handle_,
+            wait_for.size(), tpu_events.get()));
+  }
+
+  std::shared_ptr<Event> ExecuteProgram(
+      LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
+      absl::Span<BufferHandle* const> outputs,
+      const xla::DeviceAssignmentProto& device_assignment,
+      absl::Span<Event* const> wait_for) override {
+    auto tpu_events = MakeEventArray(wait_for);
+
+    struct DeviceAssignmentProto da_proto;
+    da_proto.size = device_assignment.ByteSizeLong();
+    da_proto.bytes = malloc(da_proto.size);
+    if (!device_assignment.SerializeToArray(da_proto.bytes, da_proto.size)) {
+      LOG(ERROR) << "Unable to serialize device assignment to array.";
+      free(da_proto.bytes);
+      return nullptr;
+    }
+
+    std::vector<::TpuBufferHandle*> inputv;
+    inputv.reserve(inputs.size());
+    for (BufferHandle* input : inputs) {
+      inputv.push_back(static_cast<ExternalBufferHandle*>(input)->handle_);
+    }
+    std::vector<::TpuBufferHandle*> outputv;
+    outputv.reserve(outputs.size());
+    for (BufferHandle* output : outputs) {
+      outputv.push_back(static_cast<ExternalBufferHandle*>(output)->handle_);
+    }
+
+    auto event = std::make_shared<ExternalEvent>(
+        &driver_fn_,
+        driver_fn_.TpuDriver_ExecuteProgram(
+            driver_,
+            static_cast<ExternalLoadedProgramHandle*>(program)->handle_,
+            inputs.size(), inputv.data(), outputs.size(), outputv.data(),
+            da_proto, wait_for.size(), tpu_events.get()));
+
+    free(da_proto.bytes);
+    return event;
+  }
+
+  std::unique_ptr<TpuLinearizer> GetLinearizer() override { return nullptr; }
+
+ private:
+  ::TpuDriverFn driver_fn_{};
+  ::TpuDriver* driver_ = nullptr;
+
+  // Builds the array of raw ::TpuEvent pointers the C API expects for
+  // `wait_for`, or returns null when there is nothing to wait on. The array is
+  // freed (with delete[], matching its new[] allocation) when the returned
+  // pointer goes out of scope, so callers cannot leak or mis-free it.
+  std::unique_ptr<::TpuEvent*[]> MakeEventArray(
+      absl::Span<Event* const> wait_for) {
+    if (wait_for.empty()) return nullptr;
+    std::unique_ptr<::TpuEvent*[]> events(new ::TpuEvent*[wait_for.size()]);
+    for (size_t i = 0; i < wait_for.size(); ++i) {
+      events[i] = static_cast<ExternalEvent*>(wait_for[i])->event_;
+    }
+    return events;
+  }
+};
+
+// Factory for worker addresses of the form "external://<shared library path>".
+xla::StatusOr<std::unique_ptr<TpuDriver>> RegisterExternalTpuDriver(
+    const TpuDriverConfig& config) {
+  std::string shared_lib = config.worker().substr(std::strlen("external://"));
+  return xla::StatusOr<std::unique_ptr<TpuDriver>>(
+      absl::make_unique<ExternalTpuDriver>(shared_lib));
+}
+
+REGISTER_TPU_DRIVER("external://", RegisterExternalTpuDriver);
+
+}  // namespace
+}  // namespace tpu_driver
diff --git a/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl b/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl
index d2823ae..99b07b6 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl
+++ b/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl
@@ -33,5 +33,4 @@
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
- "//tensorflow:grpc++",
]
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index b4ea4d9..1fcadc9 100755
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -608,7 +608,6 @@
":hlo",
":hlo_parser",
"//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
@@ -1844,6 +1843,7 @@
":hlo_creation_utils",
":hlo_parser",
":hlo_pass",
+ ":hlo_pass_pipeline",
":pattern_matcher",
":pattern_matcher_gmock",
":shape_inference",
@@ -1982,6 +1982,7 @@
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -2018,6 +2019,7 @@
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
)
@@ -2053,6 +2055,7 @@
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -2118,6 +2121,7 @@
":while_loop_simplifier",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
@@ -2179,6 +2183,7 @@
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -2207,6 +2212,7 @@
":hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
)
@@ -2236,6 +2242,7 @@
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@@ -2319,6 +2326,7 @@
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@@ -2339,6 +2347,7 @@
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@@ -2951,6 +2960,7 @@
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@@ -3309,6 +3319,7 @@
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
],
@@ -3450,6 +3461,7 @@
":hlo_element_type_converter",
":hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -3837,6 +3849,7 @@
":sort_simplifier",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@@ -3868,6 +3881,7 @@
":stable_sort_expander",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@@ -3959,6 +3973,7 @@
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@@ -3986,6 +4001,7 @@
":while_loop_constant_sinking",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@@ -4047,6 +4063,7 @@
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
],
@@ -4095,9 +4112,9 @@
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:verified_hlo_module",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
- "//tensorflow/core:test_main", # fixdeps: keep
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
old mode 100755
new mode 100644
index f145b44..0225d2d
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -80,6 +80,83 @@
}
}
+// Returns true if any operand of `hlo` has a complex element type.
+bool IsAnyOperandComplex(const HloInstruction* hlo) {
+  for (auto operand : hlo->operands()) {
+    if (ShapeUtil::ElementIsComplex(operand->shape())) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Conservatively returns true only if every element produced by `hlo` is known
+// to be strictly positive; false means "unknown", not "non-positive". The
+// recursion bottoms out at tuple output 2 of the custom call whose target is
+// options.get_cudnn_batchnorm_forward_training_metadata().
+bool IsPositive(const HloInstruction* hlo,
+                const AlgebraicSimplifierOptions& options) {
+  // Utility only handles real types.
+  if (IsAnyOperandComplex(hlo)) {
+    return false;
+  }
+  switch (hlo->opcode()) {
+    case HloOpcode::kGetTupleElement: {
+      const HloInstruction* gte_operand = hlo->operand(0);
+      switch (gte_operand->opcode()) {
+        case HloOpcode::kCustomCall: {
+          const auto& target = gte_operand->custom_call_target();
+          // The metadata defaults to "", which must not match custom calls
+          // that happen to have an empty target.
+          return !target.empty() &&
+                 target ==
+                     options.get_cudnn_batchnorm_forward_training_metadata() &&
+                 hlo->tuple_index() == 2;
+        }
+        default:
+          return false;
+      }
+    }
+    // These are positive whenever operand 0 is.
+    case HloOpcode::kPower:
+    case HloOpcode::kAbs:
+    case HloOpcode::kRsqrt:
+    case HloOpcode::kSqrt:
+      return IsPositive(hlo->operand(0), options);
+
+    case HloOpcode::kMultiply: {
+      // A*A with A positive.
+      return hlo->operand(0) == hlo->operand(1) &&
+             IsPositive(hlo->operand(0), options);
+    }
+    default:
+      return false;
+  }
+}
+
+// Conservatively returns true only if `hlo` is known never to produce a
+// negative value; opcodes not handled here fall back to IsPositive.
+bool IsNonNegative(const HloInstruction* hlo,
+                   const AlgebraicSimplifierOptions& options) {
+  // Utility only handles real types.
+  if (IsAnyOperandComplex(hlo)) {
+    return false;
+  }
+  switch (hlo->opcode()) {
+    case HloOpcode::kMultiply: {
+      // A*A cannot be negative for floating-point or unsigned A, but signed
+      // integer multiplication wraps on overflow, so A*A may be negative there.
+      return hlo->operand(0) == hlo->operand(1) &&
+             !primitive_util::IsSignedIntegralType(hlo->shape().element_type());
+    }
+    case HloOpcode::kAbs: {
+      return true;
+    }
+    default:
+      return IsPositive(hlo, options);
+  }
+}
+
// Checks whether `op` is a floating-point constant or broadcast of a constant
// of the form +/- 2^k for some integer k positive, negative, or zero. Such
// values are interesting because multiplying by a power of 2 just moves the
@@ -212,6 +274,8 @@
AlgebraicSimplifier* simplifier)
: options_(options), simplifier_(simplifier) {}
+ Status HandleAbs(HloInstruction* abs) override;
+
Status HandleAdd(HloInstruction* add) override;
Status HandleAnd(HloInstruction* logical_and) override;
@@ -279,8 +343,15 @@
Status HandleReduceWindow(HloInstruction* reduce_window) override;
Status HandleReverse(HloInstruction* reverse) override;
+
+ Status HandleRsqrt(HloInstruction* rsqrt) override;
+
Status HandleSlice(HloInstruction* slice) override;
+
+ Status HandleSqrt(HloInstruction* sqrt) override;
+
Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
+
Status HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) override;
Status HandleScatter(HloInstruction* scatter) override;
@@ -501,6 +572,16 @@
return true;
}
+// Drops a redundant absolute value: Abs(A) => A whenever the operand is known
+// to be non-negative (see IsNonNegative).
+Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) {
+  HloInstruction* operand = abs->mutable_operand(0);
+  VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString()
+           << " Abs operand is: " << operand->ToString();
+  if (!IsNonNegative(operand, options_)) return Status::OK();
+  return ReplaceInstruction(abs, operand);
+}
+
Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
HloInstruction *lhs, *rhs;
CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));
@@ -2127,24 +2208,24 @@
Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
HloInstruction *lhs, *rhs;
CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
- // A*1 => A
- VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
+ // LHS*1 => LHS
+ VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString();
if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) {
return Status::OK();
}
- // 1*A => A
- VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString();
+ // 1*RHS => RHS
+ VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString();
if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) {
return Status::OK();
}
- // 0*A => 0. Only applies for integral types for correct NaN-handling.
+ // 0*RHS => 0. Only applies for integral types for correct NaN-handling.
if (IsAll(lhs, 0) &&
primitive_util::IsIntegralType(multiply->shape().element_type()) &&
ReplaceInstructionIfSameShape(multiply, lhs)) {
return Status::OK();
}
- // A*0 => 0
+ // LHS*0 => 0
if (IsAll(rhs, 0) &&
primitive_util::IsIntegralType(multiply->shape().element_type()) &&
ReplaceInstructionIfSameShape(multiply, rhs)) {
@@ -2174,7 +2255,8 @@
product_of_constants));
}
- // exp(A) * exp(B) => exp(A+B)
+ VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) "
+ << multiply->ToString();
if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
multiply->shape(), HloOpcode::kAdd, lhs, rhs));
@@ -2182,6 +2264,24 @@
multiply,
HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
}
+
+  VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] "
+           << multiply->ToString();
+  // Capture the two rsqrt operands separately and require them to be the same
+  // instruction. Reusing one capture pointer for both sub-patterns presumably
+  // just keeps the last capture rather than checking that they agree, so
+  // rsqrt(A) * rsqrt(C) could otherwise be rewritten to 1/C.
+  HloInstruction* b;
+  HloInstruction* b_rhs;
+  if (Match(multiply,
+            m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b_rhs)))) &&
+      b == b_rhs && IsPositive(b, options_)) {
+    return ReplaceWithNewInstruction(
+        multiply,
+        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide,
+                                     MakeScalarLike(b, 1), b));
+  }
+
return Status::OK();
}
@@ -3329,6 +3423,31 @@
return Status::OK();
}
+// Rewrites rsqrt(Pow(A, -2)) => |A| and rsqrt(Divide(1, A)) => sqrt(A), both
+// only when A is provably positive.
+Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) {
+  HloInstruction* operand = rsqrt->mutable_operand(0);
+  VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] "
+           << rsqrt->ToString();
+  const bool is_inverse_square = operand->opcode() == HloOpcode::kPower &&
+                                 IsAll(operand->operand(1), -2) &&
+                                 IsPositive(operand, options_);
+  if (is_inverse_square) {
+    return ReplaceWithNewInstruction(
+        rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs,
+                                           operand->mutable_operand(0)));
+  }
+  VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] "
+           << rsqrt->ToString();
+  const bool is_reciprocal = operand->opcode() == HloOpcode::kDivide &&
+                             IsAll(operand->operand(0), 1) &&
+                             IsPositive(operand->operand(1), options_);
+  if (!is_reciprocal) return Status::OK();
+  return ReplaceWithNewInstruction(
+      rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt,
+                                         operand->mutable_operand(1)));
+}
+
Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
HloInstruction* dynamic_slice) {
auto operand = dynamic_slice->mutable_operand(0);
@@ -3813,6 +3932,26 @@
return Status::OK();
}
+// Rewrites sqrt(A*A) => |A| for real-valued A.
+Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) {
+  VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString();
+  HloInstruction* sqrt_operand = sqrt->mutable_operand(0);
+  if (sqrt_operand->opcode() == HloOpcode::kMultiply &&
+      sqrt_operand->operand(0) == sqrt_operand->operand(1)) {
+    // Only valid for real types: for complex A, sqrt(A*A) is +/-A rather
+    // than |A|, and Abs of a complex value has a real element type, so
+    // creating it with A's complex shape would build an ill-typed HLO.
+    if (ShapeUtil::ElementIsComplex(sqrt_operand->shape())) {
+      return Status::OK();
+    }
+    return ReplaceWithNewInstruction(
+        sqrt, HloInstruction::CreateUnary(
+                  sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs,
+                  sqrt_operand->mutable_operand(0)));
+  }
+  return Status::OK();
+}
+
namespace {
bool OnlyPermutesDegenerateDims(const Shape& shape,
absl::Span<const int64> perm) {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index 74d8b1d..ce364a1 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -99,7 +99,30 @@
int64 very_small_gather_size() const { return very_small_gather_size_; }
+ // Sets the custom-call target name that identifies cuDNN batchnorm forward
+ // training (see Metadata). It is the empty string until set.
+ void set_cudnn_batchnorm_forward_training_metadata(const string& c) {
+ metadata_.cudnn_batchnorm_forward_training_metadata = c;
+ }
+
+ // Returns the name stored by set_cudnn_batchnorm_forward_training_metadata.
+ const string& get_cudnn_batchnorm_forward_training_metadata() const {
+ return metadata_.cudnn_batchnorm_forward_training_metadata;
+ }
+
private:
+ // Metadata struct can be used to store any metadata information encapsulated
+ // with the AlgebraicSimplifierOptions that can be later used in an
+ // AlgebraicSimplifier pass. For example,
+ // cudnn_batchnorm_forward_training_metadata can be used to store the name of
+ // a custom call. If the custom call is
+ // __cudnn$batchNormalizationForwardTraining, the output with index 2 is
+ // guaranteed to be positive. This property has been used to recursively
+ // determine if the operand of an instruction is always positive.
+ struct Metadata {
+ string cudnn_batchnorm_forward_training_metadata{""};
+ Metadata() {}
+ };
ReshapeIsBitcastCallback reshape_is_bitcast_callback_;
bool is_layout_sensitive_{false};
bool enable_dot_strength_reduction_{true};
@@ -107,6 +127,7 @@
bool enable_conv_simplification_{true};
bool enable_window_reduce_to_reduce_replacement_{true};
int64 very_small_gather_size_{4};
+ Metadata metadata_;
};
// A pass which performs algebraic simplifications.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index f37ff53..b4e66eb 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -31,6 +31,7 @@
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -5847,5 +5848,243 @@
GmockMatch(m::Parameter(1)));
}
+TEST_F(AlgebraicSimplifierTest, SqrtOfSelfMultiply) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p0 = f32[32]{0} parameter(0)
+ m0 = f32[32]{0} multiply(f32[32]{0} p0, f32[32]{0} p0)
+ ROOT s0 = f32[32]{0} sqrt(f32[32]{0} m0)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Abs(m::Parameter(0))));
+}
+
+TEST_F(AlgebraicSimplifierTest, RsqrtOfRPower) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
+ p1 = f32[32]{0} parameter(1)
+ p2 = f32[32]{0} parameter(2)
+ c0 = f32[] constant(0.001)
+ c1 = s64[] constant(1)
+ custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, c0, c1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
+ get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
+ get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
+ get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
+ c2 = f32[] constant(-2)
+ broadcast = f32[32]{0} broadcast(f32[] c2), dimensions={}
+ power = f32[32]{0} power(get-tuple-element, broadcast)
+ rsqrt = f32[32]{0} rsqrt(f32[32]{0} power)
+ ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ default_options_.set_cudnn_batchnorm_forward_training_metadata(
+ "__cudnn$batchNormalizationForwardTraining");
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ // Expect transformation: rsqrt(power(gte.2,-2)) -> abs(gte.2)
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kPower), nullptr);
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);
+ auto computation = m->entry_computation();
+ auto root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
+ EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kAbs);
+ EXPECT_EQ(root->operand(2)->operand(0)->opcode(),
+ HloOpcode::kGetTupleElement);
+}
+
+TEST_F(AlgebraicSimplifierTest, RsqrtDivide) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
+ p1 = f32[32]{0} parameter(1)
+ p2 = f32[32]{0} parameter(2)
+ constant = f32[] constant(0.001)
+ constant.1 = s64[] constant(1)
+ custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
+ get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
+ get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
+ get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
+ constant.2 = f32[] constant(1)
+ broadcast.1 = f32[32]{0} broadcast(constant.2), dimensions={}
+ divide = f32[32]{0} divide(broadcast.1, get-tuple-element)
+ rsqrt = f32[32]{0} rsqrt(divide)
+ ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, rsqrt)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ default_options_.set_cudnn_batchnorm_forward_training_metadata(
+ "__cudnn$batchNormalizationForwardTraining");
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ // Expect transformation: rsqrt(divide(1,gte.2)) -> sqrt(gte.2)
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kDivide), nullptr);
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);
+ auto computation = m->entry_computation();
+ auto root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
+ EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kSqrt);
+ EXPECT_EQ(root->operand(2)->operand(0)->opcode(),
+ HloOpcode::kGetTupleElement);
+}
+
+TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
+ p1 = f32[32]{0} parameter(1)
+ p2 = f32[32]{0} parameter(2)
+ constant = f32[] constant(0.001)
+ constant.1 = s64[] constant(1)
+ custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
+ get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
+ get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
+ get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
+ rsqrt = f32[32]{0} rsqrt(get-tuple-element)
+ multiply = f32[32]{0} multiply(rsqrt, rsqrt)
+ ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ default_options_.set_cudnn_batchnorm_forward_training_metadata(
+ "__cudnn$batchNormalizationForwardTraining");
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+
+ // Expect transformation: multiply(rsqrt(gte.2), rsqrt(gte.2)) -> divide(1,
+ // gte.2)
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kMultiply), nullptr);
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);
+
+ auto computation = m->entry_computation();
+ auto root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
+ EXPECT_EQ(root->operand(2)->opcode(), HloOpcode::kDivide);
+ EXPECT_EQ(root->operand(2)->operand(0)->opcode(), HloOpcode::kBroadcast);
+ EXPECT_EQ(root->operand(2)->operand(1)->opcode(),
+ HloOpcode::kGetTupleElement);
+}
+
+TEST_F(AlgebraicSimplifierTest, MultiplySelfRsqrt_NegativeTestCase) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
+ p1 = f32[32]{0} parameter(1)
+ p2 = f32[32]{0} parameter(2)
+ constant = f32[] constant(0.001)
+ constant.1 = s64[] constant(1)
+ custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
+ get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
+ get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
+ get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
+ rsqrt = f32[32]{0} rsqrt(get-tuple-element)
+ multiply = f32[32]{0} multiply(rsqrt, rsqrt)
+ ROOT tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, multiply)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ default_options_.set_cudnn_batchnorm_forward_training_metadata(
+ "__cudnn$batchNormalizationForward");
+ ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_NE(FindInstruction(m.get(), HloOpcode::kMultiply), nullptr);
+ EXPECT_NE(FindInstruction(m.get(), HloOpcode::kRsqrt), nullptr);
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kDivide), nullptr);
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kBroadcast), nullptr);
+ EXPECT_EQ(m->entry_computation()->root_instruction()->operand(2)->opcode(),
+ HloOpcode::kMultiply);
+}
+
+TEST_F(AlgebraicSimplifierTest, AbsEliminationBatchnormTraining) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
+ p1 = f32[32]{0} parameter(1)
+ p2 = f32[32]{0} parameter(2)
+ constant = f32[] constant(0.001)
+ constant.1 = s64[] constant(1)
+ custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
+ get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
+ get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
+ get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
+ abs = f32[32]{0} abs(get-tuple-element)
+ ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ default_options_.set_cudnn_batchnorm_forward_training_metadata(
+ "__cudnn$batchNormalizationForwardTraining");
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ // Verify that the module doesn't have any abs node.
+ EXPECT_EQ(FindInstruction(m.get(), HloOpcode::kAbs), nullptr);
+ EXPECT_EQ(m->entry_computation()->root_instruction()->operand(2)->opcode(),
+ HloOpcode::kGetTupleElement);
+}
+
+TEST_F(AlgebraicSimplifierTest,
+ AbsEliminationBatchnormTraining_NegativeTestCase) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p0 = f32[128,32,2,112]{3,2,1,0} parameter(0)
+ p1 = f32[32]{0} parameter(1)
+ p2 = f32[32]{0} parameter(2)
+ constant = f32[] constant(0.001)
+ constant.1 = s64[] constant(1)
+ custom-call.1 = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) custom-call(p0, p1, p2, constant, constant.1), custom_call_target="__cudnn$batchNormalizationForwardTraining"
+ get-tuple-element.1 = f32[128,32,2,112]{3,2,1,0} get-tuple-element(custom-call.1), index=0
+ get-tuple-element.2 = f32[32]{0} get-tuple-element(custom-call.1), index=1
+ get-tuple-element = f32[32]{0} get-tuple-element(custom-call.1), index=2
+ abs = f32[32]{0} abs(get-tuple-element)
+ ROOT %tuple = (f32[128,32,2,112]{3,2,1,0}, f32[32]{0}, f32[32]{0}) tuple(get-tuple-element.1, get-tuple-element.2, abs)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ default_options_.set_cudnn_batchnorm_forward_training_metadata(
+ "__cudnn$batchNormalizationForwardInference");
+ ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_NE(FindInstruction(m.get(), HloOpcode::kAbs), nullptr);
+}
+
+TEST_F(AlgebraicSimplifierTest, AbsEliminationMultiply) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p = f32[32]{0} parameter(0)
+ m = f32[32]{0} multiply(p, p)
+ ROOT a = f32[32]{0} abs(m)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
+}
+
+TEST_F(AlgebraicSimplifierTest, AbsEliminationPower2) {
+ const char* kModuleStr = R"(
+ HloModule m
+ test {
+ p0 = f32[32]{0} parameter(0)
+ c0 = f32[] constant(2)
+ b0 = f32[32]{0} broadcast(c0), dimensions={}
+ pow = f32[32]{0} power(p0, b0)
+ ROOT a = f32[32]{0} abs(pow)
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+ ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+ // Pow(A, 2) is transformed to AA. As a result, Abs(Power(A, 2)) is
+ // transformed to AA.
+ EXPECT_THAT(m->entry_computation()->root_instruction(),
+ GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc
index f942d67..06bcd77 100644
--- a/tensorflow/compiler/xla/service/convolution_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc
@@ -218,14 +218,127 @@
int64 input_batch_dimension = dim_numbers.input_batch_dimension();
int64 output_batch_dimension = dim_numbers.output_batch_dimension();
+ const int64 kernel_output_feature_dimension =
+ dim_numbers.kernel_output_feature_dimension();
int64 output_feature_dimension = dim_numbers.output_feature_dimension();
int64 input_batch = activation->shape().dimensions(input_batch_dimension);
+ const int64 output_feature =
+ filter->shape().dimensions(kernel_output_feature_dimension);
+
+ VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
+ const bool cost_too_high = !is_cost_viable_(convolution);
+
+ if (output_feature != batch_group_count) {
+ const int64 group_size = output_feature / batch_group_count;
+
+ VLOG(2) << "Need to insert a spatial dimension in activations and in the "
+ "kernel to deal with backprop of grouped convolutions "
+ << " group size " << group_size;
+
+ // Add spatial dimension to the activation, and reshape.
+ Shape reshaped_activation_shape = activation->shape();
+ ShapeUtil::AppendMajorDimension(1, &reshaped_activation_shape);
+ const int64 new_spatial_dim =
+ reshaped_activation_shape.dimensions().size() - 1;
+
+ activation = add(
+ HloInstruction::CreateReshape(reshaped_activation_shape, activation));
+
+ // Insert new spatial dimension after the output feature dimension on the
+ // kernel.
+ auto dims = filter->shape().dimensions();
+ std::vector<int64> new_dims;
+ for (int i = 0; i < dims.size(); i++) {
+ if (i == kernel_output_feature_dimension) {
+ new_dims.push_back(batch_group_count);
+ new_dims.push_back(group_size);
+ } else {
+ new_dims.push_back(dims[i]);
+ }
+ }
+
+ Shape reshaped_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
+ filter->shape().element_type(), new_dims);
+
+ filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
+
+ Shape new_output_shape = convolution->shape();
+ ShapeUtil::AppendMajorDimension(1, &new_output_shape);
+
+ // Edit convolution dimension numbers. Note that kernel_input_feature_dim
+ // now becomes a spatial dimension, and the newly added dimension of size
+ // 1 is the new kernel_input_feature_dim.
+ dim_numbers.add_input_spatial_dimensions(new_spatial_dim);
+
+ // Update spatial dimension numbers if they show up after the newly added
+ // spatial dimension.
+ for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
+ if (d > kernel_output_feature_dimension) {
+ ++d;
+ }
+ }
+
+ // Same for input feature dimension.
+ if (dim_numbers.kernel_input_feature_dimension() >
+ kernel_output_feature_dimension) {
+ dim_numbers.set_kernel_input_feature_dimension(
+ dim_numbers.kernel_input_feature_dimension() + 1);
+ }
+
+ dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension +
+ 1);
+
+ dim_numbers.add_output_spatial_dimensions(output_batch_dimension);
+
+ dim_numbers.set_output_batch_dimension(new_spatial_dim);
+
+ // Add window for the new spatial dimension.
+ Window new_window = convolution->window();
+ auto* dim = new_window.add_dimensions();
+ dim->set_window_dilation(1);
+ dim->set_base_dilation(1);
+ dim->set_stride(1);
+ dim->set_size(group_size);
+ dim->set_padding_high(group_size - 1);
+ dim->set_padding_low(group_size - 1);
+ dim->set_window_reversal(false);
+
+ auto new_convolution = add(HloInstruction::CreateConvolve(
+ new_output_shape, activation, filter, /*feature_group_count=*/1,
+ batch_group_count, new_window, dim_numbers,
+ convolution->precision_config()));
+
+ VLOG(2) << "New convolution " << new_convolution->ToString();
+
+ // This reversal is not done via set_window_reversal because GPUs don't
+ // support it.
+ auto rev = add(HloInstruction::CreateReverse(
+ new_output_shape, new_convolution, {output_batch_dimension}));
+
+ // Delete the extra spatial dimension, and reshape.
+ Shape reshaped_convolution_shape =
+ ShapeUtil::DeleteDimension(new_spatial_dim, rev->shape());
+ auto reshaped_convolution =
+ HloInstruction::CreateReshape(reshaped_convolution_shape, rev);
+
+ VLOG(2) << "Reshaped convolution " << reshaped_convolution->ToString();
+
+ TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
+ convolution, std::move(reshaped_convolution)));
+
+ changed_ = true;
+
+ convolution = new_convolution;
+ dim_numbers = convolution->convolution_dimension_numbers();
+ output_batch_dimension = new_spatial_dim;
+ }
+
// We are not yet supporting batch_group of sizes greater than 1.
TF_RET_CHECK(input_batch == batch_group_count);
- if (!is_cost_viable_(convolution) || filter_expansion_) {
+ if (cost_too_high || filter_expansion_) {
// We first obtain the expanded the filter (which is the convolution
// output). The batch dimension is the expanded one (which originally
// represents kernel input feature dimension). We mask the filter to zero
@@ -238,11 +351,17 @@
auto expanded_filter_shape = ExpandedFilterShape(
convolution->shape(), batch_group_count, output_batch_dimension);
+ VLOG(2) << "output_batch_dimension " << output_batch_dimension;
+ VLOG(2) << "New output shape of convolution "
+ << expanded_filter_shape.ToString();
+
auto new_convolution = add(HloInstruction::CreateConvolve(
expanded_filter_shape, activation, filter,
/*feature_group_count=*/1, /*batch_group_count=*/1,
convolution->window(), dim_numbers, convolution->precision_config()));
+ VLOG(2) << "Expanded convolution " << new_convolution->ToString();
+
auto zero = add(HloInstruction::CreateConstant(
LiteralUtil::Zero(expanded_filter_shape.element_type())));
auto zero_filter =
@@ -354,6 +473,7 @@
changed_ = false;
return Status::OK();
}
+ VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
// We want to repeat 'filter' in the 'input_feature_dim' dimension
// 'group_count' times.
if (!is_cost_viable_(convolution) || filter_expansion_) {
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index bec66ae..5f0e687 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -818,6 +818,7 @@
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:span",
],
@@ -914,6 +915,7 @@
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 13e8a3f..1d7f9fa 100755
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1116,6 +1116,7 @@
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
+ "//tensorflow/compiler/xla/service:convolution_group_converter",
"//tensorflow/compiler/xla/service:depthwise_convolution_converter",
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:dump",
@@ -1196,6 +1197,7 @@
":gpu_conv_padding_legalization",
":gpu_conv_rewriter",
":gpu_layout_assignment",
+ ":ir_emission_utils",
":reduction_degenerate_dim_remover",
":reduction_dimension_grouper",
":reduction_layout_normalizer",
@@ -1604,6 +1606,7 @@
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 37095ad..4ecf6ed 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -577,10 +577,24 @@
se::DeviceMemory<ElementT> rhs_typed(rhs);
uint64 buffer_size = lhs_typed.ElementCount();
- TF_ASSIGN_OR_RETURN(absl::Span<const uint8> compiled_ptx,
- se::CompileGpuAsmOrGetCached(executor->device_ordinal(),
- buffer_compare_ptx,
- PtxOptsFromConfig(config)));
+ absl::Span<const uint8> compiled_ptx = {};
+ StatusOr<absl::Span<const uint8>> compiled_ptx_or =
+ se::CompileGpuAsmOrGetCached(executor->device_ordinal(),
+ buffer_compare_ptx,
+ PtxOptsFromConfig(config));
+ if (compiled_ptx_or.ok()) {
+ compiled_ptx = compiled_ptx_or.ConsumeValueOrDie();
+ } else {
+ static std::once_flag ptxas_not_found_logged;
+ std::call_once(ptxas_not_found_logged, [&]() {
+ LOG(WARNING)
+ << compiled_ptx_or.status().ToString()
+ << "\nRelying on driver to perform ptx compilation. "
+ << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
+ << " or modifying $PATH can be used to set the location of ptxas"
+ << "\nThis message will only be logged once.";
+ });
+ }
TF_ASSIGN_OR_RETURN(
std::unique_ptr<ComparisonKernelT<ElementT>> comparison_kernel,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 30b204e..6709a51 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -36,6 +36,7 @@
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
@@ -138,11 +139,28 @@
// TODO(b/64094172): make Call work on GPU instead of inlining.
pipeline.AddPass<CallInliner>();
+
+ pipeline.AddPass<DotDecomposer>();
+
+ // We use the ConvolutionGroupConverter to convert backprops of filter
+ // grouped convolutions into non-grouped equivalents.
+ auto batch_group_cost_model = [](HloInstruction* conv) {
+ auto dim_numbers = conv->convolution_dimension_numbers();
+ const int64 input_batch_size = conv->operand(0)->shape().dimensions(
+ dim_numbers.input_batch_dimension());
+ return conv->batch_group_count() != input_batch_size;
+ };
+
+ pipeline.AddPass<ConvolutionGroupConverter>(
+ batch_group_cost_model,
+ /*convert_batch_groups_only=*/true,
+ /*canonicalize_depthwise_filter=*/false);
+
auto cost_model = [](HloInstruction* conv) {
// We need a cost model for GPUs. Currently, do nothing.
return false;
};
- pipeline.AddPass<DotDecomposer>();
+
pipeline.AddPass<DepthwiseConvolutionConverter>(cost_model);
// Expand the sort op to support stable sorting if required.
pipeline.AddPass<StableSortExpander>();
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
old mode 100755
new mode 100644
index fa01d75..d48c36b
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -31,6 +31,7 @@
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h"
#include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h"
@@ -134,6 +135,8 @@
/*allow_mixed_precision=*/false);
AlgebraicSimplifierOptions options;
+ options.set_cudnn_batchnorm_forward_training_metadata(
+ kCudnnBatchNormForwardTrainingCallTarget);
pass.AddPass<AlgebraicSimplifier>(options);
}
@@ -432,7 +435,7 @@
"Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to the "
"GPU driver for PTX -> sass compilation. This is OK so long "
"as you don't see a warning below about an out-of-date driver "
- "version.",
+ "version. Custom ptxas location can be specified using $PATH.",
hlo_module_config);
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 4c1ef5b..075d244 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -2615,18 +2615,37 @@
static double min() { return -max(); }
};
+// MSVC's standard C++ library does not define isnan/isfinite for integer types.
+// To work around that we will need to provide our own.
+template <typename T>
+std::enable_if_t<std::is_floating_point<T>::value, bool> IsFinite(T val) {
+ return std::isfinite(val);
+}
+template <typename T>
+std::enable_if_t<std::is_floating_point<T>::value, bool> IsNaN(T val) {
+ return std::isnan(val);
+}
+template <typename T>
+std::enable_if_t<std::is_integral<T>::value, bool> IsFinite(T val) {
+ return std::isfinite(static_cast<double>(val));
+}
+template <typename T>
+std::enable_if_t<std::is_integral<T>::value, bool> IsNaN(T val) {
+ return std::isnan(static_cast<double>(val));
+}
+
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
if (std::is_floating_point<ParsedElemT>::value) {
auto value_as_native_t = static_cast<LiteralNativeT>(value);
auto value_double_converted = static_cast<ParsedElemT>(value_as_native_t);
- if (!std::isfinite(value) || std::isfinite(value_double_converted)) {
+ if (!IsFinite(value) || IsFinite(value_double_converted)) {
value = value_double_converted;
}
}
PrimitiveType literal_ty =
primitive_util::NativeToPrimitiveType<LiteralNativeT>();
- if (std::isnan(value) ||
+ if (IsNaN(value) ||
(std::numeric_limits<ParsedElemT>::has_infinity &&
(std::numeric_limits<ParsedElemT>::infinity() == value ||
-std::numeric_limits<ParsedElemT>::infinity() == value))) {
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 445a3ea..5d38bbe 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -370,7 +370,8 @@
const HloRematerialization::ShapeSizeFunction& size_function,
const HloRematerialization::CompactShapeFunction& compact_shape_function,
const TuplePointsToAnalysis& points_to_analysis,
- const InstructionList& instruction_list);
+ const InstructionList& instruction_list,
+ HloRematerialization::RematerializationMode mode);
// Starts the placement of the given instruction. This adds the sizes of the
// LogicalBuffers defined by the instruction to the current memory
@@ -607,6 +608,7 @@
// between the calling of BeginInstruction and EndInstruction.
Item* in_progress_item_ = nullptr;
+ HloRematerialization::RematerializationMode mode_;
// All buffers in the computation.
std::vector<Buffer> buffers_;
};
@@ -616,11 +618,13 @@
const HloRematerialization::ShapeSizeFunction& size_function,
const HloRematerialization::CompactShapeFunction& compact_shape_function,
const TuplePointsToAnalysis& points_to_analysis,
- const InstructionList& instruction_list)
+ const InstructionList& instruction_list,
+ HloRematerialization::RematerializationMode mode)
: computation_(computation),
instruction_list_(instruction_list),
size_function_(size_function),
- compact_shape_function_(compact_shape_function) {
+ compact_shape_function_(compact_shape_function),
+ mode_(mode) {
PointsToSet::BufferSet live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
@@ -1155,7 +1159,10 @@
continue;
}
- if (item->buffers_output.size() == 1) {
+ if (item->buffers_output.size() == 1 &&
+ (mode_ == HloRematerialization::RematerializationMode::kCompressOnly ||
+ mode_ == HloRematerialization::RematerializationMode::
+ kRecomputeAndCompress)) {
// Only consider compressing single output instruction.
const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
@@ -1196,6 +1203,11 @@
continue;
}
+ // Do not consider recomputation in compress-only mode.
+ if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
+ continue;
+ }
+
const int64 memory_reduced = MemoryReducedIfRematerialized(item);
if (memory_reduced > 0) {
@@ -1370,7 +1382,7 @@
InstructionList instruction_list(order);
MemoryUsageTracker tracker(computation, size_function_,
compact_shape_function_, *points_to_analysis_,
- instruction_list);
+ instruction_list, mode_);
int64 peak_memory = tracker.memory_usage();
for (auto* item = instruction_list.first(); item != nullptr;
item = instruction_list.next(item)) {
@@ -1412,9 +1424,9 @@
CHECK(!ContainsKey(rematerialized_computations_, computation));
InstructionList instruction_list(schedule->sequence(computation));
- MemoryUsageTracker memory_tracker(computation, size_function_,
- compact_shape_function_,
- *points_to_analysis_, instruction_list);
+ MemoryUsageTracker memory_tracker(
+ computation, size_function_, compact_shape_function_,
+ *points_to_analysis_, instruction_list, mode_);
bool changed = false;
// If the rematerialization makes the source instruction dead, then the
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 9ab34b4..69cdc84 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -49,6 +49,13 @@
int64 after_bytes;
};
+ // Mode in which the rematerialization algorithm should be run.
+ enum class RematerializationMode {
+ kRecomputeOnly, // Only consider the kRecompute RematStrategy.
+ kCompressOnly, // Only consider the kCompress RematStrategy.
+ kRecomputeAndCompress // Consider both kRecompute and kCompress.
+ };
+
static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
// Constructor parameters:
@@ -69,13 +76,15 @@
explicit HloRematerialization(
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
RematerializationSizes* sizes,
- CompactShapeFunction compact_shape_function = nullptr)
+ CompactShapeFunction compact_shape_function = nullptr,
+ RematerializationMode mode = RematerializationMode::kRecomputeAndCompress)
: size_function_(size_function),
memory_limit_bytes_(memory_limit_bytes),
sizes_(sizes),
compact_shape_function_(compact_shape_function == nullptr
? DefaultCompactShapeFunction
- : std::move(compact_shape_function)) {}
+ : std::move(compact_shape_function)),
+ mode_(mode) {}
~HloRematerialization() override = default;
absl::string_view name() const override { return "rematerialization"; }
@@ -152,6 +161,8 @@
// uses of the original instruction and the original instruction is
// dead. Hence, no net instructions were added.
int64 net_instructions_added_ = 0;
+
+ RematerializationMode mode_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index caf8fce..30682aa 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -91,6 +91,11 @@
return end_time - start_time <= max_overlap_count_;
}
+int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
+ const Shape& shape, int64 start_time, int64 latest_end_time) const {
+ return std::min(start_time + min_overlap_count_, latest_end_time);
+}
+
void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
int64 start_time,
int64 end_time) {
@@ -153,6 +158,21 @@
logical_interval_elapsed;
}
+int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
+ const Shape& shape, int64 start_time, int64 latest_end_time) const {
+ float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
+ int64 end_time; // First end time whose interval meets the overlap ratio.
+ for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
+ float logical_interval_elapsed =
+ GetLogicalIntervalElapsed(start_time, end_time);
+ if (logical_interval_elapsed >=
+ min_async_copy_to_overlap_ratio_ * async_copy_elapsed) {
+ break;
+ }
+ }
+ return std::min(end_time, latest_end_time); // Never past latest_end_time.
+}
+
void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
int64 start_time,
int64 end_time) {
@@ -337,8 +357,7 @@
absl::make_unique<MemorySpaceAssignment::Allocation>(
value->defining_instruction(), value->defining_position(),
aliased_allocation->memory_space(), aliased_allocation->chunk(),
- aliased_allocation->start_time(),
- aliased_allocation->end_time()));
+ definition_time, definition_time));
}
// Iterate over the uses.
@@ -418,6 +437,28 @@
return result_;
}
+bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
+ return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
+ (a.start_time <= b.start_time && a.end_time < b.end_time);
+}
+
+void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
+ auto it_and_inserted = ranges_.insert(copy);
+ CHECK(it_and_inserted.second ||
+ it_and_inserted.first->start_time == copy.start_time);
+}
+
+bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time,
+ int64 end_time) const {
+ // We allow identical start and end times. It is enough to check for just the
+ // start time in case we find a match in ranges_ because the found value will
+ // either be identical to {start_time, end_time} (and this doesn't violate) or
+ // its start_time will be smaller and end_time will be larger (this violates).
+ auto copy_it = ranges_.find(
+ {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
+ return copy_it != ranges_.end() && copy_it->start_time != start_time;
+}
+
void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
// Go through the parameters and outputs and pin them to the corresponding
// memory by adding a required assignment.
@@ -520,14 +561,7 @@
kDummyChunk);
}
if (interval.destination == MemorySpace::kAlternate) {
- // If there is already an asynchronous copy ending the same time, pick
- // the earliest copy start time.
- auto range_it = async_copy_range_map_.find(interval.end_time);
- if (range_it != async_copy_range_map_.end()) {
- range_it->second = std::min(range_it->second, interval.start_time);
- } else {
- async_copy_range_map_[interval.end_time] = interval.start_time;
- }
+ async_copy_ordering_.AddCopy(interval);
}
}
pending_async_copies_.clear();
@@ -648,27 +682,63 @@
prev_allocation->defining_position() == defining_position) {
// If there was an allocation for this HloValue that was in the alternate
// memory space, we also need to perform an eviction.
- // TODO(berkin): For now evictions happen relative to the most recent
- // allocation in the alternate memory. We can potentially start evictions
- // earlier and end later.
+ int64 eviction_start_time = prev_allocation->start_time();
+ int64 eviction_end_time = prev_allocation->end_time();
+ CHECK(eviction_start_time <= eviction_end_time);
+
+ int64 preferred_eviction_end_time = std::max(
+ options_.prefetch_interval_picker->PreferredEvictionEndTime(
+ non_bitcast_operand->shape(), eviction_start_time, end_time),
+ eviction_end_time);
+
+ BufferInterval eviction_mem_interval;
+ eviction_mem_interval.buffer = buffer;
+ eviction_mem_interval.size = size;
+ // Try to reserve a buffer from the end of the previous allocation to the
+ // preferred eviction end time.
+ eviction_mem_interval.start = prev_allocation->end_time() + 1;
+ eviction_mem_interval.end = preferred_eviction_end_time;
+ int64 preferred_offset = prev_allocation->chunk().offset;
+ VLOG(4) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
+ << ") preferred end time = " << preferred_eviction_end_time;
+
+ while (preferred_eviction_end_time > eviction_end_time) {
+ ChunkCandidate chunk_candidate =
+ FindChunkCandidate(eviction_mem_interval, preferred_offset);
+ if (chunk_candidate.chunk.offset == preferred_offset) {
+ eviction_end_time = preferred_eviction_end_time;
+ AddToPendingChunks(eviction_mem_interval, chunk_candidate);
+ break;
+ }
+ eviction_mem_interval.end = --preferred_eviction_end_time;
+ }
+
VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
- << prev_allocation->start_time() << ", "
- << prev_allocation->end_time() << ")";
+ << eviction_start_time << ", " << eviction_end_time << ")";
+
+ bool eviction_interval_too_short =
+ (eviction_start_time == eviction_end_time);
+ bool eviction_violates_outstanding_copies =
+ ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
+ eviction_end_time);
// See if this interval would violate the asynchronous copy limit.
- if (!ViolatesMaximumOutstandingAsyncCopies(prev_allocation->start_time(),
- prev_allocation->end_time())) {
+ if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) {
+ prev_allocation->Extend(eviction_end_time);
AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
- prev_allocation->start_time(), prev_allocation->end_time(),
- prev_allocation->end_time(), allocations);
-
+ eviction_start_time, prev_allocation->end_time(),
+ eviction_end_time, allocations);
} else {
- VLOG(3) << "This violates the maximum async copies.";
+ if (eviction_violates_outstanding_copies) {
+ VLOG(3) << "This violates the maximum async copies.";
+ } else {
+ VLOG(3) << "Eviction interval is too short (" << eviction_start_time
+ << ", " << eviction_end_time << ").";
+ }
// If the original interval violated the limit, try sub-intervals within
// this interval.
bool eviction_scheduled = false;
- for (int64 time = prev_allocation->start_time();
- time <= prev_allocation->end_time(); ++time) {
+ for (int64 time = eviction_start_time; time < eviction_end_time; ++time) {
VLOG(3) << "Try evicting (" << time << ", " << time << ")";
if (!ViolatesMaximumOutstandingAsyncCopies(time, time)) {
VLOG(3) << "Eviction successful.";
@@ -686,10 +756,10 @@
<< " because we hit the limit of maximum asynchronous copies "
<< "between "
<< hlo_live_range_.flattened_instruction_sequence()
- .instructions()[prev_allocation->start_time()]
+ .instructions()[eviction_start_time]
<< " and "
<< hlo_live_range_.flattened_instruction_sequence()
- .instructions()[prev_allocation->end_time()];
+ .instructions()[eviction_end_time];
return false;
}
}
@@ -736,8 +806,8 @@
VLOG(4) << "This would violate the outstanding async copy limit.";
continue;
}
- if (ViolatesAsynchronousCopyOrdering(alternate_mem_interval.start,
- alternate_mem_interval.end)) {
+ if (async_copy_ordering_.ViolatesOrdering(alternate_mem_interval.start,
+ alternate_mem_interval.end)) {
VLOG(4) << "This would violate asynchronous copy ordering.";
continue;
}
@@ -812,13 +882,6 @@
return num_async_copies + 1 > options_.max_outstanding_async_copies;
}
-bool AlternateMemoryBestFitHeap::ViolatesAsynchronousCopyOrdering(
- int64 start_time, int64 end_time) const {
- auto async_copy_range_it = async_copy_range_map_.lower_bound(end_time);
- return async_copy_range_it != async_copy_range_map_.end() &&
- async_copy_range_it->second < start_time;
-}
-
bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
int64 start_time, int64 end_time, int64 last_use_time,
HloPosition defining_position, HloUse use,
@@ -1313,6 +1376,13 @@
return;
}
for (HloInstruction* operand : new_instruction->operands()) {
+ // CopyStart/CopyDone dependencies should always be already inserted; it is
+ // a red flag when they haven't already been inserted.
+ CHECK((operand->opcode() != HloOpcode::kCopyStart &&
+ operand->opcode() != HloOpcode::kCopyDone) ||
+ inserted_instructions->contains(operand))
+ << "Inserted instruction " << new_instruction->ToString()
+ << " has un-inserted dependency: " << operand->ToString();
EnsureInstructionAndOperandsInserted(operand, new_sequence,
inserted_instructions);
}
@@ -1404,10 +1474,14 @@
}
HloInstruction* instruction = flattened_instructions_[instruction_index];
// Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
- // it was deleted) and not previously inserted.
+ // it was deleted) and not previously inserted. Also bitcasts and tuples
+ // are treated specially and only inserted as a result of operand
+ // dependencies.
if (instruction != nullptr &&
!inserted_instructions.contains(instruction) &&
- instruction->parent() == computation) {
+ instruction->parent() == computation &&
+ instruction->opcode() != HloOpcode::kBitcast &&
+ instruction->opcode() != HloOpcode::kTuple) {
EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
&inserted_instructions);
}
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 67ced4c..0242ded 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -123,6 +123,11 @@
int64 start_time,
int64 end_time) const = 0;
+ // Returns the preferred end time for an eviction that starts at a given time
+ // and must end by the given end time.
+ virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
+ int64 latest_end_time) const = 0;
+
// Begins the iterator for the first start time of the prefetch.
virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;
@@ -166,6 +171,9 @@
bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
int64 end_time) const override;
+ int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
+ int64 latest_end_time) const override;
+
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
int64 Next() override;
@@ -206,6 +214,9 @@
bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
int64 end_time) const override;
+ int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
+ int64 latest_end_time) const override;
+
void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
int64 Next() override;
@@ -526,6 +537,48 @@
int64 time;
};
+// A struct representing an asynchronous copy with its logical start and end
+// time and its destination memory space.
+struct AsynchronousCopy {
+ int64 start_time;
+ int64 end_time;
+ MemorySpaceAssignment::MemorySpace destination;
+};
+
+// Compare asynchronous copies such that an earlier start time has the same or
+// earlier end time and an earlier end time has the same or earlier start time.
+bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);
+
+// Helper class to enforce asynchronous copy ordering. We only allow
+// asynchronous copies that are pipelined: if an asynchronous copy ends earlier
+// than another asynchronous copy, it must start the same time or earlier than
+// the other asynchronous copy; and if an asynchronous copy starts earlier than
+// another asynchronous copy, it must end the same time or earlier than the
+// other asynchronous copy.
+class AsynchronousCopyOrdering {
+ public:
+ AsynchronousCopyOrdering() = default;
+
+ // Adds an asynchronous copy.
+ void AddCopy(const AsynchronousCopy& copy);
+
+ // Returns true if the addition of an asynchronous copy in the the given time
+ // interval would violate the asynchronous copy ordering. E.g., consider the
+ // following scenario:
+ // CS CD
+ // already committed async copy: +-----------+
+ // new async copy: +--------+
+ //
+ // The new asynchronous copy would violate the ordering guarantee because the
+ // copy start is after an already committed asynchronous copy while its copy
+ // done is before the committed copy.
+ bool ViolatesOrdering(int64 start_time, int64 end_time) const;
+
+ private:
+ // Stores asynchronous copies in a tree set respecting the pipelining order.
+ std::set<AsynchronousCopy> ranges_;
+};
+
// This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
// maximum size.
class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
@@ -551,14 +604,6 @@
HeapSimulator::Result Finish() override;
private:
- // A struct representing an asynchronous copy with its logical start and end
- // time and its destination memory space.
- struct AsynchronousCopy {
- int64 start_time;
- int64 end_time;
- MemorySpace destination;
- };
-
// Finds an allocation for the given interval. Internally, it will attempt to
// find a suitable chunk candidate within the heap size and prefetch interval
// limits, and append the new allocation(s) to allocations. The new
@@ -603,18 +648,6 @@
bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time,
int64 end_time) const;
- // Returns true if the addition of an asynchronous copy in the the given time
- // interval would violate the asynchronous copy ordering. E.g., consider the
- // following scenario:
- // CS CD
- // already committed async copy: +-----------+
- // new async copy: +--------+
- //
- // The new asynchronous copy would violate the ordering guarantee because the
- // copy start is after an already committed asynchronous copy while its copy
- // done is before the committed copy.
- bool ViolatesAsynchronousCopyOrdering(int64 start_time, int64 end_time) const;
-
// Adds an asynchronous copy to the allocations.
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
MemorySpace memory_space, Chunk chunk, int64 start_time,
@@ -639,9 +672,7 @@
// We use a interval tree to keep track of the number of outstanding
// asynchronous copies.
BufferIntervalTree async_copy_interval_tree_;
- // Given the logical time for CopyDone in key, stores the earliest time for
- // the corresponding CopyStart.
- std::map<int64, int64> async_copy_range_map_;
+ AsynchronousCopyOrdering async_copy_ordering_;
std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
std::vector<AsynchronousCopy> pending_async_copies_;
// This map contains required memory assignments for HloValues (e.g., input
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 238bbed..49acdeb 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -67,9 +67,9 @@
std::unique_ptr<PresetAssignments> AssignMemorySpace(
HloModule* module, int64 max_outstanding_async_copies = -1,
- int64 max_prefetch_interval = 10) {
+ int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2) {
InstructionCountPrefetchIntervalPicker prefetch_interval_picker(
- /*min_overlap_count=*/2, max_prefetch_interval);
+ min_prefetch_interval, max_prefetch_interval);
return AssignMemorySpace(module, max_outstanding_async_copies,
/*buffer_interval_compare=*/{},
&prefetch_interval_picker);
@@ -759,6 +759,77 @@
AssignMemorySpace(module.get());
}
+TEST_P(MemorySpaceAssignmentTest, BitcastScheduleBug) {
+ // Bitcasts can force asynchronous copies to be scheduled too early, possibly
+ // leading to memory corruption.
+ // Bug:
+ // p0------------------>neg-->neg-->neg ... -->neg-->neg-->neg->add
+ // /
+ // p1->cs->cd->bitcast-----------------------------------------+
+ //
+ // Expected:
+ // p0-->neg-->neg-->neg ... -->neg-->neg-->neg------------->add
+ // /
+ // p1--------------------->cs----------------->cd->bitcast-+
+ HloComputation::Builder builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+ Shape param_shape = ShapeUtil::MakeShape(F32, {6});
+ HloInstruction* p0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, param_shape, "p1"));
+ HloInstruction* bitcast =
+ builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
+ HloInstruction* negate0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
+ HloInstruction* negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
+ HloInstruction* negate2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
+ HloInstruction* negate3 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
+ HloInstruction* negate4 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
+ HloInstruction* negate5 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
+ HloInstruction* negate6 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
+ HloInstruction* negate7 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
+ HloInstruction* negate8 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
+ HloInstruction* negate9 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate9));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+ HloSchedule schedule(module.get());
+ schedule.set_sequence(
+ computation, {p0, p1, bitcast, negate0, negate1, negate2, negate3,
+ negate4, negate5, negate6, negate7, negate8, negate9, add});
+ TF_CHECK_OK(module->set_schedule(schedule));
+
+ AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
+ /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/4);
+
+ EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
+ const auto& instructions =
+ module->schedule().sequence(module->entry_computation()).instructions();
+ for (int i = 0; i < instructions.size(); ++i) {
+ // Expect that there is a negate before and after the CopyStart and there is
+ // a negate before CopyDone.
+ if (instructions.at(i)->opcode() == HloOpcode::kCopyStart) {
+ EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
+ EXPECT_EQ(instructions.at(i + 1)->opcode(), HloOpcode::kNegate);
+ } else if (instructions.at(i)->opcode() == HloOpcode::kCopyDone) {
+ EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
+ }
+ }
+}
+
TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
// Test that checks the last use optimization. It uses two buffers that should
// be placed in alternate memory.
@@ -2266,5 +2337,38 @@
MemorySpaceAssignmentTest,
::testing::Values(false, true));
+using AsynchronousCopyOrderingTest = ::testing::Test;
+
+TEST_F(AsynchronousCopyOrderingTest, Simple) {
+ // Given asynchronous copies like the following, ensure the pipelining order
+ // is maintained (earlier start time must have earlier end time).
+ // 3,11 +-------+ OK
+ // 1,8 +------+ OK
+ // 5,14 +--------+ OK
+ // 7,14 +------+ OK
+ // 2,16 +-------------+ Violate
+ // 9,12 +--+ Violate
+ // 6,17 +----------+ Violate
+ // 5,13 +-------+ OK (same start as 5,14)
+ // 5,14 +--------+ OK (same as 5,14)
+ auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
+ AsynchronousCopyOrdering ordering;
+ EXPECT_FALSE(ordering.ViolatesOrdering(3, 11));
+ ordering.AddCopy({3, 11, alternate_mem_space});
+ EXPECT_FALSE(ordering.ViolatesOrdering(1, 8));
+ ordering.AddCopy({1, 8, alternate_mem_space});
+ EXPECT_FALSE(ordering.ViolatesOrdering(5, 14));
+ ordering.AddCopy({5, 14, alternate_mem_space});
+ EXPECT_FALSE(ordering.ViolatesOrdering(7, 14));
+ ordering.AddCopy({7, 14, alternate_mem_space});
+ EXPECT_TRUE(ordering.ViolatesOrdering(2, 16));
+ EXPECT_TRUE(ordering.ViolatesOrdering(9, 12));
+ EXPECT_TRUE(ordering.ViolatesOrdering(6, 17));
+ EXPECT_FALSE(ordering.ViolatesOrdering(5, 13));
+ ordering.AddCopy({5, 13, alternate_mem_space});
+ EXPECT_FALSE(ordering.ViolatesOrdering(5, 14));
+ ordering.AddCopy({5, 14, alternate_mem_space});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
index 84e239a..eec3e4d 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc
@@ -117,18 +117,18 @@
struct BoundAffineMap {
mlir::AffineMap affine_map;
- std::vector<mlir::Value*> operands;
+ std::vector<mlir::Value> operands;
};
BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
return {load.getAffineMap(),
- std::vector<mlir::Value*>(load.getMapOperands().begin(),
- load.getMapOperands().end())};
+ std::vector<mlir::Value>(load.getMapOperands().begin(),
+ load.getMapOperands().end())};
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
return {store.getAffineMap(),
- std::vector<mlir::Value*>(store.getMapOperands().begin(),
- store.getMapOperands().end())};
+ std::vector<mlir::Value>(store.getMapOperands().begin(),
+ store.getMapOperands().end())};
} else {
CHECK(false);
}
@@ -150,7 +150,7 @@
}
}
-void SetMemRef(mlir::Operation* op, mlir::Value* memref) {
+void SetMemRef(mlir::Operation* op, mlir::Value memref) {
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
load.setMemRef(memref);
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
@@ -325,7 +325,7 @@
auto new_alloc =
builder.create<mlir::AllocOp>(builder.getUnknownLoc(), new_type);
- std::vector<mlir::Value*> indvars;
+ std::vector<mlir::Value> indvars;
for (auto ancestor : ancestors) {
indvars.push_back(ancestor.getInductionVar());
}
@@ -418,7 +418,7 @@
// output[...] = output_acc[]
// }
StatusOr<InitialMlirConvAnchors> CreateNaiveMlirConv(
- mlir::Value* input, mlir::Value* filter, mlir::Value* output,
+ mlir::Value input, mlir::Value filter, mlir::Value output,
const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info,
const ShapeInfo& output_shape_info, const Window& window,
mlir::OpBuilder builder) {
@@ -440,7 +440,7 @@
location,
builder.create<mlir::ConstantOp>(
location, mlir::FloatAttr::get(builder.getF32Type(), 0)),
- output_acc, llvm::ArrayRef<mlir::Value*>());
+ output_acc, llvm::ArrayRef<mlir::Value>());
std::vector<mlir::AffineForOp> reduction_loops;
reduction_loops = CreateNestedSimpleLoops(
@@ -450,11 +450,11 @@
mlir::AffineForOp loop_o = cartesian_product_loops[1];
mlir::AffineForOp loop_c = reduction_loops[0];
- std::vector<mlir::Value*> output_spatial_indvars;
+ std::vector<mlir::Value> output_spatial_indvars;
for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) {
output_spatial_indvars.push_back(loop.getInductionVar());
}
- std::vector<mlir::Value*> filter_spatial_indvars;
+ std::vector<mlir::Value> filter_spatial_indvars;
for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) {
filter_spatial_indvars.push_back(loop.getInductionVar());
}
@@ -463,7 +463,7 @@
builder = reduction_loops.back().getBodyBuilder();
- mlir::Value* loaded_input = [&] {
+ mlir::Value loaded_input = [&] {
std::vector<mlir::AffineExpr> input_indices;
input_indices.push_back(builder.getAffineDimExpr(0));
input_indices.push_back(builder.getAffineDimExpr(1));
@@ -479,7 +479,7 @@
builder.getAffineDimExpr(2 + num_spatial_dims + i) -
window_dim.padding_low());
}
- std::vector<mlir::Value*> input_vars;
+ std::vector<mlir::Value> input_vars;
input_vars.push_back(loop_n.getInductionVar());
input_vars.push_back(loop_c.getInductionVar());
input_vars.insert(input_vars.end(), output_spatial_indvars.begin(),
@@ -499,8 +499,8 @@
builder.getF32Type());
}();
- mlir::Value* loaded_filter = [&] {
- std::vector<mlir::Value*> filter_vars;
+ mlir::Value loaded_filter = [&] {
+ std::vector<mlir::Value> filter_vars;
filter_vars.push_back(loop_o.getInductionVar());
filter_vars.push_back(loop_c.getInductionVar());
filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(),
@@ -519,11 +519,11 @@
location,
builder.createOrFold<mlir::AffineLoadOp>(location, output_acc),
builder.create<mlir::MulFOp>(location, loaded_input, loaded_filter)),
- output_acc, llvm::ArrayRef<mlir::Value*>());
+ output_acc, llvm::ArrayRef<mlir::Value>());
builder.setInsertionPointAfter(reduction_loops[0]);
{
- std::vector<mlir::Value*> output_vars;
+ std::vector<mlir::Value> output_vars;
output_vars.push_back(loop_n.getInductionVar());
output_vars.push_back(loop_o.getInductionVar());
output_vars.insert(output_vars.end(), output_spatial_indvars.begin(),
@@ -735,9 +735,9 @@
builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
builder.setInsertionPointToStart(entry_block);
- mlir::Value* input = entry_block->getArgument(1);
- mlir::Value* filter = entry_block->getArgument(2);
- mlir::Value* output = entry_block->getArgument(0);
+ mlir::Value input = entry_block->getArgument(1);
+ mlir::Value filter = entry_block->getArgument(2);
+ mlir::Value output = entry_block->getArgument(0);
TF_RETURN_IF_ERROR(ConvIsImplemented(conv));
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
index 60b5d08..ec5ae03 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc
@@ -43,14 +43,21 @@
namespace hlo = ::mlir::xla_hlo;
// TODO(b/137624192) Use tablegen for this.
-StatusOr<Value*> InsertMlirOp(
- HloOpcode opcode, OpBuilder func_builder, Location loc, ArrayRef<Type> rets,
- ArrayRef<Value*> args, ArrayRef<std::pair<Identifier, Attribute>> attrs) {
+StatusOr<Value> InsertMlirOp(HloOpcode opcode, OpBuilder func_builder,
+ Location loc, ArrayRef<Type> rets,
+ ArrayRef<Value> args,
+ ArrayRef<std::pair<Identifier, Attribute>> attrs) {
switch (opcode) {
+ case HloOpcode::kAbs:
+ return {func_builder.create<hlo::AbsOp>(loc, rets, args, attrs)};
case HloOpcode::kAdd:
return {func_builder.create<hlo::AddOp>(loc, rets, args, attrs)};
case HloOpcode::kAnd:
return {func_builder.create<hlo::AndOp>(loc, rets, args, attrs)};
+ case HloOpcode::kCeil:
+ return {func_builder.create<hlo::CeilOp>(loc, rets, args, attrs)};
+ case HloOpcode::kCos:
+ return {func_builder.create<hlo::CosOp>(loc, rets, args, attrs)};
case HloOpcode::kDivide:
return {func_builder.create<hlo::DivOp>(loc, rets, args, attrs)};
case HloOpcode::kExp:
@@ -61,10 +68,18 @@
return {func_builder.create<hlo::MinOp>(loc, rets, args, attrs)};
case HloOpcode::kMultiply:
return {func_builder.create<hlo::MulOp>(loc, rets, args, attrs)};
+ case HloOpcode::kNegate:
+ return {func_builder.create<hlo::NegOp>(loc, rets, args, attrs)};
+ case HloOpcode::kRemainder:
+ return {func_builder.create<hlo::RemOp>(loc, rets, args, attrs)};
case HloOpcode::kSelect:
return {func_builder.create<hlo::SelectOp>(loc, rets, args, attrs)};
+ case HloOpcode::kSign:
+ return {func_builder.create<hlo::SignOp>(loc, rets, args, attrs)};
case HloOpcode::kSubtract:
return {func_builder.create<hlo::SubOp>(loc, rets, args, attrs)};
+ case HloOpcode::kTanh:
+ return {func_builder.create<hlo::TanhOp>(loc, rets, args, attrs)};
default:
return tensorflow::errors::Internal(absl::StrCat(
"HLO Opcode ", HloOpcodeString(opcode), " is not supported."));
@@ -78,7 +93,7 @@
return emission_context_->getLocation(instr);
}
-StatusOr<Value*> HloDialectEmitter::EmitComputation(
+StatusOr<Value> HloDialectEmitter::EmitComputation(
const HloComputation& computation) {
const auto root = computation.root_instruction();
TF_RETURN_IF_ERROR(root->Accept(this));
@@ -88,7 +103,7 @@
Status HloDialectEmitter::DefaultAction(HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto res_type, ConvertTensorShapeToType<RankedTensorType>(
instr->shape(), builder_));
- llvm::SmallVector<Value*, 4> arguments;
+ llvm::SmallVector<Value, 4> arguments;
for (auto operand : instr->operands()) {
arguments.push_back(instruction_to_values_[operand]);
}
@@ -135,7 +150,7 @@
}
Status HloDialectEmitter::HandleReduce(HloInstruction* reduce) {
- llvm::SmallVector<Value*, 4> operands;
+ llvm::SmallVector<Value, 4> operands;
for (auto operand : reduce->operands()) {
operands.push_back(instruction_to_values_.at(operand));
}
@@ -152,7 +167,7 @@
{
auto computation = reduce->to_apply();
auto block = new mlir::Block();
- llvm::SmallVector<Value*, 4> arguments;
+ llvm::SmallVector<Value, 4> arguments;
arguments.reserve(computation->num_parameters());
for (auto parameter : computation->parameter_instructions()) {
TF_ASSIGN_OR_RETURN(auto param_type,
@@ -166,7 +181,7 @@
OpBuilder body_builder(block);
body_builder.setInsertionPointToEnd(block);
body_builder.create<hlo::ReturnOp>(getLocation(reduce),
- ArrayRef<Value*>{result});
+ ArrayRef<Value>{result});
}
// TODO(b/137624192) Add support for multiple results.
instruction_to_values_[reduce] = reduceOp.getResult(0);
@@ -180,7 +195,7 @@
"comparison_direction",
builder_.getStringAttr(
ComparisonDirectionToString(compare->comparison_direction())));
- llvm::SmallVector<Value*, 4> arguments;
+ llvm::SmallVector<Value, 4> arguments;
for (auto operand : compare->operands()) {
arguments.push_back(instruction_to_values_[operand]);
}
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
index 86ed97b..eeff31b 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
+++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.h
@@ -37,19 +37,19 @@
public:
HloDialectEmitter(xla::mlir_gpu::EmissionContext* emission_context,
::mlir::Region* region,
- llvm::ArrayRef<::mlir::Value*> arguments)
+ llvm::ArrayRef<::mlir::Value> arguments)
: emission_context_(emission_context),
builder_(region),
arguments_(arguments) {}
HloDialectEmitter(xla::mlir_gpu::EmissionContext* emission_context,
::mlir::OpBuilder builder,
- llvm::ArrayRef<::mlir::Value*> arguments)
+ llvm::ArrayRef<::mlir::Value> arguments)
: emission_context_(emission_context),
builder_(builder),
arguments_(arguments) {}
- StatusOr<mlir::Value*> EmitComputation(const HloComputation& computation);
+ StatusOr<mlir::Value> EmitComputation(const HloComputation& computation);
Status DefaultAction(HloInstruction* instr) override;
Status HandleBroadcast(HloInstruction* broadcast) override;
@@ -64,8 +64,8 @@
xla::mlir_gpu::EmissionContext* emission_context_;
::mlir::OpBuilder builder_;
- llvm::ArrayRef<::mlir::Value*> arguments_;
- absl::flat_hash_map<const xla::HloInstruction*, ::mlir::Value*>
+ llvm::ArrayRef<::mlir::Value> arguments_;
+ absl::flat_hash_map<const xla::HloInstruction*, ::mlir::Value>
instruction_to_values_;
};
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 186dacc..78d83db 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -17,7 +17,6 @@
#include <memory>
-#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // TF:local_config_mlir
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // TF:local_config_mlir
@@ -108,7 +107,7 @@
struct SingleTripLoopRemoval
: public mlir::FunctionPass<SingleTripLoopRemoval> {
void runOnFunction() override {
- auto getConstantValue = [](mlir::Value* value) -> llvm::Optional<int64_t> {
+ auto getConstantValue = [](mlir::Value value) -> llvm::Optional<int64_t> {
auto definingOp = value->getDefiningOp();
if (!definingOp) return llvm::None;
auto constantOp = llvm::dyn_cast<mlir::ConstantOp>(definingOp);
@@ -145,7 +144,7 @@
// same address with the stored value. This needs generalization.
struct StoreForwardingPass : mlir::FunctionPass<StoreForwardingPass> {
void runOnFunction() override {
- absl::flat_hash_map<mlir::Value*, mlir::Operation*> memrefToAllocOp;
+ llvm::DenseMap<mlir::Value, mlir::Operation*> memrefToAllocOp;
getFunction().walk([&](mlir::LoadOp loadOp) {
auto* block = loadOp.getOperation()->getBlock();
@@ -180,7 +179,7 @@
// Recursively checks defining ops until finds AllocOp. Return either AllocOp
// if it is found or nullptr.
- mlir::Operation* SearchAllocOp(mlir::Value* memref) {
+ mlir::Operation* SearchAllocOp(mlir::Value memref) {
mlir::Operation* defOp = memref->getDefiningOp();
while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
defOp = subviewOp.source()->getDefiningOp();
@@ -193,8 +192,8 @@
// Retrieves AllocOp from the cache or actually looks for it.
mlir::Operation* GetAllocOp(
- mlir::Value* memref,
- absl::flat_hash_map<mlir::Value*, mlir::Operation*>* memrefToAllocOp) {
+ mlir::Value memref,
+ llvm::DenseMap<mlir::Value, mlir::Operation*>* memrefToAllocOp) {
auto allocOpIt = memrefToAllocOp->find(memref);
if (allocOpIt != memrefToAllocOp->end()) {
return allocOpIt->second;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
index fd38cd3..8e8af22 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc
@@ -59,15 +59,24 @@
// TODO(b/137624192) Use tablegen for this.
Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc,
- ArrayRef<Type> rets, ArrayRef<Value*> args,
+ ArrayRef<Type> rets, ArrayRef<Value> args,
ArrayRef<std::pair<Identifier, Attribute>> attrs) {
switch (opcode) {
+ case HloOpcode::kAbs:
+ func_builder.create<lhlo::AbsOp>(loc, rets, args, attrs);
+ break;
case HloOpcode::kAdd:
func_builder.create<lhlo::AddOp>(loc, rets, args, attrs);
break;
case HloOpcode::kAnd:
func_builder.create<lhlo::AndOp>(loc, rets, args, attrs);
break;
+ case HloOpcode::kCeil:
+ func_builder.create<lhlo::CeilOp>(loc, rets, args, attrs);
+ break;
+ case HloOpcode::kCos:
+ func_builder.create<lhlo::CosOp>(loc, rets, args, attrs);
+ break;
case HloOpcode::kDivide:
func_builder.create<lhlo::DivOp>(loc, rets, args, attrs);
break;
@@ -83,12 +92,24 @@
case HloOpcode::kMultiply:
func_builder.create<lhlo::MulOp>(loc, rets, args, attrs);
break;
+ case HloOpcode::kNegate:
+ func_builder.create<lhlo::NegOp>(loc, rets, args, attrs);
+ break;
+ case HloOpcode::kRemainder:
+ func_builder.create<lhlo::RemOp>(loc, rets, args, attrs);
+ break;
case HloOpcode::kSelect:
func_builder.create<lhlo::SelectOp>(loc, rets, args, attrs);
break;
+ case HloOpcode::kSign:
+ func_builder.create<lhlo::SignOp>(loc, rets, args, attrs);
+ break;
case HloOpcode::kSubtract:
func_builder.create<lhlo::SubOp>(loc, rets, args, attrs);
break;
+ case HloOpcode::kTanh:
+ func_builder.create<lhlo::TanhOp>(loc, rets, args, attrs);
+ break;
default:
return tensorflow::errors::Internal(absl::StrCat(
"LHLO opcode ", HloOpcodeString(opcode), " is not supported."));
@@ -168,8 +189,8 @@
Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
OpBuilder func_builder(function.getBody());
- llvm::SmallVector<Value*, 4> arg_values{function.args_begin(),
- function.args_end()};
+ llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
+ function.args_end()};
TF_RETURN_IF_ERROR(InsertMlirOp(instr->opcode(), func_builder,
getLocation(instr), ArrayRef<Type>{},
arg_values, llvm::None));
@@ -197,7 +218,7 @@
// Load the HLO argument tensors from the corresponding buffers. The last
// argument is for the result, so no need to load it.
OpBuilder body_builder(fusion_op.region());
- llvm::SmallVector<Value*, 4> arg_values;
+ llvm::SmallVector<Value, 4> arg_values;
for (int i = 0, e = function.getNumArguments() - 1; i < e; ++i) {
arg_values.push_back(body_builder.create<::mlir::TensorLoadOp>(
getLocation(fusion), function.getArgument(i)));
@@ -211,7 +232,7 @@
// Insert the write-back from the HLO computation to the result argument
// buffer.
body_builder.setInsertionPoint(fusion_op.region().back().getTerminator());
- Value* result_memref = function.getArgument(function.getNumArguments() - 1);
+ Value result_memref = function.getArgument(function.getNumArguments() - 1);
body_builder.create<::mlir::TensorStoreOp>(getLocation(fusion), result,
result_memref);
@@ -220,8 +241,8 @@
Status LhloDialectEmitter::HandleReduce(HloInstruction* reduce) {
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*reduce));
- llvm::SmallVector<Value*, 4> arg_values{function.args_begin(),
- function.args_end()};
+ llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
+ function.args_end()};
OpBuilder builder(function.getBody());
auto loc = getLocation(reduce);
int input_count = reduce->operand_count() / 3;
@@ -239,7 +260,7 @@
OpBuilder body_builder(reduce_op.body());
auto block = body_builder.getInsertionBlock();
auto to_apply = reduce->to_apply();
- llvm::SmallVector<Value*, 4> reduce_arg_values;
+ llvm::SmallVector<Value, 4> reduce_arg_values;
// First map parameters to memrefs on the operation.
for (auto param : to_apply->parameter_instructions()) {
TF_ASSIGN_OR_RETURN(auto arg_type, ConvertShapeToType<MemRefType>(
@@ -280,8 +301,8 @@
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*compare));
OpBuilder func_builder(function.getBody());
- llvm::SmallVector<Value*, 4> arg_values{function.args_begin(),
- function.args_end()};
+ llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
+ function.args_end()};
func_builder.create<lhlo::CompareOp>(getLocation(compare), llvm::None,
arg_values, comparison_direction_attr);
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
index d332392..cde08fc 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc
@@ -213,7 +213,7 @@
}
using OperandToValueMap =
- absl::flat_hash_map<const HloInstruction*, std::vector<BlockArgument*>>;
+ absl::flat_hash_map<const HloInstruction*, std::vector<BlockArgument>>;
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
@@ -224,7 +224,7 @@
for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
++kernel_index) {
auto launchop_operand =
- dyn_cast<BlockArgument>(launchOp.getKernelOperand(kernel_index));
+ launchOp.getKernelOperand(kernel_index)->dyn_cast<BlockArgument>();
if (!launchop_operand) {
launchOp.emitError("argument to kernel is not a function input");
has_failed = true;
@@ -272,7 +272,7 @@
std::vector<mlir::Type> as_mlir_types(new_arg_types.begin(),
new_arg_types.end());
auto new_args = kernel.front().addArguments(as_mlir_types);
- std::vector<Value*> buffer_args(new_args.begin(), new_args.end());
+ std::vector<Value> buffer_args(new_args.begin(), new_args.end());
auto zero = builder.create<mlir::LLVM::ConstantOp>(
loc, offset_type, builder.getI64IntegerAttr(0));
@@ -310,23 +310,21 @@
builder.create<mlir::LLVM::AllocaOp>(loc, target_type, one, 0);
// Fill the base and aligned pointers.
auto casted = builder.create<mlir::LLVM::BitcastOp>(
- loc, struct_type.getStructElementType(0),
- llvm::ArrayRef<Value*>{ptr});
+ loc, struct_type.getStructElementType(0), llvm::ArrayRef<Value>{ptr});
auto structPtrAddr = builder.create<mlir::LLVM::GEPOp>(
loc, struct_type.getStructElementType(0), descPtr,
- llvm::ArrayRef<Value*>{zero, baseIndex});
+ llvm::ArrayRef<Value>{zero, baseIndex});
builder.create<mlir::LLVM::StoreOp>(loc, casted, structPtrAddr);
casted = builder.create<mlir::LLVM::BitcastOp>(
- loc, struct_type.getStructElementType(1),
- llvm::ArrayRef<Value*>{ptr});
+ loc, struct_type.getStructElementType(1), llvm::ArrayRef<Value>{ptr});
structPtrAddr = builder.create<mlir::LLVM::GEPOp>(
loc, struct_type.getStructElementType(1), descPtr,
- llvm::ArrayRef<Value*>{zero, dataIndex});
+ llvm::ArrayRef<Value>{zero, dataIndex});
builder.create<mlir::LLVM::StoreOp>(loc, casted, structPtrAddr);
// Fill the offset value.
auto structOffsetAddr = builder.create<mlir::LLVM::GEPOp>(
loc, struct_type.getStructElementType(1), descPtr,
- llvm::ArrayRef<Value*>{zero, offsetIndex});
+ llvm::ArrayRef<Value>{zero, offsetIndex});
builder.create<mlir::LLVM::StoreOp>(loc, offset, structOffsetAddr);
// Fill the shape.
auto shape = operand->shape();
@@ -341,7 +339,7 @@
loc, offset_type, builder.getI64IntegerAttr(extent.index()));
auto shapeEntryPtr = builder.create<mlir::LLVM::GEPOp>(
loc, entry_type, descPtr,
- llvm::ArrayRef<Value*>{zero, shapeIndex, index});
+ llvm::ArrayRef<Value>{zero, shapeIndex, index});
auto extentValue = builder.create<mlir::LLVM::ConstantOp>(
loc, entry_type, builder.getI64IntegerAttr(extent.value()));
builder.create<mlir::LLVM::StoreOp>(loc, extentValue, shapeEntryPtr);
@@ -349,13 +347,13 @@
// Finally, fill the strides.
// TODO(b/137624192): Take assigned layout into account.
entry_type = struct_type.getStructElementType(4).getArrayElementType();
- Value* accumulator = nullptr;
+ Value accumulator = nullptr;
for (int64 idx = shape.rank() - 1; idx >= 0; --idx) {
auto indexValue = builder.create<mlir::LLVM::ConstantOp>(
loc, offset_type, builder.getI64IntegerAttr(idx));
auto strideEntryPtr = builder.create<mlir::LLVM::GEPOp>(
loc, entry_type, descPtr,
- llvm::ArrayRef<Value*>{zero, strideIndex, indexValue});
+ llvm::ArrayRef<Value>{zero, strideIndex, indexValue});
if (accumulator) {
auto strideValue = builder.create<mlir::LLVM::ConstantOp>(
loc, entry_type,
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
index 505d16d..afcac65 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc
@@ -393,5 +393,104 @@
)");
}
+TEST_F(LhloGenTest, Abs) {
+ CompileAndVerifyIr(R"(
+HloModule Abs
+ENTRY %Abs (val: f32[2,2]) -> f32[2,2] {
+ %val = f32[2,2]{1,0} parameter(0)
+ ROOT %abs = f32[2,2]{1,0} abs(f32[2,2]{1,0} %val)
+})",
+ R"(
+;CHECK: func @abs(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Ceil) {
+ CompileAndVerifyIr(R"(
+HloModule Ceil
+ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] {
+ %val = f32[2,2]{1,0} parameter(0)
+ ROOT %ceil = f32[2,2]{1,0} ceil(f32[2,2]{1,0} %val)
+})",
+ R"(
+;CHECK: func @ceil(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Cos) {
+ CompileAndVerifyIr(R"(
+HloModule Cos
+ENTRY %Cos (val: f32[2,2]) -> f32[2,2] {
+ %val = f32[2,2]{1,0} parameter(0)
+ ROOT %cos = f32[2,2]{1,0} cosine(f32[2,2]{1,0} %val)
+})",
+ R"(
+;CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.cos"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Neg) {
+ CompileAndVerifyIr(R"(
+HloModule Neg
+ENTRY %Neg (val: f32[2,2]) -> f32[2,2] {
+ %val = f32[2,2]{1,0} parameter(0)
+ ROOT %neg = f32[2,2]{1,0} negate(f32[2,2]{1,0} %val)
+})",
+ R"(
+;CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.neg"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Rem) {
+ CompileAndVerifyIr(R"(
+HloModule Rem
+ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
+ %x = f32[2,2]{1,0} parameter(0)
+ %y = f32[2,2]{1,0} parameter(1)
+ ROOT %rem = f32[2,2]{1,0} remainder(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
+})",
+ R"(
+;CHECK: func @remainder(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Sign) {
+ CompileAndVerifyIr(R"(
+HloModule Sign
+ENTRY %Sign (val: f32[2,2]) -> f32[2,2] {
+ %val = f32[2,2]{1,0} parameter(0)
+ ROOT %sign = f32[2,2]{1,0} sign(f32[2,2]{1,0} %val)
+})",
+ R"(
+;CHECK: func @sign(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
+TEST_F(LhloGenTest, Tanh) {
+ CompileAndVerifyIr(R"(
+HloModule Tanh
+ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] {
+ %val = f32[2,2]{1,0} parameter(0)
+ ROOT %tanh = f32[2,2]{1,0} tanh(f32[2,2]{1,0} %val)
+})",
+ R"(
+;CHECK: func @tanh(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) {
+;CHECK: "xla_lhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> ()
+;CHECK: }
+ )");
+}
+
} // namespace mlir_gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 41e2b0e..16e3433 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -151,6 +151,37 @@
return remaining;
}
+HloInstruction* MultiOutputFusion::CreateFusion(HloInstruction* base,
+ HloInstruction* to_fuse) {
+ HloInstruction* input_fusion =
+ computation()->AddInstruction(HloInstruction::CreateFusion(
+ base->shape(), HloInstruction::FusionKind::kLoop, base));
+
+ // Update candidates_, candidates_index_ and all_fusion_candidates_.
+ std::vector<std::pair<HloInstruction*, int64>> new_fusibles =
+ GetNewFusibles(base, to_fuse);
+ int64 index;
+ if (candidates_index_.contains(input_fusion)) {
+ index = candidates_index_[input_fusion];
+ } else {
+ index = candidates_.size();
+ InsertOrDie(&candidates_index_, input_fusion, index);
+ candidates_.emplace_back(input_fusion);
+ all_fusion_candidates_.push_back(input_fusion);
+ }
+
+ // Update the worklist_.
+ FusionCandidate& candidate_node = candidates_[index];
+ for (auto it : new_fusibles) {
+ candidate_node.fusibles.emplace_back(it.first, it.second);
+ worklist_.emplace(input_fusion, it.first, it.second);
+ }
+
+ reachability_->Replace(base, input_fusion);
+ TF_CHECK_OK(computation()->ReplaceInstruction(base, input_fusion));
+ return input_fusion;
+}
+
bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
// kConstant instruction will not have memory reads, so it won't be a profit
// source. Skip them.
@@ -167,29 +198,12 @@
return true;
}
-void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
- HloInstruction* fusion = instr1;
- HloInstruction* fused = instr2;
- if (is_fused(instr1)) {
- fusion = instr2;
- fused = instr1;
- }
-
- // Insert the newly created instruction (if any), to candidates_.
- for (auto use : fusion->users()) {
- if (candidates_index_.find(use) == candidates_index_.end()) {
- int64 index = candidates_.size();
- candidates_.emplace_back(use);
- InsertOrDie(&candidates_index_, use, index++);
- }
- }
+std::vector<std::pair<HloInstruction*, int64>>
+MultiOutputFusion::GetNewFusibles(HloInstruction* fusion,
+ HloInstruction* fused) {
FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)];
FusionCandidate& fused_node = candidates_[get_candidate_id(fused)];
- // Update the reachability graph.
- UpdateReachability(fusion, fused, all_fusion_candidates_,
- [this](HloInstruction* instr) { return is_fused(instr); });
-
// Update the fusible list for fusion. Variable new_fusibles keeps
// track of the new or changed entries.
std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
@@ -227,6 +241,33 @@
}
fused_node.fusibles.clear();
+ return new_fusibles;
+}
+
+void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
+ HloInstruction* fusion = instr1;
+ HloInstruction* fused = instr2;
+ if (is_fused(instr1)) {
+ fusion = instr2;
+ fused = instr1;
+ }
+
+ // Insert the newly created instruction (if any), to candidates_.
+ for (auto use : fusion->users()) {
+ if (candidates_index_.find(use) == candidates_index_.end()) {
+ int64 index = candidates_.size();
+ candidates_.emplace_back(use);
+ InsertOrDie(&candidates_index_, use, index++);
+ }
+ }
+
+ // Update the reachability graph.
+ UpdateReachability(fusion, fused, all_fusion_candidates_,
+ [this](HloInstruction* instr) { return is_fused(instr); });
+
+ std::vector<std::pair<HloInstruction*, int64>> new_fusibles =
+ GetNewFusibles(fusion, fused);
+
// Update the worklist_.
for (auto it : new_fusibles) {
worklist_.emplace(fusion, it.first, it.second);
@@ -235,10 +276,15 @@
bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1,
HloInstruction* instr2) {
- if (instr1 == instr2) {
+ if (instr1->opcode() != HloOpcode::kFusion) {
return false;
}
- if (instr1->opcode() != HloOpcode::kFusion) {
+ return LegalToFuseMainConstraints(instr1, instr2);
+}
+
+bool MultiOutputFusion::LegalToFuseMainConstraints(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ if (instr1 == instr2) {
return false;
}
@@ -342,7 +388,12 @@
}
Update(instr1, instr2);
HloInstruction* ret = Fuse(instr1, instr2);
- set_is_fused(ret == instr1 ? instr2 : instr1);
+ if (ret != instr1) {
+ set_is_fused(instr1);
+ }
+ if (ret != instr2) {
+ set_is_fused(instr2);
+ }
changed = true;
VLOG(2) << "After fusion, \t this: " << ret->name() << "\n"
<< ret->fused_instructions_computation()->ToString(
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 9be69f8..55cb15e 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -79,6 +79,11 @@
// Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);
+ // Test if it's legal to fuse instr1 and instr2 into one fusion instruction
+ // using main constraints only; unlike the default LegalToFuse, instr1 need not be a kFusion.
+ bool LegalToFuseMainConstraints(HloInstruction* instr1,
+ HloInstruction* instr2);
+
// Fuse HloInstruction instr1 and instr2 and return the fused instruction.
// The other instruction is removed from its parent computation.
virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
@@ -105,6 +110,17 @@
// InstructionFusion instead.
virtual bool DoProducerConsumerMultiOutputFusion();
+ // Returns a list of new fusible instructions that can be fused into `fusion'
+ // once it is merged with `fused'. The second entry of each pair is the profit
+ // value of fusing the corresponding instruction.
+ std::vector<std::pair<HloInstruction*, int64>> GetNewFusibles(
+ HloInstruction* fusion, HloInstruction* fused);
+
+ // Creates a new loop fusion instruction containing `base', replaces `base' with
+ // it, and prepares for fusing `to_fuse' into the created fusion by updating
+ // reachability, worklist, and fusion candidates. Returns the new fusion.
+ HloInstruction* CreateFusion(HloInstruction* base, HloInstruction* to_fuse);
+
private:
// An internal data structure for each instruction in current computation.
// When an instruction is removed, member 'hlo' is set to nullptr.
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index ec6a97e..4ce34ea 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1720,7 +1720,8 @@
const int64 kernel_output_features =
rhs.dimensions(dnums.kernel_output_feature_dimension());
- if (batch_group_count > 1 && kernel_output_features != batch_group_count) {
+ if (batch_group_count > 1 &&
+ kernel_output_features % batch_group_count != 0) {
return InvalidArgument(
"Expected output feature dimension size (value %d) to be equal to "
"batch group count %d; got <conv>(%s, %s)\n"
@@ -1759,7 +1760,7 @@
dnums.DebugString());
}
- if (input_batch % batch_group_count > 0) {
+ if (input_batch % batch_group_count != 0) {
return InvalidArgument(
"Expected input batch dimension (value %d) to be divisible by "
"batch_group_count (value %d); "
@@ -1793,6 +1794,13 @@
std::vector<int64> dimensions(num_dims);
dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count;
dimensions[dnums.output_feature_dimension()] = kernel_output_features;
+
+ if (batch_group_count > 1) {
+ dimensions[dnums.output_batch_dimension()] =
+ kernel_output_features / batch_group_count;
+ dimensions[dnums.output_feature_dimension()] = batch_group_count;
+ }
+
for (int i = 0; i < num_spatial_dims; ++i) {
dimensions[dnums.output_spatial_dimensions(i)] =
window_output_shape.dimensions(i);
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
old mode 100644
new mode 100755
index 17e3760..0746588
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -364,7 +364,6 @@
instruction->set_raw_backend_config_string(backend_config);
}
- // return ::testing::AssertionSuccess();
auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs,
/*run_hlo_passes=*/run_hlo_passes,
/*profile=*/profile);
@@ -501,6 +500,19 @@
return nullptr;
}
+HloInstruction* HloTestBase::FindInstruction(HloModule* module,
+ HloOpcode opcode) {
+ for (const HloComputation* c : module->computations()) {
+ auto instructions = c->instructions();
+ auto it = absl::c_find_if(
+ instructions, [&](HloInstruction* i) { return i->opcode() == opcode; });
+ if (it != instructions.end()) {
+ return *it;
+ }
+ }
+ return nullptr;
+}
+
Backend& HloTestBase::backend() { return test_runner_.backend(); }
/* static */
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
old mode 100644
new mode 100755
index 848b334..45917f3
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -274,6 +274,8 @@
// inspect a particular computation or instruction.
HloComputation* FindComputation(HloModule* module, absl::string_view name);
HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
+ // Gets the instruction from the given module with the given opcode.
+ HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
// Return an HLO verifier constructed for the test backend.
HloVerifier& verifier() const { return *hlo_verifier_; }
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 4563d7e..c160d6c 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -218,6 +218,23 @@
}
}
+// std::uniform_int_distribution has undefined behavior for 8-bit integer types,
+// so values for int8/uint8 are drawn from the corresponding 16-bit type instead.
+template <typename IntT>
+struct RngT {
+ using type = IntT;
+};
+
+template <>
+struct RngT<int8> {
+ using type = int16;
+};
+
+template <>
+struct RngT<uint8> {
+ using type = uint16;
+};
+
template <typename IntT>
void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine,
bool no_duplicates) {
@@ -230,7 +247,7 @@
std::shuffle(literal->data<IntT>().begin(), literal->data<IntT>().end(),
*engine);
} else {
- std::uniform_int_distribution<IntT> generator(
+ std::uniform_int_distribution<typename RngT<IntT>::type> generator(
std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
for (IntT& value : literal->data<IntT>()) {
value = generator(*engine);
@@ -341,7 +358,7 @@
CHECK(engine != nullptr);
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<IntT>());
- std::uniform_int_distribution<IntT> generator(min, max);
+ std::uniform_int_distribution<typename RngT<IntT>::type> generator(min, max);
for (IntT& value : literal->data<IntT>()) {
value = generator(*engine);
}
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 603e94c..db819c3 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -206,6 +206,7 @@
":hlo_extractor",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 0956550..639f91b 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -349,7 +349,7 @@
tensorflow::tstring record;
while (reader.ReadRecord(&offset, &record).ok()) {
HloSnapshot snapshot;
- if (snapshot.mutable_hlo()->ParseFromStringPiece(record)) {
+ if (snapshot.mutable_hlo()->ParseFromString(record)) {
snapshots.push_back(std::move(snapshot));
} else {
LOG(ERROR) << "Encountered bad proto";
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 9a1b0a5..fbdcb4d 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -88,15 +88,29 @@
"tf_opts_nortti_if_emscripten",
"transitive_hdrs",
)
+
+# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "if_nccl")
+
+# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tensorflow_opensource_extra_deps")
+# buildifier: disable=same-origin-load
# load("//tensorflow:tensorflow.bzl", "tf_android_full_lite_protos")
+
+# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
+
+# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu")
+
+# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+# buildifier: disable=same-origin-load
# Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib")
+
+# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
# For platform specific build config
@@ -310,7 +324,6 @@
"//tensorflow/core/platform:threadpool_interface",
"//tensorflow/core/platform:threadpool_options",
"//tensorflow/core/platform:types",
- "//tensorflow/core/platform/default/build_config:base",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],
@@ -1432,7 +1445,7 @@
"//tensorflow/core/lib/random:legacy_lib_random_all_srcs",
"//tensorflow/core/lib/strings:legacy_lib_strings_all_headers",
"//tensorflow/core/lib/strings:legacy_lib_strings_all_srcs",
- "//tensorflow/core/platform/default/build_config:android_srcs",
+ "//tensorflow/core/platform:legacy_mobile_srcs",
"//tensorflow/core/profiler:mobile_srcs",
"//tensorflow/core/public:mobile_srcs_no_runtime",
"//tensorflow/core/util/ctc:android_srcs",
@@ -1743,7 +1756,7 @@
visibility = ["//visibility:public"],
deps = [
":android_tensorflow_lib",
- ":protos_cc",
+ ":protos_all_cc",
"//tensorflow/core/platform/default/build_config:gtest",
"//third_party/eigen3",
],
@@ -2236,7 +2249,7 @@
visibility = ["//visibility:public"],
deps = [
":platform_base",
- "//tensorflow/core/platform/default/build_config:logging",
+ "//tensorflow/core/platform:logging",
],
)
@@ -2269,8 +2282,8 @@
":core_stringpiece",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:jpeg",
+ "//tensorflow/core/platform:logging",
"//tensorflow/core/platform:stringpiece",
- "//tensorflow/core/platform/default/build_config:logging",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
@@ -2304,10 +2317,10 @@
"//tensorflow/core/lib/strings:strcat",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:gif",
+ "//tensorflow/core/platform:logging",
"//tensorflow/core/platform:numbers",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:stringpiece",
- "//tensorflow/core/platform/default/build_config:logging",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
@@ -2519,6 +2532,7 @@
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/util:port",
"//tensorflow/core/util:stats_calculator_portable",
+ "//tensorflow/compiler/jit:common",
] + if_static(
extra_deps = ["@com_google_protobuf//:protobuf"],
otherwise = ["@com_google_protobuf//:protobuf_headers"],
@@ -2588,13 +2602,6 @@
visibility = ["//visibility:public"],
)
-# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
-cc_library(
- name = "protos_cc",
- visibility = ["//visibility:public"],
- deps = ["//tensorflow/core/platform/default/build_config:protos_cc"],
-)
-
# Library containing all of the graph construction code that is
# independent of the runtime.
#
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt
new file mode 100644
index 0000000..2bbaba2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt
@@ -0,0 +1,124 @@
+op {
+ graph_op_name: "BoostedTreesCalculateBestFeatureSplitV2"
+ visibility: HIDDEN
+ in_arg {
+ name: "node_id_range"
+ description: <<END
+A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summaries_list`. The nodes are iterated between the two nodes specified by the tensor, as in `for node_id in range(node_id_range[0], node_id_range[1])` (note that the last index node_id_range[1] is exclusive).
+END
+ }
+ in_arg {
+ name: "stats_summaries_list"
+ description: <<END
+A list of Rank 4 tensors (#shape=[max_splits, feature_dims, bucket, stats_dims]) for accumulated stats summary (gradient/hessian) per node, per dimension, per bucket for each feature.
+The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
+END
+ }
+ in_arg {
+ name: "split_types"
+ description: <<END
+A Rank 1 tensor indicating if this Op should perform inequality split or equality split per feature.
+END
+ }
+ in_arg {
+ name: "candidate_feature_ids"
+ description: <<END
+Rank 1 tensor with ids for each feature. This is the real id of the feature.
+END
+ }
+ in_arg {
+ name: "l1"
+ description: <<END
+l1 regularization factor on leaf weights, per instance based.
+END
+ }
+ in_arg {
+ name: "l2"
+ description: <<END
+l2 regularization factor on leaf weights, per instance based.
+END
+ }
+ in_arg {
+ name: "tree_complexity"
+ description: <<END
+adjustment to the gain, per leaf based.
+END
+ }
+ in_arg {
+ name: "min_node_weight"
+ description: <<END
+minimum avg of hessians in a node required for the node to be considered for splitting.
+END
+ }
+ out_arg {
+ name: "node_ids"
+ description: <<END
+A Rank 1 tensor indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has a different size as each feature provides different possible nodes. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "gains"
+ description: <<END
+A Rank 1 tensor indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "feature_ids"
+ description: <<END
+A Rank 1 tensor indicating the best feature id for each node. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "feature_dimensions"
+ description: <<END
+A Rank 1 tensor indicating the best feature dimension for each feature to split for certain nodes if the feature is multi-dimension. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "thresholds"
+ description: <<END
+A Rank 1 tensor indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "left_node_contribs"
+ description: <<END
+A Rank 2 tensor indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.
+END
+ }
+ out_arg {
+ name: "right_node_contribs"
+ description: <<END
+A Rank 2 tensor, with the same shape/conditions as left_node_contribs, except that the values are for the right node.
+END
+ }
+ out_arg {
+ name: "split_with_default_directions"
+ description: <<END
+A Rank 1 tensor indicating which direction to go if data is missing. See above for details like shapes and sizes.
+Inequality with default left returns 0, inequality with default right returns 1, equality with default right returns 2.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred from the size of `stats_summaries_list`; the total number of features.
+END
+}
+ attr {
+ name: "logits_dimension"
+ description: <<END
+The dimension of logit, i.e., number of classes.
+END
+ }
+ summary: "Calculates gains for each feature and returns the best possible split information for each node. However, if no split is found, then no split information is returned for that node."
+ description: <<END
+The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
+
+It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
+
+In this manner, the output is the best split per feature and per node, so it needs to be combined later to produce the best split for each node (among all possible features).
+
+The output shapes are compatible in that the first dimension of all tensors is the same and equal to the number of possible split nodes for each feature.
+END
+}
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0d9f897..9731d74 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1313,10 +1313,12 @@
options_.config.experimental().has_session_metadata()
? &options_.config.experimental().session_metadata()
: nullptr;
+ const CustomKernelCreator* custom_kernel_creator =
+ GetDefaultCustomKernelCreator();
func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), options_.env, &options_.config, graph_def_version,
func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
- nullptr, nullptr, session_metadata));
+ nullptr, custom_kernel_creator, session_metadata));
GraphOptimizer optimizer(optimizer_opts);
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index a0071ce..930f70b 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -91,14 +91,11 @@
if (node->AsAsync() != nullptr) {
return errors::Internal("Executor does not support executing async nodes");
}
- Status s = status();
- if (!s.ok()) {
- return s;
- }
+ // NOTE: SyncExecute runs every node regardless of error status in executor.
uint64 id = next_node_id_++;
- s = node->Prepare();
+ Status s = node->Prepare();
if (!s.ok()) {
return s;
}
@@ -129,11 +126,8 @@
// Inline execution in sync mode.
if (!Async()) {
- status = this->status();
- if (status.ok()) {
- status = RunItem(std::move(item), false);
- }
- return status;
+ // In sync mode, run the node item regardless of executor status.
+ return RunItem(std::move(item), false);
} else {
tensorflow::mutex_lock l(node_queue_mutex_);
DVLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 9584056..1d80f59 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -352,8 +352,8 @@
}
}
-Status ShouldCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
- bool* compile_with_xla) {
+Status MustCompileWithXLA(const EagerOperation* op, const EagerContext* ctx,
+ bool* compile_with_xla) {
if (!op->is_function()) {
*compile_with_xla = false;
return Status::OK();
@@ -368,7 +368,7 @@
}
// Does node have an explicit request to compile or not?
- Status status = op->Attrs().Get(kXlaCompileAttr, compile_with_xla);
+ Status status = op->Attrs().Get(kXlaMustCompileAttr, compile_with_xla);
if (status.ok()) {
DVLOG(2) << "Caller explicitly requested "
<< (*compile_with_xla ? "" : "not ")
@@ -383,7 +383,7 @@
return errors::NotFound("Failed to find function '", op->Name(), "'");
}
- status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaCompileAttr,
+ status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaMustCompileAttr,
compile_with_xla);
if (status.ok()) {
DVLOG(2) << "Function definition explicitly specifies "
@@ -511,12 +511,12 @@
bool run_function_with_flr = false;
if (op->is_function()) {
bool compile_with_xla;
- TF_RETURN_IF_ERROR(ShouldCompileWithXLA(op, ctx, &compile_with_xla));
+ TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla));
if (compile_with_xla) {
// Note that it is not ideal, but currently correct, to set this
// attribute after computing the kernel cache key above.
// Note: If the attribute is already set to true, this is a noop.
- op->MutableAttrs()->Set(kXlaCompileAttr, true);
+ op->MutableAttrs()->Set(kXlaMustCompileAttr, true);
} else {
run_function_with_flr = true;
}
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
index 44d72fc..0a912b1 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -127,7 +127,6 @@
}
// Copy all attributes to the new op.
- string name;
const NodeDef& orig_ndef = orig_op->MutableAttrs()->BuildNodeDef();
AttrSlice attr_list(orig_ndef);
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 1c04adf..8141838 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -2532,7 +2532,7 @@
}
}
delete this;
- runner([=]() {
+ runner([step_id, status, done_cb = std::move(done_cb)]() {
profiler::TraceMe traceme(
[&] {
return absl::StrCat("ExecutorDoneCallback#id=", step_id, "#");
@@ -2548,10 +2548,10 @@
// devices like GPUs that continue to execute Ops after their Compute
// methods have completed, this ensures that control is not returned to
// the user until the step (and its side-effects) has actually completed.
- device->Sync([=](Status new_status) mutable {
- status.Update(new_status);
+ device->Sync([this, step_id, runner = std::move(runner),
+ done_cb = std::move(done_cb)](const Status& status) mutable {
delete this;
- runner([=]() {
+ runner([step_id, status, done_cb = std::move(done_cb)]() {
profiler::TraceMe traceme(
[&] {
return absl::StrCat("ExecutorDoneCallback#id=", step_id, "#");
@@ -2562,7 +2562,7 @@
});
} else {
delete this;
- runner([=]() {
+ runner([step_id, status, done_cb = std::move(done_cb)]() {
profiler::TraceMe traceme(
[&] {
return absl::StrCat("ExecutorDoneCallback#id=", step_id, "#");
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 89e4daa..c124719 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -275,7 +275,6 @@
opts.runner = nullptr;
}
Notification done;
- std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
status = s;
diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc
index 8f31cda..7198341 100644
--- a/tensorflow/core/common_runtime/function_threadpool_test.cc
+++ b/tensorflow/core/common_runtime/function_threadpool_test.cc
@@ -171,7 +171,6 @@
opts.runner = nullptr;
}
Notification done;
- std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
status = s;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 2287bf8..eaf16d2 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -833,6 +833,9 @@
// RAM and Video RAM
min_system_memory = 1 << 30;
#endif
+
+ VLOG(5) << "available_memory = " << available_memory;
+ VLOG(5) << "min_system_memory = " << min_system_memory;
return min_system_memory;
}
@@ -1186,7 +1189,7 @@
", name: ", desc.name(),
", pci bus id: ", desc.pci_bus_id(),
", compute capability: ", cc_major, ".", cc_minor);
- // LINT.ThenChange(//tensorflow/python/platform/test.py)
+ // LINT.ThenChange(//tensorflow/python/framework/gpu_util.py)
#elif TENSORFLOW_USE_ROCM
return strings::StrCat("device: ", platform_gpu_id.value(),
", name: ", desc.name(),
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 7bd5d09..671f067 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -843,7 +843,7 @@
auto attrs = AttrSlice(&shard.attr());
VLOG(1) << "Start instantiating component function " << unique_name
<< " on device " << target;
- VLOG(2) << DebugString(shard);
+ VLOG(4) << DebugString(shard);
auto* component_handle = new FunctionLibraryRuntime::Handle;
auto done = [this, status, unique_name, comp_data, component_handle,
@@ -851,7 +851,7 @@
status->Update(s);
VLOG(1) << "Finished instantiating component function " << unique_name
- << "with handle " << *component_handle << " status: " << s;
+ << " with handle " << *component_handle << " status: " << s;
if (status->ok()) {
{
mutex_lock l(mu_);
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 9c95c29..897efc0 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -31,7 +31,6 @@
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
-#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
@@ -473,8 +472,8 @@
c->req.set_session_handle(session_handle_);
c->req.set_create_worker_session_called(!should_deregister_);
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
- StripDefaultAttributes(*OpRegistry::Global(),
- c->req.mutable_graph_def()->mutable_node());
+ // TODO(b/146354085): Default attributes should be stripped here from
+ // c->req.graph_def(), but this causes some TFX pipelines to fail.
*c->req.mutable_config_proto() = session_opts_.config;
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
index 9b118ce..7ffff94 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
@@ -80,7 +80,6 @@
std::unique_ptr<GrpcSession> session;
TF_RETURN_IF_ERROR(GrpcSession::Create(options_copy, &session));
- std::vector<DeviceAttributes> device_attributes;
TF_RETURN_IF_ERROR(session->ListDevices(&ret->devices_));
*out_cluster = std::move(ret);
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index d6fb07f..e2151e0 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -199,7 +199,6 @@
}
protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
- std::unique_ptr<DeviceMgr> remote_devices;
AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
&added_remote_devices);
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 67918fe..4a45691 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -71,24 +71,47 @@
// Interface for reading values from a key-value store.
// Used for restoring iterator state.
+// See the comment on IteratorStateWriter for guidance on when to use
+// Read*(key, val) vs. Read*(name, key, val).
class IteratorStateReader {
public:
virtual Status ReadScalar(StringPiece key, int64* val) = 0;
virtual Status ReadScalar(StringPiece key, tstring* val) = 0;
virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;
+
+ virtual Status ReadScalar(StringPiece name, StringPiece key, int64* val) = 0;
+ virtual Status ReadScalar(StringPiece name, StringPiece key,
+ tstring* val) = 0;
+ virtual Status ReadTensor(StringPiece name, StringPiece key, Tensor* val) = 0;
+
virtual bool Contains(StringPiece key) = 0;
+ virtual bool Contains(StringPiece name, StringPiece key) = 0;
virtual ~IteratorStateReader() {}
};
// Interface for writing values to a key-value store.
// Used for saving iterator state.
+// The IteratorStateWriter creates a tensor for each unique iterator name it
+// sees. For the Write*(key, val) APIs, the key is expected to encode this
+// name, since keys are required to be produced using the full_name() method.
+// Each tensor has an upper limit of 2 GB and so if the state for an iterator
+// might exceed the 2 GB limit, you can pass an explicit name in via the
+// Write*(name, key, val) APIs allowing you to further split up the state
+// into more manageable chunks.
class IteratorStateWriter {
public:
virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
+ virtual Status WriteScalar(StringPiece name, StringPiece key,
+ const int64 val) = 0;
+ virtual Status WriteScalar(StringPiece name, StringPiece key,
+ const tstring& val) = 0;
+ virtual Status WriteTensor(StringPiece name, StringPiece key,
+ const Tensor& val) = 0;
+
virtual ~IteratorStateWriter() {}
};
@@ -475,6 +498,14 @@
// latter makes sense to do when performing data agnostic graph rewrites to
// reduce the memory usage.
bool serialize_data_tensors = true;
+
+ // Indicates whether datasets that use random seeds should have the values
+ // of random seeds serialized or not. If the values of random seeds are
+ // serialized, the deserialized dataset will have the same seeds as the
+ // original dataset. Otherwise, the deserialized dataset will use different
+ // seeds. This param does not affect datasets that use fixed seeds; fixed
+ // seeds will always be preserved.
+ bool preserve_random_seeds = true;
};
explicit SerializationContext(Params params) : params_(params) {}
@@ -491,6 +522,8 @@
bool serialize_data_tensors() const { return params_.serialize_data_tensors; }
+ bool preserve_random_seeds() const { return params_.preserve_random_seeds; }
+
private:
Params params_;
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 246f50a..5393b16 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -17,6 +17,8 @@
#include <utility>
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -97,6 +99,11 @@
inp_mtypes->clear();
out_mtypes->clear();
+ bool has_xla_compile = [&] {
+ const auto& it = ndef.attr().find(kXlaMustCompileAttr);
+ return it != ndef.attr().end() && it->second.b();
+ }();
+
// For functions (which have no KernelDef) and their gradients, we can only
// best-effort derive the memory type from the data type. For now, we assume
// int32 is always on host memory and other types are always on device memory.
@@ -104,7 +111,7 @@
// to derive the correct input/output memory types. We should also split
// host-memory and non host-memory arguments into separate type lists.
if (!status.ok() || IsFunctionCallOp(ndef.op())) {
- if (device_type.type_string() == "TPU") {
+ if (device_type.type_string() == "TPU" || has_xla_compile) {
// Here we assume that if tf.function() is called within
// "with tf.device('/device:TPU:0')", the whole function will be compiled
// and executed on TPU. This is true today, but when we implement auto
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index 3236d18..c8ac08b 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -84,6 +84,27 @@
}
};
+#ifdef USE_TSTRING
+template <>
+struct NumTraits<tensorflow::tstring> : GenericNumTraits<tensorflow::tstring> {
+ enum {
+ RequireInitialization = 1,
+ ReadCost = HugeCost,
+ AddCost = HugeCost,
+ MulCost = HugeCost
+ };
+
+ static inline int digits10() { return 0; }
+
+ private:
+ static inline tensorflow::tstring epsilon();
+ static inline tensorflow::tstring dummy_precision();
+ static inline tensorflow::tstring lowest();
+ static inline tensorflow::tstring highest();
+ static inline tensorflow::tstring infinity();
+ static inline tensorflow::tstring quiet_NaN();
+};
+#endif // USE_TSTRING
using ::tensorflow::operator==;
using ::tensorflow::operator!=;
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 34ef6e6..67fc398 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -38,7 +38,6 @@
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes) {
ResourceHandle result;
result.set_device(device.name());
- string actual_container;
result.set_container(container);
if (name == ResourceHandle::ANONYMOUS_NAME) {
result.set_name(strings::StrCat("_AnonymousVar", current_id_.fetch_add(1)));
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 6f2a90d..5511932 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -1486,7 +1486,6 @@
// false otherwise.
static bool FusedMatMulRewrite(const Node* n) {
bool trans_a;
- std::vector<string> fused_ops;
// Do not rewrite with transpose attribute because reorder has performance
// impact.
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 61fb5d1..329f770 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -866,898 +866,387 @@
}
REGISTER_TEST_ALL_TYPES(NodeMerge_PadWithConv2D_Negative);
#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT) \
+ TEST_F(MklLayoutPassTest, NAME##_##T) { \
+ InitGraph( \
+ "node { name: 'Input0' op: '" #INPUT "'}" \
+ "node { name: 'Input1' op: '" #INPUT "'}" \
+ "node { name: 'Const0' op: 'Const'" \
+ " attr {key: 'dtype' value { type: DT_INT32 } }" \
+ " attr {key: 'value' value { " \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim { size: 4 } }" \
+ " tensor_content: " \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\" \
+ "000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Const1' op: 'Const'" \
+ " attr {key: 'dtype' value { type: DT_INT32 } }" \
+ " attr {key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape {dim {size: 4 }}" \
+ " tensor_content: " \
+ "'\\000\\000\\000\\000\\003\\000\\000\\000\\001" \
+ "\\000\\000\\000\\002\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Transpose0' op: 'Transpose'" \
+ " input: ['Input0', 'Const0']" \
+ " attr { key: 'T' value { type: " #T "} }" \
+ " attr { key: 'Tperm' value { type: DT_INT32 } } }" \
+ "node { name: 'Conv2D' op: 'Conv2D'" \
+ " input: ['Transpose0', 'Input1']" \
+ " attr { key: 'T' value { type: " #T "} }" \
+ " attr { key: 'data_format' value { s: 'NHWC' }}" \
+ " attr { key: 'dilations' value {list: {i:1,i:1,i:1,i:1}}}" \
+ " attr { key: 'padding' value {s: 'SAME'}}" \
+ " attr { key: 'strides' value {list: {i:1,i:1,i:1,i:1}}}" \
+ " attr { key: 'use_cudnn_on_gpu' value {b: true}}}" \
+ "node { name: 'Transpose1' op: 'Transpose'" \
+ " input: ['Conv2D', 'Const1' ]" \
+ " attr { key: 'T' value { type: " #T "}}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Relu' op: 'Relu'" \
+ " attr { key: 'T' value { type: " #T "} }" \
+ " input: ['Transpose1'] }"); \
+ EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+ "Const0(Const);Const1(Const);Conv2D(_MklConv2D);DMT/_0(Const);" \
+ "DMT/_1(Const);Input0(" #INPUT ");Input1(" #INPUT ");" \
+ "Relu(_MklRelu)|Conv2D->Relu;Conv2D:2->Relu:1;DMT/_0->Conv2D:2;" \
+ "DMT/_1->Conv2D:3;Input0->Conv2D;Input0:control->DMT/_0:control;"\
+ "Input0:control->DMT/_1:control;Input1->Conv2D:1"); \
+}
+REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeConv2DTranspose_Positive);
+#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT) \
+ TEST_F(MklLayoutPassTest, NAME##_##T) { \
+ InitGraph( \
+ "node { name: 'Input0' op: '" #INPUT "'}" \
+ "node { name: 'Input1' op: '" #INPUT "'}" \
+ "node { name: 'Const0' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape {dim {size: 4}}" \
+ " tensor_content: " \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\" \
+ "000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Const1' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 }}" \
+ " attr {" \
+ " key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim { size: 4 }}" \
+ " tensor_content: " \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\" \
+ "000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node {name: 'Transpose0' op: 'Transpose'" \
+ " input: ['Input0', 'Const0']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Conv2D' op: 'Conv2D'" \
+ " input: ['Transpose0', 'Input1']" \
+ " attr { key: 'T' value { type: " #T "}}" \
+ " attr { key: 'data_format' value { s: 'NHWC' }}" \
+ " attr { key: 'dilations' value { list: {i:1,i:1,i:1,i:1}}}" \
+ " attr { key: 'padding' value { s: 'SAME' }}" \
+ " attr { key: 'strides' value { list: {i:1,i:1,i:1,i:1}}}" \
+ " attr { key: 'use_cudnn_on_gpu' value { b: true }}}" \
+ "node {name: 'Transpose1' op: 'Transpose'" \
+ " input: ['Conv2D', 'Const1']" \
+ " attr { key: 'T' value { type: " #T "}}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Relu' op: 'Relu'" \
+ " attr { key: 'T' value { type: " #T "}}" \
+ " input: ['Transpose1'] }"); \
+ EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+ "Const0(Const);Const1(Const);Conv2D(_MklConv2D);DMT/_0(Const);" \
+ "DMT/_1(Const);DMT/_2(Const);Input0(" #INPUT ");Input1(" #INPUT \
+ ");Relu(_MklRelu);Transpose0(_MklTranspose);" \
+ "Transpose1(_MklTranspose)|Const0->Transpose0:1;" \
+ "Const1->Transpose1:1;Conv2D->Transpose1;DMT/_0->Conv2D:2;" \
+ "DMT/_1->Conv2D:3;DMT/_2->Relu:1;Input0->Transpose0;" \
+ "Input1->Conv2D:1;Transpose0->Conv2D;Transpose0:control->DMT/" \
+ "_0:control;Transpose0:control->DMT/_1:control;Transpose1->Relu;"\
+ "Transpose1:control->DMT/_2:control"); \
+}
+REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeConv2DTranspose_Negative);
+#undef REGISTER_TEST
+
+
+#define REGISTER_TEST(NAME, T, INPUT) \
+ TEST_F(MklLayoutPassTest, NAME##_##T) { \
+ InitGraph( \
+ "node { name: 'Input0' op: '" #INPUT "'}" \
+ "node { name: 'Input1' op: '" #INPUT "'}" \
+ "node { name: 'Const0' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim {size: 5}}" \
+ " tensor_content:" \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
+ "\\000\\000\\000\\001\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Const1' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim { size: 5 }}" \
+ " tensor_content: " \
+ "'\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002" \
+ "\\000\\000\\000\\003\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Transpose0' op: 'Transpose'" \
+ " input: ['Input0', 'Const0']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Conv3D' op: 'Conv3D'" \
+ "input: ['Transpose0', 'Input1']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'data_format' value { s: 'NDHWC' }}" \
+ " attr { key: 'dilations' value { list: {i:1,i:1,i:1,i:1,i:1}}}" \
+ " attr { key: 'padding' value { s: 'SAME' }}" \
+ " attr { key: 'strides' value { list: {i:1,i:1,i:1,i:1,i:1}}}" \
+ " attr { key: 'use_cudnn_on_gpu' value { b: true }}}" \
+ "node { name: 'Transpose1' op: 'Transpose'" \
+ " input: ['Conv3D', 'Const1']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Relu' op: 'Relu'" \
+ " attr { key: 'T' value { type: " #T " } }" \
+ " input: ['Transpose1'] }"); \
+ EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+ "Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);" \
+ "DMT/_1(Const);Input0(" #INPUT ");Input1(" #INPUT ");" \
+ "Relu(_MklRelu)|Conv3D->Relu;Conv3D:2->Relu:1;" \
+ "DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;Input0->Conv3D;" \
+ "Input0:control->DMT/_0:control;" \
+ "Input0:control->DMT/_1:control;Input1->Conv3D:1"); \
+}
+REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeConv3DTranspose_Positive);
+#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT) \
+ TEST_F(MklLayoutPassTest, NAME##_##T) { \
+ InitGraph( \
+ "node { name: 'Input0' op: '" #INPUT "'}" \
+ "node { name: 'Input1' op: '" #INPUT "'}" \
+ "node { name: 'Const0' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim {size: 5}}" \
+ " tensor_content:" \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
+ "\\000\\000\\000\\001\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Const1' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim { size: 5 }}" \
+ " tensor_content: " \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
+ "\\000\\000\\000\\001\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Transpose0' op: 'Transpose'" \
+ " input: ['Input0', 'Const0']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Conv3D' op: 'Conv3D'" \
+ "input: ['Transpose0', 'Input1']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'data_format' value { s: 'NDHWC' }}" \
+ " attr { key: 'dilations' value { list: {i:1,i:1,i:1,i:1,i:1}}}" \
+ " attr { key: 'padding' value { s: 'SAME' }}" \
+ " attr { key: 'strides' value { list: {i:1,i:1,i:1,i:1,i:1}}}" \
+ " attr { key: 'use_cudnn_on_gpu' value { b: true }}}" \
+ "node { name: 'Transpose1' op: 'Transpose'" \
+ " input: ['Conv3D', 'Const1']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Relu' op: 'Relu'" \
+ " attr { key: 'T' value { type: " #T " } }" \
+ " input: ['Transpose1'] }"); \
+ EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+ "Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);" \
+ "DMT/_1(Const);DMT/_2(Const);Input0(" #INPUT ");" \
+ "Input1(" #INPUT ");Relu(_MklRelu);Transpose0(_MklTranspose);" \
+ "Transpose1(_MklTranspose)|Const0->Transpose0:1;Const1->" \
+ "Transpose1:1;Conv3D->Transpose1;DMT/_0->Conv3D:2;" \
+ "DMT/_1->Conv3D:3;DMT/_2->Relu:1;Input0->Transpose0;Input1->" \
+ "Conv3D:1;Transpose0->Conv3D;Transpose0:control->" \
+ "DMT/_0:control;Transpose0:control->DMT/_1:control;" \
+ "Transpose1->Relu;Transpose1:control->DMT/_2:control"); \
+}
+REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeConv3DTranspose_Negative);
+#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT) \
+ TEST_F(MklLayoutPassTest, NAME##_##T) { \
+ InitGraph( \
+ "node { name: 'Input0' op: '" #INPUT "'}" \
+ "node { name: 'Const0' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim {size: 5}}" \
+ " tensor_content:" \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
+ "\\000\\000\\000\\001\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Const1' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim { size: 5 }}" \
+ " tensor_content: " \
+ "'\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002" \
+ "\\000\\000\\000\\003\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Transpose0' op: 'Transpose'" \
+ " input: ['Input0', 'Const0']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'MaxPool3D' op: 'MaxPool3D'" \
+ "input: ['Transpose0']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'data_format' value { s: 'NDHWC' }}" \
+ " attr { key: 'padding' value { s: 'SAME' }}" \
+ " attr { key: 'strides' value { list: {i:1,i:2,i:2,i:2,i:1}}}" \
+ " attr { key: 'ksize' value { list: {i:1,i:1,i:1,i:1,i:1}}}}"\
+ "node { name: 'Transpose1' op: 'Transpose'" \
+ " input: ['MaxPool3D', 'Const1']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Relu' op: 'Relu'" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " input: ['Transpose1'] }"); \
+ EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+ "Const0(Const);Const1(Const);DMT/_0(Const);Input0(" #INPUT ");" \
+ "MaxPool3D(_MklMaxPool3D);Relu(_MklRelu)|DMT/_0->MaxPool3D:1;" \
+ "Input0->MaxPool3D;Input0:control->DMT/_0:control;" \
+ "MaxPool3D->Relu;MaxPool3D:2->Relu:1"); \
+}
+REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeMaxPool3DTranspose_Positive);
+#undef REGISTER_TEST
+
+#define REGISTER_TEST(NAME, T, INPUT) \
+ TEST_F(MklLayoutPassTest, NAME##_##T) { \
+ InitGraph( \
+ "node { name: 'Input0' op: '" #INPUT "'}" \
+ "node { name: 'Const0' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim {size: 5}}" \
+ " tensor_content:" \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
+ "\\000\\000\\000\\001\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Const1' op: 'Const'" \
+ " attr { key: 'dtype' value { type: DT_INT32 } }" \
+ " attr { key: 'value'" \
+ " value {" \
+ " tensor {" \
+ " dtype: DT_INT32" \
+ " tensor_shape { dim { size: 5 }}" \
+ " tensor_content: " \
+ "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004" \
+ "\\000\\000\\000\\001\\000\\000\\000'" \
+ " }" \
+ " }" \
+ " }" \
+ "}" \
+ "node { name: 'Transpose0' op: 'Transpose'" \
+ " input: ['Input0', 'Const0']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'MaxPool3D' op: 'MaxPool3D'" \
+ "input: ['Transpose0']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'data_format' value { s: 'NDHWC' }}" \
+ " attr { key: 'padding' value { s: 'SAME' }}" \
+ " attr { key: 'strides' value { list: {i:1,i:2,i:2,i:2,i:1}}}" \
+ " attr { key: 'ksize' value { list: {i:1,i:1,i:1,i:1,i:1}}}}"\
+ "node { name: 'Transpose1' op: 'Transpose'" \
+ " input: ['MaxPool3D', 'Const1']" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " attr { key: 'Tperm' value { type: DT_INT32 }}}" \
+ "node { name: 'Relu' op: 'Relu'" \
+ " attr { key: 'T' value { type: " #T " }}" \
+ " input: ['Transpose1'] }"); \
+ EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+ "Const0(Const);Const1(Const);DMT/_0(Const);DMT/_1(Const);Input0(" #INPUT\
+ ");MaxPool3D(_MklMaxPool3D);Relu(_MklRelu);Transpose0(_MklTranspose);" \
+ "Transpose1(_MklTranspose)|Const0->Transpose0:1;" \
+ "Const1->Transpose1:1;DMT/_0->MaxPool3D:1;DMT/_1->Relu:1;Input0->" \
+ "Transpose0;MaxPool3D->Transpose1;Transpose0->MaxPool3D;Transpose0:" \
+ "control->DMT/_0:control;Transpose1->Relu;Transpose1:control->" \
+ "DMT/_1:control"); \
+}
+REGISTER_TEST_ALL_TYPES(NodeMerge_TransposeMaxPool3DTranspose_Negative);
+#undef REGISTER_TEST
// clang-format on
-TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv2DTranspose_Positive) {
- InitGraph(
- "node { name: 'Input0' op: 'Input'}"
- "node { name: 'Input1' op: 'Input'}"
- "node { name: 'Const0' op: 'Const'"
- " attr {"
- " key: 'dtype'"
- " value {"
- " type: DT_INT32"
- " }"
- " }"
- " attr {"
- " key: 'value'"
- " value {"
- " tensor {"
- " dtype: DT_INT32"
- " tensor_shape {"
- " dim {"
- " size: 4"
- " }"
- " }"
- " tensor_content: "
- "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\000"
- "\\000\\000'"
- " }"
- " }"
- " }"
- "}"
- "node { name: 'Const1' op: 'Const'"
- " attr {"
- " key: 'dtype'"
- " value {"
- " type: DT_INT32"
- " }"
- " }"
- " attr {"
- " key: 'value'"
- " value {"
- " tensor {"
- " dtype: DT_INT32"
- " tensor_shape {"
- " dim {"
- " size: 4"
- " }"
- " }"
- " tensor_content: "
- "'\\000\\000\\000\\000\\003\\000\\000\\000\\001\\000\\000\\000\\002\\000"
- "\\000\\000'"
- " }"
- " }"
- " }"
- "}"
- "node { \
- name: 'Transpose0' \
- op: 'Transpose' \
- input: 'Input0' \
- input: 'Const0' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { \
- name: 'Conv2D' \
- op: 'Conv2D' \
- input: 'Transpose0' \
- input: 'Input1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'data_format' \
- value { \
- s: 'NHWC' \
- } \
- } \
- attr { \
- key: 'dilations' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'padding' \
- value { \
- s: 'SAME' \
- } \
- } \
- attr { \
- key: 'strides' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'use_cudnn_on_gpu' \
- value { \
- b: true \
- } \
- } \
- }"
- "node { \
- name: 'Transpose1' \
- op: 'Transpose' \
- input: 'Conv2D' \
- input: 'Const1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { name: 'Relu' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['Transpose1'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "Const0(Const);Const1(Const);"
- "Conv2D(_MklConv2D);DMT/_0(Const);DMT/_1(Const);Input0(Input);"
- "Input1(Input);Relu(_MklRelu)|Conv2D->Relu;Conv2D:2->Relu:1;DMT/"
- "_0->Conv2D:2;DMT/_1->Conv2D:3;Input0->Conv2D;"
- "Input0:control->DMT/_0:control;Input0:control->DMT/"
- "_1:control;Input1->Conv2D:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv2DTranspose_Negative) {
- InitGraph(
- "node { name: 'Input0' op: 'Input'}"
- "node { name: 'Input1' op: 'Input'}"
- "node { name: 'Const0' op: 'Const'"
- " attr {"
- " key: 'dtype'"
- " value {"
- " type: DT_INT32"
- " }"
- " }"
- " attr {"
- " key: 'value'"
- " value {"
- " tensor {"
- " dtype: DT_INT32"
- " tensor_shape {"
- " dim {"
- " size: 4"
- " }"
- " }"
- " tensor_content: "
- "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\000"
- "\\000\\000'"
- " }"
- " }"
- " }"
- "}"
- "node { name: 'Const1' op: 'Const'"
- " attr {"
- " key: 'dtype'"
- " value {"
- " type: DT_INT32"
- " }"
- " }"
- " attr {"
- " key: 'value'"
- " value {"
- " tensor {"
- " dtype: DT_INT32"
- " tensor_shape {"
- " dim {"
- " size: 4"
- " }"
- " }"
- " tensor_content: "
- "'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\001\\000"
- "\\000\\000'"
- " }"
- " }"
- " }"
- "}"
- "node { \
- name: 'Transpose0' \
- op: 'Transpose' \
- input: 'Input0' \
- input: 'Const0' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { \
- name: 'Conv2D' \
- op: 'Conv2D' \
- input: 'Transpose0' \
- input: 'Input1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'data_format' \
- value { \
- s: 'NHWC' \
- } \
- } \
- attr { \
- key: 'dilations' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'padding' \
- value { \
- s: 'SAME' \
- } \
- } \
- attr { \
- key: 'strides' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'use_cudnn_on_gpu' \
- value { \
- b: true \
- } \
- } \
- }"
- "node { \
- name: 'Transpose1' \
- op: 'Transpose' \
- input: 'Conv2D' \
- input: 'Const1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { name: 'Relu' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['Transpose1'] }");
- EXPECT_EQ(
- DoMklLayoutOptimizationPass(),
- "Const0(Const);Const1(Const);Conv2D(_MklConv2D);"
- "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);Input0(Input);"
- "Input1(Input);Relu(_MklRelu);Transpose0(_MklTranspose);"
- "Transpose1(_MklTranspose)|Const0->Transpose0:1;"
- "Const1->Transpose1:1;Conv2D->Transpose1;DMT/_0->Conv2D:2;"
- "DMT/_1->Conv2D:3;DMT/_2->Relu:1;Input0->Transpose0;"
- "Input1->Conv2D:1;Transpose0->Conv2D;Transpose0:control->DMT/_0:control;"
- "Transpose0:control->DMT/_1:control;Transpose1->Relu;"
- "Transpose1:control->DMT/_2:control");
-}
-
-TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Positive) {
- InitGraph(
- "node { name: 'Input0' op: 'Input'} \
- node { name: 'Input1' op: 'Input'} \
- node { name: 'Const0' op: 'Const' \
- attr { key: 'dtype' value { type: DT_INT32 } } \
- attr { \
- key: 'value' \
- value { \
- tensor { \
- dtype: DT_INT32 \
- tensor_shape { \
- dim { \
- size: 5 \
- } \
- } \
- tensor_content: \
- '\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004' \
- '\\000\\000\\000\\001\\000\\000\\000' \
- } \
- } \
- } \
- } \
- node { name: 'Const1' op: 'Const' \
- attr { key: 'dtype' value { type: DT_INT32 } } \
- attr { \
- key: 'value' \
- value { \
- tensor { \
- dtype: DT_INT32 \
- tensor_shape { \
- dim { \
- size: 5 \
- } \
- } \
- tensor_content: \
- '\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002' \
- '\\000\\000\\000\\003\\000\\000\\000' \
- } \
- } \
- } \
- }"
- "node { \
- name: 'Transpose0' \
- op: 'Transpose' \
- input: 'Input0' \
- input: 'Const0' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { \
- name: 'Conv3D' \
- op: 'Conv3D' \
- input: 'Transpose0' \
- input: 'Input1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'data_format' \
- value { \
- s: 'NDHWC' \
- } \
- } \
- attr { \
- key: 'dilations' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'padding' \
- value { \
- s: 'SAME' \
- } \
- } \
- attr { \
- key: 'strides' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'use_cudnn_on_gpu' \
- value { \
- b: true \
- } \
- } \
- }"
- "node { \
- name: 'Transpose1' \
- op: 'Transpose' \
- input: 'Conv3D' \
- input: 'Const1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { name: 'Relu' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['Transpose1'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);"
- "DMT/_1(Const);Input0(Input);Input1(Input);"
- "Relu(_MklRelu)|Conv3D->Relu;Conv3D:2->Relu:1;"
- "DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;Input0->Conv3D;"
- "Input0:control->DMT/_0:control;"
- "Input0:control->DMT/_1:control;Input1->Conv3D:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Negative) {
- InitGraph(
- "node { name: 'Input0' op: 'Input'} \
- node { name: 'Input1' op: 'Input'} \
- node { name: 'Const0' op: 'Const' \
- attr { \
- key: 'dtype' \
- value { \
- type: DT_INT32 \
- } \
- } \
- attr { \
- key: 'value' \
- value { \
- tensor { \
- dtype: DT_INT32 \
- tensor_shape { \
- dim { \
- size: 5 \
- } \
- } \
- tensor_content: \
- '\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004' \
- '\\000\\000\\000\\001\\000\\000\\000' \
- } \
- } \
- } \
- } \
- node { name: 'Const1' op: 'Const' \
- attr { \
- key: 'dtype' \
- value { \
- type: DT_INT32 \
- } \
- } \
- attr { \
- key: 'value' \
- value { \
- tensor { \
- dtype: DT_INT32 \
- tensor_shape { \
- dim { \
- size: 5 \
- } \
- } \
- tensor_content: \
- '\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004' \
- '\\000\\000\\000\\001\\000\\000\\000' \
- } \
- } \
- } \
- }"
- "node { \
- name: 'Transpose0' \
- op: 'Transpose' \
- input: 'Input0' \
- input: 'Const0' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { \
- name: 'Conv3D' \
- op: 'Conv3D' \
- input: 'Transpose0' \
- input: 'Input1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'data_format' \
- value { \
- s: 'NDHWC' \
- } \
- } \
- attr { \
- key: 'dilations' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'padding' \
- value { \
- s: 'SAME' \
- } \
- } \
- attr { \
- key: 'strides' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'use_cudnn_on_gpu' \
- value { \
- b: true \
- } \
- } \
- }"
- "node { \
- name: 'Transpose1' \
- op: 'Transpose' \
- input: 'Conv3D' \
- input: 'Const1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { name: 'Relu' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['Transpose1'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "Const0(Const);Const1(Const);Conv3D(_MklConv3D);"
- "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);"
- "Input0(Input);Input1(Input);Relu(_MklRelu);"
- "Transpose0(_MklTranspose);Transpose1(_MklTranspose)"
- "|Const0->Transpose0:1;Const1->Transpose1:1;"
- "Conv3D->Transpose1;DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;"
- "DMT/_2->Relu:1;Input0->Transpose0;Input1->Conv3D:1;"
- "Transpose0->Conv3D;Transpose0:control->DMT/_0:control;"
- "Transpose0:control->DMT/_1:control;Transpose1->Relu;"
- "Transpose1:control->DMT/_2:control");
-}
-
-TEST_F(MklLayoutPassTest, NodeMerge_TransposeMaxPool3DTranspose_Positive) {
- InitGraph(
- "node { name: 'Input0' op: 'Input'} \
- node { name: 'Const0' op: 'Const' \
- attr { key: 'dtype' value { type: DT_INT32 } } \
- attr { \
- key: 'value' \
- value { \
- tensor { \
- dtype: DT_INT32 \
- tensor_shape { \
- dim { \
- size: 5 \
- } \
- } \
- tensor_content: \
- '\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004' \
- '\\000\\000\\000\\001\\000\\000\\000' \
- } \
- } \
- } \
- } \
- node { name: 'Const1' op: 'Const' \
- attr { key: 'dtype' value { type: DT_INT32 } } \
- attr { \
- key: 'value' \
- value { \
- tensor { \
- dtype: DT_INT32 \
- tensor_shape { \
- dim { \
- size: 5 \
- } \
- } \
- tensor_content: \
- '\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002' \
- '\\000\\000\\000\\003\\000\\000\\000' \
- } \
- } \
- } \
- }"
- "node { \
- name: 'Transpose0' \
- op: 'Transpose' \
- input: 'Input0' \
- input: 'Const0' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { \
- name: 'MaxPool3D' \
- op: 'MaxPool3D' \
- input: 'Transpose0' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'data_format' \
- value { \
- s: 'NDHWC' \
- } \
- } \
- attr { \
- key: 'padding' \
- value { \
- s: 'SAME' \
- } \
- } \
- attr { \
- key: 'strides' \
- value { \
- list { \
- i: 1 \
- i: 2 \
- i: 2 \
- i: 2 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'ksize' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'use_cudnn_on_gpu' \
- value { \
- b: true \
- } \
- } \
- }"
- "node { \
- name: 'Transpose1' \
- op: 'Transpose' \
- input: 'MaxPool3D' \
- input: 'Const1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { name: 'Relu' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['Transpose1'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "Const0(Const);Const1(Const);DMT/_0(Const);Input0(Input);"
- "MaxPool3D(_MklMaxPool3D);Relu(_MklRelu)"
- "|DMT/_0->MaxPool3D:1;Input0->MaxPool3D;"
- "Input0:control->DMT/_0:control;MaxPool3D->Relu;"
- "MaxPool3D:2->Relu:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeMerge_TransposeMaxPool3DTranspose_Negative) {
- InitGraph(
- "node { name: 'Input0' op: 'Input'} \
- node { name: 'Const0' op: 'Const' \
- attr { key: 'dtype' value { type: DT_INT32 } } \
- attr { \
- key: 'value' \
- value { \
- tensor { \
- dtype: DT_INT32 \
- tensor_shape { \
- dim { \
- size: 5 \
- } \
- } \
- tensor_content: \
- '\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004' \
- '\\000\\000\\000\\001\\000\\000\\000' \
- } \
- } \
- } \
- } \
- node { name: 'Const1' op: 'Const' \
- attr { key: 'dtype' value { type: DT_INT32 } } \
- attr { \
- key: 'value' \
- value { \
- tensor { \
- dtype: DT_INT32 \
- tensor_shape { \
- dim { \
- size: 5 \
- } \
- } \
- tensor_content: \
- '\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\004' \
- '\\000\\000\\000\\003\\000\\000\\000' \
- } \
- } \
- } \
- }"
- "node { \
- name: 'Transpose0' \
- op: 'Transpose' \
- input: 'Input0' \
- input: 'Const0' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { \
- name: 'MaxPool3D' \
- op: 'MaxPool3D' \
- input: 'Transpose0' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'data_format' \
- value { \
- s: 'NDHWC' \
- } \
- } \
- attr { \
- key: 'padding' \
- value { \
- s: 'SAME' \
- } \
- } \
- attr { \
- key: 'strides' \
- value { \
- list { \
- i: 1 \
- i: 2 \
- i: 2 \
- i: 2 \
- i: 1 \
- } \
- } \
- } \
- attr { \
- key: 'ksize' \
- value { \
- list { \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- i: 1 \
- } \
- } \
- } \
- }"
- "node { \
- name: 'Transpose1' \
- op: 'Transpose' \
- input: 'MaxPool3D' \
- input: 'Const1' \
- attr { \
- key: 'T' \
- value { \
- type: DT_FLOAT \
- } \
- } \
- attr { \
- key: 'Tperm' \
- value { \
- type: DT_INT32 \
- } \
- } \
- }"
- "node { name: 'Relu' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['Transpose1'] }");
- EXPECT_EQ(
- DoMklLayoutOptimizationPass(),
- "Const0(Const);Const1(Const);DMT/_0(Const);DMT/_1(Const);Input0(Input);"
- "MaxPool3D(_MklMaxPool3D);Relu(_MklRelu);"
- "Transpose0(_MklTranspose);Transpose1(_MklTranspose)|Const0->Transpose0:"
- "1;"
- "Const1->Transpose1:1;DMT/_0->MaxPool3D:1;"
- "DMT/_1->Relu:1;Input0->Transpose0;MaxPool3D->Transpose1;"
- "Transpose0->MaxPool3D;Transpose0:control->DMT/_0:control;"
- "Transpose1->Relu;Transpose1:control->DMT/_1:control");
-}
-
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node to Mkl node
/////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index fd2ea4f..3f79c02 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -23,20 +23,14 @@
hdrs = ["utils.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
- ] + select({
- "//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
- ],
- "//conditions:default": [
- "//tensorflow/core:framework",
- "//tensorflow/core:graph",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- ],
- }),
+ ],
)
tf_cc_test(
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 3fb249c..409e68c 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -136,6 +136,7 @@
hdrs = ["utils.h"],
visibility = ["//visibility:public"],
deps = [
+ ":cost_estimator",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
@@ -289,6 +290,7 @@
deps = [
":cost_estimator",
":op_context",
+ ":utils",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 4a17975..f018e88 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -23,6 +23,7 @@
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/utils.h"
+#include "tensorflow/core/grappler/costs/utils.h"
namespace tensorflow {
namespace grappler {
@@ -659,7 +660,7 @@
Costs::NanoSeconds(intermediate_read_time);
costs.intermediate_memory_write_time =
Costs::NanoSeconds(intermediate_write_time);
- CombineCostsAndUpdateExecutionTime(&costs);
+ CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &costs);
return costs;
}
@@ -1715,7 +1716,7 @@
fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time;
}
- CombineCostsAndUpdateExecutionTime(&fused_cost);
+ CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &fused_cost);
return fused_cost;
}
@@ -2050,17 +2051,5 @@
costs.max_memory = total_output_size;
return costs;
}
-
-void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
- Costs* costs) const {
- if (compute_memory_overlap_) {
- costs->execution_time =
- std::max(costs->intermediate_memory_time,
- std::max(costs->compute_time, costs->memory_time));
- } else {
- costs->execution_time = costs->compute_time + costs->memory_time +
- costs->intermediate_memory_time;
- }
-}
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 9183c54..b76884e 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -194,11 +194,6 @@
static OpInfo::TensorProperties DescribeTensor(
DataType type, const std::vector<int64>& dims);
- // This method calculates the execution time depending on whether IO can
- // overlap with computation. It assumes the memory and the compute times have
- // already been calculated.
- void CombineCostsAndUpdateExecutionTime(Costs* costs) const;
-
protected:
std::map<string, int> elementwise_ops_;
typedef std::function<Costs(const OpContext& op_context)> CostImpl;
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 2f3d171..f3bcf53 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -504,5 +504,16 @@
return output.str();
}
+void CombineCostsAndUpdateExecutionTime(bool compute_memory_overlap,
+ Costs* costs) {
+ if (compute_memory_overlap) {
+ costs->execution_time =
+ std::max(costs->intermediate_memory_time,
+ std::max(costs->compute_time, costs->memory_time));
+ } else {
+ costs->execution_time = costs->compute_time + costs->memory_time +
+ costs->intermediate_memory_time;
+ }
+}
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index ea64e5a..3dfbd67 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -25,6 +25,7 @@
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
@@ -119,6 +120,12 @@
string GetStatsStringFromRunMetadata(const RunMetadata& run_metadata,
bool verbosity);
+// This method calculates the execution time depending on whether IO can
+// overlap with computation. It assumes the memory and the compute times have
+// already been calculated.
+void CombineCostsAndUpdateExecutionTime(bool compute_memory_overlap,
+ Costs* costs);
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index b18c7c5..8324f40 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -3594,7 +3594,6 @@
protobuf::RepeatedPtrField<string> parent_inputs;
parent_inputs.Swap(parent->mutable_input());
- std::vector<string> ctrl_output;
// TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
// collapse it into the parent multiple times? Probably not.
for (const auto& input : parent_inputs) {
diff --git a/tensorflow/core/grappler/optimizers/data/auto_shard.cc b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
index bcc8feb..7ed80a1 100644
--- a/tensorflow/core/grappler/optimizers/data/auto_shard.cc
+++ b/tensorflow/core/grappler/optimizers/data/auto_shard.cc
@@ -396,7 +396,6 @@
MutableGraphView graph(output);
FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
- NodeDef target_node;
absl::flat_hash_set<string> nodes_to_delete;
NodeDef* sink_node;
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 7572141..8941d55 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -386,17 +386,11 @@
hdrs = ["transitive_fanin.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:utils",
- ] + select({
- "//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
- ],
- "//conditions:default": [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ],
- }),
+ ],
)
tf_cc_test(
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 45dc248..851e5b7 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -34,7 +34,10 @@
using ConstVectorMap = Eigen::Map<const Eigen::VectorXf>;
using VectorMap = Eigen::Map<Eigen::VectorXf>;
-// V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOp is V2.
+constexpr char kInequalitySplit[] = "inequality";
+constexpr char kEqualitySplit[] = "equality";
+
+// V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOpV2 is V2.
class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
public:
explicit BoostedTreesCalculateBestGainsPerFeatureOp(
@@ -227,7 +230,7 @@
Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
BoostedTreesCalculateBestGainsPerFeatureOp);
-// V2 Op.
+// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2.
class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
public:
explicit BoostedTreesCalculateBestFeatureSplitOp(
@@ -545,11 +548,394 @@
string split_type_;
};
-// v2 op that supports multi-class.
+// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2.
REGISTER_KERNEL_BUILDER(
Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU),
BoostedTreesCalculateBestFeatureSplitOp);
+// V2 Op.
+class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
+ public:
+ explicit BoostedTreesCalculateBestFeatureSplitV2(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
+ OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // node_id_range
+ const Tensor* node_id_range_t;
+ OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
+ const auto node_id_range = node_id_range_t->vec<int32>();
+ const int32 node_id_first = node_id_range(0); // Inclusive.
+ const int32 node_id_last = node_id_range(1); // Exclusive.
+
+ // Get stats_summaries_list.
+ OpInputList stats_summaries_list;
+ OP_REQUIRES_OK(context, context->input_list("stats_summaries_list",
+ &stats_summaries_list));
+
+ // Infer dimensions of a stats_summary.
+ DCHECK_GT(stats_summaries_list.size(), 0);
+ const int32 feature_dims = stats_summaries_list[0].dim_size(1);
+ // The last bucket is for default/missing value.
+ const int32 num_buckets = stats_summaries_list[0].dim_size(2) - 1;
+ const int32 logits_dim = logits_dim_;
+ const int32 hessian_dim = stats_summaries_list[0].dim_size(3) - logits_dim;
+ DCHECK_GT(hessian_dim, 0);
+ DCHECK_LE(hessian_dim, logits_dim * logits_dim);
+
+ // Vector of stats_summaries; each element is stats for feature of shape
+ // [max_splits, feature_dim, num_buckets, logits_dim + hessian_dim].
+ std::vector<TTypes<float, 4>::ConstTensor> stats_summaries;
+ DCHECK_EQ(stats_summaries_list.size(), num_features_);
+ stats_summaries.reserve(num_features_);
+ for (const auto& tensor : stats_summaries_list) {
+ stats_summaries.emplace_back(tensor.tensor<float, 4>());
+ }
+
+ // Split types.
+ const Tensor* split_types_t;
+ OP_REQUIRES_OK(context, context->input("split_types", &split_types_t));
+ const auto split_types = split_types_t->vec<tstring>();
+ DCHECK_EQ(split_types.size(), num_features_);
+ // Validate.
+ for (int i = 0; i < num_features_; ++i) {
+ if (!(split_types(i) == kInequalitySplit ||
+ split_types(i) == kEqualitySplit)) {
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted(
+ "Operation received an exception: Incorrect split type"));
+ }
+ }
+ // Feature ids.
+ const Tensor* candidate_feature_ids_t;
+ OP_REQUIRES_OK(context, context->input("candidate_feature_ids",
+ &candidate_feature_ids_t));
+ const auto candidate_feature_ids = candidate_feature_ids_t->vec<int32>();
+ DCHECK_EQ(candidate_feature_ids.size(), num_features_);
+
+ // L1, L2, tree_complexity, min_node_weight.
+ const Tensor* l1_t;
+ OP_REQUIRES_OK(context, context->input("l1", &l1_t));
+ const auto l1 = l1_t->scalar<float>()();
+ DCHECK_GE(l1, 0);
+ if (logits_dim_ > 1) {
+ // Multi-class L1 regularization not supported yet.
+ DCHECK_EQ(l1, 0);
+ }
+ const Tensor* l2_t;
+ OP_REQUIRES_OK(context, context->input("l2", &l2_t));
+ const auto l2 = l2_t->scalar<float>()();
+ DCHECK_GE(l2, 0);
+ const Tensor* tree_complexity_t;
+ OP_REQUIRES_OK(context,
+ context->input("tree_complexity", &tree_complexity_t));
+ const auto tree_complexity = tree_complexity_t->scalar<float>()();
+ const Tensor* min_node_weight_t;
+ OP_REQUIRES_OK(context,
+ context->input("min_node_weight", &min_node_weight_t));
+ const auto min_node_weight = min_node_weight_t->scalar<float>()();
+
+ std::vector<int32> output_node_ids;
+ std::vector<float> output_gains;
+ std::vector<int32> output_feature_ids;
+ std::vector<int32> output_feature_dimensions;
+ std::vector<int32> output_thresholds;
+ std::vector<Eigen::VectorXf> output_left_node_contribs;
+ std::vector<Eigen::VectorXf> output_right_node_contribs;
+ std::vector<string> output_split_types;
+
+ // TODO(tanzheny) parallelize the computation.
+ // Iterate each node and find the best gain per node.
+ float parent_gain;
+ for (int32 node_id = node_id_first; node_id < node_id_last; ++node_id) {
+ float best_gain = std::numeric_limits<float>::lowest();
+ int32 best_bucket;
+ int32 best_f_id;
+ int32 best_f_dim;
+ string best_split_type;
+ Eigen::VectorXf best_contrib_for_left(logits_dim);
+ Eigen::VectorXf best_contrib_for_right(logits_dim);
+
+ // Sum of gradient and hessian. Compute parent gain using first feature.
+ ConstMatrixMap stats_mat(&stats_summaries[0](node_id, 0, 0, 0),
+ num_buckets + 1, // Including default bucket.
+ logits_dim + hessian_dim);
+ const Eigen::VectorXf total_grad =
+ stats_mat.leftCols(logits_dim).colwise().sum();
+ const Eigen::VectorXf total_hess =
+ stats_mat.rightCols(hessian_dim).colwise().sum();
+ if (total_hess.norm() < min_node_weight) {
+ continue;
+ }
+ Eigen::VectorXf unused(logits_dim);
+ CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
+ &parent_gain);
+ for (int f_idx = 0; f_idx < num_features_; ++f_idx) {
+ const string split_type = split_types(f_idx);
+ TTypes<float, 4>::ConstTensor stats_summary = stats_summaries[f_idx];
+ float f_best_gain = std::numeric_limits<float>::lowest();
+ int32 f_best_bucket;
+ int32 f_best_f_dim;
+ string f_best_split_type;
+ Eigen::VectorXf f_best_contrib_for_left(logits_dim);
+ Eigen::VectorXf f_best_contrib_for_right(logits_dim);
+
+ if (split_type == kInequalitySplit) {
+ CalculateBestInequalitySplit(
+ stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
+ num_buckets, min_node_weight, l1, l2, &f_best_gain,
+ &f_best_bucket, &f_best_f_dim, &f_best_split_type,
+ &f_best_contrib_for_left, &f_best_contrib_for_right);
+ } else {
+ CalculateBestEqualitySplit(
+ stats_summary, total_grad, total_hess, node_id, feature_dims,
+ logits_dim, hessian_dim, num_buckets, l1, l2, &f_best_gain,
+ &f_best_bucket, &f_best_f_dim, &f_best_split_type,
+ &f_best_contrib_for_left, &f_best_contrib_for_right);
+ }
+ if (f_best_gain > best_gain) {
+ best_gain = f_best_gain;
+ best_f_id = candidate_feature_ids(f_idx);
+ best_f_dim = f_best_f_dim;
+ best_split_type = f_best_split_type;
+ best_bucket = f_best_bucket;
+ best_contrib_for_left = f_best_contrib_for_left;
+ best_contrib_for_right = f_best_contrib_for_right;
+ }
+ } // For feature id.
+ if (best_gain == std::numeric_limits<float>::lowest()) {
+ // Do not add the node if no split is found.
+ continue;
+ }
+ output_node_ids.push_back(node_id);
+ // Remove the parent gain for the parent node.
+ output_gains.push_back(best_gain - parent_gain);
+ output_feature_ids.push_back(best_f_id);
+ output_feature_dimensions.push_back(best_f_dim);
+ // Default direction is fixed for dense splits.
+ // TODO(tanzheny) account for default values.
+ output_split_types.push_back(best_split_type);
+ output_thresholds.push_back(best_bucket);
+ output_left_node_contribs.push_back(best_contrib_for_left);
+ output_right_node_contribs.push_back(best_contrib_for_right);
+ } // for node id.
+ const int num_nodes = output_node_ids.size();
+ // output_node_ids
+ Tensor* output_node_ids_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
+ &output_node_ids_t));
+ auto output_node_ids_vec = output_node_ids_t->vec<int32>();
+
+ // output_gains
+ Tensor* output_gains_t;
+ OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
+ &output_gains_t));
+ auto output_gains_vec = output_gains_t->vec<float>();
+
+ // output_feature_ids
+ Tensor* output_features_ids_t;
+ OP_REQUIRES_OK(context, context->allocate_output("feature_ids", {num_nodes},
+ &output_features_ids_t));
+ auto output_features_vec = output_features_ids_t->vec<int32>();
+
+ // output_feature_dimensions
+ Tensor* output_feature_dimension_t;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("feature_dimensions", {num_nodes},
+ &output_feature_dimension_t));
+ auto output_feature_dimensions_vec =
+ output_feature_dimension_t->vec<int32>();
+
+ // output_thresholds
+ Tensor* output_thresholds_t;
+ OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
+ &output_thresholds_t));
+ auto output_thresholds_vec = output_thresholds_t->vec<int32>();
+
+ // output_left_node_contribs
+ Tensor* output_left_node_contribs_t;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "left_node_contribs", {num_nodes, logits_dim},
+ &output_left_node_contribs_t));
+ auto output_left_node_contribs_matrix =
+ output_left_node_contribs_t->matrix<float>();
+
+ // output_right_node_contribs
+ Tensor* output_right_node_contribs_t;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "right_node_contribs", {num_nodes, logits_dim},
+ &output_right_node_contribs_t));
+ auto output_right_node_contribs_matrix =
+ output_right_node_contribs_t->matrix<float>();
+
+ // split type
+ Tensor* output_split_types_t;
+ OP_REQUIRES_OK(
+ context, context->allocate_output("split_with_default_directions",
+ {num_nodes}, &output_split_types_t));
+ auto output_split_types_vec = output_split_types_t->vec<tstring>();
+
+ // Sets output tensors from vectors.
+ for (int i = 0; i < num_nodes; ++i) {
+ output_node_ids_vec(i) = output_node_ids[i];
+ output_features_vec(i) = output_feature_ids[i];
+ // Adjust the gains to penalize by tree complexity.
+ output_gains_vec(i) = output_gains[i] - tree_complexity;
+ output_feature_dimensions_vec(i) = output_feature_dimensions[i];
+ output_thresholds_vec(i) = output_thresholds[i];
+ for (int j = 0; j < logits_dim; ++j) {
+ output_left_node_contribs_matrix(i, j) =
+ output_left_node_contribs[i][j];
+ output_right_node_contribs_matrix(i, j) =
+ output_right_node_contribs[i][j];
+ }
+ output_split_types_vec(i) = output_split_types[i];
+ }
+ }
+
+ private:
+ // TODO(crawles): Simplify inequality path just like equality b/138329196
+ // Currently this is not simplify-able due to numerical instability in math
+ // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g)
+ // It caused gain to be Inf when g is approaching 0 but not exactly 0 while
+ // there is no regularization.
+ // Calculate the best inequality split per node.
+ void CalculateBestInequalitySplit(
+ TTypes<float, 4>::ConstTensor stats_summary, const int32 node_id,
+ const int32 feature_dims, const int32 logits_dim, const int32 hessian_dim,
+ const int32 num_buckets, const float min_node_weight, const float l1,
+ const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
+ string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
+ Eigen::VectorXf* best_contrib_for_right) {
+ std::vector<Eigen::VectorXf> cum_grad;
+ std::vector<Eigen::VectorXf> cum_hess;
+ // get all cumulative gradients including default bucket.
+ cum_grad.reserve(num_buckets);
+ cum_hess.reserve(num_buckets);
+
+ for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
+ ConstVectorMap default_stats_vec(
+ &stats_summary(node_id, f_dim, num_buckets, 0),
+ logits_dim + hessian_dim);
+ Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
+ Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
+ cum_grad.clear();
+ cum_hess.clear();
+ Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
+ Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
+ // sum all the gradients including default bucket.
+ for (int bucket = 0; bucket <= num_buckets; ++bucket) {
+ for (int i = 0; i < logits_dim; ++i) {
+ total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
+ }
+ for (int i = 0; i < hessian_dim; ++i) {
+ // Full hessian.
+ total_hess[i] +=
+ stats_summary(node_id, f_dim, bucket, logits_dim + i);
+ }
+ if (bucket < num_buckets) {
+ cum_grad.push_back(total_grad);
+ cum_hess.push_back(total_hess);
+ }
+ }
+ const string kInequalityDefaultLeft =
+ boosted_trees::SplitTypeWithDefault_Name(
+ boosted_trees::INEQUALITY_DEFAULT_LEFT);
+ const string kInequalityDefaultRight =
+ boosted_trees::SplitTypeWithDefault_Name(
+ boosted_trees::INEQUALITY_DEFAULT_RIGHT);
+
+ // Iterate from left to right, excluding default bucket.
+ for (int bucket = 0; bucket < num_buckets; ++bucket) {
+ // default value goes to left node.
+ const Eigen::VectorXf total_left_grad =
+ cum_grad[bucket] + missing_bucket_grad;
+ const Eigen::VectorXf total_left_hess =
+ cum_hess[bucket] + missing_bucket_hess;
+ MaybeUpdateBestSplit(
+ total_left_grad, total_grad - total_left_grad, total_left_hess,
+ total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
+ kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
+ best_split_type, best_contrib_for_left, best_contrib_for_right);
+ // default value goes to right node.
+ MaybeUpdateBestSplit(
+ cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
+ total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
+ kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
+ best_split_type, best_contrib_for_left, best_contrib_for_right);
+ } // for bucket
+ }
+ }
+
+ // Calculate the best equality split per node.
+ void CalculateBestEqualitySplit(
+ TTypes<float, 4>::ConstTensor stats_summary,
+ const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
+ const int32 node_id, const int32 feature_dims, const int32 logits_dim,
+ const int32 hessian_dim, const int32 num_buckets, const float l1,
+ const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
+ string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
+ Eigen::VectorXf* best_contrib_for_right) {
+ const string kEqualityDefaultRight =
+ boosted_trees::SplitTypeWithDefault_Name(
+ boosted_trees::EQUALITY_DEFAULT_RIGHT);
+ for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
+ for (int bucket = 0; bucket < num_buckets; ++bucket) {
+ ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
+ logits_dim + hessian_dim);
+ Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
+ Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
+ MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
+ total_hess - curr_hess, logits_dim, bucket, f_dim,
+ l1, l2, kEqualityDefaultRight, best_gain,
+ best_bucket, best_f_dim, best_split_type,
+ best_contrib_for_left, best_contrib_for_right);
+ }
+ }
+ }
+
+ void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
+ const Eigen::VectorXf& grad_for_right,
+ const Eigen::VectorXf& hess_for_left,
+ const Eigen::VectorXf& hess_for_right,
+ const int32 logits_dim, const int32 bucket,
+ const int32 f_dim, const float l1, const float l2,
+ const string split_type, float* best_gain,
+ int32* best_bucket, int32* best_f_dim,
+ string* best_split_type,
+ Eigen::VectorXf* best_contrib_for_left,
+ Eigen::VectorXf* best_contrib_for_right) {
+ // Left child.
+ Eigen::VectorXf contrib_for_left(logits_dim);
+ float gain_for_left;
+ CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
+ &contrib_for_left, &gain_for_left);
+ Eigen::VectorXf contrib_for_right(logits_dim);
+ float gain_for_right;
+ CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
+ &contrib_for_right, &gain_for_right);
+ if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
+ *best_gain = gain_for_left + gain_for_right;
+ *best_bucket = bucket;
+ *best_f_dim = f_dim;
+ *best_contrib_for_left = contrib_for_left;
+ *best_contrib_for_right = contrib_for_right;
+ *best_split_type = split_type;
+ }
+ }
+ int num_features_;
+ int logits_dim_;
+};
+
+// v2 op that supports multi-class.
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesCalculateBestFeatureSplitV2").Device(DEVICE_CPU),
+ BoostedTreesCalculateBestFeatureSplitV2);
+
// Map from bucket id to vector of statistics.
typedef std::map<int32, std::vector<float>> BucketMap;
typedef BucketMap::iterator BucketMapIterator;
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
index cb042fb..ecc58da 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
@@ -20,10 +20,11 @@
#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
-#define EIGEN_USE_GPU
-
+#define _USE_MATH_DEFINES
+#include <cmath>
#include <complex>
+#define EIGEN_USE_GPU
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 407819d..cd6682d 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -480,24 +480,23 @@
/* static */
Status CapturedFunction::Create(
- OpKernelContext* ctx,
- const std::shared_ptr<const FunctionMetadata> metadata,
+ OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
const string& argument_name,
std::unique_ptr<CapturedFunction>* out_function) {
OpInputList inputs;
TF_RETURN_IF_ERROR(ctx->input_list(argument_name, &inputs));
std::vector<Tensor> captured_inputs(inputs.begin(), inputs.end());
- return Create(ctx, metadata, std::move(captured_inputs), out_function);
+ return Create(ctx, std::move(metadata), std::move(captured_inputs),
+ out_function);
}
/* static */
Status CapturedFunction::Create(
- OpKernelContext* ctx,
- const std::shared_ptr<const FunctionMetadata> metadata,
+ OpKernelContext* ctx, std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor>&& captured_inputs,
std::unique_ptr<CapturedFunction>* out_function) {
*out_function = absl::WrapUnique(
- new CapturedFunction(metadata, std::move(captured_inputs)));
+ new CapturedFunction(std::move(metadata), std::move(captured_inputs)));
return Status::OK();
}
@@ -602,8 +601,7 @@
*instantiated_captured_function =
absl::WrapUnique<InstantiatedCapturedFunction>(
new InstantiatedCapturedFunction(lib, f_handle, std::move(ret_types),
- *ctx->runner(),
- ctx->cancellation_manager(), this));
+ *ctx->runner(), this));
return Status::OK();
}
@@ -620,12 +618,11 @@
InstantiatedCapturedFunction::InstantiatedCapturedFunction(
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types, std::function<void(std::function<void()>)> runner,
- CancellationManager* cancellation_manager, CapturedFunction* captured_func)
+ CapturedFunction* captured_func)
: lib_(lib),
f_handle_(f_handle),
ret_types_(std::move(ret_types)),
captured_runner_(std::move(runner)),
- captured_cancellation_manager_(cancellation_manager),
captured_func_(captured_func) {}
// NOTE: We don't release f_handle_ here and instead delegate the function
@@ -664,7 +661,7 @@
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
- lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
+ lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
s.Update(func_status);
n.Notify();
});
@@ -704,7 +701,7 @@
f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
- lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
+ lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
s.Update(func_status);
n.Notify();
});
@@ -728,7 +725,7 @@
f_opts.step_container = &step_container;
f_opts.runner = &captured_runner_;
f_opts.create_rendezvous = ShouldCreateRendezvous();
- CancellationManager cancellation_manager(captured_cancellation_manager_);
+ CancellationManager cancellation_manager;
f_opts.cancellation_manager = &cancellation_manager;
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
@@ -742,7 +739,7 @@
f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
- lib_->Run(f_opts, f_handle_, &frame, [&n, &s](Status func_status) {
+ lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
s.Update(func_status);
n.Notify();
});
@@ -849,9 +846,10 @@
}
CapturedFunction::CapturedFunction(
- const std::shared_ptr<const FunctionMetadata> metadata,
+ std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs)
- : metadata_(metadata), captured_inputs_(std::move(captured_inputs)) {}
+ : metadata_(std::move(metadata)),
+ captured_inputs_(std::move(captured_inputs)) {}
Status CapturedFunction::IsMultiDevice(IteratorContext* ctx,
bool* is_multi_device) {
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index 2ac5f80..8747e73 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -98,7 +98,6 @@
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types,
std::function<void(std::function<void()>)> runner,
- CancellationManager* cancellation_manager,
CapturedFunction* captured_func);
// Determines whether a rendezvous object should be created when running the
@@ -110,10 +109,9 @@
FunctionLibraryRuntime* const lib_;
const FunctionLibraryRuntime::Handle f_handle_;
const DataTypeVector ret_types_;
- // Note: Since we have no IteratorContext in `RunInstantiated`, we have to
- // capture these at function instantiation time.
+ // Note: We capture the runner at function instantiation time to be able to
+ // run the function without `IteratorContext` via `RunInstantiated`.
std::function<void(std::function<void()>)> captured_runner_;
- CancellationManager* captured_cancellation_manager_;
CapturedFunction* const captured_func_;
TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
@@ -192,14 +190,14 @@
// Creates a new instance using a list of named attributes, fetching captured
// inputs from a context argument.
static Status Create(OpKernelContext* ctx,
- const std::shared_ptr<const FunctionMetadata> metadata,
+ std::shared_ptr<const FunctionMetadata> metadata,
const string& argument_name,
std::unique_ptr<CapturedFunction>* out_function);
// Creates a new instance using a list of named attributes, using provided
// captured inputs.
static Status Create(OpKernelContext* ctx,
- const std::shared_ptr<const FunctionMetadata> metadata,
+ std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor>&& captured_inputs,
std::unique_ptr<CapturedFunction>* out_function);
@@ -258,7 +256,7 @@
}
private:
- CapturedFunction(const std::shared_ptr<const FunctionMetadata> metadata,
+ CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs);
// Determines whether the captured function requires the use of the
diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc
index 0984629..ce194a8 100644
--- a/tensorflow/core/kernels/data/dataset_test_base.cc
+++ b/tensorflow/core/kernels/data/dataset_test_base.cc
@@ -426,7 +426,6 @@
params->op_kernel = kernel;
params->resource_manager = resource_mgr_.get();
params->runner = &runner_;
- checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
slice_reader_cache_ =
absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
params->slice_reader_cache = slice_reader_cache_.get();
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index 3a260f3..dea569c 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -21,6 +21,7 @@
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -448,11 +449,31 @@
return ReadTensorInternal(key, val);
}
+Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
+ int64* val) {
+ return ReadScalarInternal(name, key, val);
+}
+
+Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
+ tstring* val) {
+ return ReadScalarInternal(name, key, val);
+}
+
+Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key,
+ Tensor* val) {
+ return ReadTensorInternal(name, key, val);
+}
+
bool VariantTensorDataReader::Contains(StringPiece key) {
string name;
if (!GetIteratorName(key, &name).ok()) {
return false;
}
+ return Contains(name, key);
+}
+
+bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) {
+ string name(n);
return map_[name].find(string(key)) != map_[name].end();
}
@@ -460,6 +481,20 @@
Status VariantTensorDataReader::ReadScalarInternal(StringPiece key, T* val) {
string name;
TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
+ return ReadScalarInternal(name, key, val);
+}
+
+Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
+ Tensor* val) {
+ string name;
+ TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
+ return ReadTensorInternal(name, key, val);
+}
+
+template <typename T>
+Status VariantTensorDataReader::ReadScalarInternal(StringPiece n,
+ StringPiece key, T* val) {
+ string name(n);
if (map_[name].find(string(key)) == map_[name].end()) {
return errors::NotFound(key);
}
@@ -467,10 +502,10 @@
return Status::OK();
}
-Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
+Status VariantTensorDataReader::ReadTensorInternal(StringPiece n,
+ StringPiece key,
Tensor* val) {
- string name;
- TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
+ string name(n);
if (map_[name].find(string(key)) == map_[name].end()) {
return errors::NotFound(key);
}
@@ -492,6 +527,21 @@
return WriteTensorInternal(key, val);
}
+Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
+ const int64 val) {
+ return WriteScalarInternal(name, key, val);
+}
+
+Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
+ const tstring& val) {
+ return WriteScalarInternal(name, key, val);
+}
+
+Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key,
+ const Tensor& val) {
+ return WriteTensorInternal(name, key, val);
+}
+
void VariantTensorDataWriter::MaybeFlush() {
if (is_flushed_) return;
for (auto& keys : keys_) {
@@ -535,9 +585,9 @@
return errors::FailedPrecondition(
"Cannot call WriteScalar after GetData or ReleaseData is called");
}
- Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
- val_t.scalar<T>()() = val;
- return WriteTensorInternal(key, val_t);
+ string name;
+ TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
+ return WriteScalarInternal(name, key, val);
}
Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key,
@@ -548,7 +598,31 @@
}
string name;
TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
+ return WriteTensorInternal(name, key, val);
+}
+
+template <typename T>
+Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name,
+ StringPiece key,
+ const T& val) {
+ if (is_flushed_) {
+ return errors::FailedPrecondition(
+ "Cannot call WriteScalar after GetData or ReleaseData is called");
+ }
+ Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
+ val_t.scalar<T>()() = val;
+ return WriteTensorInternal(name, key, val_t);
+}
+
+Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n,
+ StringPiece key,
+ const Tensor& val) {
+ if (is_flushed_) {
+ return errors::FailedPrecondition(
+ "Cannot call WriteTensor after GetData or ReleaseData is called");
+ }
DCHECK_EQ(key.find(kDelimiter), string::npos);
+ string name(n);
if (keys_.count(name) == 0) {
keys_[name] = std::vector<string>();
}
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 82f05e9..0401e3d 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -173,11 +173,20 @@
Status ReadTensor(StringPiece key, Tensor* val) override;
bool Contains(StringPiece key) override;
+ Status ReadScalar(StringPiece name, StringPiece key, int64* val) override;
+ Status ReadScalar(StringPiece name, StringPiece key, tstring* val) override;
+ Status ReadTensor(StringPiece name, StringPiece key, Tensor* val) override;
+ bool Contains(StringPiece name, StringPiece key) override;
+
private:
template <typename T>
Status ReadScalarInternal(StringPiece key, T* val);
Status ReadTensorInternal(StringPiece key, Tensor* val);
+ template <typename T>
+ Status ReadScalarInternal(StringPiece name, StringPiece key, T* val);
+ Status ReadTensorInternal(StringPiece name, StringPiece key, Tensor* val);
+
std::map<string, std::map<string, size_t>> map_;
std::map<string, const VariantTensorData*> data_; // Not owned.
};
@@ -198,6 +207,13 @@
Status WriteScalar(StringPiece key, const tstring& val) override;
Status WriteTensor(StringPiece key, const Tensor& val) override;
+ Status WriteScalar(StringPiece name, StringPiece key,
+ const int64 val) override;
+ Status WriteScalar(StringPiece name, StringPiece key,
+ const tstring& val) override;
+ Status WriteTensor(StringPiece name, StringPiece key,
+ const Tensor& val) override;
+
// Releases the built VariantTensorData's to `variants`. Clears out all
// class state.
void ReleaseData(std::vector<std::unique_ptr<VariantTensorData>>* variants);
@@ -213,6 +229,11 @@
Status WriteScalarInternal(StringPiece key, const T& val);
Status WriteTensorInternal(StringPiece key, const Tensor& val);
+ template <typename T>
+ Status WriteScalarInternal(StringPiece name, StringPiece key, const T& val);
+ Status WriteTensorInternal(StringPiece name, StringPiece key,
+ const Tensor& val);
+
bool is_flushed_ = false;
std::map<string, std::unique_ptr<VariantTensorData>> data_;
std::map<string, std::vector<string>> keys_;
diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc
index b8de855..5ad0d0b 100644
--- a/tensorflow/core/kernels/data/dataset_utils_test.cc
+++ b/tensorflow/core/kernels/data/dataset_utils_test.cc
@@ -91,6 +91,45 @@
reader.ReadTensor(full_name("NonExistentKey"), &val_tensor).code());
}
+TEST(DatasetUtilsTest, VariantTensorDataRoundtripIteratorName) {
+ VariantTensorDataWriter writer;
+ TF_ASSERT_OK(writer.WriteScalar("Iterator", "Int64", 24));
+ Tensor input_tensor(DT_FLOAT, {1});
+ input_tensor.flat<float>()(0) = 2.0f;
+ TF_ASSERT_OK(writer.WriteTensor("Iterator", "Tensor", input_tensor));
+ std::vector<const VariantTensorData*> data;
+ writer.GetData(&data);
+
+ VariantTensorDataReader reader(data);
+ int64 val_int64;
+ TF_ASSERT_OK(reader.ReadScalar("Iterator", "Int64", &val_int64));
+ EXPECT_EQ(val_int64, 24);
+ Tensor val_tensor;
+ TF_ASSERT_OK(reader.ReadTensor("Iterator", "Tensor", &val_tensor));
+ EXPECT_EQ(input_tensor.NumElements(), val_tensor.NumElements());
+ EXPECT_EQ(input_tensor.flat<float>()(0), val_tensor.flat<float>()(0));
+}
+
+TEST(DatasetUtilsTest, VariantTensorDataNonExistentKeyIteratorName) {
+ VariantTensorData data;
+ strings::StrAppend(&data.metadata_, "key1", "@@");
+ data.tensors_.push_back(Tensor(DT_INT64, {1}));
+ std::vector<const VariantTensorData*> reader_data;
+ reader_data.push_back(&data);
+ VariantTensorDataReader reader(reader_data);
+ int64 val_int64;
+ tstring val_string;
+ Tensor val_tensor;
+ EXPECT_EQ(error::NOT_FOUND,
+ reader.ReadScalar("Iterator", "NonExistentKey", &val_int64).code());
+ EXPECT_EQ(
+ error::NOT_FOUND,
+ reader.ReadScalar("Iterator", "NonExistentKey", &val_string).code());
+ EXPECT_EQ(
+ error::NOT_FOUND,
+ reader.ReadTensor("Iterator", "NonExistentKey", &val_tensor).code());
+}
+
TEST(DatasetUtilsTest, VariantTensorDataWriteAfterFlushing) {
VariantTensorDataWriter writer;
TF_ASSERT_OK(writer.WriteScalar(full_name("Int64"), 24));
diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
index b2df8aa..c3d120d 100644
--- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -298,7 +298,6 @@
Status DumpDatasetGraph(const std::string& path, uint64 hash,
const GraphDef& graph) {
- std::unique_ptr<WritableFile> file;
std::string hash_hex =
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
std::string graph_file =
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index e57a185..7a5d7e4 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -143,11 +143,10 @@
s = Status::OK();
*end_of_sequence = true;
- // NOTE(mrry): We ignore any tensors returned by the
- // finalize function.
+ // NOTE(mrry): We ignore any tensors returned by the finalize function.
std::vector<Tensor> ignored;
- TF_RETURN_IF_ERROR(
- instantiated_finalize_func_->RunInstantiated(state_, &ignored));
+ TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
+ ctx, state_, &ignored));
finalized_ = true;
}
return s;
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index 4f43d78..7871cd7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -74,9 +74,9 @@
std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* flr,
std::unique_ptr<DatasetBaseIterator> iterator)
- : flib_def(flib_def),
+ : flib_def(std::move(flib_def)),
flr(flr),
- pflr(pflr),
+ pflr(std::move(pflr)),
function_handle_cache(absl::make_unique<FunctionHandleCache>(flr)),
iterator(std::move(iterator)) {}
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 3f5f98a..9c5b3ca 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -228,18 +228,18 @@
mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name(kBufferSize), buffer_.size()));
+ writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
for (size_t i = 0; i < buffer_.size(); i++) {
auto& buffer_element = buffer_[i];
TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
if (buffer_element.status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
- buffer_element.value.size()));
+ absl::StrCat(prefix(), "::", i),
+ absl::StrCat(kBuffer, kSizeSuffix), buffer_element.value.size()));
for (size_t j = 0; j < buffer_element.value.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
- buffer_element.value[j]));
+ absl::StrCat(prefix(), "::", i),
+ absl::StrCat(kBuffer, "[", j, "]"), buffer_element.value[j]));
}
}
}
@@ -255,7 +255,7 @@
size_t buffer_size;
{
int64 temp;
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBufferSize), &temp));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBufferSize, &temp));
buffer_size = static_cast<size_t>(temp);
}
for (size_t i = 0; i < buffer_size; i++) {
@@ -266,17 +266,18 @@
size_t value_size;
{
int64 temp;
- TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
- &temp));
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(absl::StrCat(prefix(), "::", i),
+ absl::StrCat(kBuffer, kSizeSuffix), &temp));
value_size = static_cast<size_t>(temp);
}
buffer_element.value.reserve(value_size);
for (size_t j = 0; j < value_size; j++) {
buffer_element.value.emplace_back();
- TF_RETURN_IF_ERROR(reader->ReadTensor(
- full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
- &buffer_element.value.back()));
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(absl::StrCat(prefix(), "::", i),
+ absl::StrCat(kBuffer, "[", j, "]"),
+ &buffer_element.value.back()));
}
}
}
@@ -435,11 +436,13 @@
Status WriteStatus(IteratorStateWriter* writer, size_t index,
const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- CodeKey(index), static_cast<int64>(status.code())));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(absl::StrCat(prefix(), "::", index), CodeKey(),
+ static_cast<int64>(status.code())));
if (!status.ok()) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
- status.error_message()));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(absl::StrCat(prefix(), "::", index),
+ ErrorMessageKey(), status.error_message()));
}
return Status::OK();
}
@@ -447,13 +450,15 @@
Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
- TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(absl::StrCat(prefix(), "::", index),
+ CodeKey(), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
tstring error_message;
TF_RETURN_IF_ERROR(
- reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ reader->ReadScalar(absl::StrCat(prefix(), "::", index),
+ ErrorMessageKey(), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
@@ -461,13 +466,10 @@
return Status::OK();
}
- string CodeKey(size_t index) {
- return full_name(strings::StrCat(kStatus, "[", index, "]", kCodeSuffix));
- }
+ string CodeKey() { return absl::StrCat(kStatus, kCodeSuffix); }
- string ErrorMessageKey(size_t index) {
- return full_name(
- strings::StrCat(kStatus, "[", index, "]", kErrorMessageSuffix));
+ string ErrorMessageKey() {
+ return absl::StrCat(kStatus, kErrorMessageSuffix);
}
// This mutex is used to ensure exclusivity between multiple threads
diff --git a/tensorflow/core/kernels/data/rewrite_utils.cc b/tensorflow/core/kernels/data/rewrite_utils.cc
index a284aa8..8c43b95 100644
--- a/tensorflow/core/kernels/data/rewrite_utils.cc
+++ b/tensorflow/core/kernels/data/rewrite_utils.cc
@@ -151,6 +151,7 @@
SerializationContext::ExternalStatePolicy::kIgnore;
params.fail_if_unimplemented = false;
params.serialize_data_tensors = false;
+ params.preserve_random_seeds = false;
SerializationContext serialization_ctx(params);
GraphDef graph_def;
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 684ab0b..327fe3a 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -70,6 +70,29 @@
constexpr char kReshufflingDatasetPrefix[] = "Reshuffling";
constexpr char kShuffleDataset[] = "ShuffleDataset";
+namespace {
+class Seeds {
+ public:
+ Seeds(int64 seed, int64 seed2) {
+ input_seed_ = seed;
+ input_seed2_ = seed2;
+ seed_ = seed;
+ seed2_ = seed2;
+ // By TensorFlow convention, if both seeds are 0, then shuffling should be
+ // seeded non-deterministically.
+ if (seed == 0 && seed2 == 0) {
+ seed_ = random::New64();
+ seed2_ = random::New64();
+ }
+ }
+
+ int64 input_seed_;
+ int64 input_seed2_;
+ int64 seed_;
+ int64 seed2_;
+};
+} // namespace
+
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
@@ -110,6 +133,18 @@
}
protected:
+ // Adds the seeds to the given graphdef builder. `preserve_random_seeds`
+ // controls whether to add the input seeds or the resolved seeds.
+ Status AddSeeds(Seeds seeds, bool preserve_random_seeds,
+ DatasetGraphDefBuilder* b, Node** seed, Node** seed2) const {
+ int64 seed_to_add = preserve_random_seeds ? seeds.input_seed_ : seeds.seed_;
+ int64 seed2_to_add =
+ preserve_random_seeds ? seeds.input_seed2_ : seeds.seed2_;
+ TF_RETURN_IF_ERROR(b->AddScalar(seed_to_add, seed));
+ TF_RETURN_IF_ERROR(b->AddScalar(seed2_to_add, seed2));
+ return Status::OK();
+ }
+
template <class T>
class Iterator : public DatasetIterator<T> {
public:
@@ -408,29 +443,6 @@
const int64 count_;
};
-namespace {
-class Seeds {
- public:
- Seeds(int64 seed, int64 seed2) {
- input_seed_ = seed;
- input_seed2_ = seed2;
- seed_ = seed;
- seed2_ = seed2;
- // By TensorFlow convention, if both seeds are 0, then shuffling should be
- // seeded non-deterministically.
- if (seed == 0 && seed2 == 0) {
- seed_ = random::New64();
- seed2_ = random::New64();
- }
- }
-
- int64 input_seed_;
- int64 input_seed2_;
- int64 seed_;
- int64 seed2_;
-};
-} // namespace
-
// A dataset that uses a pseudorandom sequence of seeds for the iterators
// created from it. Used when `reshuffle_each_iteration` is true.
class ShuffleDatasetOp::ReshufflingDataset : public ShuffleDatasetBase {
@@ -543,8 +555,8 @@
AttrValue reshuffle_each_iteration;
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
- TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed_, &seed));
- TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2_, &seed2));
+ TF_RETURN_IF_ERROR(
+ AddSeeds(seeds_, ctx->preserve_random_seeds(), b, &seed, &seed2));
b->BuildAttrValue(true, &reshuffle_each_iteration);
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, buffer_size, seed, seed2}, // Inputs
@@ -700,8 +712,8 @@
AttrValue reshuffle_each_iteration;
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
- TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed_, &seed));
- TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2_, &seed2));
+ TF_RETURN_IF_ERROR(
+ AddSeeds(seeds_, ctx->preserve_random_seeds(), b, &seed, &seed2));
b->BuildAttrValue(false, &reshuffle_each_iteration);
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, buffer_size, seed, seed2}, // Inputs
@@ -799,8 +811,8 @@
Node* count = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
- TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed_, &seed));
- TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2_, &seed2));
+ TF_RETURN_IF_ERROR(
+ AddSeeds(seeds_, ctx->preserve_random_seeds(), b, &seed, &seed2));
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 5d1c78e..8ac9684 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -686,7 +686,8 @@
static constexpr int kNegInfBit = 0x01;
static constexpr int kPosInfBit = 0x02;
static constexpr int kNaNBit = 0x04;
- static constexpr int64 kMaxTensorId = 1L << std::numeric_limits<Tout>::digits;
+ static constexpr int64 kMaxTensorId = 1LL
+ << std::numeric_limits<Tout>::digits;
};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/kernels/debug_ops_gpu.cu.cc b/tensorflow/core/kernels/debug_ops_gpu.cu.cc
index a388b06..e1df486 100644
--- a/tensorflow/core/kernels/debug_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/debug_ops_gpu.cu.cc
@@ -42,9 +42,9 @@
const int32 total_thread_count = gridDim.x * blockDim.x;
int32 offset = thread_id;
-
while (offset < size) {
- if (isinf(data[offset]) || isnan(data[offset])) {
+ if (Eigen::numext::isinf(data[offset]) ||
+ Eigen::numext::isnan(data[offset])) {
output[0] = 1.0;
}
offset += total_thread_count;
@@ -63,14 +63,14 @@
Tout accum[3] = {0.0, 0.0, 0.0};
while (offset < size) {
- if (isinf(data[offset])) {
+ if (Eigen::numext::isinf(data[offset])) {
if (data[offset] < static_cast<Tin>(0.f)) {
++accum[0];
} else {
++accum[1];
}
}
- if (isnan(data[offset])) {
+ if (Eigen::numext::isnan(data[offset])) {
++accum[2];
}
offset += total_thread_count;
@@ -94,13 +94,13 @@
Tout accum[6] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
while (offset < size) {
- if (isinf(data[offset])) {
+ if (Eigen::numext::isinf(data[offset])) {
if (data[offset] < static_cast<Tin>(0.f)) {
++accum[0];
} else {
++accum[1];
}
- } else if (isnan(data[offset])) {
+ } else if (Eigen::numext::isnan(data[offset])) {
++accum[2];
} else {
if (data[offset] < static_cast<Tin>(0.f)) {
@@ -136,14 +136,14 @@
int32 offset = thread_id;
while (offset < size) {
- if (isinf(data[offset])) {
+ if (Eigen::numext::isinf(data[offset])) {
if (data[offset] < static_cast<Tin>(0.f)) {
output[0] = -std::numeric_limits<Tout>::infinity();
} else {
output[1] = std::numeric_limits<Tout>::infinity();
}
}
- if (isnan(data[offset])) {
+ if (Eigen::numext::isnan(data[offset])) {
output[2] = std::numeric_limits<Tout>::quiet_NaN();
}
offset += total_thread_count;
diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc
index 4681c62..52676f6 100644
--- a/tensorflow/core/kernels/gpu_utils.cc
+++ b/tensorflow/core/kernels/gpu_utils.cc
@@ -40,7 +40,6 @@
if (RedzoneCheckDisabled()) {
return buffer;
}
- se::DeviceMemoryBase output_tensor;
auto output_rz_or = rz_allocator->AllocateBytes(buffer.size());
if (!output_rz_or.ok()) {
static std::once_flag rz_allocation_failure_logged;
diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
index 690d13c..7a6924e 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
@@ -548,7 +548,6 @@
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
std::vector<string> output_node_names = {"softmax"};
- RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info0;
GraphTransferer gt0;
gt0.EnableStrictCheckMode(false);
ClockCycleProfiler prof0;
@@ -568,7 +567,6 @@
LOG(INFO) << "(0) node count: " << gfi0.node_info_size() << ", "
<< gfi0.const_node_info_size();
- RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info1;
GraphTransferer gt1;
gt1.EnableStrictCheckMode(true);
ClockCycleProfiler prof1;
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index e754341..a0a3d9f 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -792,7 +792,8 @@
// Tensorflow format to MKL format by caching the filter when it is
// converted for the first time. This cached filter can then be reused
// in subsequent iterations.
- if (is_filter_const_) {
+ bool do_cache_filter = src_dims[MklDnnDims::Dim_N] > kSmallBatchSize;
+ if (is_filter_const_ && do_cache_filter) {
if (IsFilterCacheEmpty(context)) {
// Cache filter if it is not already cached.
CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
@@ -805,6 +806,13 @@
filter_data = GetCachedFilter(
context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
is_filter_cached = (filter_data != nullptr);
+ if (filter_out_tensor != nullptr) {
+ Tfilter* filter_out_tensor_buf =
+ static_cast<Tfilter*>(const_cast<Tfilter*>(
+ filter_out_tensor->flat<Tfilter>().data()));
+ memcpy(filter_out_tensor_buf, filter_data,
+ filter_out_tensor->AllocatedBytes());
+ }
}
if (!is_filter_cached) {
filter.SetUsrMem(filter_md, &filter_tensor);
@@ -1591,8 +1599,6 @@
const float* min_filter = min_filter_vector.flat<float>().data();
const float* max_filter = max_filter_vector.flat<float>().data();
- std::vector<mkldnn::primitive> net;
-
const float int_const_scale_limit =
(std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
// Re-scale bias if either of following 2 conditions are met:
diff --git a/tensorflow/core/kernels/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl_quantize_op.cc
index e161f8e..985f1cd 100644
--- a/tensorflow/core/kernels/mkl_quantize_op.cc
+++ b/tensorflow/core/kernels/mkl_quantize_op.cc
@@ -293,26 +293,6 @@
ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_));
}
- ~MklQuantizeV2Op() {
- if (minfirst_input_ != nullptr) {
- delete minfirst_input_;
- minfirst_input_ = nullptr;
- }
- }
-
- float* GetMinfirstInputBuf(int size) {
- if (!minfirst_input_) {
- minfirst_input_ = new float[size];
- minfirst_input_size_ = size;
- } else if (size > minfirst_input_size_) {
- delete minfirst_input_;
- minfirst_input_ = new float[size];
- minfirst_input_size_ = size;
- }
-
- return minfirst_input_;
- }
-
void ComputeScalar(OpKernelContext* ctx, float min_range, float max_range) {
// TODO(intel-tf): Scalar support has to be added for SCALE mode
OP_REQUIRES(ctx, (mode_ == QUANTIZE_MODE_MIN_FIRST),
@@ -434,8 +414,11 @@
// If the mode is min_first, input data has to be subtracted from
// min_range, before being scaled
auto flat_input = input.flat<float>().data();
+ Tensor minfirst_tmpinput;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DT_FLOAT, input.shape(), &minfirst_tmpinput));
if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
- float* minfirst_input = GetMinfirstInputBuf(input.NumElements());
+ auto minfirst_input = minfirst_tmpinput.flat<float>().data();
const Eigen::TensorOpCost cost(
sizeof(float), /*load bytes*/
sizeof(float), /*saved bytes*/
@@ -557,8 +540,6 @@
int round_mode_;
int axis_;
bool narrow_range_;
- float* minfirst_input_ = nullptr;
- int minfirst_input_size_;
};
REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2")
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index ff48bd5..f9dd7c6 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -368,7 +368,6 @@
}
std::vector<int> selected;
- std::vector<float> selected_boxes;
Candidate next_candidate;
std::sort(candidate_vector.begin(), candidate_vector.end(), cmp);
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
index 44251e6..a55ea39 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
@@ -87,7 +87,6 @@
Status FuseByInOut() {
// Feed output shapes and types
- RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
GraphDef graph_def_with_shapetype = graph_def_;
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
input_tensors_, /*dry_run_inference*/ true, &graph_def_with_shapetype));
diff --git a/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc b/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc
index 53f9fbf..a03d60e 100644
--- a/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc
+++ b/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc
@@ -369,8 +369,6 @@
CSRSparseMatrix c;
Tensor c_row_ptrs;
- Tensor c_col_inds;
- Tensor c_values;
// TODO(ebrevdo): Re-enable transposing within the GEMM kernel when cuSparse
// stops spitting out CUSPARSE_STATUS_INTERNAL_ERROR values for transposes.
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
index 6fb6bab..6f40816 100644
--- a/tensorflow/core/lib/random/random_distributions.h
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -18,10 +18,7 @@
#include <string.h>
-#define _USE_MATH_DEFINES
-#include <math.h>
#include <cmath>
-#undef _USE_MATH_DEFINES
#include <algorithm>
#include <type_traits>
diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc
index 8868672..a497316 100644
--- a/tensorflow/core/lib/random/random_distributions_test.cc
+++ b/tensorflow/core/lib/random/random_distributions_test.cc
@@ -15,8 +15,8 @@
#include "tensorflow/core/lib/random/random_distributions.h"
-#include <math.h>
#include <algorithm>
+#include <cmath>
#include <functional>
#include <numeric>
#include <unordered_map>
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index d028ceb..639a753 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -141,6 +141,74 @@
return Status::OK();
});
+REGISTER_OP("BoostedTreesCalculateBestFeatureSplitV2")
+ .Input("node_id_range: int32")
+ .Input("stats_summaries_list: num_features * float32")
+ .Input("split_types: string")
+ .Input("candidate_feature_ids: int32")
+ .Input("l1: float")
+ .Input("l2: float")
+ .Input("tree_complexity: float")
+ .Input("min_node_weight: float")
+ .Attr("num_features: int >= 1") // not passed but populated automatically.
+ .Attr("logits_dimension: int >= 1")
+ .Output("node_ids: int32")
+ .Output("gains: float32")
+ .Output("feature_ids: int32")
+ .Output("feature_dimensions: int32")
+ .Output("thresholds: int32")
+ .Output("left_node_contribs: float32")
+ .Output("right_node_contribs: float32")
+ .Output("split_with_default_directions: string")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // Attributes.
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ int logits_dimension;
+ TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
+ // Inputs.
+ shape_inference::ShapeHandle unused_shape;
+ // node id range is rank 1 with 2 values.
+ shape_inference::ShapeHandle node_id_range_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
+ // Stats summary validation.
+ shape_inference::ShapeHandle summary_shape_base;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &summary_shape_base));
+ // All stats summary entries are of the same shape.
+ for (int i = 1; i < num_features; ++i) {
+ shape_inference::ShapeHandle summary_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 4, &summary_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(summary_shape_base, summary_shape, &unused_shape));
+ }
+ // Validate rank 1 split_types.
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(1 + num_features), 1, &unused_shape));
+ // Validate rank 1 candidate_feature_ids.
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(2 + num_features), 1, &unused_shape));
+ // Validate rank 0: l1, l2, tree_complexity, min_node_weight.
+ for (int i = 0; i < 4; ++i) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(3 + num_features + i), 0, &unused_shape));
+ }
+ // Output shapes.
+ ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
+ c->set_output(0, rank_1_output_shape);
+ c->set_output(1, rank_1_output_shape);
+ c->set_output(2, rank_1_output_shape);
+ c->set_output(3, rank_1_output_shape);
+ c->set_output(4, rank_1_output_shape);
+ ShapeHandle contribs_output_shape =
+ c->MakeShape({c->UnknownDim(), logits_dimension});
+ c->set_output(5, contribs_output_shape);
+ c->set_output(6, contribs_output_shape);
+ c->set_output(7, rank_1_output_shape);
+ return Status::OK();
+ });
+
REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit")
.Input("node_id_range: int32")
.Input("stats_summary_indices: int32")
@@ -395,7 +463,6 @@
int num_bucketized_features;
TF_RETURN_IF_ERROR(
c->GetAttr("num_bucketized_features", &num_bucketized_features));
- shape_inference::ShapeHandle unused_input;
shape_inference::DimensionHandle batch_size = c->Dim(c->input(1), 0);
for (int i = 0; i < num_bucketized_features; ++i) {
TF_RETURN_IF_ERROR(
@@ -425,7 +492,6 @@
int num_bucketized_features;
TF_RETURN_IF_ERROR(
c->GetAttr("num_bucketized_features", &num_bucketized_features));
- shape_inference::ShapeHandle unused_input;
shape_inference::DimensionHandle batch_dim = c->Dim(c->input(1), 0);
for (int i = 0; i < num_bucketized_features; ++i) {
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestFeatureSplitV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestFeatureSplitV2.pbtxt
new file mode 100644
index 0000000..e900ed9
--- /dev/null
+++ b/tensorflow/core/ops/compat/ops_history_v1/BoostedTreesCalculateBestFeatureSplitV2.pbtxt
@@ -0,0 +1,80 @@
+op {
+ name: "BoostedTreesCalculateBestFeatureSplitV2"
+ input_arg {
+ name: "node_id_range"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "stats_summaries_list"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "split_types"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "candidate_feature_ids"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "l1"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "l2"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "tree_complexity"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "min_node_weight"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "node_ids"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "gains"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "feature_ids"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "feature_dimensions"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "thresholds"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "left_node_contribs"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "right_node_contribs"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "split_with_default_directions"
+ type: DT_STRING
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "logits_dimension"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 62ce446..b24089c 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -5202,6 +5202,86 @@
}
}
op {
+ name: "BoostedTreesCalculateBestFeatureSplitV2"
+ input_arg {
+ name: "node_id_range"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "stats_summaries_list"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "split_types"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "candidate_feature_ids"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "l1"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "l2"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "tree_complexity"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "min_node_weight"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "node_ids"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "gains"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "feature_ids"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "feature_dimensions"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "thresholds"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "left_node_contribs"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "right_node_contribs"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "split_with_default_directions"
+ type: DT_STRING
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "logits_dimension"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 865f505..cefb86c 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -17,6 +17,7 @@
"tf_legacy_srcs_no_runtime_google",
"tf_logging_deps",
"tf_monitoring_deps",
+ "tf_platform_alias",
"tf_platform_deps",
"tf_protobuf_compiler_deps",
"tf_protobuf_deps",
@@ -932,6 +933,16 @@
visibility = ["//tensorflow/core:__pkg__"],
)
+# These are the sources needed to build the target tensorflow/core:mobile_srcs_no_runtime.
+# We want to get rid of all such android targets, as described in
+# https://github.com/tensorflow/community/pull/179.
+# This temporary filegroup allows us to remove the legacy "build_config" directories.
+filegroup(
+ name = "legacy_mobile_srcs",
+ srcs = tf_platform_alias("legacy_mobile_srcs"),
+ visibility = ["//tensorflow/core:__pkg__"],
+)
+
bzl_library(
name = "build_config_root_bzl",
srcs = [
diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl
index 4ca9652..e30789d 100644
--- a/tensorflow/core/platform/build_config.bzl
+++ b/tensorflow/core/platform/build_config.bzl
@@ -23,6 +23,7 @@
_tf_lib_proto_parsing_deps = "tf_lib_proto_parsing_deps",
_tf_logging_deps = "tf_logging_deps",
_tf_monitoring_deps = "tf_monitoring_deps",
+ _tf_platform_alias = "tf_platform_alias",
_tf_platform_deps = "tf_platform_deps",
_tf_proto_library = "tf_proto_library",
_tf_proto_library_cc = "tf_proto_library_cc",
@@ -60,6 +61,7 @@
tf_lib_proto_parsing_deps = _tf_lib_proto_parsing_deps
tf_logging_deps = _tf_logging_deps
tf_monitoring_deps = _tf_monitoring_deps
+tf_platform_alias = _tf_platform_alias
tf_platform_deps = _tf_platform_deps
tf_proto_library = _tf_proto_library
tf_proto_library_cc = _tf_proto_library_cc
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index 1df371a..b48831a 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -525,7 +525,6 @@
RetryingFileSystem<MockFileSystem> fs(
std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
- std::vector<string> result;
TF_EXPECT_OK(fs.DeleteFile("gs://path/file.txt"));
}
@@ -536,7 +535,6 @@
RetryingFileSystem<MockFileSystem> fs(
std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
- std::vector<string> result;
const auto& status = fs.DeleteFile("gs://path/file.txt");
EXPECT_TRUE(absl::StrContains(status.error_message(), "Retriable error #10"))
<< status;
@@ -551,7 +549,6 @@
RetryingFileSystem<MockFileSystem> fs(
std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
- std::vector<string> result;
TF_EXPECT_OK(fs.CreateDir("gs://path/newdir"));
}
@@ -562,7 +559,6 @@
RetryingFileSystem<MockFileSystem> fs(
std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
- std::vector<string> result;
const auto& status = fs.CreateDir("gs://path/newdir");
EXPECT_TRUE(absl::StrContains(status.error_message(), "Retriable error #10"))
<< status;
@@ -577,7 +573,6 @@
RetryingFileSystem<MockFileSystem> fs(
std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
- std::vector<string> result;
TF_EXPECT_OK(fs.DeleteDir("gs://path/dir"));
}
@@ -588,7 +583,6 @@
RetryingFileSystem<MockFileSystem> fs(
std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
- std::vector<string> result;
const auto& status = fs.DeleteDir("gs://path/dir");
EXPECT_TRUE(absl::StrContains(status.error_message(), "Retriable error #10"))
<< status;
diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD
index 04893ec..491f845 100644
--- a/tensorflow/core/platform/default/BUILD
+++ b/tensorflow/core/platform/default/BUILD
@@ -463,6 +463,11 @@
visibility = ["//tensorflow:__subpackages__"],
)
+filegroup(
+ name = "legacy_mobile_srcs",
+ visibility = ["//tensorflow/core/platform:__pkg__"],
+)
+
package_group(
name = "core_and_platform_packages",
packages = [
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 3c50725..2876330 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -480,9 +480,6 @@
def tf_jspb_proto_library(**kwargs):
pass
-def tf_nano_proto_library(**kwargs):
- pass
-
def tf_proto_library(
name,
srcs = [],
@@ -535,23 +532,6 @@
visibility = visibility,
)
-# A list of all files under platform matching the pattern in 'files'. In
-# contrast with 'tf_platform_srcs' below, which seletive collects files that
-# must be compiled in the 'default' platform, this is a list of all headers
-# mentioned in the platform/* files.
-def tf_platform_hdrs(files):
- return native.glob(["*/" + f for f in files])
-
-def tf_platform_srcs(files):
- base_set = ["default/" + f for f in files]
- windows_set = base_set + ["windows/" + f for f in files]
- posix_set = base_set + ["posix/" + f for f in files]
-
- return select({
- clean_dep("//tensorflow:windows"): native.glob(windows_set),
- "//conditions:default": native.glob(posix_set),
- })
-
def tf_additional_lib_hdrs():
return [
"//tensorflow/core/platform/default:context.h",
@@ -753,7 +733,10 @@
],
})
-def tf_platform_deps(name):
+def tf_platform_deps(name, platform_dir = "//tensorflow/core/platform/"):
+ return [platform_dir + "default:" + name]
+
+def tf_platform_alias(name):
return ["//tensorflow/core/platform/default:" + name]
def tf_logging_deps():
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index 2354015..7545bc5 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -1,21 +1,17 @@
# Description:
# Platform-specific build configurations.
+load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_cuda_library")
+load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
+load("@local_config_sycl//sycl:platform.bzl", "sycl_library_path")
+load("@local_config_sycl//sycl:build_defs.bzl", "if_ccpp")
+
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "check_deps")
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
-load("//tensorflow:tensorflow.bzl", "tf_copts")
-load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
-load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
-load("@local_config_sycl//sycl:platform.bzl", "sycl_library_path")
-load("@local_config_sycl//sycl:build_defs.bzl", "if_ccpp")
-
cc_library(
name = "gtest",
testonly = 1,
@@ -133,17 +129,10 @@
name = "proto_parsing",
copts = tf_copts(),
deps = [
- "//tensorflow/core:protos_cc",
+ "//tensorflow/core:protos_all_cc",
],
)
-# Minimal lib so that tools used for mobile compilation
-# don't have to depend on platformlib.
-cc_library(
- name = "logging",
- copts = tf_copts(),
-)
-
# Minimal lib to be used by tensorflow/core:framework_lite.
# This provides minimal support for writing operator implementations (kernels),
# and excludes anything that can bloat binary size if used.
@@ -154,48 +143,12 @@
)
cc_library(
- name = "base",
- srcs = [],
- copts = tf_copts(),
-)
-
-cc_library(
- name = "port",
- srcs = [],
- copts = tf_copts(),
-)
-
-cc_library(
- name = "protobuf",
- srcs = [],
- copts = tf_copts(),
-)
-
-cc_library(
- name = "env",
- srcs = [],
- copts = tf_copts(),
-)
-
-cc_library(
- name = "other",
- srcs = [],
- copts = tf_copts(),
- deps = [
- "@com_googlesource_code_re2//:re2",
- "@farmhash_archive//:farmhash",
- "@fft2d",
- "@highwayhash//:sip_hash",
- ],
-)
-
-cc_library(
name = "platformlib",
copts = tf_copts(),
deps = [
":gif",
":jpeg",
- "//tensorflow/core:protos_cc",
+ "//tensorflow/core:protos_all_cc",
"@com_googlesource_code_re2//:re2",
"@farmhash_archive//:farmhash",
"@fft2d",
@@ -205,11 +158,6 @@
)
cc_library(
- name = "stacktrace",
- srcs = [],
-)
-
-cc_library(
name = "gif",
copts = tf_copts(),
deps = [
@@ -235,41 +183,12 @@
)
cc_library(
- name = "protos_cc_impl",
- copts = tf_copts(),
- deps = [
- "//tensorflow/core:protos_all_cc_impl",
- ],
-)
-
-cc_library(
- name = "protos_cc",
- copts = tf_copts(),
- deps = [
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
-cc_library(
- name = "test_lite_main",
- testonly = 1,
- linkstatic = 1,
- deps = [],
-)
-
-cc_library(
name = "test_main",
testonly = 1,
linkstatic = 1,
deps = [],
)
-filegroup(
- name = "android_proto_lib_portable_proto",
- srcs = [],
- visibility = ["//visibility:public"],
-)
-
cc_library(
name = "cuda",
data = [
@@ -314,15 +233,3 @@
["@local_config_sycl//sycl:sycl_headers"],
),
)
-
-filegroup(
- name = "mobile_srcs",
- srcs = glob(["*.h"]),
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "android_srcs",
- actual = ":mobile_srcs",
- visibility = ["//visibility:public"],
-)
diff --git a/tensorflow/core/platform/default/strong_hash.h b/tensorflow/core/platform/default/strong_hash.h
index d20ef70..8c25bf6 100644
--- a/tensorflow/core/platform/default/strong_hash.h
+++ b/tensorflow/core/platform/default/strong_hash.h
@@ -16,8 +16,8 @@
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
-#include "highwayhash/sip_hash.h"
-#include "highwayhash/state_helpers.h"
+#include "highwayhash/sip_hash.h" // TF:highwayhash
+#include "highwayhash/state_helpers.h" // TF:highwayhash
namespace tensorflow {
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 34dc1cf..091cb60 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -135,6 +135,25 @@
return libhdfs;
}
+Status SplitArchiveNameAndPath(StringPiece& path, string& nn) {
+ size_t index_end_archive_name = path.find(".har");
+ if (index_end_archive_name == path.npos) {
+ return errors::InvalidArgument(
+ "Hadoop archive path does not contain a .har extension");
+ }
+ // Case of hadoop archive. Namenode is the path to the archive.
+ std::ostringstream namenodestream;
+ namenodestream << "har://" << nn
+ << path.substr(0, index_end_archive_name + 4);
+ nn = namenodestream.str();
+ path.remove_prefix(index_end_archive_name + 4);
+ if (path.empty()) {
+ // Root of the archive
+ path = "/";
+ }
+ return Status::OK();
+}
+
// We rely on HDFS connection caching here. The HDFS client calls
// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
// internally.
@@ -143,7 +162,7 @@
StringPiece scheme, namenode, path;
io::ParseURI(fname, &scheme, &namenode, &path);
- const string nn(namenode);
+ string nn(namenode);
hdfsBuilder* builder = libhdfs()->hdfsNewBuilder();
if (scheme == "file") {
@@ -163,6 +182,9 @@
// configuration files). See:
// https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259
libhdfs()->hdfsBuilderSetNameNode(builder, "default");
+ } else if (scheme == "har") {
+ TF_RETURN_IF_ERROR(SplitArchiveNameAndPath(path, nn));
+ libhdfs()->hdfsBuilderSetNameNode(builder, nn.c_str());
} else {
libhdfs()->hdfsBuilderSetNameNode(builder,
nn.empty() ? "default" : nn.c_str());
@@ -517,5 +539,6 @@
REGISTER_FILE_SYSTEM("hdfs", HadoopFileSystem);
REGISTER_FILE_SYSTEM("viewfs", HadoopFileSystem);
+REGISTER_FILE_SYSTEM("har", HadoopFileSystem);
} // namespace tensorflow
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.h b/tensorflow/core/platform/hadoop/hadoop_file_system.h
index 11812c2..f9f2c25 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.h
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.h
@@ -70,6 +70,8 @@
Status Connect(StringPiece fname, hdfsFS* fs);
};
+Status SplitArchiveNameAndPath(StringPiece& path, string& nn);
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_HADOOP_HADOOP_FILE_SYSTEM_H_
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
index 3104add..71cf054 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system_test.cc
@@ -235,6 +235,44 @@
TF_EXPECT_OK(writer->Close());
}
+TEST_F(HadoopFileSystemTest, HarSplit) {
+ string har_path =
+ "har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
+ StringPiece scheme, namenode, path;
+ io::ParseURI(har_path, &scheme, &namenode, &path);
+ EXPECT_EQ("har", scheme);
+ EXPECT_EQ("hdfs-root", namenode);
+ EXPECT_EQ("/user/j.doe/my_archive.har/dir0/dir1/file.txt", path);
+ string nn(namenode);
+ TF_EXPECT_OK(SplitArchiveNameAndPath(path, nn));
+ EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", nn);
+ EXPECT_EQ("/dir0/dir1/file.txt", path);
+}
+
+TEST_F(HadoopFileSystemTest, NoHarExtension) {
+ string har_path = "har://hdfs-root/user/j.doe/my_archive/dir0/dir1/file.txt";
+ StringPiece scheme, namenode, path;
+ io::ParseURI(har_path, &scheme, &namenode, &path);
+ EXPECT_EQ("har", scheme);
+ EXPECT_EQ("hdfs-root", namenode);
+ EXPECT_EQ("/user/j.doe/my_archive/dir0/dir1/file.txt", path);
+ string nn(namenode);
+ EXPECT_EQ(errors::InvalidArgument("").code(),
+ SplitArchiveNameAndPath(path, nn).code());
+}
+
+TEST_F(HadoopFileSystemTest, HarRootPath) {
+ string har_path = "har://hdfs-root/user/j.doe/my_archive.har";
+ StringPiece scheme, namenode, path;
+ io::ParseURI(har_path, &scheme, &namenode, &path);
+ EXPECT_EQ("har", scheme);
+ EXPECT_EQ("hdfs-root", namenode);
+ EXPECT_EQ("/user/j.doe/my_archive.har", path);
+ string nn(namenode);
+ TF_EXPECT_OK(SplitArchiveNameAndPath(path, nn));
+ EXPECT_EQ("har://hdfs-root/user/j.doe/my_archive.har", nn);
+ EXPECT_EQ("/", path);
+}
// NewAppendableFile() is not testable. Local filesystem maps to
// ChecksumFileSystem in Hadoop, where appending is an unsupported operation.
diff --git a/tensorflow/core/platform/logging.h b/tensorflow/core/platform/logging.h
index 1ebc93f..c3a998d 100644
--- a/tensorflow/core/platform/logging.h
+++ b/tensorflow/core/platform/logging.h
@@ -22,7 +22,7 @@
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) || \
defined(PLATFORM_GOOGLE_IOS) || defined(GOOGLE_LOGGING) || \
defined(__EMSCRIPTEN__)
-#include "tensorflow/core/platform/google/build_config/logging.h"
+#include "tensorflow/core/platform/google/logging.h"
#else
#include "tensorflow/core/platform/default/logging.h"
#endif
diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD
index 75be7a3..044df45 100644
--- a/tensorflow/core/profiler/convert/BUILD
+++ b/tensorflow/core/profiler/convert/BUILD
@@ -28,7 +28,6 @@
srcs = ["run_metadata_to_trace_events.cc"],
hdrs = ["run_metadata_to_trace_events.h"],
deps = [
- "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
@@ -80,6 +79,31 @@
)
cc_library(
+ name = "op_stats_to_input_pipeline_analysis",
+ srcs = ["op_stats_to_input_pipeline_analysis.cc"],
+ hdrs = ["op_stats_to_input_pipeline_analysis.h"],
+ deps = [
+ ":op_metrics_to_record",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/platform:logging",
+ "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc",
+ "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc",
+ "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+ "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
+ "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
+ "//tensorflow/core/profiler/utils:event_span",
+ "//tensorflow/core/profiler/utils:math_utils",
+ "//tensorflow/core/profiler/utils:tf_op_utils",
+ "//tensorflow/core/profiler/utils:time_utils",
+ "//tensorflow/core/util:stats_calculator_portable",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
name = "op_stats_to_tf_stats",
srcs = ["op_stats_to_tf_stats.cc"],
hdrs = ["op_stats_to_tf_stats.h"],
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
new file mode 100644
index 0000000..965cab1
--- /dev/null
+++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
@@ -0,0 +1,402 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
+
+#include <algorithm>
+#include <cmath>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
+#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+#include "tensorflow/core/profiler/utils/event_span.h"
+#include "tensorflow/core/profiler/utils/math_utils.h"
+#include "tensorflow/core/profiler/utils/tf_op_utils.h"
+#include "tensorflow/core/profiler/utils/time_utils.h"
+#include "tensorflow/core/util/stats_calculator.h"
+
+namespace tensorflow {
+namespace profiler {
+
+namespace {
+
+// Picoseconds per millisecond (1 ms = 1e9 ps).
+constexpr double kNumPsPerMs = 1e9;
+
+// Returns the time recorded for `event_type` in `type_ps` (a map from event
+// type to picoseconds), converted to milliseconds. An event type missing from
+// the map counts as zero time.
+template <class Collection>
+double GetTimeInMs(const Collection& type_ps, EventType event_type) {
+  // `auto` copies the value, so the temporary default does not dangle.
+  const auto picos = gtl::FindWithDefault(type_ps, event_type, /*value=*/0);
+  return PicosToMillis(picos);
+}
+
+// Converts the samples accumulated in `sample_stats` into a StepSummary:
+// average, sample standard deviation, minimum and maximum.
+//
+// When no sample was recorded (e.g. a profile without any step), an all-zero
+// summary is returned. Stat<double> has no meaningful values in that case:
+// avg() is NaN and min()/max() still hold their initial sentinels (see Stat in
+// tensorflow/core/util/stats_calculator.h), which would otherwise leak into
+// the reported summary.
+StepSummary GetStepSummaryForSampleStats(const Stat<double>& sample_stats) {
+  StepSummary step_time_summary;
+  if (sample_stats.empty()) {
+    // proto3 doubles default to 0, so the untouched message is all zeros.
+    return step_time_summary;
+  }
+  step_time_summary.set_average(sample_stats.avg());
+  step_time_summary.set_standard_deviation(
+      std::sqrt(sample_stats.sample_variance()));
+  step_time_summary.set_minimum(sample_stats.min());
+  step_time_summary.set_maximum(sample_stats.max());
+  return step_time_summary;
+}
+
+// Aggregates the per-step breakdowns in `analysis.step_details()` into one
+// StepSummary per event category. All values are in milliseconds.
+//
+// Every element of `analysis.step_details()` must hold a
+// PerGenericStepDetails. If any element cannot be unpacked as one, an error is
+// logged and an empty GenericStepTimeBreakdown is returned.
+GenericStepTimeBreakdown ComputeGenericStepTimeBreakdownInMs(
+    const InputPipelineAnalysisResult& analysis) {
+  Stat<double> unknown_time_ms;
+  Stat<double> infeed_ms;
+  Stat<double> outfeed_ms;
+  Stat<double> device_compute_ms;
+  Stat<double> device_to_device_ms;
+  Stat<double> host_compute_ms;
+  Stat<double> host_prepare_ms;
+  Stat<double> host_compile_ms;
+
+  for (const google::protobuf::Any& step_details : analysis.step_details()) {
+    PerGenericStepDetails details;
+    if (!step_details.UnpackTo(&details)) {
+      // LOG terminates the message itself; std::endl would leave a blank line.
+      LOG(ERROR) << "Unable to unpack step_details. Expected: "
+                    "PerGenericStepDetails";
+      return {};
+    }
+    unknown_time_ms.UpdateStat(details.unknown_time_ms());
+    infeed_ms.UpdateStat(details.infeed_ms());
+    outfeed_ms.UpdateStat(details.outfeed_ms());
+    device_compute_ms.UpdateStat(details.device_compute_ms());
+    device_to_device_ms.UpdateStat(details.device_to_device_ms());
+    host_compute_ms.UpdateStat(details.host_compute_ms());
+    host_prepare_ms.UpdateStat(details.host_prepare_ms());
+    host_compile_ms.UpdateStat(details.host_compile_ms());
+  }
+
+  GenericStepTimeBreakdown result;
+  *result.mutable_unknown_time_ms_summary() =
+      GetStepSummaryForSampleStats(unknown_time_ms);
+  *result.mutable_infeed_ms_summary() = GetStepSummaryForSampleStats(infeed_ms);
+  *result.mutable_outfeed_ms_summary() =
+      GetStepSummaryForSampleStats(outfeed_ms);
+  *result.mutable_device_compute_ms_summary() =
+      GetStepSummaryForSampleStats(device_compute_ms);
+  *result.mutable_device_to_device_ms_summary() =
+      GetStepSummaryForSampleStats(device_to_device_ms);
+  *result.mutable_host_compute_ms_summary() =
+      GetStepSummaryForSampleStats(host_compute_ms);
+  *result.mutable_host_prepare_ms_summary() =
+      GetStepSummaryForSampleStats(host_prepare_ms);
+  *result.mutable_host_compile_ms_summary() =
+      GetStepSummaryForSampleStats(host_compile_ms);
+  return result;
+}
+
+// Builds the step-level part of an InputPipelineAnalysisResult for generic
+// hardware from `grouped_by_step`. This covers:
+//  - the step-time summary;
+//  - one packed PerGenericStepDetails per step;
+//  - the summary of infeed time as a percentage of step time;
+//  - the packed GenericStepTimeBreakdown.
+//
+// Only the step info recorded under core 0 is used, because for generic
+// hardware all step info is stored there. If any step breakdown cannot be
+// unpacked as a GenericStepBreakdown, an error is logged and an empty result
+// is returned.
+InputPipelineAnalysisResult ComputeGenericInputPipelineAnalysisResult(
+    const protobuf::RepeatedPtrField<PerCoreStepInfo>& grouped_by_step) {
+  InputPipelineAnalysisResult result;
+
+  // Computes the summary of step time in ms.
+  *result.mutable_step_time_summary() =
+      ComputeStepTimeSummaryInMs(grouped_by_step);
+
+  Stat<double> infeed_summary_stats_in_percent;
+  for (const auto& coreid_stepinfo_map : grouped_by_step) {
+    // Iterates over each step.
+    const auto* ptr =
+        gtl::FindOrNull(coreid_stepinfo_map.step_info_per_core(), 0);
+    if (ptr == nullptr) {
+      // For generic hardware, all step-info is put under core-0. If ptr
+      // is nullptr, it means there is no step at all.
+      continue;
+    }
+    const StepInfoResult& step_info = *ptr;
+    // Adds the details for a new step.
+    PerGenericStepDetails details;
+    details.set_step_number(step_info.step_num());
+    details.set_step_time_ms(PicosToMillis(step_info.duration_ps()));
+    GenericStepBreakdown generic;
+    if (!step_info.step_breakdown().UnpackTo(&generic)) {
+      // LOG terminates the message itself; std::endl would leave a blank line.
+      LOG(ERROR) << "Unable to unpack step_breakdown. Expected: generic";
+      return {};
+    }
+    const auto& type_ps = generic.type_ps();
+    details.set_unknown_time_ms(GetTimeInMs(type_ps, UNKNOWN_TIME));
+    // To be consistent with TPU case, the infeed time includes the time that
+    // the host is reading files, preprocessing, and the time to transfer the
+    // data to the device.
+    details.set_infeed_ms(GetTimeInMs(type_ps, HOST_WAIT_INPUT) +
+                          GetTimeInMs(type_ps, HOST_TO_DEVICE) +
+                          GetTimeInMs(type_ps, DEVICE_WAIT_HOST));
+    details.set_outfeed_ms(GetTimeInMs(type_ps, DEVICE_TO_HOST));
+    details.set_device_compute_ms(GetTimeInMs(type_ps, DEVICE_COMPUTE));
+    details.set_device_to_device_ms(GetTimeInMs(type_ps, DEVICE_TO_DEVICE) +
+                                    GetTimeInMs(type_ps, DEVICE_WAIT_DEVICE));
+    details.set_host_compute_ms(GetTimeInMs(type_ps, HOST_COMPUTE));
+    details.set_host_prepare_ms(GetTimeInMs(type_ps, HOST_PREPARE));
+    details.set_host_compile_ms(GetTimeInMs(type_ps, HOST_COMPILE));
+
+    result.add_step_details()->PackFrom(details);
+
+    const double infeed_pct_of_step_time =
+        100.0 * SafeDivide(details.infeed_ms(), details.step_time_ms());
+    infeed_summary_stats_in_percent.UpdateStat(infeed_pct_of_step_time);
+  }
+
+  // Computes the summary of infeed time as percentage of step time.
+  *result.mutable_infeed_percent_summary() =
+      GetStepSummaryForSampleStats(infeed_summary_stats_in_percent);
+
+  // Computes the breakdown of step time.
+  GenericStepTimeBreakdown generic_step_time_breakdown =
+      ComputeGenericStepTimeBreakdownInMs(result);
+  result.mutable_step_time_breakdown()->PackFrom(generic_step_time_breakdown);
+
+  return result;
+}
+
+// Classification of input processing on the host.
+// Each input op is assigned a category by CategorizeInputOp(), and
+// InputOpCategoryString() gives the name stored in InputOpDetails.category.
+enum class InputOpCategory {
+ kEnqueue, // enqueue data to be transferred to device.
+ kDemandedFileRead, // demanded read from file.
+ kAdvancedFileRead, // advanced read from file (including cached,
+ // prefetch, parallel-map, interleave).
+ kPreprocessing // data preprocessing.
+};
+
+// Returns the human-readable name of `category`, as stored in
+// InputOpDetails.category.
+string InputOpCategoryString(InputOpCategory category) {
+  switch (category) {
+    case InputOpCategory::kEnqueue:
+      return "Enqueue";
+    case InputOpCategory::kDemandedFileRead:
+      return "Demanded file read";
+    case InputOpCategory::kAdvancedFileRead:
+      return "Advanced file read";
+    case InputOpCategory::kPreprocessing:
+      return "Preprocessing";
+  }
+  // Unreachable for valid enumerators. Without this, an out-of-range value
+  // would fall off the end of a value-returning function (undefined behavior,
+  // and a -Wreturn-type warning). "Unknown" is one of the category names
+  // documented for InputOpDetails.category.
+  return "Unknown";
+}
+
+// Returns true if an op with this `category` is part of the input pipeline,
+// i.e. it is an infeed-enqueue op or a dataset op (as decided by
+// IsInfeedEnqueueOp() / IsDatasetOp()).
+inline bool IsInputOp(absl::string_view category) {
+  if (IsInfeedEnqueueOp(category)) return true;
+  return IsDatasetOp(category);
+}
+
+// Maps an input op (one for which IsInputOp() holds) to an InputOpCategory,
+// using its `category` and fully-qualified `name`.
+//
+// Infeed-enqueue ops are kEnqueue. Dataset ops whose name ends in a file
+// reader suffix (TFRecord, TextLine, FixedLengthRecord, SSTable, RecordIO)
+// are file reads:
+//  - "advanced" when the name also contains MemoryReader, MemoryWriter,
+//    Interleave, Prefetch or ParallelMap;
+//  - "demanded" otherwise.
+// All other dataset ops are preprocessing.
+InputOpCategory CategorizeInputOp(absl::string_view name,
+                                  absl::string_view category) {
+  if (IsInfeedEnqueueOp(category)) return InputOpCategory::kEnqueue;
+  DCHECK(IsDatasetOp(category));
+
+  const bool reads_from_file =
+      absl::EndsWith(name, "::TFRecord") ||
+      absl::EndsWith(name, "::TextLine") ||
+      absl::EndsWith(name, "::FixedLengthRecord") ||
+      absl::EndsWith(name, "::SSTable") || absl::EndsWith(name, "::RecordIO");
+  if (!reads_from_file) return InputOpCategory::kPreprocessing;
+
+  const bool reads_in_advance = absl::StrContains(name, "::MemoryReader") ||
+                                absl::StrContains(name, "::MemoryWriter") ||
+                                absl::StrContains(name, "::Interleave") ||
+                                absl::StrContains(name, "::Prefetch") ||
+                                absl::StrContains(name, "::ParallelMap");
+  return reads_in_advance ? InputOpCategory::kAdvancedFileRead
+                          : InputOpCategory::kDemandedFileRead;
+}
+
+// Input ops selected from an OpMetricsDb by SelectInputOpMetrics().
+struct InputOpMetrics {
+ // Selected ops, in SortedOpMetricsDb() order. These are non-owning pointers;
+ // presumably they point into the source OpMetricsDb, which must then outlive
+ // this struct (TODO confirm against SortedOpMetricsDb).
+ std::vector<const OpMetrics*> input_op_metrics;
+ // Sum of the self times of all ops in input_op_metrics, in picoseconds.
+ uint64 input_op_time_ps = 0;
+};
+
+// Returns every input op of `all_op_metrics` (in SortedOpMetricsDb() order)
+// together with the total self time of those ops.
+InputOpMetrics SelectInputOpMetrics(const OpMetricsDb& all_op_metrics) {
+  InputOpMetrics selected;
+  for (const OpMetrics* metrics : SortedOpMetricsDb(all_op_metrics)) {
+    if (!IsInputOp(metrics->category())) continue;
+    selected.input_op_metrics.push_back(metrics);
+    selected.input_op_time_ps += metrics->self_time_ps();
+  }
+  return selected;
+}
+
+// Builds the InputOpDetails row for one input op. Both percentages are
+// relative to `input_op_time_ps`, the total self time of all input ops.
+InputOpDetails ConvertOpMetricsToInputOpDetails(const OpMetrics& op_metrics,
+                                                uint64 input_op_time_ps,
+                                                InputOpCategory category) {
+  const auto time_ps = op_metrics.time_ps();
+  const auto self_time_ps = op_metrics.self_time_ps();
+
+  InputOpDetails details;
+  details.set_op_name(op_metrics.name());
+  details.set_count(op_metrics.occurrences());
+  details.set_time_in_ms(PicosToMillis(time_ps));
+  details.set_self_time_in_ms(PicosToMillis(self_time_ps));
+  details.set_time_in_percent(100.0 * SafeDivide(time_ps, input_op_time_ps));
+  details.set_self_time_in_percent(
+      100.0 * SafeDivide(self_time_ps, input_op_time_ps));
+  details.set_category(InputOpCategoryString(category));
+  return details;
+}
+
+// Fills in the host-side input analysis of `result` from the host op metrics
+// in `host_tf_metrics_db`. It appends one InputOpDetails per input op and sets
+// result->input_time_breakdown(). `result` is left unchanged when no input
+// ops are found.
+void GenerateHostResult(const OpMetricsDb& host_tf_metrics_db,
+ InputPipelineAnalysisResult* result) {
+ InputOpMetrics input_op_metrics = SelectInputOpMetrics(host_tf_metrics_db);
+ // Return if the program is not using an input pipeline with xprof
+ // instrumentation and no input ops are found.
+ if (input_op_metrics.input_op_metrics.empty()) return;
+
+ // Total self time per InputOpCategory, in microseconds.
+ absl::flat_hash_map<InputOpCategory, double> aggregated_input_op_times_us;
+ for (const OpMetrics* op_metrics : input_op_metrics.input_op_metrics) {
+ InputOpCategory category =
+ CategorizeInputOp(op_metrics->name(), op_metrics->category());
+ *result->add_input_op_details() = ConvertOpMetricsToInputOpDetails(
+ *op_metrics, input_op_metrics.input_op_time_ps, category);
+ aggregated_input_op_times_us[category] +=
+ PicosToMicros(op_metrics->self_time_ps());
+ }
+
+ // operator[] yields 0.0 for categories that had no ops.
+ double enqueue_time_us =
+ aggregated_input_op_times_us[InputOpCategory::kEnqueue];
+ // Despite its name, this excludes the enqueue time. It is the measured
+ // (unscaled) time of the non-enqueue categories only.
+ double total_input_op_time_us =
+ aggregated_input_op_times_us[InputOpCategory::kDemandedFileRead] +
+ aggregated_input_op_times_us[InputOpCategory::kAdvancedFileRead] +
+ aggregated_input_op_times_us[InputOpCategory::kPreprocessing];
+
+ // We use total_host_infeed_enq_start_timestamp_ps_diff() to approximate the
+ // total host step time, so `ratio` approximates the fraction of the host
+ // step time spent enqueuing.
+ double ratio = SafeDivide(
+ host_tf_metrics_db.total_host_infeed_enq_duration_ps(),
+ host_tf_metrics_db.total_host_infeed_enq_start_timestamp_ps_diff());
+ // NOTE(review): these checks only fire in debug builds. In optimized builds
+ // a ratio > 1 makes non_enqueue_time_us negative. Only the unclassified
+ // component below is clamped at zero; the scaled components are not.
+ DCHECK_LE(ratio, 1.0);
+ DCHECK_GE(ratio, 0.0);
+ // Let T be the host step time. If enqueuing takes a fraction `ratio` of T,
+ // then enqueue_time_us ~= ratio * T. The non-enqueue time is therefore
+ // T * (1 - ratio) = enqueue_time_us * (1 - ratio) / ratio. Without a
+ // non-zero ratio, fall back to the measured non-enqueue op time.
+ double non_enqueue_time_us = (ratio != 0.0)
+ ? (enqueue_time_us * (1.0 - ratio) / ratio)
+ : total_input_op_time_us;
+
+ // Scales the various input-time components with respect to
+ // non_enqueue_time_us, keeping their measured proportions.
+ double scaled_demanded_fileread_time_us = SafeDivide(
+ non_enqueue_time_us *
+ aggregated_input_op_times_us[InputOpCategory::kDemandedFileRead],
+ total_input_op_time_us);
+ double scaled_advanced_fileread_time_us = SafeDivide(
+ non_enqueue_time_us *
+ aggregated_input_op_times_us[InputOpCategory::kAdvancedFileRead],
+ total_input_op_time_us);
+ double scaled_preprocessing_time_us = SafeDivide(
+ non_enqueue_time_us *
+ aggregated_input_op_times_us[InputOpCategory::kPreprocessing],
+ total_input_op_time_us);
+ // Whatever part of non_enqueue_time_us the scaled components do not cover.
+ // When total_input_op_time_us is non-zero, the scaled components sum to
+ // non_enqueue_time_us, so this is zero up to rounding.
+ double unclassified_non_enqueue_time_us = std::max(
+ 0.0, non_enqueue_time_us - scaled_demanded_fileread_time_us -
+ scaled_advanced_fileread_time_us - scaled_preprocessing_time_us);
+
+ InputTimeBreakdown* input_time_breakdown =
+ result->mutable_input_time_breakdown();
+ input_time_breakdown->set_enqueue_us(enqueue_time_us);
+ input_time_breakdown->set_demanded_file_read_us(
+ scaled_demanded_fileread_time_us);
+ input_time_breakdown->set_advanced_file_read_us(
+ scaled_advanced_fileread_time_us);
+ input_time_breakdown->set_preprocessing_us(scaled_preprocessing_time_us);
+ input_time_breakdown->set_unclassified_non_enqueue_us(
+ unclassified_non_enqueue_time_us);
+}
+
+// Returns an HTML anchor that links `text` to `url` and opens in a new tab.
+// Neither argument is HTML-escaped.
+string AnchorElement(absl::string_view url, absl::string_view text) {
+  string anchor = absl::StrCat("<a href=\"", url, "\" target=\"_blank\">");
+  absl::StrAppend(&anchor, text, "</a>");
+  return anchor;
+}
+
+// Returns the fixed list of HTML-formatted tuning tips shown to the user. There
+// is one tip per kind of input processing: enqueue, preprocessing, advanced
+// and demanded file reads, and everything else.
+InputPipelineAnalysisRecommendation GenerateRecommendation() {
+  const absl::string_view kDatasetIntro =
+      "https://www.tensorflow.org/programmers_guide/datasets";
+  const absl::string_view kDatasetTopic =
+      "https://www.tensorflow.org/api_docs/python/tf/data/Dataset#";
+  const absl::string_view kTfRecordDataset =
+      "https://www.tensorflow.org/api_docs/python/tf/data/"
+      "TFRecordDataset#class_tfrecorddataset";
+
+  // Dataset API reference pages that are linked from more than one tip.
+  const string prefetch_url = absl::StrCat(kDatasetTopic, "prefetch");
+  const string interleave_url = absl::StrCat(kDatasetTopic, "interleave");
+
+  InputPipelineAnalysisRecommendation recommendation;
+  auto add_detail = [&recommendation](string detail) {
+    *recommendation.add_details() = std::move(detail);
+  };
+
+  add_detail(
+      "Enqueuing data: you may want to combine small input data chunks "
+      "into fewer "
+      "but larger chunks.");
+  add_detail(absl::StrCat(
+      "Data preprocessing: you may increase num_parallel_calls in ",
+      AnchorElement(absl::StrCat(kDatasetTopic, "map"), "Dataset map()"),
+      " or preprocess the data OFFLINE."));
+  add_detail(absl::StrCat(
+      "Reading data from files in advance: you may tune parameters in the "
+      "following Dataset API (",
+      AnchorElement(prefetch_url, "prefetch size"), ", ",
+      AnchorElement(interleave_url, "interleave cycle_length"), ", ",
+      AnchorElement(kTfRecordDataset, "reader buffer_size"), ")"));
+  add_detail(absl::StrCat(
+      "Reading data from files on demand: you should read data IN ADVANCE "
+      "using the following Dataset API (",
+      AnchorElement(prefetch_url, "prefetch"), ", ",
+      AnchorElement(interleave_url, "interleave"), ", ",
+      AnchorElement(kTfRecordDataset, "reader buffer"), ")"));
+  add_detail(absl::StrCat(
+      "Other data reading or processing: you may consider using the ",
+      AnchorElement(kDatasetIntro, "Dataset API"),
+      " (if you are not using it now)"));
+  return recommendation;
+}
+
+} // namespace
+
+// Summarizes step times (in milliseconds) over all steps in
+// `grouped_by_step`. The time of a step is the duration reported by its
+// slowest core.
+StepSummary ComputeStepTimeSummaryInMs(
+    const protobuf::RepeatedPtrField<PerCoreStepInfo>& grouped_by_step) {
+  Stat<double> step_time_ms_stats;
+  for (const PerCoreStepInfo& per_core_step_info : grouped_by_step) {
+    // A step is only as fast as its slowest core.
+    double slowest_core_ms = 0.0;
+    for (const auto& core_id_and_step_info :
+         per_core_step_info.step_info_per_core()) {
+      const double core_step_ms =
+          core_id_and_step_info.second.duration_ps() / kNumPsPerMs;
+      slowest_core_ms = std::max(slowest_core_ms, core_step_ms);
+    }
+    step_time_ms_stats.UpdateStat(slowest_core_ms);
+  }
+  return GetStepSummaryForSampleStats(step_time_ms_stats);
+}
+
+// Builds the complete input-pipeline analysis for `op_stats`. It combines:
+//  - per-step details and summaries from the step database;
+//  - the host input-op breakdown;
+//  - the static recommendations.
+// The result is tagged with `hardware_type`.
+InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis(
+    const OpStats& op_stats, const HardwareType& hardware_type) {
+  const auto& step_sequence = op_stats.step_db().step_sequence();
+  InputPipelineAnalysisResult analysis =
+      ComputeGenericInputPipelineAnalysisResult(step_sequence);
+  analysis.set_hardware_type(hardware_type);
+  GenerateHostResult(op_stats.host_op_metrics_db(), &analysis);
+  *analysis.mutable_recommendation() = GenerateRecommendation();
+  return analysis;
+}
+
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
new file mode 100644
index 0000000..2bbe16e
--- /dev/null
+++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
@@ -0,0 +1,39 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
+#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
+
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Converts `op_stats` into an InputPipelineAnalysisResult tagged with
+// `hardware_type`. The result holds per-step details and step-time summaries,
+// the host input-op breakdown, and the recommendations.
+InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis(
+ const OpStats& op_stats, const HardwareType& hardware_type);
+
+// Computes the summary of step time in milliseconds.
+// The time of each step is taken from its slowest core.
+StepSummary ComputeStepTimeSummaryInMs(
+ const ::tensorflow::protobuf::RepeatedPtrField<PerCoreStepInfo>&
+ grouped_by_step);
+
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
diff --git a/tensorflow/core/profiler/convert/run_metadata_to_trace_events.cc b/tensorflow/core/profiler/convert/run_metadata_to_trace_events.cc
index 6d2705c..caad306 100644
--- a/tensorflow/core/profiler/convert/run_metadata_to_trace_events.cc
+++ b/tensorflow/core/profiler/convert/run_metadata_to_trace_events.cc
@@ -21,7 +21,7 @@
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
-#include "tensorflow/core/common_runtime/step_stats_collector.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/platform/env_time.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/internal/tfprof_show.cc b/tensorflow/core/profiler/internal/tfprof_show.cc
index e7a5b03..5d57c1b 100644
--- a/tensorflow/core/profiler/internal/tfprof_show.cc
+++ b/tensorflow/core/profiler/internal/tfprof_show.cc
@@ -170,7 +170,6 @@
}
info.push_back(fops);
}
- std::vector<string> attrs;
if (opts.select.find(kShown[0]) != opts.select.end()) {
info.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
node->proto().total_requested_bytes()));
diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD
index 215eb15..2cda295 100644
--- a/tensorflow/core/profiler/lib/BUILD
+++ b/tensorflow/core/profiler/lib/BUILD
@@ -15,18 +15,19 @@
visibility = ["//tensorflow:internal"],
deps = [
":profiler_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/platform",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/internal:profiler_factory",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+ "//tensorflow/core/util:ptr_util",
] + select({
"//tensorflow:android": [],
"//conditions:default": [
"//tensorflow/core/profiler/convert:run_metadata_to_trace_events",
- "//tensorflow/core/platform",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:protos_all_cc",
],
}),
)
diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc
index 3882a63..ff2e5be 100644
--- a/tensorflow/core/profiler/lib/profiler_session.cc
+++ b/tensorflow/core/profiler/lib/profiler_session.cc
@@ -20,17 +20,18 @@
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
-#if !defined(IS_MOBILE_PLATFORM)
-#include "tensorflow/core/profiler/convert/run_metadata_to_trace_events.h"
-#include "tensorflow/core/profiler/internal/profiler_factory.h"
-#include "tensorflow/core/profiler/lib/profiler_utils.h"
-#endif
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/trace_events.pb.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/ptr_util.h"
+#if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/profiler/convert/run_metadata_to_trace_events.h"
+#include "tensorflow/core/profiler/internal/profiler_factory.h"
+#include "tensorflow/core/profiler/lib/profiler_utils.h"
+#endif
+
namespace tensorflow {
/*static*/ std::unique_ptr<ProfilerSession> ProfilerSession::Create(
diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD
index c9275d9..a42c70b 100644
--- a/tensorflow/core/profiler/protobuf/BUILD
+++ b/tensorflow/core/profiler/protobuf/BUILD
@@ -27,6 +27,16 @@
)
tf_proto_library(
+ name = "input_pipeline_proto",
+ srcs = ["input_pipeline.proto"],
+ cc_api_version = 2,
+ protodeps = [":hardware_types_proto"],
+ visibility = [
+ ":friends",
+ ],
+)
+
+tf_proto_library(
name = "op_metrics_proto",
srcs = ["op_metrics.proto"],
cc_api_version = 2,
diff --git a/tensorflow/core/profiler/protobuf/input_pipeline.proto b/tensorflow/core/profiler/protobuf/input_pipeline.proto
new file mode 100644
index 0000000..7b14e4a
--- /dev/null
+++ b/tensorflow/core/profiler/protobuf/input_pipeline.proto
@@ -0,0 +1,118 @@
+syntax = "proto3";
+
+package tensorflow.profiler;
+
+import "google/protobuf/any.proto";
+import "tensorflow/core/profiler/protobuf/hardware_types.proto";
+
+// Used for both step duration and Op duration.
+message StepSummary {
+ double average = 1;
+ double standard_deviation = 2;
+ double minimum = 3;
+ double maximum = 4;
+}
+
+// Per-step details on generic hardware.
+message PerGenericStepDetails {
+ // The step number of a step.
+ int32 step_number = 1;
+ // The step time (in ms).
+ double step_time_ms = 2;
+ // Breakdown of the step time in different event categories.
+ // The unknown time (in ms).
+ double unknown_time_ms = 3;
+ // The infeed time (in ms).
+ double infeed_ms = 4;
+ // The outfeed time (in ms).
+ double outfeed_ms = 5;
+ // The device-compute time (in ms).
+ double device_compute_ms = 6;
+ // The device-to-device communication time (in ms).
+ double device_to_device_ms = 7;
+ // The host-compute time (in ms).
+ double host_compute_ms = 8;
+ // The host-prepare time (in ms).
+ double host_prepare_ms = 9;
+ // The time spent on compiling (in ms).
+ double host_compile_ms = 10;
+}
+
+// Breakdown of the host-side input processing time by category.
+message InputTimeBreakdown {
+ // Time spent on demanded file read in microseconds.
+ double demanded_file_read_us = 1;
+ // Time spent on advanced file read in microseconds.
+ double advanced_file_read_us = 2;
+ // Time spent on data preprocessing in microseconds.
+ double preprocessing_us = 3;
+ // The infeed enqueue time in microseconds.
+ double enqueue_us = 4;
+ // This entry is for the situation where we can't further
+ // break down the non-enqueue input time (because the input pipeline
+ // is not instrumented).
+ double unclassified_non_enqueue_us = 5;
+}
+
+message InputOpDetails {
+ // The Op's name.
+ string op_name = 1;
+ // The number of occurrences.
+ uint64 count = 2;
+ // Time (accumulated over all occurrences) in milliseconds.
+ double time_in_ms = 3;
+ // Time (accumulated over all occurrences) in
+ // percentage of the total input processing time.
+ double time_in_percent = 4;
+ // Self time (accumulated over all occurrences) in milliseconds.
+ double self_time_in_ms = 5;
+ // Self time (accumulated over all occurrences) in
+ // percentage of the total input processing time.
+ double self_time_in_percent = 6;
+ // Possible categories: "Enqueue", "Advanced file read",
+ // "Demanded file read", "Preprocessing", "Unknown".
+ string category = 7;
+}
+
+message InputPipelineAnalysisRecommendation {
+ // A list of detailed recommendations.
+ repeated string details = 1;
+}
+
+message GenericStepTimeBreakdown {
+ // Summary of all unknown time as a part of step in ms.
+ StepSummary unknown_time_ms_summary = 1;
+ // Summary of all infeed time as a part of step in ms.
+ StepSummary infeed_ms_summary = 2;
+ // Summary of all outfeed time as a part of step in ms.
+ StepSummary outfeed_ms_summary = 3;
+ // Summary of all device-compute time as a part of step in ms.
+ StepSummary device_compute_ms_summary = 4;
+ // Summary of all device-to-device time as a part of step in ms.
+ StepSummary device_to_device_ms_summary = 5;
+ // Summary of all host-compute time as a part of step in ms.
+ StepSummary host_compute_ms_summary = 6;
+ // Summary of all host-prepare time as a part of step in ms.
+ StepSummary host_prepare_ms_summary = 7;
+ // Summary of all compilation time as a part of step in ms.
+ StepSummary host_compile_ms_summary = 8;
+}
+
+message InputPipelineAnalysisResult {
+ // Hardware type.
+ HardwareType hardware_type = 1;
+ // Summary of all step duration across all cores.
+ StepSummary step_time_summary = 2;
+ // Summary of all infeed dequeue op duration as percentage of step duration.
+ StepSummary infeed_percent_summary = 3;
+ // Details of each step. Can be unpacked into a PerGenericStepDetails.
+ repeated google.protobuf.Any step_details = 4;
+ // The breakdown of the input processing time.
+ InputTimeBreakdown input_time_breakdown = 5;
+ // Details of each input Op executed.
+ repeated InputOpDetails input_op_details = 6;
+ // Recommendation for next steps to users.
+ InputPipelineAnalysisRecommendation recommendation = 7;
+ // Breakdown of the step time. Can be unpacked into a
+ // GenericStepTimeBreakdown.
+ google.protobuf.Any step_time_breakdown = 8;
+}
diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc
index 8c31c55..e6e8fd2 100644
--- a/tensorflow/core/profiler/utils/event_span.cc
+++ b/tensorflow/core/profiler/utils/event_span.cc
@@ -116,17 +116,17 @@
}
+// Classifies a host (CPU) event by its name and correlation id. The checks are
+// applied in this order:
+//  1. host-to-device copies and infeed events;
+//  2. host-to-host copies;
+//  3. host prepare (a non-negative correlation id or ExecutorState::Process);
+//  4. waiting for input (IteratorGetNext);
+//  5. otherwise, host compute.
EventType ClassifyCpuEvent(absl::string_view event_name, int64 correlation_id) {
- if (absl::StartsWithIgnoreCase(event_name, "MEMCPYHtoD"))
+ // NOTE(review): "Infeed" is matched case-sensitively anywhere in the name,
+ // unlike the case-insensitive prefix checks around it. Confirm this is
+ // intended.
+ if (absl::StartsWithIgnoreCase(event_name, "MEMCPYHtoD") ||
+ absl::StrContains(event_name, "Infeed"))
return HOST_TO_DEVICE;
if (absl::StartsWithIgnoreCase(event_name, "MEMCPYHtoH")) return HOST_TO_HOST;
if (correlation_id >= 0 ||
absl::StartsWithIgnoreCase(event_name, "ExecutorState::Process")) {
return HOST_PREPARE;
- } else {
- if (absl::StartsWithIgnoreCase(event_name, "IteratorGetNext"))
- return HOST_WAIT_INPUT;
- return HOST_COMPUTE;
}
+ if (absl::StartsWithIgnoreCase(event_name, "IteratorGetNext"))
+ return HOST_WAIT_INPUT;
+ return HOST_COMPUTE;
}
std::string PrintEventType(EventType event_type) {
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index bce52c6..4f9b0aa 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -367,6 +367,17 @@
// The execution of an individual op (for some op types) can be
// parallelized on a pool of intra_op_parallelism_threads.
// 0 means the system picks an appropriate number.
+ //
+ // If you create an ordinary session, e.g., from Python or C++,
+ // then there is exactly one intra op thread pool per process.
+ // The first session created determines the number of threads in this pool.
+ // All subsequent sessions reuse/share this one global pool.
+ //
+ // There are notable exceptions to the default behavior described above:
+ // 1. There is an environment variable for overriding this thread pool,
+ // named TF_OVERRIDE_GLOBAL_THREADPOOL.
+ // 2. When connecting to a server, such as a remote `tf.train.Server`
+ // instance, then this option will be ignored altogether.
int32 intra_op_parallelism_threads = 2;
// Nodes that perform blocking operations are enqueued on a pool of
diff --git a/tensorflow/core/protobuf/debug_event.proto b/tensorflow/core/protobuf/debug_event.proto
index 8f9680f..badd518 100644
--- a/tensorflow/core/protobuf/debug_event.proto
+++ b/tensorflow/core/protobuf/debug_event.proto
@@ -100,6 +100,9 @@
// The ID of the graph (i.e., FuncGraph) executed here: applicable only
// to the execution of a FuncGraph.
string graph_id = 11;
+
+ // A device on which debugger-instrumented ops and/or tensors reside.
+ DebuggedDevice debugged_device = 12;
}
}
@@ -162,6 +165,7 @@
string graph_name = 3;
// Unique ID of the graph (generated by debugger).
+ // This is the ID of the immediately-enclosing graph.
string graph_id = 4;
// Name of the device that the op is assigned to (if available).
@@ -204,6 +208,18 @@
string outer_context_id = 6;
}
+// A device on which ops and/or tensors are instrumented by the debugger.
+message DebuggedDevice {
+ // Name of the device.
+ string device_name = 1;
+
+ // A debugger-generated ID for the device. Guaranteed to be unique within
+ // the scope of the debugged TensorFlow program, including single-host and
+ // multi-host settings.
+ // TODO(cais): Test the uniqueness guarantee in multi-host settings.
+ int32 device_id = 2;
+}
+
// Data relating to the eager execution of an op or a Graph.
// For a op that generates N output tensors (N >= 0), only one
// Execution proto will be used to describe the execution event.
@@ -236,6 +252,11 @@
// Stack trace of the eager execution.
CodeLocation code_location = 8;
+  // Debugger-generated IDs of the devices on which the output tensors reside.
+ // To look up details about the device (e.g., name), cross-reference this
+ // field with the DebuggedDevice messages.
+ repeated int32 output_tensor_device_ids = 9;
+
// TODO(cais): When backporting to V1 Session.run() support, add more fields
// such as fetches and feeds.
}
diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD
index b79d266..8568507 100644
--- a/tensorflow/core/util/BUILD
+++ b/tensorflow/core/util/BUILD
@@ -490,6 +490,7 @@
"//tensorflow/core/platform:regexp",
"//third_party/eigen3",
"@com_google_absl//absl/base",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/core/util/debug_events_writer.cc b/tensorflow/core/util/debug_events_writer.cc
index 58994e7..595f92d 100644
--- a/tensorflow/core/util/debug_events_writer.cc
+++ b/tensorflow/core/util/debug_events_writer.cc
@@ -322,6 +322,23 @@
}
}
+int DebugEventsWriter::RegisterDeviceAndGetId(const string& device_name) {
+  mutex_lock device_lock(device_mu_);
+  int& id = device_name_to_id_[device_name];
+  if (id != 0) return id;  // Already registered; reuse the existing ID.
+  // operator[] just inserted a zero-valued entry, so the map size (which
+  // includes that new entry) is the next 1-based ID.
+  id = device_name_to_id_.size();
+  DebugEvent event;
+  MaybeSetDebugEventTimestamp(&event, env_);
+  event.mutable_debugged_device()->set_device_name(device_name);
+  event.mutable_debugged_device()->set_device_id(id);
+  string serialized;
+  event.SerializeToString(&serialized);
+  graphs_writer_->WriteSerializedDebugEvent(serialized);
+  return id;
+}
+
Status DebugEventsWriter::FlushNonExecutionFiles() {
TF_RETURN_IF_ERROR(Init());
if (source_files_writer_ != nullptr) {
@@ -448,7 +465,9 @@
execution_buffer_(),
execution_buffer_mu_(),
graph_execution_trace_buffer_(),
- graph_execution_trace_buffer_mu_() {}
+ graph_execution_trace_buffer_mu_(),
+ device_name_to_id_(),
+ device_mu_() {}
Status DebugEventsWriter::InitNonMetadataFile(DebugEventFileType type) {
std::unique_ptr<SingleDebugEventFileWriter>* writer = nullptr;
diff --git a/tensorflow/core/util/debug_events_writer.h b/tensorflow/core/util/debug_events_writer.h
index 951dcba..78c23e3 100644
--- a/tensorflow/core/util/debug_events_writer.h
+++ b/tensorflow/core/util/debug_events_writer.h
@@ -18,6 +18,7 @@
#include <deque>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_writer.h"
@@ -177,6 +178,11 @@
void WriteSerializedExecutionDebugEvent(const string& debug_event_str,
DebugEventFileType type);
+  // Given the name of a device, returns a unique integer ID for it (IDs start
+  // at 1). As a side effect, the first time this object encounters a device
+  // name, writes a DebuggedDevice proto to the .graphs file in the file set.
+ int RegisterDeviceAndGetId(const string& device_name);
+
// EventWriter automatically flushes and closes on destruction, but
// this method is provided for users who want to write to disk sooner
// and/or check for success.
@@ -233,6 +239,9 @@
GUARDED_BY(graph_execution_trace_buffer_mu_);
mutex graph_execution_trace_buffer_mu_;
+ absl::flat_hash_map<string, int> device_name_to_id_ GUARDED_BY(device_mu_);
+ mutex device_mu_;
+
std::unique_ptr<SingleDebugEventFileWriter> metadata_writer_;
std::unique_ptr<SingleDebugEventFileWriter> source_files_writer_;
std::unique_ptr<SingleDebugEventFileWriter> stack_frames_writer_;
diff --git a/tensorflow/core/util/debug_events_writer_test.cc b/tensorflow/core/util/debug_events_writer_test.cc
index 6ce7a06..e442a41 100644
--- a/tensorflow/core/util/debug_events_writer_test.cc
+++ b/tensorflow/core/util/debug_events_writer_test.cc
@@ -17,6 +17,7 @@
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
@@ -756,6 +757,50 @@
EXPECT_EQ(actuals.size(), 0);
}
+TEST_F(DebugEventsWriterTest, RegisterDeviceAndGetIdTrace) {
+  DebugEventsWriter* writer =
+      DebugEventsWriter::GetDebugEventsWriter(dump_root_);
+  TF_ASSERT_OK(writer->Init());
+
+  // Register and get some device IDs in a concurrent fashion.
+  thread::ThreadPool* thread_pool =
+      new thread::ThreadPool(Env::Default(), "test_pool", 8);
+  int device_ids[8];
+  for (int i = 0; i < 8; ++i) {
+    thread_pool->Schedule([i, writer, &device_ids]() {
+      const string device_name = strings::Printf(
+          "/job:localhost/replica:0/task:0/device:GPU:%d", i % 4);
+      device_ids[i] = writer->RegisterDeviceAndGetId(device_name);
+    });
+  }
+  delete thread_pool;  // Destroying the pool waits for the scheduled closures.
+  TF_ASSERT_OK(writer->FlushNonExecutionFiles());
+  TF_ASSERT_OK(writer->Close());
+
+  // There should be only 4 unique device IDs, because there are only 4 unique
+  // device names.
+  EXPECT_EQ(device_ids[0], device_ids[4]);
+  EXPECT_EQ(device_ids[1], device_ids[5]);
+  EXPECT_EQ(device_ids[2], device_ids[6]);
+  EXPECT_EQ(device_ids[3], device_ids[7]);
+  // Assert that the four device IDs are all unique.
+  EXPECT_EQ(absl::flat_hash_set<int>(device_ids, device_ids + 8).size(), 4);
+
+  std::vector<DebugEvent> actuals;
+  ReadDebugEventProtos(writer, DebugEventFileType::GRAPHS, &actuals);
+  // Due to the `% 4`, there are only 4 unique device names, even though there
+  // are 8 threads each calling `RegisterDeviceAndGetId`.
+  EXPECT_EQ(actuals.size(), 4);
+  for (const DebugEvent& actual : actuals) {
+    const string& device_name = actual.debugged_device().device_name();
+    int device_index = -1;
+    ASSERT_TRUE(absl::SimpleAtoi(
+        device_name.substr(device_name.rfind(':') + 1), &device_index));
+    ASSERT_TRUE(device_index >= 0 && device_index < 4) << device_name;
+    EXPECT_EQ(actual.debugged_device().device_id(), device_ids[device_index]);
+  }
+}
+
TEST_F(DebugEventsWriterTest, DisableCyclicBufferBeahavior) {
const size_t kCyclicBufferSize = 0; // A value <= 0 disables cyclic behavior.
DebugEventsWriter* writer =
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
index 49bce7a..24657ae 100644
--- a/tensorflow/core/util/device_name_utils_test.cc
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -426,8 +426,6 @@
}
TEST(DeviceNameUtilsTest, MergeDevNames) {
- DeviceNameUtils::ParsedName target;
-
// Idempotence tests.
MergeDevNamesHelper("", "", "");
MergeDevNamesHelper("/job:foo/replica:1/task:2/cpu:1",
diff --git a/tensorflow/core/util/stats_calculator.h b/tensorflow/core/util/stats_calculator.h
index 5005ee0..20cbe57 100644
--- a/tensorflow/core/util/stats_calculator.h
+++ b/tensorflow/core/util/stats_calculator.h
@@ -71,8 +71,21 @@
: static_cast<HighPrecisionValueType>(sum_) / count_;
}
+  // Returns the sample variance (divides by count_ - 1); 0 if all_same().
+  ValueType sample_variance() const {
+    return all_same()
+               ? 0
+               : (squared_sum_ - std::pow(sum_, 2.0) / count_) / (count_ - 1);
+  }
+
+  // Returns the population variance (divides by count_); 0 if all_same().
+  ValueType variance() const {
+    return all_same() ? 0 : (squared_sum_ / count_) - (avg() * avg());
+  }
+
+  // Returns the population standard deviation, i.e. sqrt(variance()).
   ValueType std_deviation() const {
-    return all_same() ? 0 : sqrt(squared_sum_ / count_ - avg() * avg());
+    return all_same() ? 0 : std::sqrt(variance());
   }
void OutputToStream(std::ostream* stream) const {
diff --git a/tensorflow/core/util/stats_calculator_test.cc b/tensorflow/core/util/stats_calculator_test.cc
index 00d7bfc..d7efae3 100644
--- a/tensorflow/core/util/stats_calculator_test.cc
+++ b/tensorflow/core/util/stats_calculator_test.cc
@@ -14,6 +14,9 @@
==============================================================================*/
#include "tensorflow/core/util/stats_calculator.h"
+
+#include <cfloat>
+
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -72,5 +75,34 @@
EXPECT_EQ(run1_mem_used + run2_mem_used, detail.mem_used.sum());
}
+TEST(StatsCalculatorTest, UpdateStat) {
+  Stat<double> stat;
+  EXPECT_TRUE(stat.empty());
+  EXPECT_TRUE(stat.all_same());
+  stat.UpdateStat(1);
+  EXPECT_TRUE(stat.all_same());
+  stat.UpdateStat(-1.0);
+  EXPECT_FALSE(stat.all_same());
+  stat.UpdateStat(100);
+  stat.UpdateStat(0);
+  EXPECT_EQ(4, stat.count());
+  EXPECT_EQ(-1, stat.min());
+  EXPECT_EQ(100, stat.max());
+  EXPECT_DOUBLE_EQ(25, stat.avg());
+  EXPECT_EQ(1, stat.first());
+  EXPECT_EQ(0, stat.newest());
+  EXPECT_EQ(10002, stat.squared_sum());
+  EXPECT_DOUBLE_EQ(625, stat.avg() * stat.avg());
+  // Sample variance (computed values are compared with a ULP tolerance).
+  EXPECT_DOUBLE_EQ(7502.0 / 3, stat.sample_variance());
+  // Sample standard deviation, from WolframAlpha
+  EXPECT_NEAR(50.00666622228147160678152, std::sqrt(stat.sample_variance()),
+              FLT_EPSILON);
+  // Population variance
+  EXPECT_NEAR(7502.0 / 4, stat.variance(), FLT_EPSILON);
+  // Population standard deviation, from WolframAlpha
+  EXPECT_NEAR(43.30704330706496060826769, stat.std_deviation(), FLT_EPSILON);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 451be22..38759ee 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3646,6 +3646,54 @@
return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6)
}
+// Calculates gains for each feature and returns the best possible split information for each node. However, if no split is found, then no split information is returned for that node.
+//
+// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
+//
+// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
+//
+// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
+//
+// The output shapes are compatible in a way that the first dimension of all tensors are the same and equal to the number of possible split nodes for each feature.
+//
+// Arguments:
+// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive).
+// stats_summaries_list: A list of Rank 4 tensor (#shape=[max_splits, feature_dims, bucket, stats_dims]) for accumulated stats summary (gradient/hessian) per node, per dimension, per buckets for each feature.
+// The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
+// split_types: A Rank 1 tensor indicating if this Op should perform inequality split or equality split per feature.
+// candidate_feature_ids: Rank 1 tensor with ids for each feature. This is the real id of the feature.
+// l1: l1 regularization factor on leaf weights, per instance based.
+// l2: l2 regularization factor on leaf weights, per instance based.
+// tree_complexity: adjustment to the gain, per leaf based.
+// min_node_weight: mininum avg of hessians in a node before required for the node to be considered for splitting.
+// logits_dimension: The dimension of logit, i.e., number of classes.
+//
+// Returns:
+// node_ids: A Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.
+// gains: A Rank 1 tensor indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.
+// feature_ids: A Rank 1 tensors indicating the best feature id for each node. See above for details like shapes and sizes.
+// feature_dimensions: A Rank 1 tensors indicating the best feature dimension for each feature to split for certain nodes if the feature is multi-dimension. See above for details like shapes and sizes.
+// thresholds: A Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.
+// left_node_contribs: A Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.
+// right_node_contribs: A Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node.
+// split_with_default_directions: A Rank 1 tensors indicating the which direction to go if data is missing. See above for details like shapes and sizes.
+// Inequality with default left returns 0, inequality with default right returns 1, equality with default right returns 2.
+func BoostedTreesCalculateBestFeatureSplitV2(scope *Scope, node_id_range tf.Output, stats_summaries_list []tf.Output, split_types tf.Output, candidate_feature_ids tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, logits_dimension int64) (node_ids tf.Output, gains tf.Output, feature_ids tf.Output, feature_dimensions tf.Output, thresholds tf.Output, left_node_contribs tf.Output, right_node_contribs tf.Output, split_with_default_directions tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"logits_dimension": logits_dimension}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesCalculateBestFeatureSplitV2",
+ Input: []tf.Input{
+ node_id_range, tf.OutputList(stats_summaries_list), split_types, candidate_feature_ids, l1, l2, tree_complexity, min_node_weight,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6), op.Output(7)
+}
+
// Calculates gains for each feature and returns the best possible split information for the feature.
//
// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
@@ -11649,7 +11697,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -11906,7 +11954,7 @@
//
// value: The cropped area of the image must have an aspect ratio =
// width / height within this range.
-// If not specified, defaults to {f:0.75 f:1.33}
+// If not specified, defaults to {f:0.75 f:1.33}
func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
@@ -11917,7 +11965,7 @@
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
-// If not specified, defaults to {f:0.05 f:1}
+// If not specified, defaults to {f:0.05 f:1}
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
m["area_range"] = value
@@ -12123,7 +12171,7 @@
//
// value: The cropped area of the image must have an aspect ratio =
// width / height within this range.
-// If not specified, defaults to {f:0.75 f:1.33}
+// If not specified, defaults to {f:0.75 f:1.33}
func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
@@ -12134,7 +12182,7 @@
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
-// If not specified, defaults to {f:0.05 f:1}
+// If not specified, defaults to {f:0.05 f:1}
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["area_range"] = value
@@ -18940,7 +18988,7 @@
// ImageSummaryBadColor sets the optional bad_color attribute to value.
//
// value: Color to use for pixels with non-finite values.
-// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
+// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
return func(m optionalAttr) {
m["bad_color"] = value
@@ -19935,7 +19983,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -21232,7 +21280,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -21940,7 +21988,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DDilations(value []int64) Conv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22136,7 +22184,7 @@
// QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22205,7 +22253,7 @@
// QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22320,7 +22368,7 @@
// QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22379,7 +22427,7 @@
// QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22553,7 +22601,7 @@
// QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value.
//
// value: list of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22744,7 +22792,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -25318,7 +25366,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -25375,7 +25423,7 @@
type Conv3DBackpropInputAttr func(optionalAttr)
// Conv3DBackpropInputDilations sets the optional dilations attribute to value.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -25707,7 +25755,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -26330,7 +26378,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -27351,7 +27399,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DDilations(value []int64) Conv3DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -33729,7 +33777,7 @@
type Conv3DBackpropFilterAttr func(optionalAttr)
// Conv3DBackpropFilterDilations sets the optional dilations attribute to value.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -45156,7 +45204,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl
index 08686da..b736af5 100644
--- a/tensorflow/lite/build_def.bzl
+++ b/tensorflow/lite/build_def.bzl
@@ -659,7 +659,7 @@
else:
return []
-def gen_model_coverage_test(src, model_name, data, failure_type, tags):
+def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "medium"):
"""Generates Python test targets for testing TFLite models.
Args:
@@ -682,7 +682,7 @@
name = "model_coverage_test_%s_%s" % (model_name, target_op_sets.lower().replace(",", "_")),
srcs = [src],
main = src,
- size = "large",
+ size = size,
args = [
"--model_name=%s" % model_name,
"--target_ops=%s" % target_op_sets,
@@ -691,6 +691,7 @@
srcs_version = "PY2AND3",
python_version = "PY3",
tags = [
+ "no_gpu", # Executing with TF GPU configurations is redundant.
"no_oss",
"no_windows",
] + tags,
diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD
index 0fe9d97..8388653 100644
--- a/tensorflow/lite/c/BUILD
+++ b/tensorflow/lite/c/BUILD
@@ -50,26 +50,20 @@
cc_library(
name = "c_api_internal",
- srcs = [
- "c_api.h",
- "common.h",
- ],
hdrs = ["c_api_internal.h"],
copts = tflite_copts(),
visibility = ["//visibility:private"],
deps = [
":common",
"//tensorflow/lite:framework",
+ "//tensorflow/lite/core/api",
],
)
cc_library(
name = "c_api",
srcs = ["c_api.cc"],
- hdrs = [
- "c_api.h",
- "common.h",
- ],
+ hdrs = ["c_api.h"],
copts = tflite_copts(),
visibility = [
":experimental",
@@ -79,6 +73,7 @@
":common",
"//tensorflow/lite:framework",
"//tensorflow/lite:version",
+ "//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:builtin_ops",
],
alwayslink = 1,
@@ -92,6 +87,8 @@
deps = [
":c_api",
":c_api_internal",
+ ":common",
+ "//tensorflow/lite:framework",
"//tensorflow/lite:kernel_api",
],
alwayslink = 1,
@@ -107,7 +104,7 @@
],
deps = [
":c_api",
- "//tensorflow/lite/c:c_api_internal",
+ ":common",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -121,6 +118,7 @@
deps = [
":c_api",
":c_api_experimental",
+ ":common",
"//tensorflow/lite:kernel_api",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
@@ -136,6 +134,7 @@
],
build_for_embedded = True,
visibility = [
+ "//speech/micro/nn:__pkg__",
"//tensorflow/lite:__subpackages__",
],
alwayslink = 1,
diff --git a/tensorflow/lite/c/builtin_op_data_test.cc b/tensorflow/lite/c/builtin_op_data_test.cc
index af4f474..8d01528 100644
--- a/tensorflow/lite/c/builtin_op_data_test.cc
+++ b/tensorflow/lite/c/builtin_op_data_test.cc
@@ -75,6 +75,7 @@
TfLiteRankParams rank_params;
TfLiteFakeQuantParams fake_quant_params;
TfLitePackParams pack_params;
+ TfLiteUnpackParams unpack_params;
TfLiteOneHotParams one_hot_params;
TfLiteBidirectionalSequenceRNNParams bidi_sequence_rnn_params;
TfLiteBidirectionalSequenceLSTMParams bidi_sequence_lstm_params;
diff --git a/tensorflow/lite/c/c_api_experimental.cc b/tensorflow/lite/c/c_api_experimental.cc
index 4b81217..dbf4cd7 100644
--- a/tensorflow/lite/c/c_api_experimental.cc
+++ b/tensorflow/lite/c/c_api_experimental.cc
@@ -15,7 +15,15 @@
#include "tensorflow/lite/c/c_api_experimental.h"
+#include <stdint.h>
+
+#include <memory>
+
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
#ifdef __cplusplus
extern "C" {
diff --git a/tensorflow/lite/c/c_api_experimental.h b/tensorflow/lite/c/c_api_experimental.h
index 554dabe..bf21e2e 100644
--- a/tensorflow/lite/c/c_api_experimental.h
+++ b/tensorflow/lite/c/c_api_experimental.h
@@ -17,6 +17,7 @@
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/c_api.h"
+#include "tensorflow/lite/c/common.h"
#ifdef __cplusplus
extern "C" {
diff --git a/tensorflow/lite/c/c_api_experimental_test.cc b/tensorflow/lite/c/c_api_experimental_test.cc
index ce72954..6de8236 100644
--- a/tensorflow/lite/c/c_api_experimental_test.cc
+++ b/tensorflow/lite/c/c_api_experimental_test.cc
@@ -18,6 +18,7 @@
#include <gtest/gtest.h>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/c_api.h"
+#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/testing/util.h"
namespace {
diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h
index 3ce7388..973d822 100644
--- a/tensorflow/lite/c/c_api_internal.h
+++ b/tensorflow/lite/c/c_api_internal.h
@@ -15,10 +15,15 @@
#ifndef TENSORFLOW_LITE_C_C_API_INTERNAL_H_
#define TENSORFLOW_LITE_C_C_API_INTERNAL_H_
-#include "tensorflow/lite/c/common.h"
+#include <stdarg.h>
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
-#include "tensorflow/lite/op_resolver.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on directly by any C API clients.
diff --git a/tensorflow/lite/c/c_api_test.cc b/tensorflow/lite/c/c_api_test.cc
index eb2a70f..03d22a8 100644
--- a/tensorflow/lite/c/c_api_test.cc
+++ b/tensorflow/lite/c/c_api_test.cc
@@ -15,11 +15,15 @@
#include "tensorflow/lite/c/c_api.h"
+#include <stdarg.h>
+#include <stdint.h>
+
#include <array>
#include <fstream>
#include <vector>
#include <gtest/gtest.h>
+#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/testing/util.h"
namespace {
diff --git a/tensorflow/lite/core/api/profiler.h b/tensorflow/lite/core/api/profiler.h
index 7bc2965..dcbdf94 100644
--- a/tensorflow/lite/core/api/profiler.h
+++ b/tensorflow/lite/core/api/profiler.h
@@ -25,9 +25,15 @@
enum class EventType {
// Default event type, the metadata field has no special significance.
DEFAULT = 0,
+
// The event is an operator invocation and the event_metadata field is the
// index of operator node.
- OPERATOR_INVOKE_EVENT = 1
+ OPERATOR_INVOKE_EVENT = 1,
+
+ // The event is an invocation for an internal operator of a TFLite delegate.
+ // The event_metadata field is the index of operator node that's specific to
+ // the delegate.
+ DELEGATE_OPERATOR_INVOKE_EVENT = 2
};
virtual ~Profiler() {}
@@ -81,6 +87,15 @@
static_cast<uint32_t>(node_index)) {}
};
+class ScopedDelegateOperatorProfile : public ScopedProfile {
+ public:
+ ScopedDelegateOperatorProfile(Profiler* profiler, const char* tag,
+ int node_index)
+ : ScopedProfile(profiler, tag,
+ Profiler::EventType::DELEGATE_OPERATOR_INVOKE_EVENT,
+ static_cast<uint32_t>(node_index)) {}
+};
+
} // namespace tflite
#define TFLITE_VARNAME_UNIQ(name, ctr) name##ctr
@@ -93,8 +108,8 @@
tflite::ScopedOperatorProfile TFLITE_VARNAME_UNIQ(_profile_, __COUNTER__)( \
(profiler), (tag), (node_index))
-#define TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(profiler, node_index) \
- TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE((profiler), "DelegateOpInvoke", \
- (node_index))
+#define TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(profiler, tag, node_index) \
+ tflite::ScopedDelegateOperatorProfile TFLITE_VARNAME_UNIQ( \
+ _profile_, __COUNTER__)((profiler), (tag), (node_index))
#endif // TENSORFLOW_LITE_CORE_API_PROFILER_H_
diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc
index f733364..09a1a73 100644
--- a/tensorflow/lite/delegates/flex/kernel.cc
+++ b/tensorflow/lite/delegates/flex/kernel.cc
@@ -529,7 +529,8 @@
// Execute the TensorFlow Ops sequentially.
for (auto& node_data : op_data->nodes) {
TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
- reinterpret_cast<Profiler*>(context->profiler), node_data->index());
+ reinterpret_cast<Profiler*>(context->profiler),
+ node_data->name().c_str(), node_data->index());
auto status = ExecuteFlexOp(context, buffer_map, node_data.get());
TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 8da62f0..9b787e7 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -1,5 +1,6 @@
load("//tensorflow/lite:special_rules.bzl", "tflite_extra_gles_deps")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
+load("@build_bazel_rules_apple//apple:macos.bzl", "macos_dylib")
package(
default_visibility = ["//visibility:public"],
@@ -101,6 +102,7 @@
"//tensorflow/lite/delegates/gpu/metal:inference_context",
"@com_google_absl//absl/types:span",
],
+ alwayslink = 1,
)
objc_library(
@@ -110,6 +112,7 @@
deps = [
"//tensorflow/lite/delegates/gpu:metal_delegate",
],
+ alwayslink = 1,
)
# build -c opt --config android_arm64 --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt --linkopt -s --strip always :libtensorflowlite_gpu_gl.so
@@ -173,6 +176,22 @@
deps = [":metal_delegate"],
)
+# Note: Support for MacOS is best-effort at the moment.
+# bazel build -c opt --copt -Os --copt -DTFLITE_GPU_BINARY_RELEASE --copt -fvisibility=hidden --linkopt -s --strip always --cxxopt=-std=c++14 :tensorflow_lite_gpu_dylib --apple_platform_type=macos
+macos_dylib(
+ name = "tensorflow_lite_gpu_dylib",
+ minimum_os_version = "10.13",
+ tags = [
+ "manual",
+ "nobuilder",
+ "notap",
+ ],
+ deps = [
+ ":metal_delegate",
+ ":metal_delegate_internal",
+ ],
+)
+
cc_library(
name = "api",
srcs = ["api.cc"],
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.cc b/tensorflow/lite/delegates/gpu/cl/cl_device.cc
index 6c29d7b..108d4ab 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_device.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_device.cc
@@ -135,7 +135,7 @@
// check that gpu_version belong to range min_version-max_version
// min_version is included and max_version is excluded.
-bool isGPUVersionInRange(int gpu_version, int min_version, int max_version) {
+bool IsGPUVersionInRange(int gpu_version, int min_version, int max_version) {
return gpu_version >= min_version && gpu_version < max_version;
}
} // namespace
@@ -262,10 +262,14 @@
extensions =
absl::StrSplit(GetDeviceInfo<std::string>(id, CL_DEVICE_EXTENSIONS), ' ');
supports_fp16 = false;
+ supports_image3d_writes = false;
for (const auto& ext : extensions) {
if (ext == "cl_khr_fp16") {
supports_fp16 = true;
}
+ if (ext == "cl_khr_3d_image_writes") {
+ supports_image3d_writes = true;
+ }
}
if (vendor == Vendor::POWERVR && !supports_fp16) {
// PowerVR doesn't have full support of fp16 and so doesn't list this
@@ -273,9 +277,17 @@
// so we will use it.
supports_fp16 = true;
}
+
+ if (vendor == Vendor::QUALCOMM &&
+ IsGPUVersionInRange(adreno_info.gpu_version, 400, 500)) {
+ // In local tests Adreno 430 can write to 3D images (at least small ones),
+ // but it does not list cl_khr_3d_image_writes among its available
+ // extensions.
+ supports_image3d_writes = true;
+ }
compute_units_count = GetDeviceInfo<cl_uint>(id, CL_DEVICE_MAX_COMPUTE_UNITS);
- image2d_max_width = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_HEIGHT);
- image2d_max_height = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_WIDTH);
+ image2d_max_width = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_WIDTH);
+ image2d_max_height = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE2D_MAX_HEIGHT);
buffer_max_size = GetDeviceInfo<cl_ulong>(id, CL_DEVICE_MAX_MEM_ALLOC_SIZE);
if (cl_version >= OpenCLVersion::CL_1_2) {
image_buffer_max_size =
@@ -283,6 +295,9 @@
image_array_max_layers =
GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE_MAX_ARRAY_SIZE);
}
+ image3d_max_width = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_WIDTH);
+ image3d_max_height = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_HEIGHT);
+ image3d_max_depth = GetDeviceInfo<size_t>(id, CL_DEVICE_IMAGE3D_MAX_DEPTH);
GetDeviceWorkDimsSizes(id, &max_work_group_sizes);
}
@@ -294,6 +309,8 @@
return cl_version >= OpenCLVersion::CL_1_2;
}
+bool DeviceInfo::SupportsImage3D() const { return supports_image3d_writes; }
+
CLDevice::CLDevice(cl_device_id id, cl_platform_id platform_id)
: id_(id), platform_id_(platform_id), info_(id) {}
@@ -347,6 +364,8 @@
return info_.SupportsImageBuffer();
}
+bool CLDevice::SupportsImage3D() const { return info_.SupportsImage3D(); }
+
std::string CLDevice::GetPlatformVersion() const {
return GetPlatformInfo(platform_id_, CL_PLATFORM_VERSION);
}
@@ -355,22 +374,22 @@
bool CLDevice::IsAdreno3xx() const {
return IsAdreno() &&
- isGPUVersionInRange(info_.adreno_info.gpu_version, 300, 400);
+ IsGPUVersionInRange(info_.adreno_info.gpu_version, 300, 400);
}
bool CLDevice::IsAdreno4xx() const {
return IsAdreno() &&
- isGPUVersionInRange(info_.adreno_info.gpu_version, 400, 500);
+ IsGPUVersionInRange(info_.adreno_info.gpu_version, 400, 500);
}
bool CLDevice::IsAdreno5xx() const {
return IsAdreno() &&
- isGPUVersionInRange(info_.adreno_info.gpu_version, 500, 600);
+ IsGPUVersionInRange(info_.adreno_info.gpu_version, 500, 600);
}
bool CLDevice::IsAdreno6xx() const {
return IsAdreno() &&
- isGPUVersionInRange(info_.adreno_info.gpu_version, 600, 700);
+ IsGPUVersionInRange(info_.adreno_info.gpu_version, 600, 700);
}
bool CLDevice::IsAdreno6xxOrHigher() const {
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.h b/tensorflow/lite/delegates/gpu/cl/cl_device.h
index b051546..c19415c 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_device.h
+++ b/tensorflow/lite/delegates/gpu/cl/cl_device.h
@@ -66,9 +66,11 @@
bool SupportsTextureArray() const;
bool SupportsImageBuffer() const;
+ bool SupportsImage3D() const;
std::vector<std::string> extensions;
bool supports_fp16;
+ bool supports_image3d_writes;
Vendor vendor;
OpenCLVersion cl_version;
int compute_units_count;
@@ -77,6 +79,9 @@
uint64_t image2d_max_height;
uint64_t image_buffer_max_size;
uint64_t image_array_max_layers;
+ uint64_t image3d_max_width;
+ uint64_t image3d_max_height;
+ uint64_t image3d_max_depth;
int3 max_work_group_sizes;
AdrenoInfo adreno_info;
@@ -107,6 +112,7 @@
bool SupportsFP16() const;
bool SupportsTextureArray() const;
bool SupportsImageBuffer() const;
+ bool SupportsImage3D() const;
bool SupportsExtension(const std::string& extension) const;
bool IsAdreno() const;
bool IsAdreno3xx() const;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
index 5b12784..7830cb3 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD
@@ -379,6 +379,7 @@
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:tensor",
"//tensorflow/lite/delegates/gpu/common:types",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
index 080af24..aeed3f4 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc
@@ -18,9 +18,11 @@
#include <string>
#include <utility>
+#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
@@ -29,7 +31,7 @@
std::string GenerateConvolutionTransposedCode(
const OperationDef& op_def, const LinearStorage& biases,
- const CLDevice& device,
+ const CLDevice& device, bool weights_are_buffer, const int3& block_size,
const std::vector<ElementwiseOperation*>& linked_operations) {
const TensorCodeGenerator::SizeVariablesNames src_size(
"src_size.x", "src_size.y", "src_size.z", "src_size.w");
@@ -37,39 +39,43 @@
"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w");
TensorCodeGenerator src_tensor("src_data", src_size, op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor("dst_data", dst_size, op_def.dst_tensors[0]);
+
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
+ bool image_buffer = src_tensor_type == TensorStorageType::IMAGE_BUFFER;
+ bool manual_clamp =
+ image_buffer || src_tensor_type == TensorStorageType::BUFFER;
const std::string batch_id = op_def.batch_support ? "B" : "";
std::string c = GetCommonDefines(op_def.precision);
- switch (op_def.precision) {
- case CalculationsPrecision::F32:
- case CalculationsPrecision::F16:
- if (src_tensor_type == TensorStorageType::BUFFER) {
- c += "#define CONV(R, S) \\\n";
- c += "R += S.x * f0.s0123; \\\n";
- c += "R += S.y * f0.s4567; \\\n";
- c += "R += S.z * f0.s89ab; \\\n";
- c += "R += S.w * f0.scdef; \n";
- } else {
- c += "#define CONV(R, S) \\\n";
- c += "R += S.x * f[0]; \\\n";
- c += "R += S.y * f[1]; \\\n";
- c += "R += S.z * f[2]; \\\n";
- c += "R += S.w * f[3]; \n";
- }
- break;
- case CalculationsPrecision::F32_F16:
- if (src_tensor_type == TensorStorageType::BUFFER) {
- c += "#define CONV(R, S) \\\n";
- c += "R += convert_float4(S.x * f0.s0123 + S.y * f0.s4567 + S.z * "
- "f0.s89ab + S.w * f0.scdef);\n";
- } else {
- c += "#define CONV(R, S) \\\n";
- c += "R += convert_float4(S.x * f[0] + S.y * f[1]";
- c += "+ S.z * f[2] + S.w * f[3]);\n";
- }
- break;
+ for (int z = 0; z < block_size.z; ++z) {
+ const std::string f0 =
+ weights_are_buffer ? "weights_cache[" + std::to_string(z) + "].s0123"
+ : "f" + std::to_string(z * 4 + 0);
+ const std::string f1 =
+ weights_are_buffer ? "weights_cache[" + std::to_string(z) + "].s4567"
+ : "f" + std::to_string(z * 4 + 1);
+ const std::string f2 =
+ weights_are_buffer ? "weights_cache[" + std::to_string(z) + "].s89ab"
+ : "f" + std::to_string(z * 4 + 2);
+ const std::string f3 =
+ weights_are_buffer ? "weights_cache[" + std::to_string(z) + "].scdef"
+ : "f" + std::to_string(z * 4 + 3);
+ switch (op_def.precision) {
+ case CalculationsPrecision::F32:
+ case CalculationsPrecision::F16:
+ c += "#define CONV" + std::to_string(z) + "(R, S) \\\n";
+ c += "R += S.x * " + f0 + "; \\\n";
+ c += "R += S.y * " + f1 + "; \\\n";
+ c += "R += S.z * " + f2 + "; \\\n";
+ c += "R += S.w * " + f3 + "; \n";
+ break;
+ case CalculationsPrecision::F32_F16:
+ c += "#define CONV" + std::to_string(z) + "(R, S) \\\n";
+ c += "R += convert_float4(S.x * " + f0 + " + S.y * " + f1 +
+ " + S.z * " + f2 + " + S.w * " + f3 + ");\n";
+ break;
+ }
}
switch (op_def.precision) {
@@ -84,179 +90,298 @@
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
- if (src_tensor_type == TensorStorageType::BUFFER) {
+ if (weights_are_buffer) {
c += " __global FLT16* filters, \n";
- c += " __global FLT4* biases";
} else {
- c += " __read_only image2d_t filters, \n";
- c += " __read_only image2d_t biases";
+ c += " __read_only image2d_t filters0, \n";
+ c += " __read_only image2d_t filters1, \n";
+ c += " __read_only image2d_t filters2, \n";
+ c += " __read_only image2d_t filters3, \n";
}
+ c += biases.GetDeclaration();
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int2 kernel_size, \n";
c += " int2 stride, \n";
c += " int2 padding, \n";
- c += " int2 k_offset, \n";
- c += " int2 inner_size, \n";
c += " int4 src_size, \n";
c += " int4 dst_size \n";
c += ") {\n";
if (op_def.batch_support) {
c += " int linear_id = get_global_id(0);\n";
- c += " int X = linear_id / dst_size.w;\n";
+ c += " int dst_x = (linear_id / dst_size.w);\n";
c += " int B = linear_id % dst_size.w;\n";
} else {
- c += " int X = get_global_id(0);\n";
+ c += " int dst_x = get_global_id(0);\n";
}
- c += " int Y = get_global_id(1);\n";
- c += " int Z = get_global_id(2);\n";
- c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) return;\n";
- if (src_tensor_type == TensorStorageType::BUFFER) {
- c += " int f_base = Z * src_size.z * kernel_size.x * kernel_size.y;\n";
- }
- c += " int2 offset = (int2)(X, Y) + padding - k_offset;\n";
- c += " offset.x = offset.x % stride.x;\n";
- c += " offset.y = offset.y % stride.y;\n";
- c += " offset += stride;\n";
- c += " offset.x = offset.x % stride.x;\n";
- c += " offset.y = offset.y % stride.y;\n";
- c += " int2 f_offset;\n";
- c += " f_offset.x = offset.x == 0 ? 0 : stride.x - offset.x;\n";
- c += " f_offset.y = offset.y == 0 ? 0 : stride.y - offset.y;\n";
- c += " ACCUM_FLT4 r0 = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
- c += " for (int ky = 0; ky < inner_size.y; ++ky) {\n";
- c += " int index_y = ky * stride.y + f_offset.y;\n";
- c += " bool inside_y = index_y < kernel_size.y;\n";
- c += " int s_y = (Y + index_y + padding.y - k_offset.y) / stride.y;\n";
- c += " index_y = kernel_size.y - 1 - index_y;\n";
- c += " bool out_y = s_y < 0 || s_y >= src_size.y;\n";
- c += " for (int kx = 0; kx < inner_size.x; ++kx) {\n";
- c += " int index_x = kx * stride.x + f_offset.x;\n";
- c += " bool inside_kernel = index_x < kernel_size.x && inside_y;\n";
- c += " int s_x = (X + index_x + padding.x - k_offset.x) / stride.x;\n";
- c += " index_x = kernel_size.x - 1 - index_x;\n";
- c += " bool out_x = s_x < 0 || s_x >= src_size.x;\n";
- c += " int kernel_index = index_y * kernel_size.x + index_x;\n";
- c += " if (inside_kernel && !(out_x || out_y)) {\n";
- if (src_tensor_type == TensorStorageType::BUFFER) {
- c += " int f_offset = f_base + kernel_index * src_size.z;\n";
- } else {
- c += " int x_c = kernel_index * src_size.z * 4;\n";
- }
- c += " for (int l = 0; l < src_size.z; ++l) {\n";
- c += " FLT4 src =" + src_tensor.Read4D("s_x", "s_y", "l", batch_id) +
+ c += " int rem_x = dst_x % stride.x;\n";
+ c += " int ceil_x = dst_x / stride.x;\n";
+ c += " dst_x = ceil_x * stride.x * " + std::to_string(block_size.x) +
+ " + rem_x;\n";
+ c += " int dst_y = get_global_id(1);\n";
+ c += " int rem_y = dst_y % stride.y;\n";
+ c += " int ceil_y = dst_y / stride.y;\n";
+ c += " dst_y = ceil_y * stride.y * " + std::to_string(block_size.y) +
+ " + rem_y;\n";
+ c += " int dst_z = get_global_id(2) * " + std::to_string(block_size.z) +
";\n";
- if (src_tensor_type == TensorStorageType::BUFFER) {
- c += " FLT16 f0 = filters[f_offset]; f_offset++;\n";
- } else {
- c += " FLT4 f[4];\n";
- c += " f[0] = READ_IMAGE(filters, smp_none, (int2)(x_c, Z)); "
- "x_c++;\n";
- c += " f[1] = READ_IMAGE(filters, smp_none, (int2)(x_c, Z)); "
- "x_c++;\n";
- c += " f[2] = READ_IMAGE(filters, smp_none, (int2)(x_c, Z)); "
- "x_c++;\n";
- c += " f[3] = READ_IMAGE(filters, smp_none, (int2)(x_c, Z)); "
- "x_c++;\n";
+ c += " if (dst_x >= dst_size.x || dst_y >= dst_size.y || dst_z >= "
+ "dst_size.z) return;\n";
+ if (weights_are_buffer) {
+ c += " int f_base = dst_z * src_size.z * kernel_size.x * kernel_size.y;\n";
}
- c += " CONV(r0, src);\n";
- c += " }\n";
+ for (int i = 0; i < block_size.x * block_size.y * block_size.z; ++i) {
+ c += " ACCUM_FLT4 r" + std::to_string(i) +
+ " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
+ }
+ c += " int kernel_first_dst_x = dst_x + padding.x;\n";
+ c += " int kernel_first_dst_y = dst_y + padding.y;\n";
+ c += " int kernel_last_dst_x = kernel_first_dst_x - kernel_size.x;\n";
+ c += " int kernel_last_dst_y = kernel_first_dst_y - kernel_size.y;\n";
+ c += " int offset_x = abs(padding.x);\n";
+ c += " int offset_x_strided = offset_x * stride.x;\n";
+ c += " int src_x = (kernel_first_dst_x + offset_x_strided) / stride.x - "
+ "offset_x;\n";
+ c += " int offset_y = abs(padding.y);\n";
+ c += " int offset_y_strided = offset_y * stride.y;\n";
+ c += " int src_y = (kernel_first_dst_y + offset_y_strided) / stride.y - "
+ "offset_y;\n";
+ c += " int src_as_dst_y = src_y * stride.y;\n";
+ c += " for (;src_as_dst_y > kernel_last_dst_y; src_y -= 1, src_as_dst_y -= "
+ "stride.y) {\n";
+ for (int y = 0; y < block_size.y; ++y) {
+ const std::string yindex = std::to_string(y);
+ c += " int sy" + yindex + " = src_y + " + yindex + ";\n";
+ if (manual_clamp) {
+ c += " bool in_y" + yindex + " = sy" + yindex + " >= 0 && sy" +
+ yindex + " < src_size.y;\n";
+ if (!image_buffer) {
+ c += " sy" + yindex + " = clamp(sy" + yindex +
+ ", 0, src_size.y - 1);\n";
+ }
+ }
+ }
+ c += " int kernel_y = kernel_first_dst_y - src_as_dst_y;\n";
+ c += " int src_as_dst_x = src_x * stride.x;\n";
+ c += " int src_x_copy = src_x;\n";
+ c += " for (;src_as_dst_x > kernel_last_dst_x; src_x_copy -= 1, "
+ "src_as_dst_x "
+ "-= stride.x) {\n";
+ for (int x = 0; x < block_size.x; ++x) {
+ const std::string xindex = std::to_string(x);
+ c += " int sx" + xindex + " = src_x_copy + " + xindex + ";\n";
+ if (manual_clamp) {
+ c += " bool in_x" + xindex + " = sx" + xindex + " >= 0 && sx" +
+ xindex + " < src_size.x;\n";
+ if (!image_buffer) {
+ c += " sx" + xindex + " = clamp(sx" + xindex +
+ ", 0, src_size.x - 1);\n";
+ }
+ }
+ }
+ const std::string layer_offset =
+ std::string("src_size.x * src_size.y") +
+ (op_def.batch_support ? " * src_size.w" : "");
+ for (int y = 0; y < block_size.y; ++y) {
+ const std::string yindex = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x) {
+ const std::string xindex = std::to_string(x);
+ const std::string id = std::to_string(y * block_size.x + x);
+ if (image_buffer) {
+ c += " " + src_tensor.GetAddress("addr_" + id, "sx" + xindex,
+ "sy" + yindex, "0", batch_id);
+ c += " addr_" + id + " = select(-1, addr_" + id + ", (in_x" +
+ xindex + " && in_y" + yindex + "));\n";
+ c += absl::Substitute(
+ " int dz_$0 = select(0, $3, (in_x$1 && "
+ "in_y$2));\n",
+ y * block_size.x + x, x, y, layer_offset);
+ } else {
+ c += " " + src_tensor.GetAddress("addr_" + id, "sx" + xindex,
+ "sy" + yindex, "0", batch_id);
+ }
+ }
+ }
+ if (src_tensor_type == TensorStorageType::BUFFER) {
+ c += " int dz = " + layer_offset + ";\n";
+ }
+ if (block_size.x == 1 && block_size.y == 1 && manual_clamp) {
+ c += " if (!in_x0 || !in_y0) continue;\n";
+ }
+ c += " int kernel_x = kernel_first_dst_x - src_as_dst_x;\n";
+ c += " int kernel_index = kernel_y * kernel_size.x + kernel_x;\n";
+ if (weights_are_buffer) {
+ c += " int f_offset = f_base + kernel_index * src_size.z * " +
+ std::to_string(block_size.z) + ";\n";
+ } else {
+ c += " int x_c = kernel_index * src_size.z;\n";
+ }
+ c += " for (int s = 0; s < src_size.z; ++s) {\n";
+ const auto mode = GetFastestZeroMode(device);
+ for (int y = 0; y < block_size.y; ++y) {
+ const std::string yindex = std::to_string(y);
+ for (int x = 0; x < block_size.x; ++x) {
+ const std::string xindex = std::to_string(x);
+ const std::string id = std::to_string(y * block_size.x + x);
+ if (image_buffer) {
+ c += " FLT4 src" + id + " = " + src_tensor.Read("addr_" + id) +
+ "; addr_" + id + " += dz_" + id + ";\n";
+ } else if (manual_clamp) {
+ c += " FLT4 src" + id + " = " + src_tensor.Read("addr_" + id) +
+ " * (FLT)(in_x" + xindex + " && in_y" + yindex + "); addr_" + id +
+ " += dz;\n";
+ } else {
+ c += " FLT4 src" + id + " = " +
+ src_tensor.Read4D("sx" + xindex, "sy" + yindex, "s", batch_id,
+ mode) +
+ ";\n";
+ }
+ }
+ }
+ if (weights_are_buffer) {
+ c += " __global FLT16* weights_cache = filters + f_offset;\n";
+ c += " f_offset += " + std::to_string(block_size.z) + ";\n";
+ } else {
+ for (int z = 0; z < block_size.z; ++z) {
+ const std::string fc = "(int2)(dst_z + " + std::to_string(z) + ", x_c)";
+ c += absl::Substitute(
+ R"( FLT4 f$1 = READ_IMAGE(filters0, smp_none, $0);
+ FLT4 f$2 = READ_IMAGE(filters1, smp_none, $0);
+ FLT4 f$3 = READ_IMAGE(filters2, smp_none, $0);
+ FLT4 f$4 = READ_IMAGE(filters3, smp_none, $0);
+)",
+ fc, z * 4 + 0, z * 4 + 1, z * 4 + 2, z * 4 + 3);
+ }
+ c += " x_c++;\n";
+ }
+ for (int z = 0; z < block_size.z; ++z) {
+ for (int i = 0; i < block_size.x * block_size.y; ++i) {
+ c += " CONV" + std::to_string(z) + "(r" +
+ std::to_string(i + z * block_size.x * block_size.y) + ", src" +
+ std::to_string(i) + ");\n";
+ }
+ }
c += " }\n";
c += " }\n";
c += " }\n";
- c += " FLT4 bias_val = " + biases.ReadLinearFLT4("Z") + ";\n";
- c += " FLT4 res0 = TO_FLT4(r0) + bias_val;\n";
- std::string x_3dcoord = op_def.batch_support ? "X * dst_size.w + B" : "X";
- const LinkingContext context{"res0", x_3dcoord, "Y", "Z"};
- c += PostProcess(linked_operations, context);
- c += " " + dst_tensor.Write4D("res0", "X", "Y", "Z", batch_id) + "\n";
+ for (int z = 0; z < block_size.z; ++z) {
+ c += " if (dst_z < dst_size.z) {\n";
+ c += " FLT4 bias_val = " + biases.ReadLinearFLT4("dst_z") + ";\n";
+ for (int y = 0; y < block_size.y; ++y) {
+ for (int x = 0; x < block_size.x; ++x) {
+ const std::string id =
+ std::to_string((z * block_size.y + y) * block_size.x + x);
+ c += " {\n";
+ c += " int xc = dst_x + stride.x * " + std::to_string(x) + ";\n";
+ c += " int yc = dst_y + stride.y * " + std::to_string(y) + ";\n";
+ c += " if (xc < dst_size.x && yc < dst_size.y) {\n";
+ c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n";
+ std::string x_3dcoord =
+ op_def.batch_support ? "xc * dst_size.w + B" : "xc";
+ const LinkingContext context{"res", x_3dcoord, "yc", "dst_z"};
+ c += PostProcess(linked_operations, context);
+ c += " " +
+ dst_tensor.Write4D("res", "xc", "yc", "dst_z", batch_id) + "\n";
+ c += " }\n";
+ c += " }\n";
+ }
+ }
+ c += " }\n";
+ c += " dst_z++;\n";
+ }
c += "}\n";
-
return c;
}
} // namespace
ConvolutionTransposed::ConvolutionTransposed(
- const OperationDef& definition, const ConvolutionTransposedAttributes& attr)
+ const OperationDef& definition, const ConvolutionTransposedAttributes& attr,
+ const CLDevice& device)
: GPUOperation(definition),
+ weights_are_buffer_(device.IsMali()),
kernel_size_(attr.weights.shape.w, attr.weights.shape.h),
stride_(attr.stride.w, attr.stride.h),
padding_(attr.padding.prepended.w, attr.padding.prepended.h),
- src_channels_(attr.weights.shape.i),
- dst_channels_(attr.weights.shape.o) {
- const int inner_size_x = (kernel_size_.x - 1) / stride_.x + 1;
- const int inner_size_y = (kernel_size_.y - 1) / stride_.y + 1;
- inner_size_ = int2(inner_size_x, inner_size_y);
- kernel_offset_ = int2(kernel_size_.x - 1, kernel_size_.y - 1);
-}
+ block_size_(2, 2, 2) {}
-ConvolutionTransposed::ConvolutionTransposed(ConvolutionTransposed&& kernel)
- : GPUOperation(std::move(kernel)),
- biases_(std::move(kernel.biases_)),
- weights_tex2d_(std::move(kernel.weights_tex2d_)),
- weights_buf_(std::move(kernel.weights_buf_)),
- weights_(kernel.weights_),
- kernel_size_(kernel.kernel_size_),
- stride_(kernel.stride_),
- padding_(kernel.padding_),
- kernel_offset_(kernel.kernel_offset_),
- inner_size_(kernel.inner_size_),
- src_channels_(kernel.src_channels_),
- dst_channels_(kernel.dst_channels_),
- kernel_(std::move(kernel.kernel_)),
- work_group_size_(kernel.work_group_size_) {}
+ConvolutionTransposed::ConvolutionTransposed(ConvolutionTransposed&& operation)
+ : GPUOperation(std::move(operation)),
+ biases_(std::move(operation.biases_)),
+ weights_0_(std::move(operation.weights_0_)),
+ weights_1_(std::move(operation.weights_1_)),
+ weights_2_(std::move(operation.weights_2_)),
+ weights_3_(std::move(operation.weights_3_)),
+ weights_buf_(std::move(operation.weights_buf_)),
+ weights_are_buffer_(operation.weights_are_buffer_),
+ kernel_size_(operation.kernel_size_),
+ stride_(operation.stride_),
+ padding_(operation.padding_),
+ block_size_(operation.block_size_),
+ kernel_(std::move(operation.kernel_)),
+ work_group_size_(operation.work_group_size_) {}
ConvolutionTransposed& ConvolutionTransposed::operator=(
- ConvolutionTransposed&& kernel) {
- if (this != &kernel) {
- biases_ = std::move(kernel.biases_);
- weights_tex2d_ = std::move(kernel.weights_tex2d_);
- weights_buf_ = std::move(kernel.weights_buf_);
- std::swap(weights_, kernel.weights_);
- std::swap(kernel_size_, kernel.kernel_size_);
- std::swap(stride_, kernel.stride_);
- std::swap(padding_, kernel.padding_);
- std::swap(kernel_offset_, kernel.kernel_offset_);
- std::swap(inner_size_, kernel.inner_size_);
- std::swap(src_channels_, kernel.src_channels_);
- std::swap(dst_channels_, kernel.dst_channels_);
- kernel_ = std::move(kernel.kernel_);
- std::swap(work_group_size_, kernel.work_group_size_);
- GPUOperation::operator=(std::move(kernel));
+ ConvolutionTransposed&& operation) {
+ if (this != &operation) {
+ biases_ = std::move(operation.biases_);
+ weights_0_ = std::move(operation.weights_0_);
+ weights_1_ = std::move(operation.weights_1_);
+ weights_2_ = std::move(operation.weights_2_);
+ weights_3_ = std::move(operation.weights_3_);
+ weights_buf_ = std::move(operation.weights_buf_);
+ std::swap(weights_are_buffer_, operation.weights_are_buffer_);
+ std::swap(kernel_size_, operation.kernel_size_);
+ std::swap(stride_, operation.stride_);
+ std::swap(padding_, operation.padding_);
+ std::swap(block_size_, operation.block_size_);
+ kernel_ = std::move(operation.kernel_);
+ std::swap(work_group_size_, operation.work_group_size_);
+ GPUOperation::operator=(std::move(operation));
}
return *this;
}
Status ConvolutionTransposed::Compile(const CreationContext& creation_context) {
const auto code = GenerateConvolutionTransposedCode(
- definition_, biases_, *creation_context.device, linked_operations_);
+ definition_, biases_, *creation_context.device, weights_are_buffer_,
+ block_size_, linked_operations_);
+ std::vector<CompilerOptions> options;
+ // options.push_back(CompilerOptions::POWERVR_FP16);
return creation_context.cache->GetOrCreateCLKernel(
- code, "main_function", *creation_context.context,
+ code, "main_function", options, *creation_context.context,
*creation_context.device, &kernel_);
}
Status ConvolutionTransposed::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
- RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_));
+ if (weights_are_buffer_) {
+ RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_buf_.GetMemoryPtr()));
+ } else {
+ RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_0_.GetMemoryPtr()));
+ RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_1_.GetMemoryPtr()));
+ RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_2_.GetMemoryPtr()));
+ RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_3_.GetMemoryPtr()));
+ }
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
- RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_offset_));
- RETURN_IF_ERROR(kernel_.SetBytesAuto(inner_size_));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHDB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHDB()));
return OkStatus();
}
int3 ConvolutionTransposed::GetGridSize() const {
- const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
- const int grid_y = dst_[0]->Height();
- const int grid_z = dst_[0]->Depth();
+ const int aligned_w = AlignByN(dst_[0]->Width(), stride_.x * block_size_.x);
+ const int aligned_h = AlignByN(dst_[0]->Height(), stride_.y * block_size_.y);
+ const int grid_x =
+ IntegralDivideRoundUp(aligned_w, block_size_.x) * dst_[0]->Batch();
+ const int grid_y = IntegralDivideRoundUp(aligned_h, block_size_.y);
+ const int grid_z = IntegralDivideRoundUp(dst_[0]->Depth(), block_size_.z);
return int3(grid_x, grid_y, grid_z);
}
@@ -275,7 +400,7 @@
const OperationDef& definition,
const ConvolutionTransposedAttributes& attr,
ConvolutionTransposed* result) {
- *result = ConvolutionTransposed(definition, attr);
+ *result = ConvolutionTransposed(definition, attr, *creation_context.device);
RETURN_IF_ERROR(
result->UploadWeights(attr.weights, creation_context.context));
LinearStorageCreateInfo create_info;
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
index 52d4b89..73fce02 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h
@@ -44,8 +44,8 @@
Status Compile(const CreationContext& creation_context) override;
// Move only
- ConvolutionTransposed(ConvolutionTransposed&& kernel);
- ConvolutionTransposed& operator=(ConvolutionTransposed&& kernel);
+ ConvolutionTransposed(ConvolutionTransposed&& operation);
+ ConvolutionTransposed& operator=(ConvolutionTransposed&& operation);
ConvolutionTransposed(const ConvolutionTransposed&) = delete;
ConvolutionTransposed& operator=(const ConvolutionTransposed&) = delete;
@@ -55,7 +55,8 @@
const ConvolutionTransposedAttributes& attr,
ConvolutionTransposed* result);
explicit ConvolutionTransposed(const OperationDef& definition,
- const ConvolutionTransposedAttributes& attr);
+ const ConvolutionTransposedAttributes& attr,
+ const CLDevice& device);
template <DataType T>
Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
CLContext* context);
@@ -69,17 +70,18 @@
LinearStorage biases_;
- Texture2D weights_tex2d_;
+ Texture2D weights_0_;
+ Texture2D weights_1_;
+ Texture2D weights_2_;
+ Texture2D weights_3_;
Buffer weights_buf_;
- cl_mem weights_;
+ bool weights_are_buffer_;
int2 kernel_size_;
int2 stride_;
int2 padding_;
- int2 kernel_offset_;
- int2 inner_size_;
- int src_channels_;
- int dst_channels_;
+
+ int3 block_size_ = int3(1, 1, 1);
CLKernel kernel_;
int3 work_group_size_ = int3(8, 4, 1);
@@ -88,90 +90,118 @@
template <DataType T>
Status ConvolutionTransposed::UploadWeights(
const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
- const int dst_depth = IntegralDivideRoundUp(dst_channels_, 4);
- const int src_depth = IntegralDivideRoundUp(src_channels_, 4);
+ const int dst_depth =
+ AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size_.z);
+ const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
const int kernel_x = kernel_size_.x;
const int kernel_y = kernel_size_.y;
+ int texture_width = dst_depth;
+ int texture_height = src_depth * kernel_x * kernel_y;
const int elements_count = kernel_x * kernel_y * src_depth * dst_depth * 4;
- bool is_buffer_storage =
- definition_.GetPrimaryStorageType() == TensorStorageType::BUFFER;
+ const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
- const int float4_size =
- definition_.precision == CalculationsPrecision::F32 ? 16 : 8;
+ const int float4_size = f32_weights ? 16 : 8;
- if (definition_.GetDataType() == DataType::FLOAT32) {
+ if (f32_weights) {
std::vector<float4> gpu_data(elements_count);
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
- if (is_buffer_storage) {
+ if (weights_are_buffer_) {
RETURN_IF_ERROR(CreateReadOnlyBuffer(float4_size * elements_count,
gpu_data.data(), context,
&weights_buf_));
} else {
RETURN_IF_ERROR(CreateTexture2DRGBA(
- definition_.GetDataType(), src_depth * kernel_x * kernel_y * 4,
- dst_depth, gpu_data.data(), context, &weights_tex2d_));
+ definition_.GetDataType(), dst_depth, src_depth * kernel_x * kernel_y,
+ gpu_data.data(), context, &weights_0_));
+ RETURN_IF_ERROR(CreateTexture2DRGBA(
+ definition_.GetDataType(), dst_depth, src_depth * kernel_x * kernel_y,
+ gpu_data.data() + texture_width * texture_height, context,
+ &weights_1_));
+ RETURN_IF_ERROR(CreateTexture2DRGBA(
+ definition_.GetDataType(), dst_depth, src_depth * kernel_x * kernel_y,
+ gpu_data.data() + texture_width * texture_height * 2, context,
+ &weights_2_));
+ RETURN_IF_ERROR(CreateTexture2DRGBA(
+ definition_.GetDataType(), dst_depth, src_depth * kernel_x * kernel_y,
+ gpu_data.data() + texture_width * texture_height * 3, context,
+ &weights_3_));
}
} else {
std::vector<half4> gpu_data(elements_count);
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
- if (is_buffer_storage) {
+ if (weights_are_buffer_) {
RETURN_IF_ERROR(CreateReadOnlyBuffer(float4_size * elements_count,
gpu_data.data(), context,
&weights_buf_));
} else {
RETURN_IF_ERROR(CreateTexture2DRGBA(
- definition_.GetDataType(), src_depth * kernel_x * kernel_y * 4,
- dst_depth, gpu_data.data(), context, &weights_tex2d_));
+ definition_.GetDataType(), dst_depth, src_depth * kernel_x * kernel_y,
+ gpu_data.data(), context, &weights_0_));
+ RETURN_IF_ERROR(CreateTexture2DRGBA(
+ definition_.GetDataType(), dst_depth, src_depth * kernel_x * kernel_y,
+ gpu_data.data() + texture_width * texture_height, context,
+ &weights_1_));
+ RETURN_IF_ERROR(CreateTexture2DRGBA(
+ definition_.GetDataType(), dst_depth, src_depth * kernel_x * kernel_y,
+ gpu_data.data() + texture_width * texture_height * 2, context,
+ &weights_2_));
+ RETURN_IF_ERROR(CreateTexture2DRGBA(
+ definition_.GetDataType(), dst_depth, src_depth * kernel_x * kernel_y,
+ gpu_data.data() + texture_width * texture_height * 3, context,
+ &weights_3_));
}
}
- if (is_buffer_storage) {
- weights_ = weights_buf_.GetMemoryPtr();
- } else {
- weights_ = weights_tex2d_.GetMemoryPtr();
- }
-
return OkStatus();
}
template <DataType S, typename T>
void ConvolutionTransposed::RearrangeWeightsData(
const ::tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
- const int dst_depth = IntegralDivideRoundUp(dst_channels_, 4);
- const int src_depth = IntegralDivideRoundUp(src_channels_, 4);
+ const int dst_depth =
+ AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size_.z);
+ const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
const int kernel_x = kernel_size_.x;
const int kernel_y = kernel_size_.y;
+ int texture_width = dst_depth;
+ int texture_height = src_depth * kernel_x * kernel_y;
int counter = 0;
- for (int d = 0; d < dst_depth; ++d) {
+ for (int d = 0; d < dst_depth / block_size_.z; ++d) {
for (int y = 0; y < kernel_y; ++y) {
for (int x = 0; x < kernel_x; ++x) {
for (int s = 0; s < src_depth; ++s) {
- T filters[4];
- for (int j = 0; j < 4; ++j) {
+ for (int sub_d = 0; sub_d < block_size_.z; ++sub_d) {
+ T filters[4];
for (int i = 0; i < 4; ++i) {
- const int s_ch = s * 4 + j;
- const int d_ch = d * 4 + i;
- if (s_ch < src_channels_ && d_ch < dst_channels_) {
- const int f_index =
- weights.shape.LinearIndex({d_ch, y, x, s_ch});
- filters[i][j] = weights.data[f_index];
- } else {
- filters[i][j] = 0.0f;
+ for (int j = 0; j < 4; ++j) {
+ const int s_ch = s * 4 + j;
+ const int d_ch = (d * block_size_.z + sub_d) * 4 + i;
+ if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
+ const int f_index =
+ weights.shape.LinearIndex({d_ch, y, x, s_ch});
+ filters[j][i] = weights.data[f_index];
+ } else {
+ filters[j][i] = 0.0f;
+ }
}
}
- }
- T filters_new[4];
- for (int i = 0; i < 4; ++i) {
- for (int j = 0; j < 4; ++j) {
- filters_new[i][j] = filters[j][i];
+ if (weights_are_buffer_) {
+ dst[counter++] = filters[0];
+ dst[counter++] = filters[1];
+ dst[counter++] = filters[2];
+ dst[counter++] = filters[3];
+ } else {
+ int x_coord = d * block_size_.z + sub_d;
+ int y_coord = (y * kernel_x + x) * src_depth + s;
+ int offset = y_coord * dst_depth + x_coord;
+ dst[offset + texture_width * texture_height * 0] = filters[0];
+ dst[offset + texture_width * texture_height * 1] = filters[1];
+ dst[offset + texture_width * texture_height * 2] = filters[2];
+ dst[offset + texture_width * texture_height * 3] = filters[3];
}
}
- dst[counter++] = filters_new[0];
- dst[counter++] = filters_new[1];
- dst[counter++] = filters_new[2];
- dst[counter++] = filters_new[3];
}
}
}
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc
index 82f71cc..038b1ec 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_thin.cc
@@ -87,24 +87,17 @@
for (int x = 0; x < kernel_size.x; ++x) {
std::string r_s =
" r[" + std::to_string(y) + "][" + std::to_string(x) + "]";
- const std::string to_accum =
- op_def.precision == CalculationsPrecision::F32_F16 ? "convert_float"
- : "";
for (int d = 0; d < dst_channels; ++d) {
- c += r_s + postfix[d] + " = " + to_accum + "(dot(src, filters[" +
- std::to_string(index) + "]));\n";
+ c += r_s + postfix[d] + " = dot(src, filters[" + std::to_string(index) +
+ "]);\n";
index++;
}
}
}
c += " }\n";
for (int i = 1; i < src_depth; ++i) {
- if (op_def.precision != CalculationsPrecision::F32_F16) {
- c += " if (X > " + std::to_string(-i) +
- ") { // always true, to reduce registers usage\n";
- } else {
- c += " {\n";
- }
+ c += " if (X > " + std::to_string(-i) +
+ ") { // always true, to reduce registers usage\n";
c += " FLT4 src = " +
src_tensor.Read4D("X", "Y", std::to_string(i), batch_id) + ";\n";
for (int y = 0; y < kernel_size.y; ++y) {
@@ -112,8 +105,8 @@
std::string r_s =
" r[" + std::to_string(y) + "][" + std::to_string(x) + "]";
for (int d = 0; d < dst_channels; ++d) {
- c += r_s + postfix[d] + " += TO_ACCUM_FLT(dot(src, filters[" +
- std::to_string(index) + "]));\n";
+ c += r_s + postfix[d] + " += dot(src, filters[" +
+ std::to_string(index) + "]);\n";
index++;
}
}
diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc
index 1bfa04b..627c781 100644
--- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc
+++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc
@@ -297,6 +297,22 @@
image_desc->image_row_pitch, host_ptr, errcode_ret);
}
}
+
+cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags,
+                           const cl_image_format* image_format,
+                           const cl_image_desc* image_desc, void* host_ptr,
+                           cl_int* errcode_ret) {
+  // clCreateImage exists only since OpenCL 1.2; fall back when it is absent.
+  if (!clCreateImage) {
+    const cl_image_desc& desc = *image_desc;
+    return clCreateImage3D(context, flags, image_format, desc.image_width,
+                           desc.image_height, desc.image_depth,
+                           desc.image_row_pitch, desc.image_slice_pitch,
+                           host_ptr, errcode_ret);
+  }
+  return clCreateImage(context, flags, image_format, image_desc, host_ptr,
+                       errcode_ret);
+}
} // namespace cl
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h
index a84cf8b..acfee78 100644
--- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h
+++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h
@@ -627,6 +627,13 @@
const cl_image_desc *image_desc, void *host_ptr,
cl_int *errcode_ret);
+// Creates a 3D image with clCreateImage when the driver provides it (OpenCL
+// 1.2 and later); otherwise falls back to the legacy clCreateImage3D.
+cl_mem CreateImage3DLegacy(cl_context context, cl_mem_flags flags,
+                           const cl_image_format *image_format,
+                           const cl_image_desc *image_desc, void *host_ptr,
+                           cl_int *errcode_ret);
+
} // namespace cl
} // namespace gpu
} // namespace tflite
diff --git a/tensorflow/lite/delegates/gpu/cl/run_tests.sh b/tensorflow/lite/delegates/gpu/cl/run_tests.sh
index c21b61a..16d2feb 100755
--- a/tensorflow/lite/delegates/gpu/cl/run_tests.sh
+++ b/tensorflow/lite/delegates/gpu/cl/run_tests.sh
@@ -61,12 +61,22 @@
ADB shell mkdir -p $OPENCL_DIR
trap "cleanup_device" EXIT
+declare -a BUILD_CONFIG  # Bazel flags matching the device's primary ABI.
+abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r')  # drop CRs
+if [[ "$abi_version" == "armeabi-v7a" ]]; then
+  # 32-bit ARM device: build a position-independent executable.
+  BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie )
+else
+  # Any other ABI is treated as 64-bit ARM.
+  BUILD_CONFIG=( --config=android_arm64 -c opt )
+fi
+
targets=($(bazel query 'tests('$test_target')'))
num_targets=${#targets[@]}
if ((num_targets == 1)); then
target=${targets[0]}
executable=${target##*:} #finds last token after ':'
- bazel build --config=android_arm64 -c opt $target
+ bazel build "${BUILD_CONFIG[@]}" $target
test_path=$(echo $target | tr : /)
exec_path=bazel-bin/$(echo $test_path | cut -c 3-)
ADB push "$exec_path" $OPENCL_DIR
@@ -77,7 +87,7 @@
for ((i = 0; i < num_targets; i++)); do
target=${targets[i]}
executable=${target##*:} #finds last token after ':'
- bazel build --config=android_arm64 -c opt $target > /dev/null 2>&1
+ bazel build "${BUILD_CONFIG[@]}" $target > /dev/null 2>&1
test_path=$(echo $target | tr : /)
exec_path=bazel-bin/$(echo $test_path | cut -c 3-)
ADB push "$exec_path" $OPENCL_DIR > /dev/null 2>&1
diff --git a/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh b/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh
index e02ce4f..0fd2d33 100755
--- a/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh
+++ b/tensorflow/lite/delegates/gpu/cl/testing/run_performance_profiling.sh
@@ -30,6 +30,7 @@
model_path=""
alias ADB='adb'
+host=""
while [[ "$1" != "" ]]; do
case $1 in
@@ -39,6 +40,10 @@
;;
-d | --device)
shift
+ if [[ "$1" == "HOST" ]]
+ then
+ host="HOST"
+ fi
alias ADB='adb -s '$1''
;;
-h | --help)
@@ -57,19 +62,36 @@
fi
SHELL_DIR=$(dirname "$0")
+BINARY_NAME=performance_profiling
+# "-d HOST": build and run the profiler on this machine instead of a device.
+if [[ "$host" == "HOST" ]]
+then
+  bazel build -c opt //"$SHELL_DIR":"$BINARY_NAME"
+  chmod +x bazel-bin/"$SHELL_DIR"/"$BINARY_NAME"
+  ./bazel-bin/"$SHELL_DIR"/"$BINARY_NAME" "$model_path"
+  exit
+fi
model_name=${model_path##*/} # finds last token after '/'
-declare OPENCL_DIR=/data/local/tmp/profiling_inference/
-declare BINARY_NAME=performance_profiling
+OPENCL_DIR=/data/local/tmp/profiling_inference/
ADB shell mkdir -p $OPENCL_DIR
ADB push "$model_path" "$OPENCL_DIR"
-# push executables and data files to device
-# bazel build --config=android_arm -c opt --copt=-fPIE --linkopt=-pie //$SHELL_DIR:$BINARY_NAME # for 32bit version
-bazel build --config=android_arm64 -c opt //$SHELL_DIR:$BINARY_NAME
+declare -a BUILD_CONFIG  # Bazel flags matching the device's primary ABI.
+abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r')  # drop CRs
+if [[ "$abi_version" == "armeabi-v7a" ]]; then
+  # 32-bit ARM device: build a position-independent executable.
+  BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie )
+else
+  # Any other ABI is treated as 64-bit ARM.
+  BUILD_CONFIG=( --config=android_arm64 -c opt )
+fi
+
+bazel build "${BUILD_CONFIG[@]}" //$SHELL_DIR:$BINARY_NAME
+
ADB push bazel-bin/$SHELL_DIR/$BINARY_NAME $OPENCL_DIR
ADB shell chmod +x $OPENCL_DIR/$BINARY_NAME
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index 4bf4431..e291eba 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -50,7 +50,6 @@
copts = DEFAULT_COPTS,
sdk_frameworks = [
"Metal",
- "UIKit",
],
deps = [
"//tensorflow/lite/delegates/gpu/common:status",
@@ -155,7 +154,6 @@
copts = DEFAULT_COPTS,
sdk_frameworks = [
"Metal",
- "UIKit",
],
deps = [
":common",
diff --git a/tensorflow/lite/delegates/gpu/metal/environment.mm b/tensorflow/lite/delegates/gpu/metal/environment.mm
index 3bc3b54..27c5110 100644
--- a/tensorflow/lite/delegates/gpu/metal/environment.mm
+++ b/tensorflow/lite/delegates/gpu/metal/environment.mm
@@ -16,7 +16,6 @@
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#import <Metal/Metal.h>
-#import <UIKit/UIKit.h>
#include <unordered_map>
#include <utility>
@@ -64,6 +63,17 @@
max_feature_set = std::max(max_feature_set, type.second);
}
}
+#elif defined(__MAC_10_5) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_5
+ std::vector<std::pair<MTLFeatureSet, int>> features;
+ if (@available(macOS 10.15, *)) {
+ features.emplace_back(MTLFeatureSet_macOS_GPUFamily2_v1, 12);
+ }
+ id<MTLDevice> device = GetBestSupportedMetalDevice();
+ for (auto &type : features) {
+ if ([device supportsFeatureSet:type.first]) {
+ max_feature_set = std::max(max_feature_set, type.second);
+ }
+ }
#endif
switch (max_feature_set) {
case 7:
diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.h b/tensorflow/lite/delegates/gpu/metal_delegate.h
index 6f8767d..032c92c 100644
--- a/tensorflow/lite/delegates/gpu/metal_delegate.h
+++ b/tensorflow/lite/delegates/gpu/metal_delegate.h
@@ -16,6 +16,20 @@
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_DELEGATE_H_
+// Export/import annotation for the public C API functions declared below.
+#if defined(SWIG)
+#define TFL_CAPI_EXPORT
+#elif defined(_WIN32) && defined(TFL_COMPILE_LIBRARY)
+// Windows, building the library: export the symbols.
+#define TFL_CAPI_EXPORT __declspec(dllexport)
+#elif defined(_WIN32)
+// Windows, using the library: import the symbols.
+#define TFL_CAPI_EXPORT __declspec(dllimport)
+#else
+// Elsewhere: give the symbols default visibility.
+#define TFL_CAPI_EXPORT __attribute__((visibility("default")))
+#endif  // SWIG
+
#ifdef __cplusplus
extern "C" {
#else
@@ -51,10 +65,11 @@
// When `options` is set to `nullptr`, the following default values are used:
// .precision_loss_allowed = false,
// .wait_type = kPassive,
-TfLiteDelegate* TFLGpuDelegateCreate(const TFLGpuDelegateOptions* options);
+TFL_CAPI_EXPORT extern TfLiteDelegate* TFLGpuDelegateCreate(
+ const TFLGpuDelegateOptions* options);
// Destroys a delegate created with `TFLGpuDelegateCreate` call.
-void TFLGpuDelegateDelete(TfLiteDelegate* delegate);
+TFL_CAPI_EXPORT extern void TFLGpuDelegateDelete(TfLiteDelegate* delegate);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java b/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java
index 5e1e896..8f815a2 100644
--- a/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java
+++ b/tensorflow/lite/delegates/nnapi/java/src/main/java/org/tensorflow/lite/nnapi/NnApiDelegate.java
@@ -25,8 +25,77 @@
private long delegateHandle;
+  /**
+   * Delegate options. Setters return {@code this} so calls can be chained; a null string leaves
+   * the corresponding native option at its default.
+   */
+  public static final class Options {
+    public Options() {}
+
+    /**
+     * Undefined; specifies the default behavior. So far, the default setting of NNAPI is
+     * EXECUTION_PREFERENCE_FAST_SINGLE_ANSWER.
+     */
+    public static final int EXECUTION_PREFERENCE_UNDEFINED = -1;
+
+    /** Prefer minimizing battery drain; desirable for compilations that are executed often. */
+    public static final int EXECUTION_PREFERENCE_LOW_POWER = 0;
+
+    /**
+     * Prefer returning a single answer as fast as possible, even if this causes more power
+     * consumption.
+     */
+    public static final int EXECUTION_PREFERENCE_FAST_SINGLE_ANSWER = 1;
+
+    /** Prefer maximizing the throughput of successive frames, e.g. frames from the camera. */
+    public static final int EXECUTION_PREFERENCE_SUSTAINED_SPEED = 2;
+
+    /**
+     * Sets the inference preference for precision/compilation/runtime tradeoffs.
+     *
+     * @param preference One of EXECUTION_PREFERENCE_UNDEFINED, EXECUTION_PREFERENCE_LOW_POWER,
+     *     EXECUTION_PREFERENCE_FAST_SINGLE_ANSWER, and EXECUTION_PREFERENCE_SUSTAINED_SPEED.
+     */
+    public Options setExecutionPreference(int preference) {
+      this.executionPreference = preference;
+      return this;
+    }
+
+    /** Sets the accelerator (NNAPI device) name passed to the native delegate. */
+    public Options setAcceleratorName(String name) {
+      this.accelerator_name = name;
+      return this;
+    }
+
+    /** Sets the cache directory passed to the native delegate. */
+    public Options setCacheDir(String name) {
+      this.cache_dir = name;
+      return this;
+    }
+
+    /** Sets the model token passed to the native delegate. */
+    public Options setModelToken(String name) {
+      this.model_token = name;
+      return this;
+    }
+
+    int executionPreference = EXECUTION_PREFERENCE_UNDEFINED;
+    String accelerator_name = null;
+    String cache_dir = null;
+    String model_token = null;
+  }
+
+  /** Creates a delegate whose native counterpart is configured from {@code options}. */
+  public NnApiDelegate(Options options) {
+    final int preference = options.executionPreference;
+    final String acceleratorName = options.accelerator_name;
+    final String cacheDir = options.cache_dir;
+    final String modelToken = options.model_token;
+    delegateHandle = createDelegate(preference, acceleratorName, cacheDir, modelToken);
+  }
+
   public NnApiDelegate() {
-    delegateHandle = createDelegate();
+    this(new Options());  // All options left at their defaults.
   }
@Override
@@ -35,16 +104,22 @@
}
   /**
-   * The NNAPI delegate is singleton. Nothing to delete for now, so mark the handle invalid only.
+   * Frees the native NNAPI delegate created when this object was constructed.
+   *
+   * <p>User is expected to call this method explicitly; calling it again is a no-op.
    */
   @Override
   public void close() {
     if (delegateHandle != INVALID_DELEGATE_HANDLE) {
+      deleteDelegate(delegateHandle);
       delegateHandle = INVALID_DELEGATE_HANDLE;
     }
   }
- private static native long createDelegate();
+ private static native long createDelegate(
+ int preference, String device_name, String cache_dir, String model_token);
+
+ private static native void deleteDelegate(long delegateHandle);
static {
// Ensure the native TensorFlow Lite libraries are available.
diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc
index d68ff5e..65d39b0 100644
--- a/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc
+++ b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc
@@ -21,10 +21,47 @@
extern "C" {
#endif // __cplusplus
+using namespace tflite;
+// Builds a StatefulNnApiDelegate from the fields of the Java Options object.
 JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate(JNIEnv* env,
-                                                            jclass clazz) {
-  return reinterpret_cast<jlong>(tflite::NnApiDelegate());
+Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate(
+    JNIEnv* env, jclass clazz, jint preference, jstring accelerator_name,
+    jstring cache_dir, jstring model_token) {
+  StatefulNnApiDelegate::Options options = StatefulNnApiDelegate::Options();
+  options.execution_preference = static_cast<
+      StatefulNnApiDelegate::Options::ExecutionPreference>(preference);
+  if (accelerator_name) {
+    options.accelerator_name =
+        env->GetStringUTFChars(accelerator_name, nullptr);
+  }
+  if (cache_dir) {
+    options.cache_dir = env->GetStringUTFChars(cache_dir, nullptr);
+  }
+  if (model_token) {
+    options.model_token = env->GetStringUTFChars(model_token, nullptr);
+  }
+
+  // NOTE(review): the UTF buffers are released right after construction,
+  // which assumes StatefulNnApiDelegate copies the strings; confirm.
+  auto delegate = new StatefulNnApiDelegate(options);
+  // Release each buffer against the jstring it was obtained from.
+  if (options.accelerator_name) {
+    env->ReleaseStringUTFChars(accelerator_name, options.accelerator_name);
+  }
+  if (options.cache_dir) {
+    env->ReleaseStringUTFChars(cache_dir, options.cache_dir);
+  }
+  if (options.model_token) {
+    env->ReleaseStringUTFChars(model_token, options.model_token);
+  }
+  return reinterpret_cast<jlong>(delegate);
+}
+
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_nnapi_NnApiDelegate_deleteDelegate(
+    JNIEnv* env, jclass clazz, jlong delegate) {
+  // Delete with the exact type that createDelegate allocated.
+  delete reinterpret_cast<StatefulNnApiDelegate*>(delegate);
 }
#ifdef __cplusplus
diff --git a/tensorflow/lite/examples/experimental_new_converter/stack_trace_example.py b/tensorflow/lite/examples/experimental_new_converter/stack_trace_example.py
index b5ac33a..f0940db 100644
--- a/tensorflow/lite/examples/experimental_new_converter/stack_trace_example.py
+++ b/tensorflow/lite/examples/experimental_new_converter/stack_trace_example.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import sys
+
from absl import app
import tensorflow as tf # TF2
diff --git a/tensorflow/lite/examples/python/label_image.py b/tensorflow/lite/examples/python/label_image.py
index 6c75338..2ef1aa1 100644
--- a/tensorflow/lite/examples/python/label_image.py
+++ b/tensorflow/lite/examples/python/label_image.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import argparse
+
import numpy as np
from PIL import Image
diff --git a/tensorflow/lite/experimental/delegates/testdata/BUILD b/tensorflow/lite/experimental/delegates/testdata/BUILD
new file mode 100644
index 0000000..1935dfc
--- /dev/null
+++ b/tensorflow/lite/experimental/delegates/testdata/BUILD
@@ -0,0 +1,3 @@
+licenses(["notice"])
+# Expose the .tflite test models in this directory to other packages.
+exports_files(glob(["*.tflite"]))
diff --git a/tensorflow/lite/experimental/delegates/testdata/README.txt b/tensorflow/lite/experimental/delegates/testdata/README.txt
new file mode 100644
index 0000000..b10966d
--- /dev/null
+++ b/tensorflow/lite/experimental/delegates/testdata/README.txt
@@ -0,0 +1 @@
+posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite: downloaded from https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite
diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
index d4b5e2b..8c48ef6 100644
--- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
@@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function
import tempfile
+
import numpy as np
from six.moves import range
import tensorflow as tf
@@ -236,8 +237,8 @@
"""
converter = tf.lite.TFLiteConverter.from_session(sess, [input_tensor],
[output_tensor])
- tflite = converter.convert()
converter.experimental_new_converter = use_mlir_converter
+ tflite = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite)
diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
index b90d4d5..49b0b8c 100644
--- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
@@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function
import tempfile
+
import numpy as np
from six.moves import range
import tensorflow as tf
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
index ba936a4..f27086a 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
@@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function
import tempfile
+
import numpy as np
from six.moves import range
import tensorflow as tf
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
index 49c3d5e..bb16191 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
@@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function
import tempfile
+
import numpy as np
from six.moves import range
import tensorflow as tf
diff --git a/tensorflow/lite/experimental/microfrontend/lib/window_util.c b/tensorflow/lite/experimental/microfrontend/lib/window_util.c
index 3e544f5..eee6e7b 100644
--- a/tensorflow/lite/experimental/microfrontend/lib/window_util.c
+++ b/tensorflow/lite/experimental/microfrontend/lib/window_util.c
@@ -14,8 +14,6 @@
==============================================================================*/
#include "tensorflow/lite/experimental/microfrontend/lib/window_util.h"
-// This macro is required to make MSVC defines math constants in math.h
-#define _USE_MATH_DEFINES
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
diff --git a/tensorflow/lite/experimental/ruy/kernel_avx2.cc b/tensorflow/lite/experimental/ruy/kernel_avx2.cc
index fdf94b6..dfc0b1f 100644
--- a/tensorflow/lite/experimental/ruy/kernel_avx2.cc
+++ b/tensorflow/lite/experimental/ruy/kernel_avx2.cc
@@ -35,11 +35,21 @@
RUY_DCHECK(false);
}
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
+  // Non-AVX2-asm build: CPU-ID-based checks should keep callers away from here.
+  RUY_DCHECK(false);
+}
+
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
// CPU-ID-based checks should disable the path that would reach this point.
RUY_DCHECK(false);
}
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
+  // Non-AVX2-asm build: CPU-ID-based checks should keep callers away from here.
+  RUY_DCHECK(false);
+}
+
#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
static constexpr int kAvx8bitBlockSize = 8;
@@ -346,6 +356,7 @@
void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
gemmlowp::ScopedProfilingLabel label("Kernel kAvx2 8-bit");
+
const std::int8_t splitter_idx_data[32] = {
0, 1, 4, 5, 8, 9, 12, 13, //
2, 3, 6, 7, 10, 11, 14, 15, //
@@ -1137,6 +1148,272 @@
} // End col-block loop.
} // NOLINT(readability/fn_size)
+// Single-column (GEMV) variant of Kernel8bitAvx2; expects dst_cols == 1.
+// Rows are processed in blocks of kAvx8bitBlockSize: int32 dot products are
+// accumulated over depth, then, for non-int32 destinations, requantized and
+// clamped before being stored.
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
+  gemmlowp::ScopedProfilingLabel label("Kernel kAvx2 8-bit GEMV");
+
+  RUY_DCHECK_EQ(params.dst_cols, 1);
+  RUY_DCHECK_EQ(params.last_col, 0);
+  RUY_DCHECK_EQ(params.start_col, 0);
+
+  const std::int8_t splitter_idx_data[32] = {
+      0, 1, 4, 5, 8,  9,  12, 13,  //
+      2, 3, 6, 7, 10, 11, 14, 15,  //
+      0, 1, 4, 5, 8,  9,  12, 13,  //
+      2, 3, 6, 7, 10, 11, 14, 15   //
+  };
+
+  int bias_ptr_block_increment =
+      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
+
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  void* dst_col_ptr = params.dst_base_ptr;
+  const std::int32_t* bias_col_ptr = params.bias;
+  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+    bias_col_ptr += params.start_row;
+  }
+
+  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+  void* dst_ptr = dst_col_ptr;
+  const std::int32_t* bias_ptr = bias_col_ptr;
+
+  const std::int32_t lhs_zero_point = params.lhs_zero_point;
+  const bool has_rhs_sums_offsets =
+      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+  std::int32_t rhs_sums_offsets[8];
+  if (has_rhs_sums_offsets) {
+    const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
+        _mm256_set1_epi32(lhs_zero_point),
+        _mm256_loadu_si256(
+            reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
+    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
+                        rhs_sums_offset_v);
+  }
+
+  for (int row = params.start_row; row <= params.last_row;
+       row += kAvx8bitBlockSize) {
+    const int residual_rows =
+        std::min(params.dst_rows - row, kAvx8bitBlockSize);
+
+    const __m256i splitter_idx =
+        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));
+
+    __m256i accum_data_v0;
+
+    // Initialize with bias.
+    __m256i initial_accum_data =
+        intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr);
+    bias_ptr += bias_ptr_block_increment;
+
+    // Adjustments common across columns.
+    const std::int32_t rhs_zero_point = params.rhs_zero_point;
+    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+      const __m256i lhs_sums_offset = _mm256_mullo_epi32(
+          _mm256_set1_epi32(rhs_zero_point),
+          _mm256_loadu_si256(
+              reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
+      initial_accum_data =
+          _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
+    }
+    const std::int32_t prod_zp_depth = params.prod_zp_depth;
+    if (prod_zp_depth) {
+      initial_accum_data = _mm256_add_epi32(initial_accum_data,
+                                            _mm256_set1_epi32(prod_zp_depth));
+    }
+
+    // Adjustments differing across columns.
+    if (has_rhs_sums_offsets) {
+      accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
+                                       _mm256_set1_epi32(rhs_sums_offsets[0]));
+    } else {
+      accum_data_v0 = initial_accum_data;
+    }
+
+    const std::int8_t* lhs_ptr = lhs_col_ptr;
+    const std::int8_t* rhs_ptr = rhs_col_ptr;
+    for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+      const __m256i lhs_data =
+          _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
+      const __m128i rhs_data_8bit = _mm_loadu_si32(rhs_ptr);
+
+      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+      // For simplicity we load 4x the data that we need and process twice the
+      // data that we need and store only the data we need.
+      std::int32_t rhs_data[2];
+      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+      // Now that we have cast the RHS data, we store it so that each value
+      // can be separately loaded in the accumulation loop.
+      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+
+      const __m256i lhs_data_split =
+          _mm256_shuffle_epi8(lhs_data, splitter_idx);
+      const __m256i lhs_data_split_expand_bottom =
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
+      const __m256i lhs_data_split_expand_top =
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));
+
+      // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
+      const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
+          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
+      // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
+      const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
+          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
+      // Accumulate for column 0.
+      const std::int32_t low_rhs_value = rhs_data[0];
+      const std::int32_t high_rhs_value = rhs_data[1];
+
+      const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
+      const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
+
+      accum_data_v0 = _mm256_add_epi32(
+          accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+      accum_data_v0 = _mm256_add_epi32(
+          accum_data_v0,
+          _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+
+      lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+      rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+    }
+
+    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+      __m256i m_vector;
+      __m256i e_vector;
+      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+      if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+        m_vector = intrin_utils::mm256_n_loadu_epi32(
+            residual_rows, &params.multiplier_fixedpoint[row]);
+        e_vector = intrin_utils::mm256_n_loadu_epi32(
+            residual_rows, &params.multiplier_exponent[row]);
+      } else {
+        // These arrays have size LhsCols, and are pre-filled.
+        m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]);
+        e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]);
+      }
+
+      const __m256i m_64bit_low =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
+      const __m256i m_64bit_high =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));
+
+      const __m256i zero_vector = _mm256_setzero_si256();
+      const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
+      const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
+      const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
+      const __m256i final_right_shift =
+          _mm256_add_epi32(right_shift, _mm256_set1_epi32(31));
+      const __m256i final_right_shift_low =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
+      const __m256i final_right_shift_high =
+          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
+      // Really we want 0x100000000, but use half to avoid overflowing.
+      const __m256i convert_to_signed_halved =
+          _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift);
+      const __m256i convert_to_unsigned_64 =
+          _mm256_set1_epi64x(0x8000000000000000);
+
+      __m256i post_scaling_offset =
+          _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved);
+
+      const __m256i offset_vector =
+          _mm256_slli_epi64(_mm256_set1_epi64x(1), 30);
+      // Really these should be shifted by neg_e_vector, but tests pass when
+      // using right_shift.
+      const __m256i offset_vector_low = _mm256_add_epi64(
+          _mm256_sllv_epi64(
+              offset_vector,
+              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))),
+          convert_to_unsigned_64);
+      const __m256i offset_vector_high = _mm256_add_epi64(
+          _mm256_sllv_epi64(
+              offset_vector,
+              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))),
+          convert_to_unsigned_64);
+
+      if (params.dst_zero_point) {
+        const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point);
+        // The post-scaling offset is subtracted later, so this has the effect
+        // of adding the zero point.
+        post_scaling_offset =
+            _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
+      }
+
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+      RUY_DCHECK(false);
+#endif
+      const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
+
+      // See GEMM version for details of this process.
+      {
+        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
+        // Apply the fixed-point part of the multiplier.
+        __m256i scaled_v_low = _mm256_mul_epi32(
+            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+            m_64bit_low);
+        __m256i scaled_v_high = _mm256_mul_epi32(
+            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+            m_64bit_high);
+
+        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+        scaled_v_high =
+            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+        __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+        results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+        accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
+      }
+    }
+    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
+    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
+
+    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+      __m256i result = accum_data_v0;
+      result = _mm256_min_epi32(result, clamp_max_v);
+      result = _mm256_max_epi32(result, clamp_min_v);
+      intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
+                                                 result);
+      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+                                   kAvx8bitBlockSize);
+    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+      __m256i result = accum_data_v0;
+      result = _mm256_min_epi32(result, clamp_max_v);
+      result = _mm256_max_epi32(result, clamp_min_v);
+      intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
+                                                 result);
+      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+                                   kAvx8bitBlockSize);
+    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+      __m256i result = accum_data_v0;
+      result = _mm256_min_epi32(result, clamp_max_v);
+      result = _mm256_max_epi32(result, clamp_min_v);
+      intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows,
+                                                  result);
+      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+                                   kAvx8bitBlockSize);
+    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+      intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows,
+                                         accum_data_v0);
+      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+                                   kAvx8bitBlockSize);
+    } else {
+      RUY_DCHECK(false);
+    }
+
+    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+  }  // End row-block loop.
+}  // NOLINT(readability/fn_size)
+
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
gemmlowp::ScopedProfilingLabel label("Kernel kAvx2 float");
@@ -1274,6 +1551,118 @@
} // End col-block terminal conditional.
}
+// Single-column (GEMV) variant of KernelFloatAvx2: for rows
+// [start_row, min(dst_rows, last_row + 8)) it initializes accumulators from
+// params.bias, adds the LHS x RHS products over params.depth, clamps to
+// [clamp_min, clamp_max] and stores the result. Rows are handled 8 at a
+// time; a trailing partial block uses the intrin_utils::mm256_n_* helpers.
+// Expects dst_cols == 1 and start_col == last_col == 0 (debug-checked).
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
+  gemmlowp::ScopedProfilingLabel label("Kernel kAvx2 float GEMV");
+
+  RUY_DCHECK_EQ(params.dst_cols, 1);
+  RUY_DCHECK_EQ(params.last_col, 0);
+  RUY_DCHECK_EQ(params.start_col, 0);
+
+  // params.lhs_stride is expressed in bytes; >> 2 divides by sizeof(float).
+  const std::int64_t lhs_stride = params.lhs_stride >> 2;
+  // Without a bias, bias_ptr stays at params.bias for every row block.
+  const int bias_ptr_block_increment =
+      (params.flags & RUY_ASM_FLAG_HAS_BIAS) ? 1 : 0;
+  // AVX2 float block size = 8.
+  const int end_row = std::min(params.dst_rows, params.last_row + 8);
+
+  // Rebased so that they can be indexed with absolute row numbers.
+  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
+  const float* bias_col_ptr = params.bias;
+
+  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+
+  const float* rhs_col_ptr = params.rhs_base_ptr;
+  float* dst_col_ptr = adj_dst_col_ptr;
+
+  int row = params.start_row;
+  for (; row <= end_row - 8; row += 8) {
+    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+    float* dst_ptr = dst_col_ptr + row;
+    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+    // Initialize with bias.
+    __m256 accum_data_v = _mm256_loadu_ps(bias_ptr);
+
+    const float* lhs_ptr = lhs_col_ptr;
+    const float* rhs_ptr = rhs_col_ptr;
+    // Depth loop unrolled by 4. Successive depth levels are 8 floats apart
+    // in both LHS and RHS; only the first RHS float of each level is used.
+    int d = 0;
+    for (; d <= params.depth - 4; d += 4) {
+      const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
+      const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
+      accum_data_v =
+          _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v);
+      const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
+      const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
+      accum_data_v =
+          _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v);
+      const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
+      const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
+      accum_data_v =
+          _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v);
+      const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
+      const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
+      accum_data_v =
+          _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v);
+      lhs_ptr += 32;  // Loaded 8 * 4 floats.
+      rhs_ptr += 32;
+    }
+    // Remaining depth levels, one at a time.
+    for (; d < params.depth; ++d) {
+      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_ptr[0]);
+      accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+      lhs_ptr += 8;
+      rhs_ptr += 8;
+    }
+
+    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+    _mm256_storeu_ps(dst_ptr, accum_data_v);
+  }  // End row-block loop.
+
+  if (row < end_row) {
+    const int residual_rows = end_row - row;
+    RUY_CHECK_GE(residual_rows, 1);
+    RUY_CHECK_LT(residual_rows, 8);
+
+    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+    float* dst_ptr = dst_col_ptr + row;
+    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+    // Initialize with bias (partial load of residual_rows values).
+    __m256 accum_data_v =
+        intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+
+    // LHS loads read a full 8-row block; only residual_rows results are
+    // stored at the end.
+    const float* lhs_ptr = lhs_col_ptr;
+    const float* rhs_ptr = rhs_col_ptr;
+    for (int d = 0; d < params.depth; ++d) {
+      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_ptr[0]);
+      accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+      lhs_ptr += 8;
+      rhs_ptr += 8;
+    }
+
+    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
+    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
+    intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v);
+  }  // End handling of residual rows.
+}
+
#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)
} // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/kernel_avx512.cc b/tensorflow/lite/experimental/ruy/kernel_avx512.cc
index a0b3afc..f74f338 100644
--- a/tensorflow/lite/experimental/ruy/kernel_avx512.cc
+++ b/tensorflow/lite/experimental/ruy/kernel_avx512.cc
@@ -35,11 +35,21 @@
RUY_DCHECK(false);
}
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+ // Stub for non-AVX512-asm builds; CPU-ID-based checks should prevent calls.
+ RUY_DCHECK(false);
+}
+
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
// CPU-ID-based checks should disable the path that would reach this point.
RUY_DCHECK(false);
}
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
+ // Stub for non-AVX512-asm builds; CPU-ID-based checks should prevent calls.
+ RUY_DCHECK(false);
+}
+
#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
@@ -1039,6 +1049,221 @@
} // End col-block loop.
} // NOLINT(readability/fn_size)
+// Single-column (GEMV) variant of Kernel8bitAvx512. For each block of 16
+// rows in [start_row, last_row] it initializes int32 accumulators from the
+// bias and the zero-point correction terms, accumulates LHS x RHS over
+// params.depth (4 depth levels per iteration) and, unless the destination
+// is int32, requantizes with the fixed-point multiplier/exponent and adds
+// dst_zero_point. Non-int32 results are clamped to [clamp_min, clamp_max].
+// Stores go through row_mask, so a trailing partial block writes only
+// residual_rows values. Expects dst_cols == 1, start_col == last_col == 0.
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+  gemmlowp::ScopedProfilingLabel label("Kernel kAvx512 8-bit GEMV");
+
+  RUY_DCHECK_EQ(params.dst_cols, 1);
+  RUY_DCHECK_EQ(params.last_col, 0);
+  RUY_DCHECK_EQ(params.start_col, 0);
+
+  const int bias_ptr_block_increment =
+      (params.flags & RUY_ASM_FLAG_HAS_BIAS) ? 16 : 0;
+
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+  void* dst_ptr = params.dst_base_ptr;
+  const std::int32_t* bias_ptr = params.bias;
+  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+    bias_ptr += params.start_row;
+  }
+
+  // Zero-point terms and clamp bounds do not vary across row blocks.
+  const std::int32_t lhs_zero_point = params.lhs_zero_point;
+  const std::int32_t rhs_zero_point = params.rhs_zero_point;
+  const std::int32_t prod_zp_depth = params.prod_zp_depth;
+  const bool has_rhs_sums_offsets =
+      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+  // lhs_zero_point * rhs_sums; only element 0 (the single column) is read.
+  std::int32_t rhs_sums_offsets[16];
+  if (has_rhs_sums_offsets) {
+    const __m512i rhs_sums_offset_v =
+        _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+                           _mm512_loadu_epi32(&params.rhs_sums[0]));
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+                        rhs_sums_offset_v);
+  }
+  const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
+  const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
+
+  for (int row = params.start_row; row <= params.last_row; row += 16) {
+    const int residual_rows = std::min(params.dst_rows - row, 16);
+    // Lanes [0, residual_rows) of this block are valid destination rows.
+    const __mmask16 row_mask =
+        (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+
+    // Initialize with bias.
+    __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr);
+    bias_ptr += bias_ptr_block_increment;
+
+    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+      const __m512i lhs_sums_offset =
+          _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+                             _mm512_loadu_epi32(&params.lhs_sums[row]));
+      initial_accum_data =
+          _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+    }
+
+    if (prod_zp_depth != 0) {
+      initial_accum_data = _mm512_add_epi32(initial_accum_data,
+                                            _mm512_set1_epi32(prod_zp_depth));
+    }
+
+    // Adjustment for the (single) destination column.
+    __m512i accum_data_v0 = initial_accum_data;
+    if (has_rhs_sums_offsets) {
+      accum_data_v0 = _mm512_sub_epi32(accum_data_v0,
+                                       _mm512_set1_epi32(rhs_sums_offsets[0]));
+    }
+
+    const std::int8_t* lhs_ptr = lhs_col_ptr;
+    const std::int8_t* rhs_ptr = rhs_col_ptr;
+    for (int d = 0; d < params.depth; d += 4) {
+      const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr);
+      const __m128i rhs_data_8bit = _mm_loadu_epi8(rhs_ptr);
+
+      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+      // For simplicity we load 4x the data that we need and process twice the
+      // data that we need and store only the data we need.
+      std::int32_t rhs_data[2];
+      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+      // Now that we have cast the RHS data, we store it so that each value
+      // can be separately loaded in the accumulation loop.
+      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+
+      // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
+      const __m512i lhs_16_bit_low =
+          _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
+      // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
+      const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
+          _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
+
+      // Broadcast each stored pair of 16-bit RHS values to all lanes.
+      const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[0]);
+      const __m512i rhs_16_bit_dup_high = _mm512_set1_epi32(rhs_data[1]);
+
+      accum_data_v0 = _mm512_add_epi32(
+          accum_data_v0, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+      accum_data_v0 = _mm512_add_epi32(
+          accum_data_v0,
+          _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+
+      lhs_ptr += 16 * 4;
+      rhs_ptr += 16 * 4;
+    }
+
+    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+      __m512i m_vector;
+      __m512i e_vector;
+      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+      if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+        m_vector = _mm512_maskz_loadu_epi32(row_mask,
+                                            &params.multiplier_fixedpoint[row]);
+        e_vector = _mm512_maskz_loadu_epi32(row_mask,
+                                            &params.multiplier_exponent[row]);
+      } else {
+        // These arrays have size LhsCols, and are pre-filled.
+        m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]);
+        e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]);
+      }
+
+      const __m512i m_64bit_low =
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
+      const __m512i m_64bit_high =
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
+
+      const __m512i zero_vector = _mm512_setzero_epi32();
+      const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
+      const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
+      const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
+      const __m512i final_right_shift =
+          _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+      const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
+          _mm512_extracti32x8_epi32(final_right_shift, 0));
+      const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
+          _mm512_extracti32x8_epi32(final_right_shift, 1));
+
+      const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
+      // Really these should be shifted by neg_e_vector, but tests pass when
+      // using right_shift.
+      const __m512i offset_vector_low = _mm512_sllv_epi64(
+          offset_vector,
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
+      const __m512i offset_vector_high = _mm512_sllv_epi64(
+          offset_vector,
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
+
+      // Shift and round column 0.
+      accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
+      // Apply the fixed-point part of the multiplier.
+      __m512i scaled_v_low = _mm512_mul_epi32(
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
+          m_64bit_low);
+      __m512i scaled_v_high = _mm512_mul_epi32(
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
+          m_64bit_high);
+
+      scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+      scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+      scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+      scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+      accum_data_v0 =
+          _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+      accum_data_v0 = _mm512_inserti32x8(
+          accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+      RUY_DCHECK(false);
+#endif
+
+      if (params.dst_zero_point != 0) {
+        __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
+        accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
+      }
+    }
+
+    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+      __m512i result = accum_data_v0;
+      result = _mm512_min_epi32(result, clamp_max_v);
+      result = _mm512_max_epi32(result, clamp_min_v);
+      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
+      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
+    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+      __m512i result = accum_data_v0;
+      result = _mm512_min_epi32(result, clamp_max_v);
+      result = _mm512_max_epi32(result, clamp_min_v);
+      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
+      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
+    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+      __m512i result = accum_data_v0;
+      result = _mm512_min_epi32(result, clamp_max_v);
+      result = _mm512_max_epi32(result, clamp_min_v);
+      _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
+                               _mm512_cvtepi32_epi16(result));
+      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+      std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+      _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
+      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+    } else {
+      RUY_DCHECK(false);
+    }
+
+    lhs_col_ptr += 16 * params.lhs_stride;
+  }  // End row-block loop.
+}  // NOLINT(readability/fn_size)
+
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
gemmlowp::ScopedProfilingLabel label("Kernel kAvx512 float");
@@ -1495,6 +1731,96 @@
} // Residual cols.
}
+// Single-column (GEMV) variant of KernelFloatAvx512: for rows
+// [start_row, min(dst_rows, last_row + 16)) it initializes accumulators from
+// params.bias, adds the LHS x RHS products over params.depth, clamps to
+// [clamp_min, clamp_max] and stores the result. Rows are handled 16 at a
+// time; a trailing partial block uses masked loads/stores.
+// Expects dst_cols == 1 and start_col == last_col == 0 (debug-checked).
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
+  gemmlowp::ScopedProfilingLabel label("Kernel kAvx512 float GEMV");
+
+  RUY_DCHECK_EQ(params.dst_cols, 1);
+  RUY_DCHECK_EQ(params.last_col, 0);
+  RUY_DCHECK_EQ(params.start_col, 0);
+
+  // params.lhs_stride is expressed in bytes; >> 2 divides by sizeof(float).
+  const std::int64_t lhs_stride = params.lhs_stride >> 2;
+
+  // Without a bias, bias_ptr stays at params.bias for every row block.
+  const int bias_ptr_block_increment =
+      (params.flags & RUY_ASM_FLAG_HAS_BIAS) ? 1 : 0;
+  const int end_row = std::min(params.dst_rows, params.last_row + 16);
+
+  // Rebased so that they can be indexed with absolute row numbers.
+  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
+  const float* bias_col_ptr = params.bias;
+
+  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
+  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+
+  const float* rhs_col_ptr = params.rhs_base_ptr;
+  float* dst_col_ptr = adj_dst_col_ptr;
+
+  int row = params.start_row;
+  for (; row <= end_row - 16; row += 16) {
+    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+    float* dst_ptr = dst_col_ptr + row;
+    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+    // Initialize with bias.
+    __m512 accum_data_v = _mm512_loadu_ps(bias_ptr);
+
+    // Successive depth levels are 16 floats apart in both LHS and RHS; only
+    // the first RHS float of each level is used.
+    const float* lhs_ptr = lhs_col_ptr;
+    const float* rhs_ptr = rhs_col_ptr;
+    for (int d = 0; d < params.depth; ++d) {
+      const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+      const __m512 dup_rhs_element_j = _mm512_set1_ps(*rhs_ptr);
+      accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+      lhs_ptr += 16;
+      rhs_ptr += 16;
+    }
+
+    accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
+    accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
+    _mm512_storeu_ps(dst_ptr, accum_data_v);
+  }  // End row-block loop.
+
+  if (row < end_row) {
+    const int residual_rows = end_row - row;
+    RUY_CHECK_GE(residual_rows, 1);
+    RUY_CHECK_LT(residual_rows, 16);
+
+    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+    float* dst_ptr = dst_col_ptr + row;
+    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+    // Lanes [0, residual_rows) are active: bias is loaded and dst is stored
+    // only for those lanes (inactive lanes of the load are zeroed).
+    const __mmask16 row_mask =
+        (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+    __m512 accum_data_v = _mm512_maskz_loadu_ps(row_mask, bias_ptr);
+
+    const float* lhs_ptr = lhs_col_ptr;
+    const float* rhs_ptr = rhs_col_ptr;
+    for (int d = 0; d < params.depth; ++d) {
+      const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+      const __m512 dup_rhs_element_j = _mm512_set1_ps(*rhs_ptr);
+      accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
+      lhs_ptr += 16;
+      rhs_ptr += 16;
+    }
+
+    accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
+    accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
+    _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v);
+  }  // End handling of residual rows.
+}
+
#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
} // namespace ruy
diff --git a/tensorflow/lite/experimental/ruy/kernel_x86.h b/tensorflow/lite/experimental/ruy/kernel_x86.h
index 78dcffb..6564875 100644
--- a/tensorflow/lite/experimental/ruy/kernel_x86.h
+++ b/tensorflow/lite/experimental/ruy/kernel_x86.h
@@ -32,6 +32,7 @@
#if RUY_PLATFORM(X86)
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);
template <typename DstScalar>
struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, DstScalar,
@@ -48,11 +49,16 @@
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
dst, &params);
- Kernel8bitAvx512(params);
+ if (dst->layout.cols == 1) {
+ Kernel8bitAvx512SingleCol(params);
+ } else {
+ Kernel8bitAvx512(params);
+ }
}
};
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);
template <>
struct Kernel<Path::kAvx512, float, float, float, BasicSpec<float, float>> {
@@ -66,11 +72,16 @@
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
end_col, dst, &params);
- KernelFloatAvx512(params);
+ if (dst->layout.cols == 1) {
+ KernelFloatAvx512SingleCol(params);
+ } else {
+ KernelFloatAvx512(params);
+ }
}
};
void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);
template <typename DstScalar>
struct Kernel<Path::kAvx2, std::int8_t, std::int8_t, DstScalar,
@@ -87,11 +98,16 @@
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
dst, &params);
- Kernel8bitAvx2(params);
+ if (dst->layout.cols == 1) {
+ Kernel8bitAvx2SingleCol(params);
+ } else {
+ Kernel8bitAvx2(params);
+ }
}
};
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);
template <>
struct Kernel<Path::kAvx2, float, float, float, BasicSpec<float, float>> {
@@ -105,7 +121,11 @@
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
end_col, dst, &params);
- KernelFloatAvx2(params);
+ if (dst->layout.cols == 1) {
+ KernelFloatAvx2SingleCol(params);
+ } else {
+ KernelFloatAvx2(params);
+ }
}
};
#endif // RUY_PLATFORM(X86)
diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD
index 110a1f8..704c74a 100644
--- a/tensorflow/lite/java/BUILD
+++ b/tensorflow/lite/java/BUILD
@@ -41,14 +41,6 @@
android_library = ":tensorflowlite_flex",
)
-# DEPRECATED: AAR target that supports TensorFlow op execution with TFLite.
-# Please use `tensorflowlite-select-tf-ops` instead (along with the standard
-# `tensorflowlite` AAR).
-aar_with_jni(
- name = "tensorflow-lite-with-select-tf-ops",
- android_library = ":tensorflowlite_flex_deprecated",
-)
-
# EXPERIMENTAL: AAR target for GPU acceleration. Note that this .aar contains
# *only* the GPU delegate; clients must also include the core `tensorflow-lite`
# runtime.
@@ -86,22 +78,6 @@
],
)
-# DEPRECATED: Android target that supports TensorFlow op execution with TFLite.
-# Please use `tensorflowlite_flex`.
-android_library(
- name = "tensorflowlite_flex_deprecated",
- srcs = JAVA_SRCS + [
- "//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate",
- ],
- manifest = "AndroidManifest.xml",
- proguard_specs = ["proguard.flags"],
- deps = [
- ":tensorflowlite",
- ":tensorflowlite_native_flex",
- "@org_checkerframework_qual",
- ],
-)
-
# EXPERIMENTAL: Android target target for GPU acceleration. Note that this
# library contains *only* the GPU delegate and its Java wrapper; clients must
# also include the core `tensorflowlite` runtime.
diff --git a/tensorflow/lite/java/aar_with_jni.bzl b/tensorflow/lite/java/aar_with_jni.bzl
index e33479e..71da735 100644
--- a/tensorflow/lite/java/aar_with_jni.bzl
+++ b/tensorflow/lite/java/aar_with_jni.bzl
@@ -72,12 +72,12 @@
for src in headers:
if flatten_headers:
cmd += """
- cp -rL $$origdir/$(location {0}) headers/$$(basename $(location {0}))
+ cp -RL $$origdir/$(location {0}) headers/$$(basename $(location {0}))
""".format(src)
else:
cmd += """
mkdir -p headers/$$(dirname $(location {0}))
- cp -rL $$origdir/$(location {0}) headers/$(location {0})
+ cp -RL $$origdir/$(location {0}) headers/$(location {0})
""".format(src)
cmd += "zip -r $$origdir/$(location :{0}.aar) headers".format(name)
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index ca794fd..7d86af5 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -518,6 +518,7 @@
":eigen_support",
":kernel_util",
":lstm_eval",
+ ":lstm_shared",
":op_macros",
":padding",
"//tensorflow/lite:framework",
@@ -613,6 +614,12 @@
)
cc_library(
+ name = "lstm_shared",
+ hdrs = ["lstm_shared.h"],
+ copts = tflite_copts(),
+)
+
+cc_library(
name = "builtin_ops",
srcs = ["register.cc"],
hdrs = [
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_params.h b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
index 763e931..66700ea 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_params.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
@@ -47,6 +47,10 @@
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
// When Scalar is floating-point, this must be 0.
Scalar zero_point = 0;
+ // Indicates whether the underlying data will remain unchanged for
+ // some period of time. Defaults to false, but should be set to true
+ // for unchanging data (e.g. weight buffers in many cases).
+ bool cacheable = false;
};
// Enumeration of broad categories of Gemm.
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
index f3b2430..4e1158b 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
@@ -41,6 +41,7 @@
// It does care whether we assign to it a Scalar* or a const Scalar*.
dst->data = data_ptr;
dst->zero_point = params.zero_point;
+ dst->cacheable = params.cacheable;
}
template <typename GemmParamsType, typename RuySpecType>
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 13224fa..7919df2 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -191,6 +191,13 @@
)
config_setting(
+ name = "windows",
+ values = {
+ "cpu": "x64_windows",
+ },
+)
+
+config_setting(
name = "raspberry_pi_with_neon",
define_values = {
"raspberry_pi_with_neon": "true",
@@ -263,6 +270,7 @@
":darwin": tflite_deps_intel,
":darwin_x86_64": tflite_deps_intel,
":freebsd": tflite_deps_intel,
+ ":windows": tflite_deps_intel,
"//conditions:default": [],
}),
)
@@ -309,6 +317,7 @@
":darwin": tflite_deps_intel,
":darwin_x86_64": tflite_deps_intel,
":freebsd": tflite_deps_intel,
+ ":windows": tflite_deps_intel,
"//conditions:default": [],
}),
)
@@ -480,6 +489,7 @@
":darwin": tflite_deps_intel,
":darwin_x86_64": tflite_deps_intel,
":freebsd": tflite_deps_intel,
+ ":windows": tflite_deps_intel,
"//conditions:default": [],
}),
)
@@ -541,6 +551,7 @@
":darwin": tflite_deps_intel,
":darwin_x86_64": tflite_deps_intel,
":freebsd": tflite_deps_intel,
+ ":windows": tflite_deps_intel,
"//conditions:default": [],
}),
)
@@ -738,6 +749,7 @@
":freebsd": [
":sse_tensor_utils",
],
+ ":windows": [":sse_tensor_utils"],
"//conditions:default": [
":portable_tensor_utils",
],
@@ -974,6 +986,7 @@
":darwin": tflite_deps_intel,
":darwin_x86_64": tflite_deps_intel,
":freebsd": tflite_deps_intel,
+ ":windows": tflite_deps_intel,
"//conditions:default": [],
}),
)
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 9622f30..7371a9f 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -13,7 +13,6 @@
limitations under the License.
==============================================================================*/
#include <sys/types.h>
-#include <unistd.h>
#include <algorithm>
#include <cmath>
@@ -962,6 +961,7 @@
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
lhs_params.rows = n_output;
lhs_params.cols = n_input;
+ lhs_params.cacheable = true;
MatrixParams<int8_t> rhs_params;
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
@@ -1817,54 +1817,6 @@
free(aligned_vec_free);
}
-void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2,
- int v_size, float* result) {
- // If v_size is not divisible by the vector size, then we need to process the
- // final few elements sequentially. postamble_start shows the start index
- // where this should happen.
- const int postamble_start =
- RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
- int v = 0;
- for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
- // Load 4 float values from vector1 and vector2.
- const float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
- const float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
- // Vector multiply 4 float
- const float32x4_t mul_32x4 = vmulq_f32(v1_f32x4, v2_f32x4);
- // Save to result array.
- vst1q_f32(result + v, mul_32x4);
- }
-#pragma clang loop vectorize(disable) unroll(disable)
- for (; v < v_size; v++) {
- result[v] = vector1[v] * vector2[v];
- }
-}
-
-void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2, int v_size,
- float* result) {
- // If v_size is not divisible by the vector size, then we need to process the
- // final few elements sequentially. postamble_start shows the start index
- // where this should happen.
- const int postamble_start =
- RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
- int v = 0;
- for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
- // Load 4 float values from vector1 and vector2 and accumulator.
- const float32x4_t v1_f32x4 = vld1q_f32(vector1 + v);
- const float32x4_t v2_f32x4 = vld1q_f32(vector2 + v);
- float32x4_t acc_32x4 = vld1q_f32(result + v);
- // Vector multiply-accumulate 4 float
- acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4);
- // Save to result array.
- vst1q_f32(result + v, acc_32x4);
- }
-#pragma clang loop vectorize(disable) unroll(disable)
- for (; v < v_size; v++) {
- result[v] += vector1[v] * vector2[v];
- }
-}
-
void NeonSub1Vector(const float* vector, int v_size, float* result) {
// If v_size is not divisible by the vector size, then we need to process the
// final few elements sequentially. postamble_start shows the start index
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
index cbb2cab..571d3ff 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -142,11 +142,6 @@
NEON_OR_PORTABLE(CwiseClipping, input, clipping_value, n_batch, n_input);
}
-void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
- int v_size, float* result) {
- NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result);
-}
-
void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
const int16_t* vector2, int v_size,
int n_batch, int32_t* result,
@@ -155,13 +150,6 @@
vector1, vector2, v_size, n_batch, result, result_stride);
}
-void VectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2, int v_size,
- float* result) {
- NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size,
- result);
-}
-
void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
const int16_t* batch_vector,
int n_batch, int32_t multiplier,
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
index ec98185..8e604d9 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
@@ -107,16 +107,6 @@
const float* scaling_factors, int n_batch, float* __restrict__ result,
int result_stride);
-// Cwise product of two vectors.
-void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2,
- int v_size, float* result);
-
-// Cwise product and accumulate of two vectors. Since it's a MAC operation, the
-// assumption here is that result array is initialized to valid values.
-void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2, int v_size,
- float* result);
-
// Dot product of two vectors.
float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
int v_size);
diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
index 0127645..9ceaa27 100644
--- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h
@@ -152,11 +152,6 @@
PortableCwiseClipping(input, clipping_value, n_batch, n_input);
}
-void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
- int v_size, float* result) {
- NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result);
-}
-
void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
const int16_t* vector2, int v_size,
int n_batch, int32_t* result,
@@ -165,13 +160,6 @@
vector1, vector2, v_size, n_batch, result, result_stride);
}
-void VectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2, int v_size,
- float* result) {
- NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size,
- result);
-}
-
void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
const int16_t* batch_vector,
int n_batch, int32_t multiplier,
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
index 1b36144..8648096 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -504,14 +504,6 @@
}
}
-void PortableVectorVectorCwiseProduct(const float* vector1,
- const float* vector2, int v_size,
- float* result) {
- for (int v = 0; v < v_size; v++) {
- result[v] = vector1[v] * vector2[v];
- }
-}
-
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
int v_size) {
float result = 0.0;
@@ -545,14 +537,6 @@
}
}
-void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2,
- int v_size, float* result) {
- for (int v = 0; v < v_size; v++) {
- result[v] += vector1[v] * vector2[v];
- }
-}
-
void PortableVectorBatchVectorCwiseProductAccumulate(
const int16_t* vector, int v_size, const int16_t* batch_vector, int n_batch,
int32_t multiplier, int shift, int16_t* result) {
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
index f3f41f7..b3f7c08 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -176,17 +176,6 @@
PortableCwiseClipping(input, clipping_value, n_batch, n_input);
}
-void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
- int v_size, float* result) {
- PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result);
-}
-
-void VectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2, int v_size,
- float* result) {
- PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result);
-}
-
void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
const int16_t* batch_vector,
int n_batch, int32_t multiplier,
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
index 0398edf..96d46ee 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
@@ -78,17 +78,6 @@
const float* scaling_factors, int n_batch, float* __restrict__ result,
int result_stride);
-// Cwise product of two vectors.
-void PortableVectorVectorCwiseProduct(const float* vector1,
- const float* vector2, int v_size,
- float* result);
-
-// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the
-// assumption here is that result array is initialized to valid values.
-void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2,
- int v_size, float* result);
-
// Dot product of two vectors.
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
int v_size);
diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h
index 76162e3..62fe08b 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/tensor_utils.h
@@ -314,14 +314,26 @@
int32_t n_input);
// Cwise product of two vectors.
-void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
- int v_size, float* result);
+template <typename T>
+inline void VectorVectorCwiseProduct(const T* __restrict__ vector1,
+ const T* __restrict__ vector2, int v_size,
+ T* __restrict__ result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = *vector1++ * *vector2++;
+ }
+}
// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the
// assumption here is that result array is initialized to valid values.
-void VectorVectorCwiseProductAccumulate(const float* vector1,
- const float* vector2, int v_size,
- float* result);
+template <typename T>
+inline void VectorVectorCwiseProductAccumulate(const T* __restrict__ vector1,
+ const T* __restrict__ vector2,
+ int v_size,
+ T* __restrict__ result) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ += *vector1++ * *vector2++;
+ }
+}
// Dot product of two vectors.
float VectorVectorDotProduct(const float* vector1, const float* vector2,
diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc
index bbb9e17..4ef01dc 100644
--- a/tensorflow/lite/kernels/lstm.cc
+++ b/tensorflow/lite/kernels/lstm.cc
@@ -35,6 +35,7 @@
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
+#include "tensorflow/lite/kernels/lstm_shared.h"
namespace tflite {
namespace ops {
@@ -56,57 +57,7 @@
lstm_eval::QuantizedLstmParameter quantized_lstm_param;
};
-// For full inputs kernel (24-inputs).
-// Please note the 20-input full kernel is deprecated and only kept
-// here for backward compatibility.
namespace full {
-
-// Input Tensors of size {n_batch, n_input}
-constexpr int kInputTensor = 0;
-
-// Input weight tensors of size: {n_cell, n_input}
-constexpr int kInputToInputWeightsTensor = 1; // Optional
-constexpr int kInputToForgetWeightsTensor = 2;
-constexpr int kInputToCellWeightsTensor = 3;
-constexpr int kInputToOutputWeightsTensor = 4;
-
-// Recurrent weight tensors of size {n_cell, n_output}
-constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
-constexpr int kRecurrentToForgetWeightsTensor = 6;
-constexpr int kRecurrentToCellWeightsTensor = 7;
-constexpr int kRecurrentToOutputWeightsTensor = 8;
-
-// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
-constexpr int kCellToInputWeightsTensor = 9; // Optional
-constexpr int kCellToForgetWeightsTensor = 10; // Optional
-constexpr int kCellToOutputWeightsTensor = 11; // Optional
-
-// Gates bias tensors of size {n_cell}
-constexpr int kInputGateBiasTensor = 12; // Optional
-constexpr int kForgetGateBiasTensor = 13;
-constexpr int kCellGateBiasTensor = 14;
-constexpr int kOutputGateBiasTensor = 15;
-
-// Projection weight tensor of size {n_output, n_cell}
-constexpr int kProjectionWeightsTensor = 16; // Optional
-// Projection bias tensor of size {n_output}
-constexpr int kProjectionBiasTensor = 17; // Optional
-
-// These state tensors are defined as variable tensors, and will be modified by
-// this op.
-constexpr int kInputActivationStateTensor = 18;
-constexpr int kInputCellStateTensor = 19;
-
-// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
-// matrix.
-constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
-constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
-constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
-constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
-
-// Output tensors.
-constexpr int kOutputTensor = 0;
-
namespace {
TfLiteStatus PopulateQuantizedLstmParams(
TfLiteContext* context, TfLiteNode* node,
diff --git a/tensorflow/lite/kernels/lstm_shared.h b/tensorflow/lite/kernels/lstm_shared.h
new file mode 100644
index 0000000..9e29650
--- /dev/null
+++ b/tensorflow/lite/kernels/lstm_shared.h
@@ -0,0 +1,78 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
+#define TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm {
+// For full inputs kernel (24-inputs).
+// Please note the 20-input full kernel is deprecated and only kept
+// here for backward compatibility.
+namespace full {
+
+// Input Tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1; // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9; // Optional
+constexpr int kCellToForgetWeightsTensor = 10; // Optional
+constexpr int kCellToOutputWeightsTensor = 11; // Optional
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 12; // Optional
+constexpr int kForgetGateBiasTensor = 13;
+constexpr int kCellGateBiasTensor = 14;
+constexpr int kOutputGateBiasTensor = 15;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 16; // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 17; // Optional
+
+// These state tensors are defined as variable tensors, and will be modified by
+// this op.
+constexpr int kInputActivationStateTensor = 18;
+constexpr int kInputCellStateTensor = 19;
+
+// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
+// matrix.
+constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
+constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
+constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
+constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
+
+// Output tensors.
+constexpr int kOutputTensor = 0;
+} // namespace full
+
+} // namespace lstm
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
+#endif // TENSORFLOW_LITE_KERNELS_LSTM_SHARED_H_
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 1a545b1..620f6ee 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -247,7 +247,7 @@
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(),
/* min_version */ 1,
- /* max_version */ 2);
+ /* max_version */ 3);
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV(),
/* min_version */ 1,
/* max_version */ 2);
diff --git a/tensorflow/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc
index 7de891c..8e66432 100644
--- a/tensorflow/lite/kernels/unpack.cc
+++ b/tensorflow/lite/kernels/unpack.cc
@@ -43,7 +43,8 @@
}
TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
- input->type != kTfLiteUInt8 && input->type != kTfLiteInt8) {
+ input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
+ input->type != kTfLiteBool) {
context->ReportError(context, "Type '%s' is not supported by unpack.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
@@ -112,6 +113,10 @@
UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
break;
}
+ case kTfLiteBool: {
+ UnpackImpl<bool>(context, node, input, data->num, data->axis);
+ break;
+ }
default: {
context->ReportError(context, "Type '%s' is not supported by unpack.",
TfLiteTypeGetName(input->type));
diff --git a/tensorflow/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc
index 28d21cc..88eb706 100644
--- a/tensorflow/lite/kernels/unpack_test.cc
+++ b/tensorflow/lite/kernels/unpack_test.cc
@@ -87,43 +87,43 @@
TEST(UnpackOpTest, FloatThreeOutputs) {
Check<float>(/*axis=*/0, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{2}, {2}, {2}},
- /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
}
TEST(UnpackOpTest, FloatThreeOutputsAxisOne) {
Check<float>(/*axis=*/1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{3}, {3}},
- /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}});
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}});
}
TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisOne) {
Check<float>(/*axis=*/-1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{3}, {3}},
- /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}});
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}});
}
TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisTwo) {
Check<float>(/*axis=*/-2, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{2}, {2}, {2}},
- /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
}
TEST(UnpackOpTest, FloatOneOutput) {
Check<float>(/*axis=*/0, /*input_shape=*/{1, 6},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{6}},
- /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}});
+ /*exp_output_shape=*/{{6}},
+ /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}});
}
TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
Check<float>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
- /*expected_output_shape=*/{{2, 2}, {2, 2}},
- /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
+ /*exp_output_shape=*/{{2, 2}, {2, 2}},
+ /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
}
TEST(UnpackOpTest, FloatVectorToScalar) {
@@ -137,32 +137,32 @@
TEST(UnpackOpTest, IntThreeOutputs) {
Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{2}, {2}, {2}},
- /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_INT32);
}
TEST(UnpackOpTest, IntThreeOutputsAxisOne) {
Check<int32_t>(/*axis=*/1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{3}, {3}},
- /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_INT32);
}
TEST(UnpackOpTest, IntOneOutput) {
Check<int32_t>(/*axis=*/0, /*input_shape=*/{1, 6},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{6}},
- /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
+ /*exp_output_shape=*/{{6}},
+ /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
/*type=*/TensorType_INT32);
}
TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
Check<int32_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
- /*expected_output_shape=*/{{2, 2}, {2, 2}},
- /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
+ /*exp_output_shape=*/{{2, 2}, {2, 2}},
+ /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
/*type=*/TensorType_INT32);
}
@@ -178,48 +178,48 @@
TEST(UnpackOpTest, Uint8ThreeOutputs) {
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{2}, {2}, {2}},
- /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) {
Check<uint8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{3}, {3}},
- /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) {
Check<uint8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{3}, {3}},
- /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) {
Check<uint8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{2}, {2}, {2}},
- /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8OneOutput) {
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{6}},
- /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
+ /*exp_output_shape=*/{{6}},
+ /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
Check<uint8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
- /*expected_output_shape=*/{{2, 2}, {2, 2}},
- /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
+ /*exp_output_shape=*/{{2, 2}, {2, 2}},
+ /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
/*type=*/TensorType_UINT8);
}
@@ -235,48 +235,48 @@
TEST(UnpackOpTest, Int8ThreeOutputs) {
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{2}, {2}, {2}},
- /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) {
Check<int8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{3}, {3}},
- /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) {
Check<int8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{3}, {3}},
- /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) {
Check<int8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{2}, {2}, {2}},
- /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8OneOutput) {
Check<int8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
/*input_data=*/{1, 2, 3, 4, 5, 6},
- /*expected_output_shape=*/{{6}},
- /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
+ /*exp_output_shape=*/{{6}},
+ /*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
Check<int8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
- /*expected_output_shape=*/{{2, 2}, {2, 2}},
- /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
+ /*exp_output_shape=*/{{2, 2}, {2, 2}},
+ /*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
/*type=*/TensorType_INT8);
}
@@ -288,5 +288,69 @@
/*type=*/TensorType_INT8);
}
+// bool tests.
+TEST(UnpackOpTest, BoolThreeOutputs) {
+ Check<bool>(
+ /*axis=*/0, /*input_shape=*/{3, 2},
+ /*input_data=*/{true, false, true, false, true, false},
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{true, false}, {true, false}, {true, false}},
+ /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolThreeOutputsAxisOne) {
+ Check<bool>(
+ /*axis=*/1, /*input_shape=*/{3, 2},
+ /*input_data=*/{true, false, true, false, true, false},
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{true, true, true}, {false, false, false}},
+ /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisOne) {
+ Check<bool>(
+ /*axis=*/-1, /*input_shape=*/{3, 2},
+ /*input_data=*/{true, false, true, false, true, false},
+ /*exp_output_shape=*/{{3}, {3}},
+ /*exp_output_data=*/{{true, true, true}, {false, false, false}},
+ /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisTwo) {
+ Check<bool>(
+ /*axis=*/-2, /*input_shape=*/{3, 2},
+ /*input_data=*/{true, false, true, false, true, false},
+ /*exp_output_shape=*/{{2}, {2}, {2}},
+ /*exp_output_data=*/{{true, false}, {true, false}, {true, false}},
+ /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolOneOutput) {
+ Check<bool>(
+ /*axis=*/0, /*input_shape=*/{1, 6},
+ /*input_data=*/{true, false, true, false, true, false},
+ /*exp_output_shape=*/{{6}},
+ /*exp_output_data=*/{{true, false, true, false, true, false}},
+ /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolThreeDimensionsOutputs) {
+ Check<bool>(
+ /*axis=*/2, /*input_shape=*/{2, 2, 2},
+ /*input_data=*/{true, false, true, false, true, false, true, false},
+ /*exp_output_shape=*/{{2, 2}, {2, 2}},
+ /*exp_output_data=*/
+ {{true, true, true, true}, {false, false, false, false}},
+ /*type=*/TensorType_BOOL);
+}
+
+TEST(UnpackOpTest, BoolVectorToScalar) {
+ Check<bool>(/*axis=*/0, /*input_shape=*/{5},
+ /*input_data=*/{true, false, true, false, true},
+ /*exp_output_shape=*/{{}, {}, {}, {}, {}},
+ /*exp_output_data=*/{{true}, {false}, {true}, {false}, {true}},
+ /*type=*/TensorType_BOOL);
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/lite/micro/examples/magic_wand/README.md b/tensorflow/lite/micro/examples/magic_wand/README.md
index 91e238a..7241ce7 100644
--- a/tensorflow/lite/micro/examples/magic_wand/README.md
+++ b/tensorflow/lite/micro/examples/magic_wand/README.md
@@ -27,14 +27,9 @@
### Install the Arduino_TensorFlowLite library
-Download the current nightly build of the library:
-[magic_wand.zip](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/tensorflow/lite/micro/tools/make/gen/arduino_x86_64/prj/magic_wand/magic_wand.zip)
-
-Next, import this zip file into the Arduino Desktop IDE by going to `Sketch
-->Include Library -> Add .ZIP Library...`. This example application is included
-as part of the official TensorFlow Lite Arduino library. To install it, open the
-Arduino library manager in `Tools -> Manage Libraries...` and search for
-`Arduino_TensorFlowLite`.
+This example application is included as part of the official TensorFlow Lite
+Arduino library. To install it, open the Arduino library manager in
+`Tools -> Manage Libraries...` and search for `Arduino_TensorFlowLite`.
### Install and patch the accelerometer driver
diff --git a/tensorflow/lite/micro/examples/magic_wand/train/data_augmentation.py b/tensorflow/lite/micro/examples/magic_wand/train/data_augmentation.py
index 45700b9..8d30fa1 100644
--- a/tensorflow/lite/micro/examples/magic_wand/train/data_augmentation.py
+++ b/tensorflow/lite/micro/examples/magic_wand/train/data_augmentation.py
@@ -22,6 +22,7 @@
from __future__ import print_function
import random
+
import numpy as np
diff --git a/tensorflow/lite/micro/examples/magic_wand/train/data_load.py b/tensorflow/lite/micro/examples/magic_wand/train/data_load.py
index 321b9c7..ceb24a7 100644
--- a/tensorflow/lite/micro/examples/magic_wand/train/data_load.py
+++ b/tensorflow/lite/micro/examples/magic_wand/train/data_load.py
@@ -22,6 +22,7 @@
from __future__ import print_function
import json
+
import numpy as np
import tensorflow as tf
diff --git a/tensorflow/lite/micro/examples/magic_wand/train/train.py b/tensorflow/lite/micro/examples/magic_wand/train/train.py
index 0f17f33..6ccaa8c 100644
--- a/tensorflow/lite/micro/examples/magic_wand/train/train.py
+++ b/tensorflow/lite/micro/examples/magic_wand/train/train.py
@@ -26,6 +26,7 @@
import datetime
import os
from data_load import DataLoader
+
import numpy as np
import tensorflow as tf
diff --git a/tensorflow/lite/micro/examples/magic_wand/train/train_test.py b/tensorflow/lite/micro/examples/magic_wand/train/train_test.py
index 18467ab..4790eb2 100644
--- a/tensorflow/lite/micro/examples/magic_wand/train/train_test.py
+++ b/tensorflow/lite/micro/examples/magic_wand/train/train_test.py
@@ -21,6 +21,7 @@
from __future__ import print_function
import unittest
+
import numpy as np
import tensorflow as tf
from train import build_cnn
diff --git a/tensorflow/lite/micro/examples/micro_speech/CMSIS/create_constants.py b/tensorflow/lite/micro/examples/micro_speech/CMSIS/create_constants.py
index 6d0b4e2..7d14dc6 100755
--- a/tensorflow/lite/micro/examples/micro_speech/CMSIS/create_constants.py
+++ b/tensorflow/lite/micro/examples/micro_speech/CMSIS/create_constants.py
@@ -19,6 +19,7 @@
from __future__ import print_function
# import soundfile as sf
+
import numpy as np
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/captured_data_to_wav.py b/tensorflow/lite/micro/examples/micro_speech/apollo3/captured_data_to_wav.py
index 52604f5..c9ba8fd 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/captured_data_to_wav.py
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/captured_data_to_wav.py
@@ -20,6 +20,7 @@
import struct
# import matplotlib.pyplot as plt
+
import numpy as np
import soundfile as sf
diff --git a/tensorflow/lite/micro/examples/micro_speech/apollo3/compare_1k.py b/tensorflow/lite/micro/examples/micro_speech/apollo3/compare_1k.py
index fab178b..b0a0cd5 100644
--- a/tensorflow/lite/micro/examples/micro_speech/apollo3/compare_1k.py
+++ b/tensorflow/lite/micro/examples/micro_speech/apollo3/compare_1k.py
@@ -20,6 +20,7 @@
import struct
import matplotlib.pyplot as plt
+
import numpy as np
# import soundfile as sf
diff --git a/tensorflow/lite/micro/examples/person_detection/utils/raw_to_bitmap.py b/tensorflow/lite/micro/examples/person_detection/utils/raw_to_bitmap.py
index 6658c60..4ebb849 100644
--- a/tensorflow/lite/micro/examples/person_detection/utils/raw_to_bitmap.py
+++ b/tensorflow/lite/micro/examples/person_detection/utils/raw_to_bitmap.py
@@ -36,6 +36,7 @@
import os
import os.path
import re
+
import numpy as np
_DICT_RESOLUTIONS = {
diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc
index 150c2c9..3f28b22 100644
--- a/tensorflow/lite/micro/micro_allocator.cc
+++ b/tensorflow/lite/micro/micro_allocator.cc
@@ -418,16 +418,14 @@
size_t type_size;
TF_LITE_ENSURE_STATUS(BytesRequiredForTensor(
flatbuffer_tensor, &result->bytes, &type_size, error_reporter));
- // Copy the shape of the tensor from the serialized data into the runtime
- // form. We have to allocate memory for this.
- result->dims =
- reinterpret_cast<TfLiteIntArray*>(memory_allocator_.AllocateFromTail(
- TfLiteIntArrayGetSizeInBytes(flatbuffer_tensor.shape()->Length()),
- alignof(TfLiteIntArray)));
- result->dims->size = flatbuffer_tensor.shape()->Length();
- for (size_t n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
- result->dims->data[n] = flatbuffer_tensor.shape()->Get(n);
- }
+
+ // TFLM doesn't allow reshaping the tensor which requires dynamic memory
+ // allocation so it is safe to drop the const qualifier. In the future, if we
+ // really want to update the tensor shape, we can always pass in a new
+ // TfLiteIntArray - especially we have to do so if the dimension is changed.
+ result->dims = const_cast<TfLiteIntArray*>(
+ reinterpret_cast<const TfLiteIntArray*>(flatbuffer_tensor.shape()));
+
// Copy the quantization information from the serialized data.
const auto* src_quantization = flatbuffer_tensor.quantization();
if (src_quantization && src_quantization->scale() &&
diff --git a/tensorflow/lite/micro/tools/make/download_and_extract.sh b/tensorflow/lite/micro/tools/make/download_and_extract.sh
index 8a82cc0..5e96899 100755
--- a/tensorflow/lite/micro/tools/make/download_and_extract.sh
+++ b/tensorflow/lite/micro/tools/make/download_and_extract.sh
@@ -80,6 +80,10 @@
local tempfile=${tempdir}/temp_file
local curl_retries=3
+ command -v curl >/dev/null 2>&1 || {
+ echo >&2 "The required 'curl' tool isn't installed. Try 'apt-get install curl'."; exit 1;
+ }
+
echo "downloading ${url}" >&2
mkdir -p "${dir}"
# We've been seeing occasional 56 errors from valid URLs, so set up a retry
diff --git a/tensorflow/lite/micro/tools/make/fix_arduino_subfolders.py b/tensorflow/lite/micro/tools/make/fix_arduino_subfolders.py
index 2465049..a68267c 100755
--- a/tensorflow/lite/micro/tools/make/fix_arduino_subfolders.py
+++ b/tensorflow/lite/micro/tools/make/fix_arduino_subfolders.py
@@ -22,6 +22,7 @@
import argparse
import glob
import os
+
import six
diff --git a/tensorflow/lite/micro/tools/make/generate_keil_project.py b/tensorflow/lite/micro/tools/make/generate_keil_project.py
index 5af4b4e..5a9950c 100644
--- a/tensorflow/lite/micro/tools/make/generate_keil_project.py
+++ b/tensorflow/lite/micro/tools/make/generate_keil_project.py
@@ -22,6 +22,7 @@
import argparse
import os.path
import re
+
import six
diff --git a/tensorflow/lite/micro/tools/make/transform_arduino_source.py b/tensorflow/lite/micro/tools/make/transform_arduino_source.py
index e6b0265..c5c74b7 100644
--- a/tensorflow/lite/micro/tools/make/transform_arduino_source.py
+++ b/tensorflow/lite/micro/tools/make/transform_arduino_source.py
@@ -22,6 +22,7 @@
import argparse
import re
import sys
+
import six
diff --git a/tensorflow/lite/micro/tools/make/transform_source.py b/tensorflow/lite/micro/tools/make/transform_source.py
index f7eaaa0..7957476 100644
--- a/tensorflow/lite/micro/tools/make/transform_source.py
+++ b/tensorflow/lite/micro/tools/make/transform_source.py
@@ -26,6 +26,7 @@
import os
import re
import sys
+
import six
diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc
index d7523db..0556f47 100644
--- a/tensorflow/lite/model.cc
+++ b/tensorflow/lite/model.cc
@@ -129,23 +129,23 @@
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
- const char* buffer, size_t buffer_size, TfLiteVerifier* extra_verifier,
- ErrorReporter* error_reporter) {
+ const char* caller_owned_buffer, size_t buffer_size,
+ TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
error_reporter = ValidateErrorReporter(error_reporter);
- flatbuffers::Verifier base_verifier(reinterpret_cast<const uint8_t*>(buffer),
- buffer_size);
+ flatbuffers::Verifier base_verifier(
+ reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size);
if (!VerifyModelBuffer(base_verifier)) {
error_reporter->Report("The model is not a valid Flatbuffer buffer");
return nullptr;
}
- if (extra_verifier &&
- !extra_verifier->Verify(buffer, buffer_size, error_reporter)) {
+ if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer,
+ buffer_size, error_reporter)) {
return nullptr;
}
- return BuildFromBuffer(buffer, buffer_size, error_reporter);
+ return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter);
}
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h
index b8b4b44..159f800 100644
--- a/tensorflow/lite/model.h
+++ b/tensorflow/lite/model.h
@@ -110,7 +110,7 @@
/// and must ensure its lifetime is longer than the FlatBufferModel instance.
/// Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer(
- const char* buffer, size_t buffer_size,
+ const char* caller_owned_buffer, size_t buffer_size,
TfLiteVerifier* extra_verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter());
diff --git a/tensorflow/lite/profiling/profile_summarizer.cc b/tensorflow/lite/profiling/profile_summarizer.cc
index 0b51b65..4b394f1 100644
--- a/tensorflow/lite/profiling/profile_summarizer.cc
+++ b/tensorflow/lite/profiling/profile_summarizer.cc
@@ -96,7 +96,9 @@
} // namespace
-ProfileSummarizer::ProfileSummarizer() {
+ProfileSummarizer::ProfileSummarizer()
+ : delegate_stats_calculator_(
+ new tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {
// Create stats calculator for the primary graph.
stats_calculator_map_[0] = std::unique_ptr<tensorflow::StatsCalculator>(
new tensorflow::StatsCalculator(GetProfileSummarizerOptions()));
@@ -126,6 +128,7 @@
// Total time will be accumulated per subgraph.
std::map<uint32_t, int64_t> total_us_per_subgraph_map;
+ int64_t delegate_internal_total_us = 0;
for (auto event : events) {
const auto subgraph_index = event->event_subgraph_index;
@@ -156,6 +159,17 @@
stats_calculator->AddNodeStats(node_name_in_stats, type_in_stats,
node_num, start_us, node_exec_time,
0 /*memory */);
+ } else if (event->event_type ==
+ Profiler::EventType::DELEGATE_OPERATOR_INVOKE_EVENT) {
+ const std::string node_name(event->tag);
+ // Append event_metadata to node name because 'stats_calculator' can not
+ // distinguish two nodes w/ the same 'node_name'.
+ const auto node_name_in_stats =
+ "Delegate/" + node_name + ":" + std::to_string(event->event_metadata);
+
+ delegate_stats_calculator_->AddNodeStats(
+ node_name_in_stats, "DelegateOpInvoke", node_num, start_us,
+ node_exec_time, 0 /*memory */);
} else {
// TODO(b/139812778) consider use a different stats_calculator to record
// non-op-invoke events so that these could be separated from
@@ -171,8 +185,11 @@
// Add total time except actual delegate ops since the elapsed time of the
// delegate ops inside are already combined at a fused DELEGATE op.
- if (strcmp(event->tag, "DelegateOpInvoke") != 0) {
+ if (event->event_type !=
+ Profiler::EventType::DELEGATE_OPERATOR_INVOKE_EVENT) {
total_us_per_subgraph_map[subgraph_index] += node_exec_time;
+ } else {
+ delegate_internal_total_us += node_exec_time;
}
++node_num;
}
@@ -182,6 +199,9 @@
GetStatsCalculator(total_us_per_subgraph_pair.first);
stats_calculator->UpdateRunTotalUs(total_us_per_subgraph_pair.second);
}
+ if (delegate_internal_total_us > 0) {
+ delegate_stats_calculator_->UpdateRunTotalUs(delegate_internal_total_us);
+ }
}
tensorflow::StatsCalculator* ProfileSummarizer::GetStatsCalculator(
@@ -217,6 +237,15 @@
}
stream << subgraph_stats->GetShortSummary() << std::endl;
}
+
+ if (delegate_stats_calculator_->num_runs() > 0) {
+ stream << "Delegate internal: " << std::endl;
+ if (include_output_string) {
+ stream << delegate_stats_calculator_->GetOutputString();
+ }
+ stream << delegate_stats_calculator_->GetShortSummary() << std::endl;
+ }
+
return stream.str();
}
diff --git a/tensorflow/lite/profiling/profile_summarizer.h b/tensorflow/lite/profiling/profile_summarizer.h
index fa12876..d097231 100644
--- a/tensorflow/lite/profiling/profile_summarizer.h
+++ b/tensorflow/lite/profiling/profile_summarizer.h
@@ -61,6 +61,8 @@
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>
stats_calculator_map_;
+ std::unique_ptr<tensorflow::StatsCalculator> delegate_stats_calculator_;
+
// GenerateReport returns the report of subgraphs in a string format.
std::string GenerateReport(std::string tag, bool include_output_string);
};
diff --git a/tensorflow/lite/profiling/profile_summarizer_test.cc b/tensorflow/lite/profiling/profile_summarizer_test.cc
index 0c4b9fc..87e689e 100644
--- a/tensorflow/lite/profiling/profile_summarizer_test.cc
+++ b/tensorflow/lite/profiling/profile_summarizer_test.cc
@@ -178,7 +178,6 @@
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
subgraph_test_util::CheckIntTensor(output, {1, 2}, {6, 9});
- ProfileSummarizer summarizer;
auto events = profiler.GetProfileEvents();
EXPECT_EQ(2, events.size());
int event_count_of_subgraph_zero = std::count_if(
@@ -206,7 +205,6 @@
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
subgraph_test_util::CheckIntTensor(output, {1, 2}, {5, 14});
- ProfileSummarizer summarizer;
auto events = profiler.GetProfileEvents();
EXPECT_EQ(2, events.size());
int event_count_of_subgraph_zero = std::count_if(
diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py
index 30f224e..153b6f1 100644
--- a/tensorflow/lite/python/interpreter.py
+++ b/tensorflow/lite/python/interpreter.py
@@ -120,7 +120,7 @@
raise ValueError(capture.message)
def __del__(self):
- # __del__ can be called multiple times, so if the delegate is destroyed.
+ # In case __del__ runs more than once: if the delegate is already destroyed,
# don't try to destroy it twice.
if self._library is not None:
self._library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p]
diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py
index bfe3459..9c8dbba 100644
--- a/tensorflow/lite/python/interpreter_test.py
+++ b/tensorflow/lite/python/interpreter_test.py
@@ -21,6 +21,7 @@
import ctypes
import io
import sys
+
import numpy as np
import six
@@ -416,7 +417,10 @@
def testFail(self):
with self.assertRaisesRegexp(
- ValueError, 'Failed to load delegate from .*\nFail argument sent.'):
+ # Due to exception chaining in PY3, we can't be more specific here and
+ # check that the phrase 'Fail argument sent' is present.
+ ValueError,
+ r'Failed to load delegate from'):
interpreter_wrapper.load_delegate(
self._delegate_file, options={'fail': 'fail'})
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index f4a6a4e..1f0156d 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -20,6 +20,7 @@
from __future__ import print_function
import os
+
from absl.testing import parameterized
import numpy as np
from six.moves import range
diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc
index d4c7a3d..89ffb34 100644
--- a/tensorflow/lite/python/optimize/calibration_wrapper.cc
+++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc
@@ -56,6 +56,12 @@
return copied_model;
}
+bool NoOpModel(const tflite::FlatBufferModel& model) {
+ return model->subgraphs()->size() == 1 &&
+ (!model->subgraphs()->begin()->operators() ||
+ model->subgraphs()->begin()->operators()->size() == 0);
+}
+
inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
switch (type) {
case kTfLiteNoType:
@@ -92,12 +98,14 @@
std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
error_reporter,
std::unique_ptr<tflite::FlatBufferModel> model,
- std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader)
+ std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader,
+ std::unique_ptr<std::string> model_str)
: interpreter_(std::move(interpreter)),
error_reporter_(std::move(error_reporter)),
resolver_(std::move(resolver)),
model_(std::move(model)),
- reader_(std::move(reader)) {}
+ reader_(std::move(reader)),
+ model_str_(std::move(model_str)) {}
CalibrationWrapper::~CalibrationWrapper() {}
@@ -197,6 +205,11 @@
int output_py_type,
bool allow_float,
bool enable_mlir_quantizer) {
+ if (NoOpModel(*model_)) {
+ return python_utils::ConvertToPyString(model_str_->data(),
+ model_str_->size());
+ }
+
TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
@@ -288,9 +301,16 @@
return nullptr;
}
+  // The raw model bytes are only needed by QuantizeModel for no-op models,
+  // which it returns unchanged, so only copy the buffer in that case.
+  std::unique_ptr<std::string> model_str;
+  if (NoOpModel(*model)) {
+    model_str = std::make_unique<std::string>(buf, length);
+  }
+
auto wrapper = new CalibrationWrapper(
std::move(interpreter), std::move(resolver), std::move(error_reporter),
- std::move(model), std::move(reader));
+ std::move(model), std::move(reader), std::move(model_str));
return wrapper;
}
diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h b/tensorflow/lite/python/optimize/calibration_wrapper.h
index 2484858..0fefc29 100644
--- a/tensorflow/lite/python/optimize/calibration_wrapper.h
+++ b/tensorflow/lite/python/optimize/calibration_wrapper.h
@@ -77,7 +77,8 @@
std::unique_ptr<tflite::interpreter_wrapper::PythonErrorReporter>
error_reporter,
std::unique_ptr<tflite::FlatBufferModel> model,
- std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader);
+ std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader,
+ std::unique_ptr<std::string> model_str);
CalibrationWrapper(const CalibrationWrapper& rhs);
@@ -89,6 +90,7 @@
std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;
std::unique_ptr<tflite::FlatBufferModel> model_;
std::unique_ptr<tflite::optimize::calibration::CalibrationReader> reader_;
+ std::unique_ptr<std::string> model_str_;
};
} // namespace calibration_wrapper
diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py
index 934e441..28e8723 100644
--- a/tensorflow/lite/python/optimize/calibrator_test.py
+++ b/tensorflow/lite/python/optimize/calibrator_test.py
@@ -18,6 +18,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
from absl.testing import parameterized
import numpy as np
from six.moves import range
diff --git a/tensorflow/lite/python/testdata/test_delegate.cc b/tensorflow/lite/python/testdata/test_delegate.cc
index a7c48e0..98854ca 100644
--- a/tensorflow/lite/python/testdata/test_delegate.cc
+++ b/tensorflow/lite/python/testdata/test_delegate.cc
@@ -66,7 +66,12 @@
void tflite_plugin_destroy_delegate(TfLiteDelegate* delegate) {
num_delegates_destroyed++;
delete delegate;
- if (destruction_callback) destruction_callback("test_delegate");
+ if (destruction_callback) {
+ destruction_callback("test_delegate");
+ // destruction_callback is a global variable,
+ // so it should be set to nullptr here to avoid crashes
+ destruction_callback = nullptr;
+ }
}
void initialize_counters() {
diff --git a/tensorflow/lite/python/tflite_convert_test.py b/tensorflow/lite/python/tflite_convert_test.py
index 1b50452..1e80907 100644
--- a/tensorflow/lite/python/tflite_convert_test.py
+++ b/tensorflow/lite/python/tflite_convert_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import os
+
import numpy as np
from tensorflow.lite.python import tflite_convert
diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py
index 39a0ca4..fd21d42 100644
--- a/tensorflow/lite/testing/generate_examples.py
+++ b/tensorflow/lite/testing/generate_examples.py
@@ -29,7 +29,6 @@
from __future__ import division
from __future__ import print_function
-
import tensorflow as tf
import argparse
import os
diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py
index 1d257e1..c974070 100644
--- a/tensorflow/lite/testing/generate_examples_lib.py
+++ b/tensorflow/lite/testing/generate_examples_lib.py
@@ -34,6 +34,7 @@
import os
import re
import zipfile
+
import tensorflow as tf
# TODO(aselle): Disable GPU for now
diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
index 6d050eb..30d102c 100644
--- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
+++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import os
+
import numpy as np
from six import PY3
diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py
index 3414903..3f445aa 100644
--- a/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py
+++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py
@@ -20,6 +20,7 @@
import os
import tempfile
+
import numpy as np
from tensorflow.lite.python import lite
diff --git a/tensorflow/lite/testing/op_tests/unpack.py b/tensorflow/lite/testing/op_tests/unpack.py
index c408748..0b59444 100644
--- a/tensorflow/lite/testing/op_tests/unpack.py
+++ b/tensorflow/lite/testing/op_tests/unpack.py
@@ -17,7 +17,6 @@
from __future__ import division
from __future__ import print_function
-import numpy as np
import tensorflow as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
@@ -31,6 +30,7 @@
test_parameters = [{
"base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
"axis": [0, 1, 2, 3],
+ "dtype": [tf.int32, tf.bool, tf.float32],
}]
def get_valid_axis(parameters):
@@ -43,12 +43,15 @@
def build_graph(parameters):
input_tensor = tf.compat.v1.placeholder(
- dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
+ dtype=parameters["dtype"],
+ name=("input"),
+ shape=parameters["base_shape"])
outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
return [input_tensor], [outs[0]]
def build_inputs(parameters, sess, inputs, outputs):
- input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
+ input_value = create_tensor_data(
+ parameters["dtype"], shape=parameters["base_shape"])
return [input_value], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value])))
diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py
index 459b72b..3d380ff 100644
--- a/tensorflow/lite/testing/zip_test_utils.py
+++ b/tensorflow/lite/testing/zip_test_utils.py
@@ -25,6 +25,7 @@
import string
import traceback
import zipfile
+
import numpy as np
from six import StringIO
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index dd7a9e3..26ce2af 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -1194,7 +1194,11 @@
softmax->outputs.push_back(node.name());
// TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
CHECK(!node.attr().count("beta")); // Stab in the dark, just in case.
- softmax->beta = 1.f;
+ if (node.attr().count("_softmax_beta")) {
+ softmax->beta = GetFloatAttr(node, "_softmax_beta");
+ } else {
+ softmax->beta = 1.f;
+ }
model->operators.emplace_back(softmax);
return tensorflow::Status::OK();
}
@@ -2235,7 +2239,6 @@
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
graphdef_copy.library());
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
- tensorflow::OptimizerOptions o_opts;
tensorflow::ProcessFunctionLibraryRuntime pflr(
&device_mgr, tensorflow::Env::Default(), &options.config,
TF_GRAPH_DEF_VERSION, &fld,
diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc
index 3e0c530..eb6ed3f 100644
--- a/tensorflow/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/lite/toco/import_tensorflow_test.cc
@@ -186,6 +186,43 @@
EXPECT_FALSE(model.HasArray("BadType"));
}
+TEST(FlexImportTest, SoftmaxWithBeta) {
+ NodeDef node;
+ node.set_op("Softmax");
+ node.set_name("softmax");
+ node.add_input();
+ node.set_input(0, "logits");
+
+ AttrValue beta_attr;
+ SetAttrValue(0.5, &beta_attr);
+ (*node.mutable_attr())["_softmax_beta"] = beta_attr;
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kSoftmax);
+ const SoftmaxOperator* op =
+ static_cast<const SoftmaxOperator*>(model.operators[0].get());
+ EXPECT_EQ(op->beta, 0.5);
+}
+
+TEST(FlexImportTest, SoftmaxWithoutBeta) {
+ NodeDef node;
+ node.set_op("Softmax");
+ node.set_name("softmax");
+ node.add_input();
+ node.set_input(0, "logits");
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kSoftmax);
+ const SoftmaxOperator* op =
+ static_cast<const SoftmaxOperator*>(model.operators[0].get());
+ EXPECT_EQ(op->beta, 1.0);
+}
+
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
};
diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc
index 241048f..456d877 100644
--- a/tensorflow/lite/toco/tflite/op_version.cc
+++ b/tensorflow/lite/toco/tflite/op_version.cc
@@ -168,6 +168,8 @@
{{OperatorType::kOneHot, 1}, "1.11.0"},
{{OperatorType::kCTCBeamSearchDecoder, 1}, "1.11.0"},
{{OperatorType::kUnpack, 1}, "1.11.0"},
+ {{OperatorType::kUnpack, 2}, "1.14.0"},
+ {{OperatorType::kUnpack, 3}, kPendingReleaseOpVersion},
{{OperatorType::kLeakyRelu, 1}, "1.13.1"},
{{OperatorType::kLogistic, 1}, "1.14.0"},
{{OperatorType::kLogistic, 2}, "1.14.0"},
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index f98a621..f106e4c 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1349,6 +1349,21 @@
op->num = options.num();
op->axis = options.axis();
}
+
+ int GetVersion(const OperatorSignature& op_signature) const override {
+ const string& input_name = op_signature.op->inputs[0];
+ const Array& input_array = op_signature.model->GetArray(input_name);
+ // If the op takes int8/uint8 input, it is version 2.
+ if (input_array.data_type == ArrayDataType::kInt8 ||
+ input_array.data_type == ArrayDataType::kUint8) {
+ return 2;
+ }
+ // If the op takes bool input, it is version 3.
+ if (input_array.data_type == ArrayDataType::kBool) {
+ return 3;
+ }
+ return 1;
+ }
};
class LeakyRelu
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
index d0e5810..9af47e2 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/BUILD
@@ -33,9 +33,10 @@
],
)
-cc_binary(
- name = "imagenet_accuracy_eval",
+cc_library(
+ name = "imagenet_accuracy_eval_lib",
srcs = ["imagenet_accuracy_eval.cc"],
+ hdrs = ["imagenet_accuracy_eval.h"],
copts = tflite_copts(),
linkopts = common_linkopts,
deps = [
@@ -43,9 +44,19 @@
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/c:common",
"//tensorflow/lite/profiling:time",
- "//tensorflow/lite/tools:command_line_flags",
"//tensorflow/lite/tools/accuracy:csv_writer",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
"@com_google_absl//absl/memory",
],
)
+
+cc_binary(
+ name = "imagenet_accuracy_eval",
+ srcs = ["imagenet_accuracy_eval_main.cc"],
+ copts = tflite_copts(),
+ linkopts = common_linkopts,
+ deps = [
+ ":imagenet_accuracy_eval_lib",
+ "//tensorflow/lite/tools:command_line_flags",
+ ],
+)
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
index ea4805a..9139cfc 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
@@ -13,12 +13,10 @@
limitations under the License.
==============================================================================*/
+#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h"
+
#include <cstdlib>
#include <iomanip>
-#include <memory>
-#include <mutex> // NOLINT(build/c++11)
-#include <ostream>
-#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
@@ -26,46 +24,55 @@
#include "tensorflow/lite/profiling/time.h"
#include "tensorflow/lite/tools/accuracy/csv_writer.h"
#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
-#include "tensorflow/lite/tools/command_line_flags.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
namespace tensorflow {
namespace metrics {
-namespace {
-
using ::tflite::evaluation::TopkAccuracyEvalMetrics;
-constexpr char kNumThreadsFlag[] = "num_threads";
-constexpr char kOutputFilePathFlag[] = "output_file_path";
-constexpr char kProtoOutputFilePathFlag[] = "proto_output_file_path";
+ResultsWriter::ResultsWriter(int top_k, const std::string& output_file_path)
+ : top_k_(top_k) {
+ if (output_file_path.empty()) {
+ LOG(ERROR) << "Empty output file path.";
+ return;
+ }
-// TODO(b/130823599): Move to tools/evaluation/stages/topk_accuracy_eval_stage.
-// Computes total number of images processed & aggregates Top-K accuracies
-// into 'accuracies'.
-void AggregateAccuraciesAndNumImages(
- int k,
- const std::unordered_map<uint64_t, TopkAccuracyEvalMetrics>&
- shard_id_accuracy_metrics_map,
- const std::unordered_map<uint64_t, int>& shard_id_done_image_count_map,
+  output_stream_.reset(new std::ofstream(output_file_path, std::ios::out));
+  if (!output_stream_->is_open()) {
+    LOG(ERROR) << "Unable to open output file path: '" << output_file_path
+               << "'";
+    // Leave writer_ unset so that IsValid() reports the failure.
+    return;
+  }
+
+  (*output_stream_) << std::setprecision(3) << std::fixed;
+  std::vector<string> columns;
+  columns.reserve(top_k);
+  for (int i = 0; i < top_k; i++) {
+    columns.push_back("Top " + std::to_string(i + 1));
+  }
+
+  writer_.reset(new CSVWriter(columns, output_stream_.get()));
+}
+
+void ResultsWriter::AggregateAccuraciesAndNumImages(
std::vector<double>* accuracies, int* num_done_images) {
// Total images done.
*num_done_images = 0;
- for (auto iter = shard_id_done_image_count_map.begin();
- iter != shard_id_done_image_count_map.end(); ++iter) {
- *num_done_images += iter->second;
+ for (const auto& entry : shard_id_done_image_count_map_) {
+ *num_done_images += entry.second;
}
// Aggregated accuracies.
- for (int i = 0; i < k; ++i) {
+ for (int i = 0; i < top_k_; ++i) {
double correct_inferences = 0;
double total_inferences = 0;
- for (auto iter = shard_id_done_image_count_map.begin();
- iter != shard_id_done_image_count_map.end(); ++iter) {
- const uint64_t shard_id = iter->first;
+ for (const auto& entry : shard_id_done_image_count_map_) {
+ const uint64_t shard_id = entry.first;
const TopkAccuracyEvalMetrics& accuracy_metrics =
- shard_id_accuracy_metrics_map.at(shard_id);
- const int num_images = iter->second;
+ shard_id_accuracy_metrics_map_.at(shard_id);
+ const int num_images = entry.second;
correct_inferences += num_images * accuracy_metrics.topk_accuracies(i);
total_inferences += num_images;
}
@@ -74,40 +81,6 @@
}
}
-} // namespace
-
-// Writes results to a CSV file & logs progress to standard output with
-// `kLogDelayUs` microseconds.
-class ResultsWriter : public ImagenetModelEvaluator::Observer {
- public:
- explicit ResultsWriter(int k, std::unique_ptr<CSVWriter> writer)
- : k_(k), writer_(std::move(writer)) {}
-
- void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
- shard_id_image_count_map) override;
-
- void OnSingleImageEvaluationComplete(uint64_t shard_id,
- const TopkAccuracyEvalMetrics& metrics,
- const string& image) override;
-
- TopkAccuracyEvalMetrics AggregatedMetrics();
-
- private:
- // For writing to CSV.
- int k_;
- std::unordered_map<uint64_t, TopkAccuracyEvalMetrics>
- shard_id_accuracy_metrics_map_;
- std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
- std::unique_ptr<CSVWriter> writer_;
-
- // For logging to stdout.
- uint64_t last_logged_time_us_ = 0;
- int total_num_images_;
- static constexpr int kLogDelayUs = 500 * 1000;
-
- std::mutex mu_;
-};
-
void ResultsWriter::OnEvaluationStart(
const std::unordered_map<uint64_t, int>& shard_id_image_count_map) {
int total_num_images = 0;
@@ -129,9 +102,7 @@
int num_evaluated;
std::vector<double> total_accuracies;
- AggregateAccuraciesAndNumImages(k_, shard_id_accuracy_metrics_map_,
- shard_id_done_image_count_map_,
- &total_accuracies, &num_evaluated);
+ AggregateAccuraciesAndNumImages(&total_accuracies, &num_evaluated);
if (writer_->WriteRow(total_accuracies) != kTfLiteOk) {
LOG(ERROR) << "Could not write to file";
return;
@@ -152,9 +123,7 @@
std::lock_guard<std::mutex> lock(mu_);
int num_evaluated;
std::vector<double> total_accuracies;
- AggregateAccuraciesAndNumImages(k_, shard_id_accuracy_metrics_map_,
- shard_id_done_image_count_map_,
- &total_accuracies, &num_evaluated);
+ AggregateAccuraciesAndNumImages(&total_accuracies, &num_evaluated);
TopkAccuracyEvalMetrics aggregated_metrics;
for (auto accuracy : total_accuracies) {
aggregated_metrics.add_topk_accuracies(accuracy);
@@ -162,74 +131,39 @@
return aggregated_metrics;
}
-int Main(int argc, char* argv[]) {
- std::string output_file_path, proto_output_file_path;
- int num_threads = 4;
- std::vector<tflite::Flag> flag_list = {
- tflite::Flag::CreateFlag(kNumThreadsFlag, &num_threads,
- "Number of threads."),
- tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
- "Path to output file."),
- tflite::Flag::CreateFlag(kProtoOutputFilePathFlag,
- &proto_output_file_path,
- "Path to proto output file."),
- };
- tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
-
- std::unique_ptr<ImagenetModelEvaluator> evaluator;
- if (output_file_path.empty()) {
- LOG(ERROR) << "Invalid output file path.";
- return EXIT_FAILURE;
- }
-
- if (num_threads <= 0) {
- LOG(ERROR) << "Invalid number of threads.";
- return EXIT_FAILURE;
- }
-
- if (ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator) !=
- kTfLiteOk)
- return EXIT_FAILURE;
-
- std::ofstream output_stream(output_file_path, std::ios::out);
- if (!output_stream) {
- LOG(ERROR) << "Unable to open output file path: '" << output_file_path
- << "'";
- }
-
- output_stream << std::setprecision(3) << std::fixed;
- std::vector<string> columns;
- columns.reserve(evaluator->params().num_ranks);
- for (int i = 0; i < evaluator->params().num_ranks; i++) {
- std::string column_name = "Top ";
- column_name = column_name + std::to_string(i + 1);
- columns.push_back(column_name);
- }
-
- ResultsWriter results_writer(
- evaluator->params().num_ranks,
- absl::make_unique<CSVWriter>(columns, &output_stream));
- evaluator->AddObserver(&results_writer);
- LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
- if (evaluator->EvaluateModel() != kTfLiteOk) {
- LOG(ERROR) << "Failed to evaluate the model!";
- return EXIT_FAILURE;
- }
-
- if (!proto_output_file_path.empty()) {
- std::ofstream proto_out_file(proto_output_file_path,
+void ResultsWriter::OutputEvalMetriccProto(
+ const std::string& proto_output_file) {
+ if (!proto_output_file.empty()) {
+ std::ofstream proto_out_file(proto_output_file,
std::ios::out | std::ios::binary);
- TopkAccuracyEvalMetrics metrics = results_writer.AggregatedMetrics();
+ TopkAccuracyEvalMetrics metrics = AggregatedMetrics();
proto_out_file << metrics.SerializeAsString();
proto_out_file.close();
+ LOG(INFO) << "The result metrics proto is written to " << proto_output_file;
+ } else {
+ LOG(INFO) << "Metrics proto output file path is not specified!";
+ }
+}
+
+std::unique_ptr<ImagenetModelEvaluator> CreateImagenetModelEvaluator(
+ int* argc, char* argv[], int num_threads) {
+ std::unique_ptr<ImagenetModelEvaluator> evaluator;
+ if (ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator) !=
+ kTfLiteOk) {
+ evaluator.reset(nullptr);
}
- return EXIT_SUCCESS;
+ return evaluator;
+}
+
+std::unique_ptr<ResultsWriter> CreateImagenetEvalResultsWriter(
+ int top_k, const std::string& output_file_path) {
+ std::unique_ptr<ResultsWriter> writer(
+ new ResultsWriter(top_k, output_file_path));
+ if (!writer->IsValid()) return nullptr;
+
+ return writer;
}
} // namespace metrics
} // namespace tensorflow
-
-int main(int argc, char* argv[]) {
- return tensorflow::metrics::Main(argc, argv);
-}
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h
new file mode 100644
index 0000000..6e3d614
--- /dev/null
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h
@@ -0,0 +1,86 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_ACCURACY_EVAL_H_
+#define TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_ACCURACY_EVAL_H_
+
+#include <memory>
+#include <mutex> // NOLINT(build/c++11)
+#include <ostream>
+#include <string>
+
+#include "tensorflow/lite/tools/accuracy/csv_writer.h"
+#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Writes topK accuracy results to a CSV file & logs progress to standard output
+// with `kLogDelayUs` microseconds.
+class ResultsWriter : public ImagenetModelEvaluator::Observer {
+ public:
+ ResultsWriter(int top_k, const std::string& output_file_path);
+
+ bool IsValid() const { return writer_ != nullptr; }
+
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override;
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id,
+ const tflite::evaluation::TopkAccuracyEvalMetrics& metrics,
+ const std::string& image) override;
+
+ tflite::evaluation::TopkAccuracyEvalMetrics AggregatedMetrics();
+
+ void OutputEvalMetriccProto(const std::string& proto_output_file);
+
+ private:
+ void AggregateAccuraciesAndNumImages(std::vector<double>* accuracies,
+ int* num_done_images);
+
+ int top_k_ = 0;
+ std::unordered_map<uint64_t, tflite::evaluation::TopkAccuracyEvalMetrics>
+ shard_id_accuracy_metrics_map_;
+ std::unordered_map<uint64_t, int> shard_id_done_image_count_map_;
+
+ // TODO(b/146988222): Refactor CSVWriter to take the memory ownership of
+ // 'output_stream_'.
+ std::unique_ptr<std::ofstream> output_stream_;
+ std::unique_ptr<CSVWriter> writer_;
+
+ // For logging to stdout.
+ uint64_t last_logged_time_us_ = 0;
+ int total_num_images_ = 0;
+ static constexpr int kLogDelayUs = 500 * 1000;
+
+ std::mutex mu_;
+};
+
+// Creates an evaluator by parsing command line arguments. Returns nullptr if
+// the evaluator cannot be created. argc and argv are updated in place, since
+// matching arguments are removed from argv.
+std::unique_ptr<ImagenetModelEvaluator> CreateImagenetModelEvaluator(
+ int* argc, char* argv[],
+ int num_threads = 1 // the number of threads used for evaluation.
+);
+
+std::unique_ptr<ResultsWriter> CreateImagenetEvalResultsWriter(
+ int top_k, const std::string& output_file_path);
+
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_ACCURACY_EVAL_H_
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc
new file mode 100644
index 0000000..af9bdf3
--- /dev/null
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval_main.cc
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.h"
+#include "tensorflow/lite/tools/command_line_flags.h"
+
+namespace {  // Flag names registered with tflite::Flags in main().
+constexpr const char* kNumThreadsFlag = "num_threads";
+constexpr const char* kOutputFilePathFlag = "output_file_path";
+constexpr const char* kProtoOutputFilePathFlag = "proto_output_file_path";
+}  // namespace
+
+int main(int argc, char* argv[]) {
+  // Flags owned by this binary. Arguments not matched here are left in argv
+  // for CreateImagenetModelEvaluator() to parse.
+  std::string output_file_path, proto_output_file_path;
+  int num_threads = 4;
+  std::vector<tflite::Flag> flag_list = {
+      tflite::Flag::CreateFlag(kNumThreadsFlag, &num_threads,
+                               "Number of threads."),
+      tflite::Flag::CreateFlag(kOutputFilePathFlag, &output_file_path,
+                               "Path to output file."),
+      tflite::Flag::CreateFlag(kProtoOutputFilePathFlag,
+                               &proto_output_file_path,
+                               "Path to proto output file."),
+  };
+  tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
+
+  if (num_threads <= 0) {
+    LOG(ERROR) << "Invalid number of threads.";
+    return EXIT_FAILURE;
+  }
+
+  auto evaluator = tensorflow::metrics::CreateImagenetModelEvaluator(
+      &argc, argv, num_threads);
+  if (!evaluator) {
+    LOG(ERROR) << "Fail to create the ImagenetModelEvaluator.";
+    return EXIT_FAILURE;
+  }
+
+  // Accuracy is tracked for top-1..top-k, with k = the evaluator's num_ranks.
+  auto writer = tensorflow::metrics::CreateImagenetEvalResultsWriter(
+      evaluator->params().num_ranks, output_file_path);
+  if (!writer) {
+    LOG(ERROR) << "Fail to create the ResultsWriter.";
+    return EXIT_FAILURE;
+  }
+
+  evaluator->AddObserver(writer.get());
+  LOG(INFO) << "Starting evaluation with: " << num_threads << " threads.";
+  if (evaluator->EvaluateModel() != kTfLiteOk) {
+    LOG(ERROR) << "Failed to evaluate the model!";
+    return EXIT_FAILURE;
+  }
+
+  writer->OutputEvalMetriccProto(proto_output_file_path);
+  return EXIT_SUCCESS;
+}
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
index d44d560..0e0c778 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
@@ -106,7 +106,7 @@
};
/*static*/ TfLiteStatus ImagenetModelEvaluator::Create(
- int argc, char* argv[], int num_threads,
+ int* argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) {
Params params;
params.number_of_images = 100;
@@ -147,7 +147,7 @@
"Generates the top-1 to top-k accuracy values"
"where k = num_ranks. Default: 10"),
};
- tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
+ tflite::Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
if (params.number_of_images < 0) {
LOG(ERROR) << "Invalid: num_examples";
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
index c4c2d66..8776a20 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
@@ -108,8 +108,10 @@
: params_(params), num_threads_(num_threads) {}
// Factory method to create the evaluator by parsing command line arguments.
+  // Note: argc and argv are updated in place, since matching arguments are
+  // removed from argv.
static TfLiteStatus Create(
- int argc, char* argv[], int num_threads,
+ int* argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* evaluator);
// Adds an observer that can observe evaluation events..
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 2edbbd0..f013be8 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -223,7 +223,6 @@
// Populate input value range if it's specified.
std::vector<std::string> value_ranges = Split(value_ranges_string, ':');
- std::vector<int> tmp_range;
for (const auto val : value_ranges) {
std::vector<std::string> name_range = Split(val, ',');
if (name_range.size() != 3) {
diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index bdc1baf..c7c5043 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -247,6 +247,8 @@
"//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
"//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin",
"//tensorflow/lite/tools/optimize:testdata/split.bin",
+ "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin",
+ "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin",
"//tensorflow/lite/tools/optimize:testdata/unpack.bin",
],
tags = [
diff --git a/tensorflow/lite/tools/optimize/calibration/BUILD b/tensorflow/lite/tools/optimize/calibration/BUILD
index f7f3d87..99175ac 100644
--- a/tensorflow/lite/tools/optimize/calibration/BUILD
+++ b/tensorflow/lite/tools/optimize/calibration/BUILD
@@ -19,6 +19,7 @@
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:kernel_util",
+ "//tensorflow/lite/kernels:lstm_shared",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:optimized_base",
diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
index cc35b14..11f9b64 100644
--- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
+++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
@@ -25,6 +25,7 @@
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
@@ -485,52 +486,6 @@
int scratch_tensor_index;
};
-// Input Tensors of size {n_batch, n_input}
-constexpr int kInputTensor = 0;
-
-// Input weight tensors of size: {n_cell, n_input}
-constexpr int kInputToInputWeightsTensor = 1; // Optional
-constexpr int kInputToForgetWeightsTensor = 2;
-constexpr int kInputToCellWeightsTensor = 3;
-constexpr int kInputToOutputWeightsTensor = 4;
-
-// Recurrent weight tensors of size {n_cell, n_output}
-constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
-constexpr int kRecurrentToForgetWeightsTensor = 6;
-constexpr int kRecurrentToCellWeightsTensor = 7;
-constexpr int kRecurrentToOutputWeightsTensor = 8;
-
-// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
-constexpr int kCellToInputWeightsTensor = 9; // Optional
-constexpr int kCellToForgetWeightsTensor = 10; // Optional
-constexpr int kCellToOutputWeightsTensor = 11; // Optional
-
-// Gates bias tensors of size {n_cell}
-constexpr int kInputGateBiasTensor = 12; // Optional
-constexpr int kForgetGateBiasTensor = 13;
-constexpr int kCellGateBiasTensor = 14;
-constexpr int kOutputGateBiasTensor = 15;
-
-// Projection weight tensor of size {n_output, n_cell}
-constexpr int kProjectionWeightsTensor = 16; // Optional
-// Projection bias tensor of size {n_output}
-constexpr int kProjectionBiasTensor = 17; // Optional
-
-// These state tensors are defined as variable tensors, and will be modified by
-// this op.
-constexpr int kInputActivationStateTensor = 18;
-constexpr int kInputCellStateTensor = 19;
-
-// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
-// matrix.
-constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
-constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
-constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
-constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
-
-// Output tensors.
-constexpr int kOutputTensor = 0;
-
// Resize the output, state tensors based on the sizes of the input tensors.
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
@@ -538,66 +493,73 @@
Logger* logger) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* input =
+ GetInput(context, node, ops::builtin::lstm::full::kInputTensor);
- const TfLiteTensor* input_to_input_weights =
- GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- const TfLiteTensor* input_to_forget_weights =
- GetInput(context, node, kInputToForgetWeightsTensor);
- const TfLiteTensor* input_to_cell_weights =
- GetInput(context, node, kInputToCellWeightsTensor);
- const TfLiteTensor* input_to_output_weights =
- GetInput(context, node, kInputToOutputWeightsTensor);
+ const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
+ context, node, ops::builtin::lstm::full::kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights = GetInput(
+ context, node, ops::builtin::lstm::full::kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights = GetInput(
+ context, node, ops::builtin::lstm::full::kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights = GetInput(
+ context, node, ops::builtin::lstm::full::kInputToOutputWeightsTensor);
- const TfLiteTensor* recurrent_to_input_weights =
- GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- const TfLiteTensor* recurrent_to_forget_weights =
- GetInput(context, node, kRecurrentToForgetWeightsTensor);
- const TfLiteTensor* recurrent_to_cell_weights =
- GetInput(context, node, kRecurrentToCellWeightsTensor);
- const TfLiteTensor* recurrent_to_output_weights =
- GetInput(context, node, kRecurrentToOutputWeightsTensor);
+ const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
+ context, node, ops::builtin::lstm::full::kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights = GetInput(
+ context, node, ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights = GetInput(
+ context, node, ops::builtin::lstm::full::kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights = GetInput(
+ context, node, ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor);
- const TfLiteTensor* cell_to_input_weights =
- GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
- const TfLiteTensor* cell_to_forget_weights =
- GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
- const TfLiteTensor* cell_to_output_weights =
- GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+ const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
+ context, node, ops::builtin::lstm::full::kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
+ context, node, ops::builtin::lstm::full::kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
+ context, node, ops::builtin::lstm::full::kCellToOutputWeightsTensor);
- const TfLiteTensor* input_layer_norm_coefficients =
- GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor);
- const TfLiteTensor* forget_layer_norm_coefficients =
- GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor);
- const TfLiteTensor* cell_layer_norm_coefficients =
- GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
- const TfLiteTensor* output_layer_norm_coefficients =
- GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor);
+ const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
+ context, node,
+ ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor);
+ const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
+ context, node,
+ ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
+ const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
+ context, node,
+ ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor);
+ const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
+ context, node,
+ ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor);
- const TfLiteTensor* input_gate_bias =
- GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* input_gate_bias = GetOptionalInputTensor(
+ context, node, ops::builtin::lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
- GetInput(context, node, kForgetGateBiasTensor);
- const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ GetInput(context, node, ops::builtin::lstm::full::kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias =
+ GetInput(context, node, ops::builtin::lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
- GetInput(context, node, kOutputGateBiasTensor);
+ GetInput(context, node, ops::builtin::lstm::full::kOutputGateBiasTensor);
- const TfLiteTensor* projection_weights =
- GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- const TfLiteTensor* projection_bias =
- GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+ const TfLiteTensor* projection_weights = GetOptionalInputTensor(
+ context, node, ops::builtin::lstm::full::kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias = GetOptionalInputTensor(
+ context, node, ops::builtin::lstm::full::kProjectionBiasTensor);
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* activation_state =
- GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* activation_state = GetVariableInput(
+ context, node, ops::builtin::lstm::full::kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
- TfLiteTensor* cell_state =
- GetVariableInput(context, node, kInputCellStateTensor);
+ TfLiteTensor* cell_state = GetVariableInput(
+ context, node, ops::builtin::lstm::full::kInputCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TfLiteTensor* output =
+ GetOutput(context, node, ops::builtin::lstm::full::kOutputTensor);
std::vector<int> intemediate_tensor_indexes(node->intermediates->size);
for (int i = 0; i < node->intermediates->size; ++i) {
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index 9b94bd3..b2044c2 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -871,6 +871,29 @@
property.version = 2;
break;
}
+ case BuiltinOperator_SVDF: {
+ TensorProperty tensor_property_time;
+      // Only 10 bits are needed because 6 bits are reserved for the reduce
+      // operation after element-wise multiplication between state and time
+      // weights.
+ tensor_property_time.number_of_bits = 10;
+ TensorProperty tensor_property_bias;
+ tensor_property_bias.use_derived_scale = true;
+ tensor_property_bias.number_of_bits = 32;
+ tensor_property_bias.derived_scale = {{2, 4}, {}, {}};
+ TensorProperty tensor_property_state;
+ tensor_property_state.number_of_bits = 16;
+ tensor_property_state.state_tensor = true;
+
+ property.inputs = {{0, {}},
+ {1, {}},
+ {2, tensor_property_time},
+ {4, tensor_property_state},
+ {3, tensor_property_bias}};
+ property.outputs = {{0, {}}};
+ property.version = 2;
+ break;
+ }
case BuiltinOperator_TRANSPOSE:
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc
index 26d5959..6fc19ff 100644
--- a/tensorflow/lite/tools/optimize/quantize_model.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model.cc
@@ -479,8 +479,26 @@
return utils::SymmetricPerLayerBiasQuantize(model, tensor, scale,
error_reporter);
+ } else if (tensor_property.number_of_bits == 10) {
+ // When the number of bits is 10 (instead of 16), quantize the tensor to
+ // [-512, 512], instead of [-32767, 32767].
+ TensorT* tensor = subgraph->tensors[tensor_idx].get();
+ int total_size = 1;
+ for (int i = 0; i < tensor->shape.size(); ++i) {
+ total_size *= tensor->shape[i];
+ }
+ BufferT* buffer = model->buffers[tensor->buffer].get();
+ float* buffer_data = reinterpret_cast<float*>(buffer->data.data());
+ auto minmax =
+ std::minmax_element(buffer_data, buffer_data + total_size);
+ const float range =
+ std::max(std::abs(*minmax.first), std::abs(*minmax.second));
+ const float quantized_range = 512.0;
+ const float scale = range / quantized_range;
+ return utils::SymmetricQuantizeFloatsToInt16(model, tensor, scale,
+ error_reporter);
} else {
- // Only 8, 16, 32 are supported.
+      // Only 8, 10, 16 and 32 are supported.
// TODO(jianlijianli): extend this to support arbitrary bits.
error_reporter->Report(
"Unable to quantize buffer or min/max value for input %d "
@@ -499,14 +517,15 @@
utils::QuantizeActivation(tensor);
} else if (tensor_property.number_of_bits == 16) {
TensorT* tensor = subgraph->tensors[tensor_idx].get();
+ float quantized_range = 32767.0;
float range = std::max(std::abs(tensor->quantization->min[0]),
std::abs(tensor->quantization->max[0]));
if (tensor_property.extend_to_power_of_two) {
const int power_of_two_scale = utils::GetPowerOfTwoScale(
tensor->quantization->min[0], tensor->quantization->max[0]);
range = std::pow(2, power_of_two_scale);
+ quantized_range = 32768.0;
}
- const float quantized_range = 32768.0;
const float scale = range / quantized_range;
utils::QuantizeActivationToInt16(tensor, scale);
}
diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc
index 89038ad..344a605 100644
--- a/tensorflow/lite/tools/optimize/quantize_model_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -1115,6 +1115,65 @@
}
}
+class QuantizeSVDFTest : public QuantizeModelTest {
+ protected:
+  QuantizeSVDFTest() {  // Loads svdf_calibrated.bin, unpacked into model_.
+    input_model_ = ReadModel(internal::kSvdfCalibrated);
+    readonly_model_ = input_model_->GetModel();
+    readonly_model_->UnPackTo(&model_);
+  }
+};
+
+TEST_F(QuantizeSVDFTest, VerifySVDF) {
+  // Quantize model.
+  auto status = QuantizeModel(&builder_, &model_, TensorType_INT8,
+                              TensorType_INT8, &error_reporter_);
+  ASSERT_EQ(kTfLiteOk, status);
+
+  // Read expected model.
+  auto expected_fb_model = ReadModel(internal::kSvdfQuantized);
+  auto expected_read_only_model = expected_fb_model->GetModel();
+  ModelT expected_model;
+  expected_read_only_model->UnPackTo(&expected_model);
+
+  // Comparison: tensors (metadata + quantization params), then raw buffers.
+  ASSERT_EQ(model_.subgraphs.size(), expected_model.subgraphs.size());
+  for (size_t subgraph_idx = 0; subgraph_idx < model_.subgraphs.size();
+       subgraph_idx++) {
+    const auto graph = model_.subgraphs[subgraph_idx].get();
+    const auto expected_graph = expected_model.subgraphs[subgraph_idx].get();
+    ASSERT_EQ(graph->tensors.size(), expected_graph->tensors.size());
+    for (size_t i = 0; i < graph->tensors.size(); i++) {
+      const auto tensor = graph->tensors[i].get();
+      const auto expected_tensor = expected_graph->tensors[i].get();
+      EXPECT_EQ(tensor->buffer, expected_tensor->buffer);
+      EXPECT_EQ(tensor->is_variable, expected_tensor->is_variable);
+      EXPECT_EQ(tensor->shape, expected_tensor->shape);
+      EXPECT_EQ(tensor->name, expected_tensor->name);
+      EXPECT_EQ(tensor->type, expected_tensor->type);
+      const auto quantization_params = tensor->quantization.get();
+      const auto expected_quantization_params =
+          expected_tensor->quantization.get();
+      if (quantization_params != nullptr ||
+          expected_quantization_params != nullptr) {
+        ASSERT_NE(quantization_params, nullptr);  // Fatal: deref'd below.
+        ASSERT_NE(expected_quantization_params, nullptr);
+        EXPECT_EQ(quantization_params->scale,
+                  expected_quantization_params->scale);
+        EXPECT_EQ(quantization_params->zero_point,
+                  expected_quantization_params->zero_point);
+      }
+    }
+  }
+  ASSERT_EQ(model_.buffers.size(), expected_model.buffers.size());
+  for (size_t buffer_idx = 0; buffer_idx < model_.buffers.size();
+       ++buffer_idx) {
+    const auto& buffer = model_.buffers[buffer_idx]->data;
+    const auto& expected_buffer = expected_model.buffers[buffer_idx]->data;
+    EXPECT_EQ(buffer, expected_buffer);
+  }
+}
+
class QuantizeFCTest : public QuantizeModelTest {
protected:
QuantizeFCTest() {
diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc
index be99f9e..0d7cfd6 100644
--- a/tensorflow/lite/tools/optimize/test_util.cc
+++ b/tensorflow/lite/tools/optimize/test_util.cc
@@ -59,6 +59,9 @@
const char* kLstmCalibrated2 = "lstm_calibrated2.bin";
const char* kLstmQuantized2 = "lstm_quantized2.bin";
+const char* kSvdfCalibrated = "svdf_calibrated.bin";
+const char* kSvdfQuantized = "svdf_quantized.bin";
+
const char* kModelWithUnpack = "unpack.bin";
int FailOnErrorReporter::Report(const char* format, va_list args) {
diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h
index 0d394b0..525fbd0 100644
--- a/tensorflow/lite/tools/optimize/test_util.h
+++ b/tensorflow/lite/tools/optimize/test_util.h
@@ -95,6 +95,10 @@
extern const char* kLstmCalibrated2;
extern const char* kLstmQuantized2;
+// Test model with SVDF op.
+extern const char* kSvdfCalibrated;
+extern const char* kSvdfQuantized;
+
// Test model with an unpack op.
extern const char* kModelWithUnpack;
diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin b/tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin
new file mode 100644
index 0000000..e363b4a
--- /dev/null
+++ b/tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin
Binary files differ
diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin b/tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin
new file mode 100644
index 0000000..fd30ba7
--- /dev/null
+++ b/tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin
Binary files differ
diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc
index 4337fa1..c16030b 100644
--- a/tensorflow/lite/tools/verifier.cc
+++ b/tensorflow/lite/tools/verifier.cc
@@ -179,7 +179,8 @@
const int total_dims = sparsity->traversal_order()->size();
- if (sparsity->dim_metadata()->size() != total_dims) {
+ if (total_dims < tensor.shape()->size() ||
+ sparsity->dim_metadata()->size() != total_dims) {
return absl::nullopt;
}
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index e638840..213e7ff 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -219,6 +219,10 @@
op_sig.input_types.at(0) == TensorType_UINT8) {
return 2;
}
+      // If the op takes bool input, it is version 3.
+ if (op_sig.input_types.at(0) == TensorType_BOOL) {
+ return 3;
+ }
return 1;
case BuiltinOperator_DEQUANTIZE:
diff --git a/tensorflow/lite/tutorials/mnist_tflite.py b/tensorflow/lite/tutorials/mnist_tflite.py
index 60df266..62c5e27 100644
--- a/tensorflow/lite/tutorials/mnist_tflite.py
+++ b/tensorflow/lite/tutorials/mnist_tflite.py
@@ -17,6 +17,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.lite.tutorials import dataset
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index db2222c..9255844 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -21,6 +21,7 @@
visibility = [
"//engedu/ml/tf_from_scratch:__pkg__",
"//third_party/cloud_tpu/convergence_tools:__subpackages__",
+ "//third_party/mlperf:__subpackages__",
"//tensorflow:internal",
"//tensorflow/lite/toco/python:__pkg__",
"//tensorflow_models:__subpackages__",
@@ -394,6 +395,8 @@
srcs = ["lib/core/numpy.cc"],
hdrs = ["lib/core/numpy.h"],
deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
@@ -405,12 +408,24 @@
hdrs = ["lib/core/bfloat16.h"],
deps = [
":numpy_lib",
+ ":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/python_runtime:headers",
],
)
+tf_python_pybind_extension(
+ name = "_pywrap_bfloat16",
+ srcs = ["lib/core/bfloat16_wrapper.cc"],
+ hdrs = ["lib/core/bfloat16.h"],
+ module_name = "_pywrap_bfloat16",
+ deps = [
+ "//third_party/python_runtime:headers",
+ "@pybind11",
+ ],
+)
+
cc_library(
name = "ndarray_tensor_bridge",
srcs = ["lib/core/ndarray_tensor_bridge.cc"],
@@ -421,7 +436,7 @@
],
),
deps = [
- ":ndarray_tensor_types",
+ ":bfloat16_lib",
":numpy_lib",
"//tensorflow/c:c_api",
"//tensorflow/core:lib",
@@ -783,31 +798,6 @@
)
cc_library(
- name = "ndarray_tensor_types",
- srcs = ["lib/core/ndarray_tensor_types.cc"],
- hdrs = ["lib/core/ndarray_tensor_types.h"],
- deps = [
- ":bfloat16_lib",
- ":numpy_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//third_party/python_runtime:headers",
- "@com_google_absl//absl/container:flat_hash_set",
- ],
-)
-
-cc_library(
- name = "ndarray_tensor_types_headers_lib",
- hdrs = ["lib/core/ndarray_tensor_types.h"],
- deps = [
- ":numpy_lib",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
-cc_library(
name = "ndarray_tensor",
srcs = ["lib/core/ndarray_tensor.cc"],
hdrs = ["lib/core/ndarray_tensor.h"],
@@ -815,8 +805,8 @@
"//learning/deepmind/courier:__subpackages__",
]),
deps = [
+ ":bfloat16_lib",
":ndarray_tensor_bridge",
- ":ndarray_tensor_types",
":numpy_lib",
":safe_ptr",
"//tensorflow/c:c_api",
@@ -1176,7 +1166,6 @@
srcs = ["framework/dtypes.cc"],
module_name = "_dtypes",
deps = [
- ":ndarray_tensor_types_headers_lib",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
@@ -1190,6 +1179,7 @@
srcs_version = "PY2AND3",
deps = [
":_dtypes",
+ ":_pywrap_bfloat16",
":pywrap_tensorflow",
"//tensorflow/core:protos_all_py",
],
@@ -1751,6 +1741,12 @@
)
py_library(
+ name = "gpu_util",
+ srcs = ["framework/gpu_util.py"],
+ deps = [],
+)
+
+py_library(
name = "framework_test_lib",
srcs = ["framework/test_util.py"],
srcs_version = "PY2AND3",
@@ -1763,6 +1759,7 @@
":client",
":errors",
":framework_for_generated_wrappers",
+ ":gpu_util",
":platform",
":platform_test",
":pywrap_tensorflow",
@@ -1850,6 +1847,17 @@
)
tf_py_test(
+ name = "framework_constant_op_test",
+ size = "small",
+ srcs = ["framework/constant_op_test.py"],
+ main = "framework/constant_op_test.py",
+ python_version = "PY3",
+ deps = [
+ ":constant_op",
+ ],
+)
+
+tf_py_test(
name = "framework_registry_test",
size = "small",
srcs = ["framework/registry_test.py"],
@@ -3651,9 +3659,6 @@
size = "small",
srcs = ["training/experimental/mixed_precision_test.py"],
python_version = "PY3",
- tags = [
- "no_rocm",
- ],
deps = [
":mixed_precision",
"//tensorflow/python:client_testlib",
@@ -4764,7 +4769,6 @@
srcs = ["ops/nn_fused_batchnorm_test.py"],
python_version = "PY3",
shard_count = 16,
- tags = ["no_rocm"],
deps = [
":array_ops",
":client_testlib",
@@ -5507,7 +5511,6 @@
"grappler/item.i",
"grappler/tf_optimizer.i",
"lib/core/strings.i",
- "lib/io/file_io.i",
"lib/io/py_record_reader.i",
"platform/base.i",
"//tensorflow/compiler/mlir/python:mlir.i",
@@ -5518,6 +5521,7 @@
"//conditions:default": None,
}),
deps = [
+ ":bfloat16_lib",
":cost_analyzer_lib",
":model_analyzer_lib",
":cpp_python_util",
@@ -5587,8 +5591,7 @@
":numpy_lib", # checkpoint_reader
":safe_ptr", # checkpoint_reader
":python_op_gen", # python_op_gen
- ":bfloat16_lib", # _dtypes
- ":ndarray_tensor_types", # _dtypes
+ ":bfloat16_lib", # bfloat16
"//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
"//tensorflow/core/common_runtime/eager:eager_executor", # tfe
@@ -5694,6 +5697,19 @@
# ** Targets for Windows build (end) **
+tf_python_pybind_extension(
+ name = "_pywrap_file_io",
+ srcs = ["lib/io/file_io_wrapper.cc"],
+ module_name = "_pywrap_file_io",
+ deps = [
+ ":pybind11_absl",
+ ":pybind11_status",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:protos_all_cc",
+ "@pybind11",
+ ],
+)
+
py_library(
name = "lib",
srcs = [
@@ -5703,6 +5719,7 @@
],
srcs_version = "PY2AND3",
deps = [
+ ":_pywrap_file_io",
":_pywrap_record_io",
":errors",
":pywrap_tensorflow",
@@ -6215,6 +6232,7 @@
":client_testlib",
":constant_op",
":dtypes",
+ ":framework_for_generated_wrappers",
":framework_ops",
":training",
":variable_scope",
diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index 7fe43cf..01dbdff 100644
--- a/tensorflow/python/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -134,6 +134,8 @@
":converters",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/core:test_lib",
+        # TODO(b/145618471): Remove this transitive dependency.
+ "//tensorflow/python/distribute:input_lib",
],
)
diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
index 1b44121..d81723c 100644
--- a/tensorflow/python/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -47,6 +47,7 @@
visibility = ["//tensorflow:__subpackages__"],
deps = [
":core",
+ "//tensorflow/python/autograph/lang",
"//tensorflow/python/autograph/operators",
"//tensorflow/python/autograph/pyct",
"//tensorflow/python/autograph/pyct/static_analysis",
diff --git a/tensorflow/python/autograph/core/config.py b/tensorflow/python/autograph/core/config.py
index b336ea7..dae441b 100644
--- a/tensorflow/python/autograph/core/config.py
+++ b/tensorflow/python/autograph/core/config.py
@@ -48,6 +48,7 @@
# Known libraries
DoNotConvert('numpy'),
+ DoNotConvert('pandas'),
DoNotConvert('tensorflow'),
DoNotConvert('PIL'),
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index dbcdf43..9e976b3 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -30,6 +30,7 @@
import traceback
# pylint:disable=g-bad-import-order
+
import six
# pylint:enable=g-bad-import-order
diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
index fd92a32..9dbfc82 100644
--- a/tensorflow/python/autograph/operators/BUILD
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -73,6 +73,8 @@
deps = [
":operators",
"//tensorflow/python:client_testlib",
+ # TODO(b/145618471): Remove this transitive dependency.
+ "//tensorflow/python/distribute:input_lib",
],
)
@@ -108,6 +110,8 @@
":operators",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/core",
+ # TODO(b/145618471): Remove this transitive dependency.
+ "//tensorflow/python/distribute:input_lib",
],
)
diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
index 63f9c02..a716ffd 100644
--- a/tensorflow/python/autograph/operators/control_flow.py
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -60,6 +60,7 @@
from __future__ import print_function
import functools
+
import numpy as np
from tensorflow.python.autograph.operators import py_builtins
@@ -78,13 +79,19 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
+# TODO(b/145618471): Remove this dependency.
+# Lazy import to work around circular dependencies
+input_lib = lazy_loader.LazyLoader(
+ 'input_lib', globals(),
+ 'tensorflow.python.distribute.input_lib')
+
LIMIT_PYTHON_ITERATIONS = True
PYTHON_MAX_ITERATIONS = 100000000 # Fails in about one minute for empty loops.
WARN_INEFFICIENT_UNROLL = True
@@ -342,6 +349,11 @@
init_vars, basic_symbol_names,
composite_symbol_names, opts)
+ if isinstance(iter_, input_lib.DistributedIterator):
+ raise NotImplementedError(
+ 'distributed iterators not supported yet, use the distributed dataset'
+ ' directly')
+
# Note: This experimental interface is subject to change.
custom_handler = getattr(iter_, '_autograph_for_loop', None)
if custom_handler is not None:
@@ -409,9 +421,7 @@
lambda: False)
return iterate_index < n
- # TODO(b/134181679): Let the op itself handle optimizations.
- if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
- opts['maximum_iterations'] = n
+ opts['maximum_iterations'] = n
results = _tf_while_stmt(
while_cond,
@@ -525,26 +535,9 @@
def while_cond(iterate, *loop_vars):
"""Cond function for `tf.while_loop`."""
-
- def build_main_test():
- """Main iteration condition."""
- # TODO(b/138857806): The optimizer should handle this.
- # LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
- # compile time constant.
- delta_const = tensor_util.constant_value(delta)
- if delta_const is not None:
- # Support single element arrays.
- delta_const = np.asscalar(delta_const)
- if delta_const >= 0:
- return iterate < limit
- else:
- return iterate > limit
- else:
- return math_ops.logical_or(
- math_ops.logical_and(delta >= 0, iterate < limit),
- math_ops.logical_and(delta < 0, iterate > limit))
-
- main_test = build_main_test()
+ main_test = math_ops.logical_or(
+ math_ops.logical_and(delta >= 0, iterate < limit),
+ math_ops.logical_and(delta < 0, iterate > limit))
if extra_test is not None:
return control_flow_ops.cond(
main_test,
@@ -553,11 +546,8 @@
)
return main_test
- # TODO(b/134181679): The op should handle this optimizations.
- if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
- # This specific dtype is required by while_loop.
- opts['maximum_iterations'] = math_ops.cast(
- misc.get_range_len(start, limit, delta), dtypes.int32)
+ opts['maximum_iterations'] = math_ops.cast(
+ misc.get_range_len(start, limit, delta), dtypes.int32)
results = _tf_while_stmt(
while_cond,
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index 2d00daf..7df4781 100644
--- a/tensorflow/python/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -38,9 +38,17 @@
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
+# TODO(b/145618471): Remove this dependency.
+# Lazy import to work around circular dependencies
+input_lib = lazy_loader.LazyLoader(
+ 'input_lib', globals(),
+ 'tensorflow.python.distribute.input_lib')
+
+
UNSPECIFIED = object()
@@ -341,6 +349,10 @@
def enumerate_(s, start=0):
if isinstance(s, dataset_ops.DatasetV2):
return _tf_dataset_enumerate(s, start)
+ if isinstance(
+ s, (input_lib.DistributedIterator, input_lib.DistributedDataset)):
+ raise NotImplementedError(
+ 'use a for loop over the dataset and keep a separate counter')
return _py_enumerate(s, start)
diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD
index b993123..6ea3d8d 100644
--- a/tensorflow/python/autograph/pyct/BUILD
+++ b/tensorflow/python/autograph/pyct/BUILD
@@ -27,6 +27,7 @@
"cfg.py",
"error_utils.py",
"errors.py",
+ "gast_util.py",
"inspect_utils.py",
"loader.py",
"origin_info.py",
diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
index a8ae864..2a81530 100644
--- a/tensorflow/python/autograph/pyct/anno.py
+++ b/tensorflow/python/autograph/pyct/anno.py
@@ -24,6 +24,7 @@
import enum
# pylint:disable=g-bad-import-order
+
import gast
# pylint:enable=g-bad-import-order
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index c2da09e..02618a8 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -35,6 +35,7 @@
from enum import Enum
# pylint:disable=g-bad-import-order
+
import gast
# pylint:enable=g-bad-import-order
diff --git a/tensorflow/python/autograph/pyct/gast_util.py b/tensorflow/python/autograph/pyct/gast_util.py
new file mode 100644
index 0000000..49eb314
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/gast_util.py
@@ -0,0 +1,78 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gast compatibility library. Supports 0.2.2 and 0.3.2."""
+# TODO(mdan): Remove this file once it's safe to break compatibility.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import gast
+
+
+GAST2 = hasattr(gast, 'Str')
+GAST3 = not GAST2
+
+
+def _is_constant_gast_2(node):
+ return isinstance(node, (gast.Num, gast.Str, gast.Bytes, gast.Ellipsis,
+ gast.NameConstant))
+
+
+def _is_constant_gast_3(node):
+ return isinstance(node, gast.Constant)
+
+
+def is_literal(node):
+ """Tests whether node represents a Python literal."""
+ # Normal literals, True/False/None/Etc. in Python3
+ if is_constant(node):
+ return True
+
+ # True/False/None/Etc. in Python2
+ if isinstance(node, gast.Name) and node.id in ['True', 'False', 'None']:
+ return True
+
+ return False
+
+
+def _is_ellipsis_gast_2(node):
+ return isinstance(node, gast.Ellipsis)
+
+
+def _is_ellipsis_gast_3(node):
+ return isinstance(node, gast.Constant) and node.value == Ellipsis
+
+
+if GAST2:
+ is_constant = _is_constant_gast_2
+ is_ellipsis = _is_ellipsis_gast_2
+
+ Module = gast.Module
+ Name = gast.Name
+ Str = gast.Str
+
+elif GAST3:
+ is_constant = _is_constant_gast_3
+ is_ellipsis = _is_ellipsis_gast_3
+
+ Module = functools.partial(gast.Module, type_ignores=None) # pylint:disable=invalid-name
+ Name = functools.partial(gast.Name, type_comment=None) # pylint:disable=invalid-name
+ Str = functools.partial(gast.Constant, kind=None) # pylint:disable=invalid-name
+
+else:
+ assert False
diff --git a/tensorflow/python/autograph/pyct/pretty_printer.py b/tensorflow/python/autograph/pyct/pretty_printer.py
index d6e8f86..c4d74d0 100644
--- a/tensorflow/python/autograph/pyct/pretty_printer.py
+++ b/tensorflow/python/autograph/pyct/pretty_printer.py
@@ -18,7 +18,6 @@
from __future__ import division
from __future__ import print_function
-
import gast
import six
import termcolor
diff --git a/tensorflow/python/client/debug_events_writer_wrapper.cc b/tensorflow/python/client/debug_events_writer_wrapper.cc
index 3c0cd31..75abf70 100644
--- a/tensorflow/python/client/debug_events_writer_wrapper.cc
+++ b/tensorflow/python/client/debug_events_writer_wrapper.cc
@@ -29,7 +29,7 @@
using namespace tensorflow::tfdbg; // NOLINT(build/namespaces)
m.def("Init",
- [](const std::string dump_root, const int64 circular_buffer_size) {
+ [](const std::string& dump_root, const int64 circular_buffer_size) {
DebugEventsWriter* writer = DebugEventsWriter::GetDebugEventsWriter(
dump_root, circular_buffer_size);
if (!writer->Init().ok()) {
@@ -39,7 +39,7 @@
}
});
m.def("WriteSourceFile",
- [](const std::string dump_root, const py::object obj) {
+ [](const std::string& dump_root, const py::object obj) {
CheckProtoType(obj, "tensorflow.DebugEvent");
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
@@ -48,7 +48,7 @@
tfdbg::DebugEventFileType::SOURCE_FILES);
});
m.def("WriteStackFrameWithId",
- [](const std::string dump_root, const py::object obj) {
+ [](const std::string& dump_root, const py::object& obj) {
CheckProtoType(obj, "tensorflow.DebugEvent");
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
@@ -57,7 +57,7 @@
tfdbg::DebugEventFileType::STACK_FRAMES);
});
m.def("WriteGraphOpCreation",
- [](const std::string dump_root, const py::object obj) {
+ [](const std::string& dump_root, const py::object& obj) {
CheckProtoType(obj, "tensorflow.DebugEvent");
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
@@ -66,7 +66,7 @@
tfdbg::DebugEventFileType::GRAPHS);
});
m.def("WriteDebuggedGraph",
- [](const std::string dump_root, const py::object obj) {
+ [](const std::string& dump_root, const py::object& obj) {
CheckProtoType(obj, "tensorflow.DebugEvent");
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
@@ -75,7 +75,7 @@
tfdbg::DebugEventFileType::GRAPHS);
});
m.def("WriteExecution",
- [](const std::string dump_root, const py::object obj) {
+ [](const std::string& dump_root, const py::object& obj) {
CheckProtoType(obj, "tensorflow.DebugEvent");
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
@@ -84,7 +84,7 @@
tfdbg::DebugEventFileType::EXECUTION);
});
m.def("WriteGraphExecutionTrace",
- [](const std::string dump_root, const py::object obj) {
+ [](const std::string& dump_root, const py::object& obj) {
CheckProtoType(obj, "tensorflow.DebugEvent");
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
@@ -92,17 +92,23 @@
obj.attr("SerializeToString")().cast<std::string>(),
tfdbg::DebugEventFileType::GRAPH_EXECUTION_TRACES);
});
- m.def("FlushNonExecutionFiles", [](const std::string dump_root) {
+ m.def("RegisterDeviceAndGetId",
+ [](const std::string& dump_root, const std::string& device_name) {
+ DebugEventsWriter* writer =
+ DebugEventsWriter::GetDebugEventsWriter(dump_root);
+ return writer->RegisterDeviceAndGetId(device_name);
+ });
+ m.def("FlushNonExecutionFiles", [](const std::string& dump_root) {
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
writer->FlushNonExecutionFiles();
});
- m.def("FlushExecutionFiles", [](const std::string dump_root) {
+ m.def("FlushExecutionFiles", [](const std::string& dump_root) {
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
writer->FlushExecutionFiles();
});
- m.def("Close", [](const std::string dump_root) {
+ m.def("Close", [](const std::string& dump_root) {
DebugEventsWriter* writer =
DebugEventsWriter::GetDebugEventsWriter(dump_root);
writer->Close();
diff --git a/tensorflow/python/client/events_writer_wrapper.cc b/tensorflow/python/client/events_writer_wrapper.cc
index b37f970..22b3811 100644
--- a/tensorflow/python/client/events_writer_wrapper.cc
+++ b/tensorflow/python/client/events_writer_wrapper.cc
@@ -33,8 +33,7 @@
.def("FileName",
[](tensorflow::EventsWriter& self) { return self.FileName(); })
.def("_WriteSerializedEvent",
- [](tensorflow::EventsWriter& self,
- const absl::string_view event_str) {
+ [](tensorflow::EventsWriter& self, const std::string& event_str) {
self.WriteSerializedEvent(event_str);
})
.def("Flush", [](tensorflow::EventsWriter& self) { return self.Flush(); })
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index b200340..90d1c36 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -31,7 +31,7 @@
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 19)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 12, 30)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD
index 2061f0c..a8c4ce2 100644
--- a/tensorflow/python/compiler/xla/BUILD
+++ b/tensorflow/python/compiler/xla/BUILD
@@ -70,7 +70,6 @@
srcs = ["xla_test.py"],
tags = [
"no_mac",
- "no_rocm", # XLA support is not enabled on the ROCm platform
"no_windows",
],
xla_enabled = True,
@@ -91,3 +90,20 @@
"@absl_py//absl/testing:parameterized",
],
)
+
+cuda_py_test(
+ name = "experimental_compile_test",
+ srcs = ["experimental_compile_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
+ ],
+ python_version = "PY3",
+ tags = [
+ "no_mac",
+ "no_windows",
+ ],
+ xla_enabled = True,
+)
diff --git a/tensorflow/python/compiler/xla/experimental_compile_test.py b/tensorflow/python/compiler/xla/experimental_compile_test.py
new file mode 100644
index 0000000..c0a1c4b
--- /dev/null
+++ b/tensorflow/python/compiler/xla/experimental_compile_test.py
@@ -0,0 +1,113 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ExperimentalCompileTest(test.TestCase):
+
+ def testBasic(self):
+ with ops.Graph().as_default() as g:
+
+ def fn(x, a):
+ return x + a
+
+ xla_func = def_function.function(fn, experimental_compile=True)
+ inputs = array_ops.placeholder(dtypes.float32, [5])
+ # XLA support is not yet enabled for TF ROCm
+ if not test.is_built_with_rocm():
+ x = xla_func(inputs, 1)
+ with session.Session(graph=g) as sess:
+ y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+ self.assertTrue(x.graph.as_graph_def().library.function[0]
+ .attr["_XlaMustCompile"].b)
+ self.assertAllClose([2, 3, 3, 4, 4], y)
+
+ def testDerivative(self):
+ # XLA support is not yet enabled for TF ROCm
+ if test.is_built_with_rocm():
+ return
+
+ def fn(x, a):
+ return 2 * x + a
+
+ with ops.Graph().as_default() as g:
+ xla_func = def_function.function(fn, experimental_compile=True)
+ with backprop.GradientTape() as tape:
+ inputs = array_ops.placeholder(dtypes.float32, [5])
+ tape.watch(inputs)
+ outputs = xla_func(inputs, 1)
+ grads = tape.gradient(outputs, inputs)
+
+ with session.Session(graph=g) as sess:
+ grads_tensor = sess.run(grads, feed_dict={inputs: [1, 2, 2, 3, 3]})
+ self.assertAllClose([2, 2, 2, 2, 2], grads_tensor)
+ (forward, backward) = xla_func.get_concrete_function(
+ inputs, 1)._delayed_rewrite_functions.forward_backward()
+
+ # Check that the must-compile attribute gets correctly propagated to the
+ # created derivatives.
+ self.assertTrue(forward.definition.attr["_XlaMustCompile"])
+ self.assertTrue(backward.function_def.attr["_XlaMustCompile"])
+
+ def testBasicInt32(self):
+ with ops.Graph().as_default() as g:
+
+ def fn(x, a):
+ return x + a
+
+ xla_func = def_function.function(fn, experimental_compile=True)
+ inputs = array_ops.placeholder(dtypes.int32, [5])
+ # XLA support is not yet enabled for TF ROCm
+ if not test.is_built_with_rocm():
+ x = xla_func(inputs, 1)
+ with session.Session(graph=g) as sess:
+ y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+ self.assertTrue(x.graph.as_graph_def().library.function[0]
+ .attr["_XlaMustCompile"].b)
+ self.assertAllClose([2, 3, 3, 4, 4], y)
+
+ # Checking that we crash on an unsupported operation lets us test that the XLA
+ # compiler was actually invoked.
+ def testUnsupportedOps(self):
+ with ops.Graph().as_default() as g:
+
+ def fn(x):
+ return array_ops.unique(x).y # Unique is not supported by XLA
+
+ xla_func = def_function.function(fn, experimental_compile=True)
+ inputs = array_ops.placeholder(dtypes.float32, [5])
+ x = xla_func(inputs)
+ # XLA support is not yet enabled for TF ROCm
+ if not test.is_built_with_rocm():
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "not compilable"):
+ with session.Session(graph=g) as sess:
+ sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py b/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py
index a065e7d..9348ae8 100644
--- a/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py
+++ b/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py
@@ -51,7 +51,7 @@
self._filenames = []
for n in self._num_cols:
fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
- with open(fn, 'wb') as f:
+ with open(fn, 'w') as f:
# Just write 100 rows and use `repeat`... Assumes the cost
# of creating an iterator is not significant
row = ','.join(str_val for _ in range(n))
diff --git a/tensorflow/python/data/experimental/benchmarks/map_and_batch_benchmark.py b/tensorflow/python/data/experimental/benchmarks/map_and_batch_benchmark.py
index d6950a0..ac3646a 100644
--- a/tensorflow/python/data/experimental/benchmarks/map_and_batch_benchmark.py
+++ b/tensorflow/python/data/experimental/benchmarks/map_and_batch_benchmark.py
@@ -116,7 +116,7 @@
def name(method, label, num_calls, inter_op, element_size, batch_size):
return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
method,
- hashlib.sha1(label).hexdigest()[:8],
+ hashlib.sha1((label).encode("utf-8")).hexdigest()[:8],
num_calls,
inter_op,
element_size,
diff --git a/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py b/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
index 1240b70..d52f348 100644
--- a/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
@@ -20,6 +20,7 @@
import os
import shutil
import tempfile
+
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import matching_files
diff --git a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
index fb1d4ea..e9cefb2 100644
--- a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
@@ -17,7 +17,6 @@
from __future__ import division
from __future__ import print_function
-
from absl.testing import parameterized
import numpy as np
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py
index c1d55da..6bdb6e0 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/snapshot_dataset_serialization_test.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import os
+
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
index d922d82..d4868e8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
@@ -20,6 +20,7 @@
import os
import shutil
import time
+
from absl.testing import parameterized
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
index 23647ec..bf70b69 100644
--- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
@@ -19,6 +19,7 @@
import os
import re
+
import numpy as np
from tensorflow.core.framework import summary_pb2
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 3f6906e..ecaaecd 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -83,9 +83,6 @@
name = "dataset_test",
size = "small",
srcs = ["dataset_test.py"],
- tags = [
- "no_rocm",
- ],
deps = [
":test_base",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/data/kernel_tests/padded_batch_test.py b/tensorflow/python/data/kernel_tests/padded_batch_test.py
index a3b8f39..6e151af 100644
--- a/tensorflow/python/data/kernel_tests/padded_batch_test.py
+++ b/tensorflow/python/data/kernel_tests/padded_batch_test.py
@@ -99,18 +99,26 @@
batch_size=4, padded_shapes=[-1]))
self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])
- @combinations.generate(test_base.default_test_combinations())
- def testPaddedBatchDatasetNonDefaultPadding(self):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(
+ padding_values=[(-1, '<end>', {'structure': ''}),
+ (-1, '<end>', None)])))
+ def testPaddedBatchDatasetNonDefaultPadding(self, padding_values):
def fill_tuple(x):
filled = array_ops.fill([x], x)
- return (filled, string_ops.as_string(filled))
+ return (filled, string_ops.as_string(filled), {
+ 'structure': string_ops.as_string(filled)
+ })
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
dataset = (
dataset_ops.Dataset.from_tensor_slices(random_seq_lens).map(fill_tuple)
.padded_batch(
- 4, padded_shapes=([-1], [-1]), padding_values=(-1, '<end>')))
+ 4, padded_shapes=([-1], [-1], {'structure': [-1]}),
+ padding_values=padding_values))
get_next = self.getNext(dataset)
for i in range(8):
@@ -118,6 +126,7 @@
padded_len = np.max(result[0])
self.assertEqual((4, padded_len), result[0].shape)
self.assertEqual((4, padded_len), result[1].shape)
+ self.assertEqual((4, padded_len), result[2]['structure'].shape)
for j in range(4):
seq_len = random_seq_lens[(i * 4) + j]
self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
@@ -127,6 +136,10 @@
[compat.as_bytes(str(seq_len))] * seq_len)
self.assertAllEqual(result[1][j, seq_len:],
[b'<end>'] * (padded_len - seq_len))
+ self.assertAllEqual(result[2]['structure'][j, :seq_len],
+ [compat.as_bytes(str(seq_len))] * seq_len)
+ self.assertAllEqual(result[2]['structure'][j, seq_len:],
+ [b''] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py
index a42abb5..7a1273c 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_test.py
@@ -243,16 +243,14 @@
combinations.combine(tf_api_version=2, mode="eager"),
combinations.combine(reshuffle=[True, False], seed=[None, 42])))
def testReshuffleIterationEpochs(self, reshuffle, seed):
+ # TensorFlow unit tests set the global graph seed. We unset it here so that
+ # we can control determinism via the `seed` parameter.
+ random_seed.set_random_seed(None)
dataset = dataset_ops.Dataset.range(10).shuffle(
10, seed=seed, reshuffle_each_iteration=reshuffle)
- first_epoch = []
- for elem in dataset:
- first_epoch.append(elem.numpy())
-
- second_epoch = []
- for elem in dataset:
- second_epoch.append(elem.numpy())
+ first_epoch = self.getDatasetOutput(dataset)
+ second_epoch = self.getDatasetOutput(dataset)
self.assertEqual(first_epoch == second_epoch, not reshuffle)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6eda3c4..f7fec93 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1457,8 +1457,9 @@
maximum size of that dimension in each batch.
padding_values: (Optional.) A nested structure of scalar-shaped
`tf.Tensor`, representing the padding values to use for the respective
- components. Defaults are `0` for numeric types and the empty string for
- string types.
+ components. None represents that the nested structure should be padded
+ with default values. Defaults are `0` for numeric types and the empty
+ string for string types.
drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
whether the last batch should be dropped in the case it has fewer than
`batch_size` elements; the default behavior is not to drop the smaller
@@ -3769,8 +3770,8 @@
return value
-def _default_padding(input_dataset):
- """Returns default padding tensors in a structure matching `input_dataset`."""
+def _padding_values_or_default(padding_values, input_dataset):
+ """Returns padding values with None elements replaced with default values."""
def make_zero(t):
if t.base_dtype == dtypes.string:
return ""
@@ -3782,9 +3783,13 @@
raise TypeError(error_msg)
else:
return np.zeros_like(t.as_numpy_dtype())
+ def value_or_default(value, default):
+ return default if value is None else value
- return nest.map_structure(
- make_zero, get_legacy_output_types(input_dataset))
+ default_padding = nest.map_structure(make_zero,
+ get_legacy_output_types(input_dataset))
+ return nest.map_structure_up_to(padding_values, value_or_default,
+ padding_values, default_padding)
class PaddedBatchDataset(UnaryDataset):
@@ -3801,9 +3806,7 @@
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
- padding_values = (
- padding_values
- if padding_values is not None else _default_padding(input_dataset))
+ padding_values = _padding_values_or_default(padding_values, input_dataset)
input_shapes = get_legacy_output_shapes(input_dataset)
flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)
diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD
index 4613d76..b5dc355 100644
--- a/tensorflow/python/data/util/BUILD
+++ b/tensorflow/python/data/util/BUILD
@@ -156,6 +156,7 @@
deps = [
":convert",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index c3b49e7..58bebfa 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -120,6 +120,7 @@
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework",
+ "@six_archive//:six",
],
)
@@ -715,7 +716,6 @@
tags = [
"guitar",
"multi_and_single_gpu",
- "no_rocm",
"no_windows", # TODO(b/142475891): Enable this test on Windows.
"no_windows_gpu", # TODO(b/130551176)
],
diff --git a/tensorflow/python/debug/lib/debug_events_reader.py b/tensorflow/python/debug/lib/debug_events_reader.py
index c6142c6..9033e48 100644
--- a/tensorflow/python/debug/lib/debug_events_reader.py
+++ b/tensorflow/python/debug/lib/debug_events_reader.py
@@ -18,17 +18,24 @@
from __future__ import division
from __future__ import print_function
+import collections
import glob
import os
import threading
-from six.moves import map
+import six
from tensorflow.core.protobuf import debug_event_pb2
-from tensorflow.python.lib.io import tf_record
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_util
from tensorflow.python.util import compat
+DebugEventWithOffset = collections.namedtuple(
+ "DebugEventWithOffset", "debug_event offset")
+
+
class DebugEventsReader(object):
"""Reader class for a tfdbg v2 DebugEvents directory."""
@@ -56,6 +63,8 @@
self._readers = dict() # A map from file path to reader.
self._readers_lock = threading.Lock()
+ self._offsets = dict()
+
def __enter__(self):
return self
@@ -64,15 +73,48 @@
self.close()
def _generic_iterator(self, file_path):
- """A helper method that makes an iterator given a debug-events file path."""
+ """A helper method that makes an iterator given a debug-events file path.
+
+ Repeated calls to this method create iterators that remember the last
+ successful reading position (offset) for each given `file_path`. So the
+ iterators are meant for incremental reading of the file.
+
+ Args:
+ file_path: Path to the file to create the iterator for.
+
+ Yields:
+ A tuple of (offset, debug_event_proto) on each `next()` call.
+ """
# The following code uses the double-checked locking pattern to optimize
# the common case (where the reader is already initialized).
if file_path not in self._readers: # 1st check, without lock.
with self._readers_lock:
if file_path not in self._readers: # 2nd check, with lock.
- self._readers[file_path] = tf_record.tf_record_iterator(file_path)
+ with errors.raise_exception_on_not_ok_status() as status:
+ # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
+ # supports offset.
+ self._readers[file_path] = pywrap_tensorflow.PyRecordReader_New(
+ compat.as_bytes(file_path), 0, b"", status)
+ reader = self._readers[file_path]
+ while True:
+ offset = reader.offset()
+ try:
+ reader.GetNext()
+ except (errors.DataLossError, errors.OutOfRangeError):
+ # We ignore partial read exceptions, because a record may be truncated.
+ # PyRecordReader holds the offset prior to the failed read, so retrying
+ # will succeed.
+ break
+ yield DebugEventWithOffset(
+ debug_event=debug_event_pb2.DebugEvent.FromString(reader.record()),
+ offset=offset)
- return map(debug_event_pb2.DebugEvent.FromString, self._readers[file_path])
+ def _create_offset_reader(self, file_path, offset):
+ with errors.raise_exception_on_not_ok_status() as status:
+ # TODO(b/136474806): Use tf_record.tf_record_iterator() once it
+ # supports ofset.
+ return pywrap_tensorflow.PyRecordReader_New(
+ file_path, offset, b"", status)
def metadata_iterator(self):
return self._generic_iterator(self._metadata_path)
@@ -86,12 +128,890 @@
def graphs_iterator(self):
return self._generic_iterator(self._graphs_path)
+  def read_graphs_event(self, offset):
+    """Read a DebugEvent proto at a given offset from the .graphs file.
+
+    Args:
+      offset: Offset to read the DebugEvent proto from.
+
+    Returns:
+      A DebugEventProto.
+
+    Raises:
+      `errors.DataLossError` if offset is at a wrong location.
+      `errors.OutOfRangeError` if offset is out of range of the file.
+    """
+    # TODO(cais): After switching to new Python wrapper of tfrecord reader,
+    # use seeking instead of repeated file opening. Same below.
+    reader = self._create_offset_reader(self._graphs_path, offset)
+    try:
+      reader.GetNext()
+      return debug_event_pb2.DebugEvent.FromString(reader.record())
+    finally:
+      # Close the reader even when GetNext() raises one of the errors above.
+      reader.Close()
+
def execution_iterator(self):
return self._generic_iterator(self._execution_path)
+  def read_execution_debug_event(self, offset):
+    """Read a DebugEvent proto at a given offset from the .execution file.
+
+    Args:
+      offset: Offset to read the DebugEvent proto from.
+
+    Returns:
+      A DebugEventProto.
+
+    Raises:
+      `errors.DataLossError` if offset is at a wrong location.
+      `errors.OutOfRangeError` if offset is out of range of the file.
+    """
+    reader = self._create_offset_reader(self._execution_path, offset)
+    try:
+      reader.GetNext()
+      return debug_event_pb2.DebugEvent.FromString(reader.record())
+    finally:
+      # Close the reader even when GetNext() raises one of the errors above.
+      reader.Close()
+
def graph_execution_traces_iterator(self):
return self._generic_iterator(self._graph_execution_traces_path)
+  def read_graph_execution_traces_event(self, offset):
+    """Read DebugEvent at given offset from .graph_execution_traces file.
+
+    Args:
+      offset: Offset to read the DebugEvent proto from.
+
+    Returns:
+      A DebugEventProto.
+
+    Raises:
+      `errors.DataLossError` if offset is at a wrong location.
+      `errors.OutOfRangeError` if offset is out of range of the file.
+    """
+    reader = self._create_offset_reader(
+        self._graph_execution_traces_path, offset)
+    try:
+      reader.GetNext()
+      return debug_event_pb2.DebugEvent.FromString(reader.record())
+    finally:
+      # Close the reader even when GetNext() raises one of the errors above.
+      reader.Close()
+
def close(self):
- with self._readers_lock:
- self._readers.clear()
+ for reader in self._readers.values():
+ reader.Close()
+
+
+class BaseDigest(object):
+  """Common base class of the light-weight digest data objects.
+
+  Properties:
+    wall_time: Timestamp of the digest, in seconds.
+    offset: Offset of the corresponding record in its DebugEvent file. It can
+      be used for fast random-access reading of that record.
+  """
+
+  def __init__(self, wall_time, offset):
+    self._wall_time, self._offset = wall_time, offset
+
+  @property
+  def offset(self):
+    return self._offset
+
+  @property
+  def wall_time(self):
+    return self._wall_time
+
+
+class ExecutionDigest(BaseDigest):
+  """Light-weight digest summarizing a top-level execution event.
+
+  To load the more detailed data object about the execution event
+  (`Execution`), use `DebugDataReader.read_execution(execution_digest)`.
+
+  Properties:
+    op_type: Type name of the executed op. For the eager execution of an
+      individual op, this is the name of the op (e.g., "MatMul"). For the
+      execution of a tf.function (FuncGraph), this is the internally-generated
+      name of the function (e.g., "__inference_my_func_123").
+    output_tensor_device_ids: IDs of the devices on which the output tensors of
+      the execution reside. `None` for an execution without outputs.
+  """
+
+  def __init__(self, wall_time, offset, op_type,
+               output_tensor_device_ids=None):
+    super(ExecutionDigest, self).__init__(wall_time, offset)
+    self._op_type, self._output_tensor_device_ids = (
+        op_type, output_tensor_device_ids)
+
+  @property
+  def output_tensor_device_ids(self):
+    return self._output_tensor_device_ids
+
+  @property
+  def op_type(self):
+    return self._op_type
+
+  # TODO(cais): Implement to_json().
+
+
+class Execution(ExecutionDigest):
+  """Detailed data relating to a top-level execution event.
+
+  The execution is of an individual op or a tf.function, which may have any
+  number of output tensors.
+
+  Properties (beyond the base class `ExecutionDigest`):
+    stack_frame_ids: Reference IDs for stack frames, ordered from bottommost to
+      topmost. Use `DebugDataReader.read_execution_stack_trace()` to load the
+      detailed stack frames (filepath, lineno and function name).
+    tensor_debug_mode: TensorDebugMode enum value, as an `int`.
+    graph_id: ID of the executed FuncGraph (applicable only to the execution of
+      a tf.function). `None` for the eager execution of an individual op.
+    input_tensor_ids: IDs of the input (eager) tensor(s) for this execution, if
+      any.
+    output_tensor_ids: IDs of the output (eager) tensor(s) from this execution,
+      if any.
+    num_outputs: Number of output tensors; 0 when `output_tensor_ids` is
+      `None` or empty.
+    debug_tensor_values: Values of the debug tensor(s), applicable only to
+      non-FULL_TENSOR tensor debug mode. A tuple of list of numbers. Each
+      element of the tuple corresponds to an output tensor of the execution.
+      See documentation of the various TensorDebugModes for the semantics of the
+      numbers.
+  """
+
+  def __init__(self,
+               execution_digest,
+               stack_frame_ids,
+               tensor_debug_mode,
+               graph_id=None,
+               input_tensor_ids=None,
+               output_tensor_ids=None,
+               debug_tensor_values=None):
+    super(Execution, self).__init__(
+        execution_digest.wall_time,
+        execution_digest.offset,
+        execution_digest.op_type,
+        output_tensor_device_ids=execution_digest.output_tensor_device_ids)
+    self._stack_frame_ids = stack_frame_ids
+    self._tensor_debug_mode = tensor_debug_mode
+    self._graph_id = graph_id
+    self._input_tensor_ids = input_tensor_ids
+    self._output_tensor_ids = output_tensor_ids
+    self._debug_tensor_values = debug_tensor_values
+
+  @property
+  def stack_frame_ids(self):
+    return self._stack_frame_ids
+
+  @property
+  def tensor_debug_mode(self):
+    return self._tensor_debug_mode
+
+  @property
+  def graph_id(self):
+    return self._graph_id
+
+  @property
+  def input_tensor_ids(self):
+    return self._input_tensor_ids
+
+  @property
+  def num_outputs(self):
+    # `output_tensor_ids` defaults to None in the constructor; report that as
+    # zero outputs instead of raising a TypeError from len(None).
+    return len(self._output_tensor_ids) if self._output_tensor_ids else 0
+
+  @property
+  def output_tensor_ids(self):
+    return self._output_tensor_ids
+
+  @property
+  def debug_tensor_values(self):
+    return self._debug_tensor_values
+
+  # TODO(cais): Implement to_json().
+
+
+class DebuggedGraph(object):
+  """Debugging information about a tf.Graph, including `FuncGraph`s.
+
+  Properties:
+    name: Name of the graph, if any. May be `None` for non-function graphs.
+    graph_id: Debugger-generated ID of the graph.
+    inner_graph_ids: List of the debugger-generated IDs of the graphs nested
+      within this graph.
+    outer_graph_id: ID of the enclosing outer graph if this graph is nested;
+      `None` for an outermost graph.
+  """
+
+  def __init__(self, name, graph_id, outer_graph_id=None):
+    self._name = name
+    self._graph_id = graph_id
+    self._outer_graph_id = outer_graph_id
+    # Debugger-generated IDs of the graphs nested directly in this graph.
+    self._inner_graph_ids = []
+    # Maps each op name to the GraphOpCreationDigest of that op.
+    self._op_by_name = {}
+
+  def add_inner_graph_id(self, inner_graph_id):
+    """Record the debugger-generated ID of a graph nested within this graph.
+
+    Args:
+      inner_graph_id: Debugger-generated ID of the nested inner graph, as a
+        string.
+    """
+    assert isinstance(inner_graph_id, six.string_types)
+    self._inner_graph_ids.append(inner_graph_id)
+
+  def add_op(self, graph_op_creation_digest):
+    """Record the creation of an op inside this graph.
+
+    Args:
+      graph_op_creation_digest: A GraphOpCreationDigest describing the op.
+        Its op name must not have been added to this graph before.
+    """
+    op_name = graph_op_creation_digest.op_name
+    assert op_name not in self._op_by_name
+    self._op_by_name[op_name] = graph_op_creation_digest
+
+  @property
+  def name(self):
+    return self._name
+
+  @property
+  def graph_id(self):
+    return self._graph_id
+
+  @property
+  def outer_graph_id(self):
+    return self._outer_graph_id
+
+  @property
+  def inner_graph_ids(self):
+    return self._inner_graph_ids
+
+  def get_op_type(self, op_name):
+    op_digest = self._op_by_name[op_name]
+    return op_digest.op_type
+
+  def get_tensor_id(self, op_name, output_slot):
+    """Get the ID of a symbolic tensor in this graph."""
+    op_digest = self._op_by_name[op_name]
+    return op_digest.output_tensor_ids[output_slot]
+
+  # TODO(cais): Implement to_json().
+
+
+class DebuggedDevice(object):
+  """Debugger data about a device involved in the debugged program.
+
+  Properties:
+    device_name: Name of the device, as a str.
+    device_id: Integer ID of the device, unique among all devices within the
+      scope of the debugged TensorFlow program.
+  """
+
+  def __init__(self, device_name, device_id):
+    self._device_name, self._device_id = device_name, device_id
+
+  @property
+  def device_id(self):
+    return self._device_id
+
+  @property
+  def device_name(self):
+    return self._device_name
+
+  # TODO(cais): Implement to_json().
+
+
+class GraphOpCreationDigest(BaseDigest):
+  """Data object describing the creation of an op inside a graph.
+
+  For size efficiency, this digest object does not contain any stack frames or
+  any references to them. To obtain the stack frames, use
+  `DebugDataReader.read_graph_op_creation_stack_trace()`.
+
+  Properties (beyond the base class):
+    graph_id: Debugger-generated ID of the immediately-enclosing graph.
+    op_type: Type name of the op (e.g., "MatMul").
+    op_name: Name of the op (e.g., "dense_1/MatMul").
+    output_tensor_ids: Debugger-generated IDs for the output(s) of the op.
+    num_outputs: Number of outputs, i.e., the length of `output_tensor_ids`.
+    input_names: Names of the input tensors to the op.
+    device_name: The name of the device that the op is placed on (if available).
+  """
+
+  def __init__(self,
+               wall_time,
+               offset,
+               graph_id,
+               op_type,
+               op_name,
+               output_tensor_ids,
+               input_names=None,
+               device_name=None):
+    super(GraphOpCreationDigest, self).__init__(wall_time, offset)
+    self._graph_id = graph_id
+    self._op_type = op_type
+    self._op_name = op_name
+    self._output_tensor_ids = output_tensor_ids
+    self._input_names = input_names
+    self._device_name = device_name
+
+  @property
+  def graph_id(self):
+    return self._graph_id
+
+  @property
+  def op_type(self):
+    return self._op_type
+
+  @property
+  def op_name(self):
+    return self._op_name
+
+  @property
+  def output_tensor_ids(self):
+    return self._output_tensor_ids
+
+  @property
+  def num_outputs(self):
+    return len(self._output_tensor_ids)
+
+  @property
+  def input_names(self):
+    return self._input_names
+
+  @property
+  def device_name(self):
+    return self._device_name
+
+  # TODO(cais): Implement to_json().
+
+
+class GraphExecutionTraceDigest(BaseDigest):
+  """Light-weight summary of an intra-graph tensor execution event.
+
+  Pass this object to `DebugDataReader.read_graph_execution_trace()` to read
+  the more detailed data (`GraphExecutionTrace`).
+
+  Properties (beyond the base class):
+    op_type: Type name of the executed op (e.g., "Conv2D").
+    op_name: Name of the op (e.g., "conv_2d_3/Conv2D").
+    output_slot: Index of the output slot of the tensor.
+  """
+
+  def __init__(self, wall_time, offset, op_type, op_name, output_slot):
+    super(GraphExecutionTraceDigest, self).__init__(wall_time, offset)
+    self._op_type, self._op_name, self._output_slot = (
+        op_type, op_name, output_slot)
+
+  @property
+  def output_slot(self):
+    return self._output_slot
+
+  @property
+  def op_name(self):
+    return self._op_name
+
+  @property
+  def op_type(self):
+    return self._op_type
+
+  # TODO(cais): Implement to_json().
+
+
+class GraphExecutionTrace(GraphExecutionTraceDigest):
+  """Detailed data object describing an intra-graph tensor execution.
+
+  Properties (beyond the base class `GraphExecutionTraceDigest`):
+    graph_ids: Debugger-generated IDs of the graphs enclosing the executed op
+      (tensor), ordered from the outermost to the innermost.
+    graph_id: Debugger-generated ID of the innermost (immediately-enclosing)
+      graph, i.e., the last element of `graph_ids`.
+    tensor_debug_mode: TensorDebugMode enum value.
+    debug_tensor_value: Debug tensor values (only for non-FULL_TENSOR
+      tensor_debug_mode), as a list of numbers. See the documentation of the
+      TensorDebugModes for the semantics of the numbers.
+    device_name: Device on which the tensor resides (if available).
+  """
+
+  def __init__(self,
+               graph_execution_trace_digest,
+               graph_ids,
+               tensor_debug_mode,
+               debug_tensor_value=None,
+               device_name=None):
+    digest = graph_execution_trace_digest
+    super(GraphExecutionTrace, self).__init__(
+        digest.wall_time, digest.offset, digest.op_type, digest.op_name,
+        digest.output_slot)
+    self._graph_ids = graph_ids
+    self._tensor_debug_mode = tensor_debug_mode
+    self._debug_tensor_value = debug_tensor_value
+    self._device_name = device_name
+
+  @property
+  def graph_ids(self):
+    return self._graph_ids
+
+  @property
+  def graph_id(self):
+    innermost_graph_id = self._graph_ids[-1]
+    return innermost_graph_id
+
+  @property
+  def tensor_debug_mode(self):
+    return self._tensor_debug_mode
+
+  @property
+  def debug_tensor_value(self):
+    return self._debug_tensor_value
+
+  @property
+  def device_name(self):
+    return self._device_name
+
+  # TODO(cais): Implement to_json().
+
+
+def _parse_tensor_value(tensor_proto, return_list=False):
+  """Helper method for reading a tensor value from a tensor proto.
+
+  The rationale for the distinction between `True` and `False` values of
+  `return_list` is as follows:
+  - `return_list=True` is used for TensorDebugMode values other than
+    FULL_TENSOR, e.g., CONCISE_HEALTH, SHAPE and FULL_HEALTH. Under
+    those modes, the value is guaranteed (by contract) to be a 1D float64
+    tensor.
+  - `return_list=False` is used for the FULL_TENSOR TensorDebugMode
+    specifically. Instead, we use `numpy.ndarray` to maximally preserve
+    the shape, dtype and value information regarding the underlying tensor
+    value. Under that mode, we don't use a python list to represent the
+    tensor value because that can lead to loss of information (e.g., both
+    float16 and float32 dtypes get mapped to Python floats).
+
+  Args:
+    tensor_proto: The TensorProto instance from which the tensor value will be
+      loaded.
+    return_list: Whether the return value will be a nested Python list that
+      comes out from `numpy.ndarray.tolist()`.
+
+  Returns:
+    If parsing is successful, the tensor value as a `numpy.ndarray` or the
+      nested Python list converted from it.
+    If parsing fails (`tensor_util.MakeNdarray()` raises a `TypeError`),
+      `None`.
+  """
+  try:
+    ndarray = tensor_util.MakeNdarray(tensor_proto)
+    return ndarray.tolist() if return_list else ndarray
+  except TypeError:
+    # Depending on tensor_debug_mode, certain dtype of tensors don't
+    # have logged debug tensor values.
+    return None
+
+
+class DebugDataReader(object):
+  """A reader that reads structured debugging data in the tfdbg v2 format.
+
+  The set of data read by an object of this class concerns the execution history
+  of a tfdbg2-instrumented TensorFlow program.
+
+  Note:
+    - An object of this class incrementally reads data from files that belong to
+      the tfdbg v2 DebugEvent file set. Calling `update()` triggers the reading
+      from the last-successful reading positions in the files.
+    - This object can be used as a context manager. Its `__exit__()` call
+      closes the file readers cleanly.
+  """
+
+  def __init__(self, dump_root):
+    self._reader = DebugEventsReader(dump_root)
+    # TODO(cais): Implement pagination for memory constraints.
+    self._execution_digests = []
+
+    # A list of (host_name, file_path) tuples.
+    self._host_name_file_paths = []
+    # A dict mapping id to (host_name, file_path, lineno, func) tuple.
+    self._stack_frame_by_id = dict()
+    # Stores unprocessed stack frame IDs. This is necessary to handle the
+    # case in which reading of the .stack_frames file gets ahead of the reading
+    # of the .source_files file.
+    self._unprocessed_stack_frames = dict()
+    # A dict mapping id to DebuggedDevice objects.
+    self._device_by_id = dict()
+    # A dict mapping id to DebuggedGraph objects.
+    self._graph_by_id = dict()
+    self._graph_op_digests = []
+    # TODO(cais): Implement pagination for memory constraints.
+    self._graph_execution_trace_digests = []
+
+    # The following timestamps keep track where we've reached in each
+    # file of the DebugEvent source file, so that we don't run into race
+    # conditions with the writer.
+    # Wall time of the last DebugEvent read from the .source_files file;
+    # updated by _load_source_files().
+    self._source_files_timestamp = 0
+    # Temporary object used to hold DebugEvent protos with stack_frames
+    # field that has been read beyond max_wall_time.
+    # self._last_successful_stack_frames_offset = -1  # TODO(cais): Fix.
+
+  # TODO(cais): Read metadata.
+  def _load_source_files(self):
+    """Incrementally read the .source_files DebugEvent file."""
+    source_files_iter = self._reader.source_files_iterator()
+    for debug_event, _ in source_files_iter:
+      source_file = debug_event.source_file
+      self._host_name_file_paths.append(
+          (source_file.host_name, source_file.file_path))
+      self._source_files_timestamp = debug_event.wall_time
+
+  def _load_stack_frames(self):
+    """Incrementally read the .stack_frames file.
+
+    This must be called after _load_source_files().
+    It assumes that the following contract is honored by the writer of the tfdbg
+    v2 data file set:
+      - Before a stack frame is written to the .stack_frames file, the
+        corresponding source file information must have been written to the
+        .source_files file first.
+    """
+    stack_frames_iter = self._reader.stack_frames_iterator()
+    for debug_event, _ in stack_frames_iter:
+      stack_frame_with_id = debug_event.stack_frame_with_id
+      file_line_col = stack_frame_with_id.file_line_col
+      self._unprocessed_stack_frames[stack_frame_with_id.id] = file_line_col
+    # We do the processing in a separate stage, because the reading of the
+    # .stack_frames file may sometimes get ahead of the reading of the
+    # .source_files file. Frames whose source file is not known yet remain in
+    # self._unprocessed_stack_frames and are retried on the next call.
+    unprocessed_stack_frame_ids = tuple(self._unprocessed_stack_frames.keys())
+    for stack_frame_id in unprocessed_stack_frame_ids:
+      file_line_col = self._unprocessed_stack_frames[stack_frame_id]
+      if len(self._host_name_file_paths) > file_line_col.file_index:
+        self._stack_frame_by_id[stack_frame_id] = (
+            self._host_name_file_paths[file_line_col.file_index][0],
+            self._host_name_file_paths[file_line_col.file_index][1],
+            file_line_col.line,
+            file_line_col.func)
+        del self._unprocessed_stack_frames[stack_frame_id]
+
+  def _load_graphs(self):
+    """Incrementally read the .graphs file.
+
+    Compiles the DebuggedGraph and GraphOpCreation data.
+    """
+    graphs_iter = self._reader.graphs_iterator()
+    for debug_event, offset in graphs_iter:
+      if debug_event.graph_op_creation.ByteSize():
+        op_creation_proto = debug_event.graph_op_creation
+        op_digest = GraphOpCreationDigest(
+            debug_event.wall_time,
+            offset,
+            op_creation_proto.graph_id,
+            op_creation_proto.op_type,
+            op_creation_proto.op_name,
+            tuple(op_creation_proto.output_tensor_ids),
+            input_names=tuple(op_creation_proto.input_names))
+        self._graph_op_digests.append(op_digest)
+        self._graph_by_id[op_creation_proto.graph_id].add_op(op_digest)
+      elif debug_event.debugged_graph.ByteSize():
+        graph_proto = debug_event.debugged_graph
+        graph = DebuggedGraph(
+            graph_proto.graph_name or None,
+            graph_proto.graph_id,
+            outer_graph_id=graph_proto.outer_context_id or None)
+        self._graph_by_id[graph_proto.graph_id] = graph
+        if graph_proto.outer_context_id:
+          self._graph_by_id[
+              graph_proto.outer_context_id].add_inner_graph_id(graph.graph_id)
+      elif debug_event.debugged_device.ByteSize():
+        device_proto = debug_event.debugged_device
+        self._device_by_id[device_proto.device_id] = DebuggedDevice(
+            device_proto.device_name, device_proto.device_id)
+
+  def _load_graph_execution_traces(self):
+    """Incrementally load the .graph_execution_traces file."""
+    traces_iter = self._reader.graph_execution_traces_iterator()
+    for debug_event, offset in traces_iter:
+      trace_proto = debug_event.graph_execution_trace
+      op_name = trace_proto.op_name
+      op_type = self._lookup_op_type(trace_proto.tfdbg_context_id, op_name)
+      digest = GraphExecutionTraceDigest(
+          debug_event.wall_time,
+          offset,
+          op_type,
+          op_name,
+          trace_proto.output_slot)
+      self._graph_execution_trace_digests.append(digest)
+
+  def _lookup_op_type(self, graph_id, op_name):
+    """Lookup the type of an op by name and the immediately enclosing graph.
+
+    Args:
+      graph_id: Debugger-generated ID of the immediately-enclosing graph.
+      op_name: Name of the op.
+
+    Returns:
+      Op type as a str.
+    """
+    return self._graph_by_id[graph_id].get_op_type(op_name)
+
+  def _load_execution(self):
+    """Incrementally read the .execution file."""
+    execution_iter = self._reader.execution_iterator()
+    for debug_event, offset in execution_iter:
+      self._execution_digests.append(ExecutionDigest(
+          debug_event.wall_time,
+          offset,
+          debug_event.execution.op_type,
+          output_tensor_device_ids=(
+              debug_event.execution.output_tensor_device_ids or None)))
+
+  def update(self):
+    """Perform incremental read of the file set."""
+    self._load_source_files()
+    self._load_stack_frames()
+    self._load_graphs()
+    self._load_graph_execution_traces()
+    self._load_execution()
+
+  def outermost_graphs(self):
+    """Get the outermost graphs read so far.
+
+    Returns:
+      A list of `DebuggedGraph` objects that have no outer graph.
+    """
+    return [graph for graph in self._graph_by_id.values()
+            if not graph.outer_graph_id]
+
+  def graph_by_id(self, graph_id):
+    """Get a DebuggedGraph object by its ID."""
+    return self._graph_by_id[graph_id]
+
+  def device_name_by_id(self, device_id):
+    """Get the name of a device by the debugger-generated ID of the device."""
+    return self._device_by_id[device_id].device_name
+
+  def device_names(self):
+    """Get a set of all device names known to the debugger."""
+    return set(device.device_name for device in self._device_by_id.values())
+
+  def graph_op_digests(self, op_type=None):
+    """Get the list of the digests for graph-op creation so far.
+
+    Args:
+      op_type: Optional op type to filter the creation events with.
+
+    Returns:
+      A list of `GraphOpCreationDigest` objects.
+    """
+    if op_type is not None:
+      return [digest for digest in self._graph_op_digests
+              if digest.op_type == op_type]
+    else:
+      return self._graph_op_digests
+
+  def graph_execution_traces(self, digest=False):
+    """Get all the intra-graph execution tensor traces read so far.
+
+    TODO(cais): Support begin and end to enable partial loading.
+
+    Args:
+      digest: Whether the results will be returned in the more light-weight
+        digest form.
+
+    Returns:
+      If `digest`: a `list` of `GraphExecutionTraceDigest` objects.
+      Else: a `list` of `GraphExecutionTrace` objects.
+    """
+    if digest:
+      return self._graph_execution_trace_digests
+    else:
+      return [self.read_graph_execution_trace(trace_digest)
+              for trace_digest in self._graph_execution_trace_digests]
+
+  def num_graph_execution_traces(self):
+    """Get the number of graph execution traces read so far."""
+    return len(self._graph_execution_trace_digests)
+
+  def executions(self, digest=False):
+    """Get `Execution`s or `ExecutionDigest`s this reader has read so far.
+
+    TODO(cais): Support begin index and end index to support partial loading.
+
+    Args:
+      digest: Whether the results are returned in a digest form, i.e.,
+        `ExecutionDigest` format, instead of the more detailed `Execution`
+        format.
+
+    Returns:
+      If `digest`: a `list` of `ExecutionDigest` objects.
+      Else: a `list` of `Execution` objects.
+    """
+    if digest:
+      return self._execution_digests
+    else:
+      # TODO(cais): Optimize performance by removing repeated file open/close.
+      return [self.read_execution(execution_digest)
+              for execution_digest in self._execution_digests]
+
+  def num_executions(self):
+    """Get the number of execution events read so far."""
+    return len(self._execution_digests)
+
+  def read_execution(self, execution_digest):
+    """Read a detailed Execution object.
+
+    Args:
+      execution_digest: An `ExecutionDigest` object.
+
+    Returns:
+      The corresponding `Execution` object.
+    """
+    debug_event = self._reader.read_execution_debug_event(
+        execution_digest.offset)
+    execution_proto = debug_event.execution
+
+    debug_tensor_values = None
+    if (execution_proto.tensor_debug_mode ==
+        debug_event_pb2.TensorDebugMode.FULL_TENSOR):
+      pass  # TODO(cais): Build tensor store.
+    elif (execution_proto.tensor_debug_mode !=
+          debug_event_pb2.TensorDebugMode.NO_TENSOR):
+      debug_tensor_values = [
+          _parse_tensor_value(tensor_proto, return_list=True)
+          for tensor_proto in execution_proto.tensor_protos]
+    return Execution(
+        execution_digest,
+        tuple(execution_proto.code_location.stack_frame_ids),
+        execution_proto.tensor_debug_mode,
+        # An empty graph_id denotes the eager execution of an individual op.
+        # Map it to `None`, like the other optional string fields read here.
+        graph_id=execution_proto.graph_id or None,
+        input_tensor_ids=tuple(execution_proto.input_tensor_ids),
+        output_tensor_ids=tuple(execution_proto.output_tensor_ids),
+        debug_tensor_values=tuple(
+            debug_tensor_values) if debug_tensor_values else None)
+
+  def read_graph_execution_trace(self, graph_execution_trace_digest):
+    """Read the detailed graph execution trace.
+
+    Args:
+      graph_execution_trace_digest: A `GraphExecutionTraceDigest` object.
+
+    Returns:
+      The corresponding `GraphExecutionTrace` object.
+    """
+    debug_event = self._reader.read_graph_execution_traces_event(
+        graph_execution_trace_digest.offset)
+    trace_proto = debug_event.graph_execution_trace
+
+    graph_ids = [trace_proto.tfdbg_context_id]
+    # Exhaust the outer contexts (graphs).
+    while True:
+      graph = self.graph_by_id(graph_ids[0])
+      if graph.outer_graph_id:
+        graph_ids.insert(0, graph.outer_graph_id)
+      else:
+        break
+
+    debug_tensor_value = None
+    if (trace_proto.tensor_debug_mode ==
+        debug_event_pb2.TensorDebugMode.FULL_TENSOR):
+      pass  # TODO(cais): Build tensor store.
+    else:
+      debug_tensor_value = _parse_tensor_value(
+          trace_proto.tensor_proto, return_list=True)
+    return GraphExecutionTrace(
+        graph_execution_trace_digest,
+        graph_ids=graph_ids,
+        tensor_debug_mode=trace_proto.tensor_debug_mode,
+        debug_tensor_value=debug_tensor_value,
+        device_name=trace_proto.device_name or None)
+
+  def read_execution_stack_trace(self, execution):
+    """Read the stack trace of a given Execution object.
+
+    Args:
+      execution: The Execution object of interest.
+
+    Returns:
+      A tuple consisting of:
+        1. The host name.
+        2. The stack trace, as a list of (file_path, lineno, func) tuples.
+    """
+    # The host name is taken from the first stack frame.
+    host_name = self._stack_frame_by_id[execution.stack_frame_ids[0]][0]
+    return (host_name, [
+        self._stack_frame_by_id[frame_id][1:]
+        for frame_id in execution.stack_frame_ids])
+
+  def read_graph_op_creation_stack_trace(self, graph_op_creation_digest):
+    """Read the stack trace of a given graph op creation object.
+
+    Args:
+      graph_op_creation_digest: The GraphOpCreationDigest object of interest.
+
+    Returns:
+      A tuple consisting of:
+        1. The host name.
+        2. The stack trace, as a list of (file_path, lineno, func) tuples.
+    """
+    debug_event = self._reader.read_graphs_event(
+        graph_op_creation_digest.offset)
+    graph_op_creation = debug_event.graph_op_creation
+    host_name = graph_op_creation.code_location.host_name
+    return host_name, [
+        self._stack_frame_by_id[frame_id][1:]
+        for frame_id in graph_op_creation.code_location.stack_frame_ids]
+
+  # TODO(cais): Add graph_execution_digests() with an ExecutionDigest
+  #   as a kwarg, to establish the association between top-level and intra-graph
+  #   execution events.
+
+  def execution_to_tensor_values(self, execution):
+    """Read the full tensor values from an Execution or ExecutionDigest.
+
+    Args:
+      execution: An `ExecutionDigest` or `Execution` object.
+
+    Returns:
+      A list of numpy arrays representing the output tensor values of the
+        execution event.
+    """
+    debug_event = self._reader.read_execution_debug_event(execution.offset)
+    return [_parse_tensor_value(tensor_proto)
+            for tensor_proto in debug_event.execution.tensor_protos]
+
+  def graph_execution_trace_to_tensor_value(self, trace):
+    """Read the full tensor value from a graph execution trace.
+
+    Args:
+      trace: A `GraphExecutionTraceDigest` or `GraphExecutionTrace` object.
+
+    Returns:
+      A numpy array representing the output tensor value of the intra-graph
+        tensor execution event.
+    """
+    debug_event = self._reader.read_graph_execution_traces_event(trace.offset)
+    return _parse_tensor_value(debug_event.graph_execution_trace.tensor_proto)
+
+  def symbolic_tensor_id(self, graph_id, op_name, output_slot):
+    """Get the ID of a symbolic tensor.
+
+    Args:
+      graph_id: The ID of the immediately-enclosing graph.
+      op_name: Name of the op.
+      output_slot: Output slot as an int.
+
+    Returns:
+      The ID of the symbolic tensor as an int.
+    """
+    return self._graph_by_id[graph_id].get_tensor_id(op_name, output_slot)
+
+  def graph_execution_trace_to_tensor_id(self, trace):
+    """Get the symbolic tensor ID from a GraphExecutionTrace object.
+
+    Args:
+      trace: A `GraphExecutionTrace` object. A `GraphExecutionTraceDigest` is
+        not sufficient, because it lacks the `graph_id` property.
+
+    Returns:
+      The ID of the symbolic tensor.
+    """
+    return self.symbolic_tensor_id(
+        trace.graph_id, trace.op_name, trace.output_slot)
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exception_type, exception_value, traceback):
+    del exception_type, exception_value, traceback  # Unused
+    self._reader.close()
diff --git a/tensorflow/python/debug/lib/debug_events_writer.py b/tensorflow/python/debug/lib/debug_events_writer.py
index 7f7ae38..3de0ab7 100644
--- a/tensorflow/python/debug/lib/debug_events_writer.py
+++ b/tensorflow/python/debug/lib/debug_events_writer.py
@@ -128,6 +128,10 @@
_pywrap_debug_events_writer.WriteGraphExecutionTrace(
self._dump_root, debug_event)
+  def RegisterDeviceAndGetId(self, device_name):
+    """Register a device name with the writer of this object's dump root.
+
+    Args:
+      device_name: Name of the device to register.
+
+    Returns:
+      Whatever `_pywrap_debug_events_writer.RegisterDeviceAndGetId()` returns
+      for this dump root; presumably the debugger-generated ID of the device
+      (TODO: confirm against the C++ writer).
+    """
+    dump_root = self._dump_root
+    return _pywrap_debug_events_writer.RegisterDeviceAndGetId(
+        dump_root, device_name)
+
def FlushNonExecutionFiles(self):
"""Flush the non-execution debug event files."""
_pywrap_debug_events_writer.FlushNonExecutionFiles(self._dump_root)
diff --git a/tensorflow/python/debug/lib/debug_events_writer_test.py b/tensorflow/python/debug/lib/debug_events_writer_test.py
index f6e973b..b62fc9b 100644
--- a/tensorflow/python/debug/lib/debug_events_writer_test.py
+++ b/tensorflow/python/debug/lib/debug_events_writer_test.py
@@ -76,20 +76,20 @@
writer.FlushNonExecutionFiles()
with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
- actuals = list(reader.source_files_iterator())
+ actuals = list(item.debug_event.source_file
+ for item in reader.source_files_iterator())
self.assertLen(actuals, num_protos)
for i in range(num_protos):
- self.assertEqual(actuals[i].source_file.file_path,
- "/home/tf2user/main.py")
- self.assertEqual(actuals[i].source_file.host_name, "machine.cluster")
- self.assertEqual(actuals[i].source_file.lines, ["print(%d)" % i])
+ self.assertEqual(actuals[i].file_path, "/home/tf2user/main.py")
+ self.assertEqual(actuals[i].host_name, "machine.cluster")
+ self.assertEqual(actuals[i].lines, ["print(%d)" % i])
- actuals = list(reader.stack_frames_iterator())
+ actuals = list(item.debug_event.stack_frame_with_id
+ for item in reader.stack_frames_iterator())
self.assertLen(actuals, num_protos)
for i in range(num_protos):
- self.assertEqual(actuals[i].stack_frame_with_id.id, "stack_%d" % i)
- self.assertEqual(
- actuals[i].stack_frame_with_id.file_line_col.file_index, i * 10)
+ self.assertEqual(actuals[i].id, "stack_%d" % i)
+ self.assertEqual(actuals[i].file_line_col.file_index, i * 10)
def testWriteGraphOpCreationAndDebuggedGraphs(self):
writer = debug_events_writer.DebugEventsWriter(self.dump_root)
@@ -106,7 +106,7 @@
writer.FlushNonExecutionFiles()
reader = debug_events_reader.DebugEventsReader(self.dump_root)
- actuals = list(reader.graphs_iterator())
+ actuals = list(item.debug_event for item in reader.graphs_iterator())
self.assertLen(actuals, num_op_creations + 1)
for i in range(num_op_creations):
self.assertEqual(actuals[i].graph_op_creation.op_type, "Conv2D")
@@ -172,24 +172,24 @@
# Verify the content of the .source_files file.
with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
source_files_iter = reader.source_files_iterator()
- actuals = list(source_files_iter)
- file_paths = sorted([actual.source_file.file_path for actual in actuals])
+ actuals = list(item.debug_event.source_file for item in source_files_iter)
+ file_paths = sorted([actual.file_path for actual in actuals])
self.assertEqual(file_paths, [
"/home/tf2user/file_0.py", "/home/tf2user/file_1.py",
"/home/tf2user/file_2.py"
])
# Verify the content of the .stack_frames file.
- actuals = list(reader.stack_frames_iterator())
- stack_frame_ids = sorted(
- [actual.stack_frame_with_id.id for actual in actuals])
+ actuals = list(item.debug_event.stack_frame_with_id
+ for item in reader.stack_frames_iterator())
+ stack_frame_ids = sorted([actual.id for actual in actuals])
self.assertEqual(stack_frame_ids,
["stack_frame_0", "stack_frame_1", "stack_frame_2"])
# Verify the content of the .graphs file.
- actuals = list(reader.graphs_iterator())
- graph_op_names = sorted(
- [actual.graph_op_creation.op_name for actual in actuals])
+ actuals = list(item.debug_event.graph_op_creation
+ for item in reader.graphs_iterator())
+ graph_op_names = sorted([actual.op_name for actual in actuals])
self.assertEqual(graph_op_names, ["Op0", "Op1", "Op2"])
def testWriteExecutionEventsWithCircularBuffer(self):
@@ -242,11 +242,12 @@
self.assertEqual(len(actuals), 0)
writer.FlushExecutionFiles()
- actuals = list(reader.graph_execution_traces_iterator())
+ actuals = list(item.debug_event.graph_execution_trace
+ for item in reader.graph_execution_traces_iterator())
self.assertLen(actuals, debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE)
for i in range(debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE):
self.assertEqual(
- actuals[i].graph_execution_trace.op_name,
+ actuals[i].op_name,
"Op%d" % (i + debug_events_writer.DEFAULT_CIRCULAR_BUFFER_SIZE))
def testWriteGraphExecutionTraceEventsWithoutCircularBufferBehavior(self):
@@ -260,10 +261,11 @@
writer.FlushExecutionFiles()
with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
- actuals = list(reader.graph_execution_traces_iterator())
+ actuals = list(item.debug_event.graph_execution_trace
+ for item in reader.graph_execution_traces_iterator())
self.assertLen(actuals, num_execution_events)
for i in range(num_execution_events):
- self.assertEqual(actuals[i].graph_execution_trace.op_name, "Op%d" % i)
+ self.assertEqual(actuals[i].op_name, "Op%d" % i)
def testConcurrentWritesToExecutionFiles(self):
circular_buffer_size = 5
@@ -308,9 +310,9 @@
# Verify the content of the .execution file.
with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
- actuals = list(reader.graph_execution_traces_iterator())
- op_names = sorted(
- [actual.graph_execution_trace.op_name for actual in actuals])
+ actuals = list(item.debug_event.graph_execution_trace
+ for item in reader.graph_execution_traces_iterator())
+ op_names = sorted([actual.op_name for actual in actuals])
self.assertLen(op_names, circular_buffer_size)
self.assertLen(op_names, len(set(op_names)))
diff --git a/tensorflow/python/debug/lib/debug_v2_ops_test.py b/tensorflow/python/debug/lib/debug_v2_ops_test.py
index c665da7..d6f0d43 100644
--- a/tensorflow/python/debug/lib/debug_v2_ops_test.py
+++ b/tensorflow/python/debug/lib/debug_v2_ops_test.py
@@ -88,7 +88,7 @@
metadata_iter = reader.metadata_iterator()
# Check that the .metadata DebugEvents data file has been created, even
# before FlushExecutionFiles() is called.
- debug_event = next(metadata_iter)
+ debug_event = next(metadata_iter).debug_event
self.assertGreater(debug_event.wall_time, 0)
self.assertTrue(debug_event.debug_metadata.tensorflow_version)
self.assertTrue(
@@ -107,7 +107,7 @@
# The circular buffer has a size of 4. So only the data from the
# last two iterations should have been written to self.dump_root.
for _ in range(2):
- debug_event = next(graph_trace_iter)
+ debug_event = next(graph_trace_iter).debug_event
self.assertGreater(debug_event.wall_time, 0)
trace = debug_event.graph_execution_trace
self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
@@ -118,7 +118,7 @@
tensor_value = tensor_util.MakeNdarray(trace.tensor_proto)
self.assertAllClose(tensor_value, [9.0, 16.0])
- debug_event = next(graph_trace_iter)
+ debug_event = next(graph_trace_iter).debug_event
self.assertGreater(debug_event.wall_time, 0)
trace = debug_event.graph_execution_trace
self.assertEqual(trace.tfdbg_context_id, "beafdead")
@@ -165,7 +165,7 @@
x_values = []
timestamp = 0
while True:
- debug_event = next(graph_trace_iter)
+ debug_event = next(graph_trace_iter).debug_event
self.assertGreater(debug_event.wall_time, timestamp)
timestamp = debug_event.wall_time
trace = debug_event.graph_execution_trace
@@ -210,7 +210,7 @@
with debug_events_reader.DebugEventsReader(debug_root) as reader:
graph_trace_iter = reader.graph_execution_traces_iterator()
- debug_event = next(graph_trace_iter)
+ debug_event = next(graph_trace_iter).debug_event
trace = debug_event.graph_execution_trace
self.assertEqual(trace.tfdbg_context_id, "deadbeaf")
self.assertEqual(trace.op_name, "")
diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py
index 98e7292..e51eedf 100644
--- a/tensorflow/python/debug/lib/dumping_callback.py
+++ b/tensorflow/python/debug/lib/dumping_callback.py
@@ -386,6 +386,7 @@
tensors,
op_type,
input_tensor_ids,
+ output_tensor_device_ids,
graph_id=None):
"""Dump the value of eager tensors.
@@ -400,6 +401,9 @@
value transform.
op_type: Type of the op that generates the tensors, as a string.
input_tensor_ids: IDs of the input EagerTensors to the op.
+ output_tensor_device_ids: Debugger-generated IDs for the devices on which
+ the output tensors are allocated, as a `list` of `int`s. Must match
+ `tensors` in length.
graph_id: ID of the executed graph, applicable only to eager execution of
a FuncGraph.
@@ -409,6 +413,7 @@
tensor_debug_mode = self._tensor_debug_mode
output_tensor_ids = [
t._id for t in tensors] # pylint:disable=protected-access
+ assert len(tensors) == len(output_tensor_device_ids)
if tensor_debug_mode == debug_event_pb2.TensorDebugMode.NO_TENSOR:
return debug_event_pb2.Execution(
op_type=op_type,
@@ -416,6 +421,7 @@
num_outputs=len(tensors),
input_tensor_ids=input_tensor_ids,
output_tensor_ids=output_tensor_ids,
+ output_tensor_device_ids=output_tensor_device_ids,
tensor_debug_mode=tensor_debug_mode,
code_location=self._process_stack_frames())
elif tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
@@ -428,6 +434,7 @@
graph_id=graph_id,
input_tensor_ids=input_tensor_ids,
output_tensor_ids=output_tensor_ids,
+ output_tensor_device_ids=output_tensor_device_ids,
tensor_debug_mode=tensor_debug_mode,
code_location=self._process_stack_frames())
for tensor in tensors:
@@ -505,8 +512,11 @@
return None
context_id = self._func_graph_id_from_func_name(op_type)
input_ids = [t._id for t in inputs] # pylint:disable=protected-access
+ output_tensor_device_ids = [writer.RegisterDeviceAndGetId(output.device)
+ for output in outputs] if outputs else []
writer.WriteExecution(self._dump_eager_tensors(
- outputs, op_type, input_ids, graph_id=context_id))
+ outputs, op_type, input_ids, output_tensor_device_ids,
+ graph_id=context_id))
def _func_graph_id_from_func_name(self, op_type):
"""Attempt to get the ID of a FuncGraph based on an op type name.
diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py
index b7e90f3..115315a 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test.py
@@ -21,6 +21,7 @@
import collections
import os
import shutil
+import socket
import tempfile
import threading
@@ -36,7 +37,6 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.keras import models
from tensorflow.python.keras.applications import mobilenet_v2
@@ -61,6 +61,10 @@
return model
+_host_name = socket.gethostname()
+_current_file_full_path = os.path.abspath(__file__)
+
+
class TracingCallbackTest(
dumping_callback_test_lib.DumpingCallbackTestBase, parameterized.TestCase):
@@ -74,6 +78,26 @@
dumping_callback.disable_dump_debug_info()
super(TracingCallbackTest, self).tearDown()
+ def _verifyStackFrames(self, stack_frames):
+ """Verify the correctness of the stack frames.
+
+ Currently, it simply asserts that the current file is found in the stack
+ frames.
+ TODO(cais): Perhaps implement a stricter check later.
+
+ Args:
+ stack_frames: The stack frames to verify.
+ """
+ self.assertTrue([
+ frame for frame in stack_frames if frame[0] == _current_file_full_path])
+
+ def _expectedDefaultDeviceName(self):
+ gpu_name = test_util.gpu_device_name()
+ if gpu_name:
+ return "/job:localhost/replica:0/task:0" + gpu_name
+ else:
+ return "/job:localhost/replica:0/task:0/device:CPU:0"
+
def testInvalidTensorDebugModeCausesError(self):
with self.assertRaisesRegexp(
ValueError,
@@ -111,73 +135,74 @@
writer.FlushNonExecutionFiles()
self._readAndCheckMetadataFile()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- # Before FlushExecutionFiles() is called, the .execution file should be
- # empty.
- with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
- execution_iter = reader.execution_iterator()
- with self.assertRaises(StopIteration):
- next(execution_iter)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ # Before FlushExecutionFiles() is called, the .execution file should be
+ # empty.
+ self.assertFalse(reader.executions())
# After the flushing, the .execution file should hold the appropriate
# contents.
writer.FlushExecutionFiles()
- execution_iter = reader.execution_iterator()
+ reader.update()
+ executions = reader.executions()
prev_wall_time = 1
executed_op_types = []
tensor_values = collections.defaultdict(lambda: [])
- for debug_event in execution_iter:
- self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
- prev_wall_time = debug_event.wall_time
- execution = debug_event.execution
+ for execution in executions:
+ self.assertGreaterEqual(execution.wall_time, prev_wall_time)
+ prev_wall_time = execution.wall_time
executed_op_types.append(execution.op_type)
+ # Check the device name.
+ if execution.op_type in ("AddV2", "Mul", "RealDiv"):
+ self.assertLen(execution.output_tensor_device_ids, 1)
+ self.assertEqual(
+ reader.device_name_by_id(execution.output_tensor_device_ids[0]),
+ self._expectedDefaultDeviceName(),
+ "Unexpected device name from eager op %s" % execution.op_type)
+
# No graph IDs should have been logged for eager op executions.
self.assertFalse(execution.graph_id)
self.assertTrue(execution.input_tensor_ids)
self.assertTrue(execution.output_tensor_ids)
+ self.assertEqual(
+ debug_event_pb2.TensorDebugMode.keys()[execution.tensor_debug_mode],
+ tensor_debug_mode)
if tensor_debug_mode == "NO_TENSOR":
# Due to the NO_TENSOR tensor debug mode, tensor_protos ought to
# be empty.
- self.assertFalse(execution.tensor_protos)
+ self.assertFalse(execution.debug_tensor_values)
elif tensor_debug_mode == "CURT_HEALTH":
- self.assertLen(execution.tensor_protos, 1)
+ self.assertLen(execution.debug_tensor_values, 1)
if execution.op_type in ("AddV2", "Mul", "RealDiv"):
# 1st element: -1 is the unset tensor_id for eager op execution.
# 2nd element: 0 means there is no inf or nan.
- self.assertAllClose(
- tensor_util.MakeNdarray(execution.tensor_protos[0]),
- [-1.0, 0.0])
+ self.assertAllClose(execution.debug_tensor_values, [[-1.0, 0.0]])
elif tensor_debug_mode == "CONCISE_HEALTH":
- self.assertLen(execution.tensor_protos, 1)
if execution.op_type in ("AddV2", "Mul", "RealDiv"):
# 1st element: -1 is the unset tensor_id for eager op execution.
# 2nd element: each scalar tensor has 1 element.
# Remaining elements: no -inf, inf or nan in these
self.assertAllClose(
- tensor_util.MakeNdarray(execution.tensor_protos[0]),
- [-1, 1, 0, 0, 0])
+ execution.debug_tensor_values, [[-1, 1, 0, 0, 0]])
elif tensor_debug_mode == "SHAPE":
- self.assertLen(execution.tensor_protos, 1)
if execution.op_type in ("AddV2", "Mul", "RealDiv"):
# 1st element: -1 is the unset tensor_id for eager op execution.
# 2nd element: dtype enum value (float32).
# 3rd element: rank (scalar).
# 4th element: element count (4).
# Remaining elements: shape at fixed length (6).
- self.assertAllClose(
- tensor_util.MakeNdarray(execution.tensor_protos[0]),
- [-1, 1, 0, 1, 0, 0, 0, 0, 0, 0])
+ self.assertAllClose(execution.debug_tensor_values,
+ [[-1, 1, 0, 1, 0, 0, 0, 0, 0, 0]])
elif tensor_debug_mode == "FULL_TENSOR":
- # Under the FULL_TENSOR mode, the value of the tensor should be
- # available through `tensor_protos`.
- tensor_value = float(
- tensor_util.MakeNdarray(execution.tensor_protos[0]))
- tensor_values[execution.op_type].append(tensor_value)
- # Verify the code_location field.
- self.assertTrue(execution.code_location.stack_frame_ids)
- for stack_frame_id in execution.code_location.stack_frame_ids:
- self.assertIn(stack_frame_id, stack_frame_by_id)
+ tensor_values[execution.op_type].append(
+ reader.execution_to_tensor_values(execution)[0])
+
+ host_name, stack_frames = reader.read_execution_stack_trace(execution)
+ self.assertEqual(host_name, _host_name)
+ self._verifyStackFrames(stack_frames)
+
if tensor_debug_mode == "FULL_TENSOR":
self.assertAllClose(tensor_values["Greater"], [1, 1, 1, 1, 1, 1, 0])
self.assertAllClose(tensor_values["RealDiv"], [5, 8, 4, 2, 1])
@@ -217,12 +242,8 @@
# Due to the pure eager op execution, the .graph file and the
# .graph_execution_traces file ought to be empty.
- graphs_iterator = reader.graphs_iterator()
- with self.assertRaises(StopIteration):
- next(graphs_iterator)
- graph_trace_iter = reader.graph_execution_traces_iterator()
- with self.assertRaises(StopIteration):
- next(graph_trace_iter)
+ self.assertFalse(reader.outermost_graphs())
+ self.assertEqual(reader.num_graph_execution_traces(), 0)
@parameterized.named_parameters(
("CurtHealth", "CURT_HEALTH"),
@@ -242,60 +263,48 @@
y = np.array([2, -1, 0, 0, 1, 1, 1, 3], dtype=np.float16)
# (x + y) / (x - y) = [0.2, -inf, nan, nan, inf, inf, inf, -5].
self.evaluate(func(x, y))
-
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (context_ids,
- _, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-
- (op_names, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"])
-
- if tensor_debug_mode == "CURT_HEALTH":
- for op_type, tensor_value in zip(executed_op_types, tensor_values):
- self.assertLen(tensor_value, 2)
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
- # 2nd element: 0 means there is no inf or nan.
- if op_type == "RealDiv":
- self.assertEqual(tensor_value[1], 1)
- else:
- self.assertEqual(tensor_value[1], 0)
- elif tensor_debug_mode == "CONCISE_HEALTH":
- for op_type, tensor_value in zip(executed_op_types, tensor_values):
- self.assertLen(tensor_value, 5)
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
- # 2nd element: element count.
- self.assertEqual(tensor_value[1], 8)
- # Remaining 3 elements: The counts of -inf, inf and nan.
- if op_type == "RealDiv":
- self.assertAllClose(tensor_value[2:], [1, 3, 2])
- else:
- self.assertAllClose(tensor_value[2:], [0, 0, 0])
- else: # SHAPE.
- for op_type, tensor_value in zip(executed_op_types, tensor_values):
- self.assertLen(tensor_value, 10)
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
- # 2nd element: dtype enum value (float16).
- self.assertEqual(tensor_value[1], 19)
- # 3rd element: rank (1)
- self.assertEqual(tensor_value[2], 1)
- # 4th element: element count.
- self.assertEqual(tensor_value[3], 8)
- # Remaining elements: shape at fixed length.
- self.assertAllClose(tensor_value[4:], [8, 0, 0, 0, 0, 0])
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ graph_exec_traces = reader.graph_execution_traces()
+ executed_op_types = [trace.op_type for trace in graph_exec_traces]
+ self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"])
+ if tensor_debug_mode == "CURT_HEALTH":
+ for trace in graph_exec_traces:
+ # 1st element: tensor_id, should be >= 0.
+ # 2nd element: indicates if there is any inf or nan.
+ tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+ self.assertGreaterEqual(tensor_id, 0)
+ if trace.op_type == "RealDiv":
+ self.assertAllClose(trace.debug_tensor_value, [tensor_id, 1])
+ else:
+ self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0])
+ elif tensor_debug_mode == "CONCISE_HEALTH":
+ for trace in graph_exec_traces:
+ # 1st element: tensor_id, should be >= 0.
+ # 2nd element: element count (8).
+ # Remaining 3 elements: The counts of -inf, inf and nan.
+ tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+ self.assertGreaterEqual(tensor_id, 0)
+ if trace.op_type == "RealDiv":
+ self.assertAllClose(trace.debug_tensor_value,
+ [tensor_id, 8, 1, 3, 2])
+ else:
+ self.assertAllClose(trace.debug_tensor_value,
+ [tensor_id, 8, 0, 0, 0])
+ else: # SHAPE.
+ for trace in graph_exec_traces:
+ # 1st element: tensor_id, should be >= 0.
+ # 2nd element: dtype enum value (float16 = 19).
+ # 3rd element: rank (1)
+ # 4th element: element count (8).
+ # Remaining elements: shape at fixed length (6).
+ tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+ self.assertGreaterEqual(tensor_id, 0)
+ self.assertAllClose(trace.debug_tensor_value,
+ [tensor_id, 19, 1, 8, 8, 0, 0, 0, 0, 0])
@parameterized.named_parameters(
("Shape", "SHAPE"),
@@ -317,28 +326,21 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (context_ids,
- _, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
-
- (op_names, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"])
-
- for tensor_value in tensor_values:
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
- # 2nd element: dtype enum value (bool).
- self.assertEqual(tensor_value[1], 10)
- # 3rd element: rank (2)
- self.assertEqual(tensor_value[2], 2)
- # 4th element: element count.
- self.assertEqual(tensor_value[3], 4)
- # Remaining elements: shape at fixed length.
- self.assertAllClose(tensor_value[4:], [2, 2, 0, 0, 0, 0])
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ graph_exec_traces = reader.graph_execution_traces()
+ executed_op_types = [trace.op_type for trace in graph_exec_traces]
+ self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"])
+ for trace in graph_exec_traces:
+ tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+ self.assertGreaterEqual(tensor_id, 0)
+ # 1st element: tensor_id, should be >= 0.
+ # 2nd element: dtype enum value (bool).
+ # 3rd element: rank (2).
+ # 4th element: element count (4).
+ # Remaining elements: shape at fixed length.
+ self.assertAllClose(
+ trace.debug_tensor_value, [tensor_id, 10, 2, 4, 2, 2, 0, 0, 0, 0])
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
@@ -366,86 +368,157 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- if context.executing_eagerly():
- # NOTE(b/142486213): Execution of the TF function happens with
- # Session.run() in v1 graph mode, so doesn't get logged to the
- # .execution file.
- (executed_op_types, executed_graph_ids,
- _, _, _, _) = self._readAndCheckExecutionFile()
- executed_op_types = [op_type for op_type in executed_op_types
- if "sin1p_log_sum" in op_type]
- self.assertLen(executed_op_types, 1)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ outermost_graphs = reader.outermost_graphs()
+ self.assertLen(outermost_graphs, 1)
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (context_ids, op_types, op_name_to_op_type,
- op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id)
+ if context.executing_eagerly():
+ # NOTE(b/142486213): Execution of the TF function happens with
+ # Session.run() in v1 graph mode, so doesn't get logged to the
+ # .execution file.
+ executions = reader.executions()
+ self.assertLen(executions, 1)
+ self.assertIn("sin1p_log_sum", executions[0].op_type)
+ # Get the executed graph and verify its identity and inner graph.
+ graph = reader.graph_by_id(executions[0].graph_id)
+ self.assertEqual(graph.name, "sin1p_log_sum")
+ self.assertLen(graph.inner_graph_ids, 1)
+ inner_graph = reader.graph_by_id(graph.inner_graph_ids[0])
+ self.assertEqual(inner_graph.name, "log_sum")
+ # Check device names.
+ self.assertLen(executions[0].output_tensor_device_ids, 1)
+ self.assertEqual(
+ reader.device_name_by_id(executions[0].output_tensor_device_ids[0]),
+ self._expectedDefaultDeviceName())
+ self.assertIn(self._expectedDefaultDeviceName(), reader.device_names())
- self.assertIn("AddV2", op_types)
- self.assertIn("Log", op_types)
- self.assertIn("Sin", op_types)
- if context.executing_eagerly():
- # Check the correctness of the ID of the executed graph ID.
- sin_op_name = [op_name for op_name in op_name_to_op_type
- if op_name_to_op_type[op_name] == "Sin"]
- self.assertLen(sin_op_name, 1)
- sin_context_id = op_name_to_context_id[sin_op_name[0]]
- # The executed "op" is a FuncGraph, and its graph ID should have been
- # recorded properly and be the ID of the graph that the Sin op belongs to.
- executed_graph_ids = [
- executed_graph_ids[i] for i, op_type
- in enumerate(executed_op_types) if "sin1p_log_sum" in op_type]
- self.assertEqual(executed_graph_ids[0], sin_context_id)
+ # Verify the recorded graph-building history.
+ add_op_digests = reader.graph_op_digests(op_type="AddV2")
+ self.assertLen(add_op_digests, 2)
+ self.assertEqual(
+ reader.graph_by_id(add_op_digests[0].graph_id).name, "log_sum")
+ self.assertEqual(
+ reader.graph_by_id(add_op_digests[1].graph_id).name, "sin1p_log_sum")
+ log_op_digests = reader.graph_op_digests(op_type="Log")
+ self.assertLen(log_op_digests, 1)
+ self.assertEqual(
+ reader.graph_by_id(log_op_digests[0].graph_id).name, "log_sum")
+ sin_op_digests = reader.graph_op_digests(op_type="Sin")
+ self.assertLen(sin_op_digests, 1)
+ self.assertEqual(
+ reader.graph_by_id(sin_op_digests[0].graph_id).name, "sin1p_log_sum")
- (op_names, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
+ # Verify the output tensor IDs and the stack traces.
+ for op_digest in add_op_digests + log_op_digests + sin_op_digests:
+ # These are all single-output ops.
+ self.assertLen(op_digest.output_tensor_ids, 1)
+ self.assertGreaterEqual(op_digest.output_tensor_ids[0], 0)
+ _, stack_frames = reader.read_graph_op_creation_stack_trace(op_digest)
+ self._verifyStackFrames(stack_frames)
- if tensor_debug_mode == "NO_TENSOR":
- # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
- # be an empty float32 tensor.
- for tensor_value in tensor_values:
- self.assertEqual(tensor_value.dtype, np.float32)
- self.assertEqual(tensor_value.shape, (0,))
- elif tensor_debug_mode == "CURT_HEALTH":
- for tensor_value in tensor_values:
- self.assertLen(tensor_value, 2)
+ graph_exec_traces = reader.graph_execution_traces()
+ executed_op_types = [digest.op_type for digest in graph_exec_traces]
+ self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2", "Sin"])
+
+ # Verify the graph ID stack of each op.
+ # 1st AddV2 op.
+ self.assertEqual(
+ reader.graph_by_id(graph_exec_traces[0].graph_ids[-1]).name,
+ "log_sum")
+ self.assertEqual(
+ reader.graph_by_id(graph_exec_traces[0].graph_ids[-2]).name,
+ "sin1p_log_sum")
+ # Log op.
+ self.assertEqual(
+ reader.graph_by_id(graph_exec_traces[1].graph_ids[-1]).name,
+ "log_sum")
+ self.assertEqual(
+ reader.graph_by_id(graph_exec_traces[1].graph_ids[-2]).name,
+ "sin1p_log_sum")
+ # 2nd AddV2 op.
+ self.assertEqual(
+ reader.graph_by_id(graph_exec_traces[2].graph_ids[-1]).name,
+ "sin1p_log_sum")
+ # Sin op.
+ self.assertEqual(
+ reader.graph_by_id(graph_exec_traces[3].graph_ids[-1]).name,
+ "sin1p_log_sum")
+
+ if tensor_debug_mode == "NO_TENSOR":
+ # Under the default NO_TENSOR tensor-debug mode, debug_tensor_value ought
+ # to be empty.
+ for trace in graph_exec_traces:
+ self.assertEqual(trace.debug_tensor_value, [])
+ elif tensor_debug_mode == "CURT_HEALTH":
+ # Test the association between graph exec and prior graph building.
+ # In each case, the 1st element of debug_tensor_value is the ID of the
+ # symbolic tensor and the 2nd element is a zero indicating there is no
+ # inf or nan.
+ self.assertAllClose(
+ graph_exec_traces[0].debug_tensor_value,
+ [add_op_digests[0].output_tensor_ids[0], 0.0]) # 1st AddV2 op.
+ self.assertAllClose(
+ graph_exec_traces[1].debug_tensor_value,
+ [log_op_digests[0].output_tensor_ids[0], 0.0]) # Log op.
+ self.assertAllClose(
+ graph_exec_traces[2].debug_tensor_value,
+ [add_op_digests[1].output_tensor_ids[0], 0.0]) # 2nd AddV2 op.
+ self.assertAllClose(
+ graph_exec_traces[3].debug_tensor_value,
+ [sin_op_digests[0].output_tensor_ids[0], 0.0]) # Sin op.
+ elif tensor_debug_mode == "CONCISE_HEALTH":
# 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
- # 2nd element: 0 means there is no inf or nan.
- self.assertEqual(tensor_value[1], 0)
- elif tensor_debug_mode == "CONCISE_HEALTH":
- for tensor_value in tensor_values:
- self.assertLen(tensor_value, 5)
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: element count. Remaining elements: all zero because there
# is no -inf, inf or nan.
- self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
- elif tensor_debug_mode == "SHAPE":
- for tensor_value in tensor_values:
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
+ # 1st AddV2 op.
+ self.assertAllClose(
+ graph_exec_traces[0].debug_tensor_value,
+ [add_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
+ # Log op.
+ self.assertAllClose(
+ graph_exec_traces[1].debug_tensor_value,
+ [log_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
+ # 2nd AddV2 op.
+ self.assertAllClose(
+ graph_exec_traces[2].debug_tensor_value,
+ [add_op_digests[1].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
+ # Sin op.
+ self.assertAllClose(
+ graph_exec_traces[3].debug_tensor_value,
+ [sin_op_digests[0].output_tensor_ids[0], 1.0, 0.0, 0.0, 0.0])
+ elif tensor_debug_mode == "SHAPE":
+ # 1st element: tensor_id.
# 2nd element: dtype (float32).
- self.assertGreaterEqual(tensor_value[1], 1)
# 3rd element: rank (scalar).
- self.assertGreaterEqual(tensor_value[2], 0)
- # 4th element: element count.
- self.assertGreaterEqual(tensor_value[3], 1)
- # Remaining elements: shape padded to fixed length.
- self.assertAllClose(tensor_value[4:], [0, 0, 0, 0, 0, 0])
- elif tensor_debug_mode == "FULL_TENSOR":
- self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op.
- self.assertAllClose(tensor_values[1], np.log(5.0)) # Log op.
- self.assertAllClose(tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op.
- self.assertAllClose(tensor_values[3],
- np.sin(np.log(5.0) + 1.0)) # Sin op.
+ # 4th element: element count (1).
+ # Remaining elements: shape padded to fixed length (6).
+ # 1st AddV2 op.
+ self.assertAllClose(
+ graph_exec_traces[0].debug_tensor_value,
+ [add_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
+ # Log op.
+ self.assertAllClose(
+ graph_exec_traces[1].debug_tensor_value,
+ [log_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
+ # 2nd AddV2 op.
+ self.assertAllClose(
+ graph_exec_traces[2].debug_tensor_value,
+ [add_op_digests[1].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
+ # Sin op.
+ self.assertAllClose(
+ graph_exec_traces[3].debug_tensor_value,
+ [sin_op_digests[0].output_tensor_ids[0], 1, 0, 1, 0, 0, 0, 0, 0, 0])
+ else: # FULL_TENSOR.
+ full_tensor_values = [
+ reader.graph_execution_trace_to_tensor_value(trace)
+ for trace in graph_exec_traces]
+ self.assertAllClose(full_tensor_values[0], 5.0) # 1st AddV2 op.
+ self.assertAllClose(full_tensor_values[1], np.log(5.0)) # Log op.
+ self.assertAllClose(
+ full_tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op.
+ self.assertAllClose(
+ full_tensor_values[3], np.sin(np.log(5.0) + 1.0)) # Sin op.
def testCapturingExecutedGraphIdsOfTwoCompilationsOfSameFunction(self):
"""Test correct executed IDs of two FuncGraphs from the same Py function."""
@@ -467,15 +540,21 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- (executed_op_types, executed_graph_ids,
- _, _, _, _) = self._readAndCheckExecutionFile()
- self.assertLen(executed_op_types, 4)
- for executed_op_type in executed_op_types:
- self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_")
- self.assertLen(executed_graph_ids, 4)
- self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
- self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
- self.assertLen(set(executed_graph_ids), 2)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+
+ executions = reader.executions()
+ self.assertLen(executions, 4)
+ for execution in executions:
+ self.assertStartsWith(execution.op_type, "__inference_ceil_times_two_")
+ executed_graph_ids = [execution.graph_id for execution in executions]
+ self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
+ self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
+ self.assertNotEqual(executed_graph_ids[0], executed_graph_ids[1])
+ self.assertNotEqual(executed_graph_ids[2], executed_graph_ids[3])
+ for executed_graph_id in executed_graph_ids:
+ self.assertEqual(
+ reader.graph_by_id(executed_graph_id).name, "ceil_times_two")
def testCapturingExecutedGraphIdsOfDuplicateFunctionNames(self):
"""Two FuncGraphs compiled from Python functions with identical names."""
@@ -503,15 +582,20 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- (executed_op_types, executed_graph_ids,
- _, _, _, _) = self._readAndCheckExecutionFile()
- self.assertLen(executed_op_types, 4)
- for executed_op_type in executed_op_types:
- self.assertStartsWith(executed_op_type, "__inference_ceil_times_two_")
- self.assertLen(executed_graph_ids, 4)
- self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
- self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
- self.assertLen(set(executed_graph_ids), 2)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ executions = reader.executions()
+ self.assertLen(executions, 4)
+ for execution in executions:
+ self.assertStartsWith(execution.op_type, "__inference_ceil_times_two_")
+ executed_graph_ids = [execution.graph_id for execution in executions]
+ self.assertEqual(executed_graph_ids[0], executed_graph_ids[2])
+ self.assertEqual(executed_graph_ids[1], executed_graph_ids[3])
+ self.assertNotEqual(executed_graph_ids[0], executed_graph_ids[1])
+ self.assertNotEqual(executed_graph_ids[2], executed_graph_ids[3])
+ for executed_graph_id in executed_graph_ids:
+ self.assertEqual(
+ reader.graph_by_id(executed_graph_id).name, "ceil_times_two")
@parameterized.named_parameters(
("AddV2", "AddV2"),
@@ -539,32 +623,35 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (context_ids, op_types,
- op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
- self.assertIn("AddV2", op_types)
- self.assertIn("Log", op_types)
- self.assertIn("Sin", op_types)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ graph_op_digests = reader.graph_op_digests()
+ op_types = [digest.op_type for digest in graph_op_digests]
+ self.assertIn("AddV2", op_types)
+ self.assertIn("Log", op_types)
+ self.assertIn("Sin", op_types)
- (op_names, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
-
- if op_regex == "AddV2":
- self.assertEqual(executed_op_types, ["AddV2", "AddV2"])
- self.assertLen(tensor_values, 2)
- self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op.
- self.assertAllClose(tensor_values[1], np.log(5.0) + 1.0) # 2nd AddV2 op.
- elif op_regex == "Log":
- self.assertEqual(executed_op_types, ["Log"])
- self.assertLen(tensor_values, 1)
- self.assertAllClose(tensor_values[0], np.log(5.0)) # Log op.
- else: # "(AddV2|Log)"
- self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2"])
- self.assertLen(tensor_values, 3)
- self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op.
- self.assertAllClose(tensor_values[1], np.log(5.0)) # Log op.
- self.assertAllClose(tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op.
+ graph_exec_digests = reader.graph_execution_traces(digest=True)
+ executed_op_types = [digest.op_type for digest in graph_exec_digests]
+ tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests]
+ if op_regex == "AddV2":
+ self.assertEqual(executed_op_types, ["AddV2", "AddV2"])
+ self.assertLen(tensor_values, 2)
+ self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op.
+ self.assertAllClose(
+ tensor_values[1], np.log(5.0) + 1.0) # 2nd AddV2 op.
+ elif op_regex == "Log":
+ self.assertEqual(executed_op_types, ["Log"])
+ self.assertLen(tensor_values, 1)
+ self.assertAllClose(tensor_values[0], np.log(5.0)) # Log op.
+ else: # "(AddV2|Log)"
+ self.assertEqual(executed_op_types, ["AddV2", "Log", "AddV2"])
+ self.assertLen(tensor_values, 3)
+ self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op.
+ self.assertAllClose(tensor_values[1], np.log(5.0)) # Log op.
+ self.assertAllClose(
+ tensor_values[2], np.log(5.0) + 1.0) # 2nd AddV2 op.
def testIncorrectTensorDTypeArgFormatLeadsToError(self):
with self.assertRaisesRegexp(
@@ -617,48 +704,54 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (context_ids, _,
- op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
- (op_names, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- if tensor_dtypes == [dtypes.float32] and not op_regex:
- self.assertEqual(executed_op_types, ["Unique", "Sum"])
- self.assertLen(tensor_values, 2)
- self.assertAllClose(tensor_values[0], [2., 6., 8., 1.]) # Unique values.
- self.assertAllClose(tensor_values[1], 17.) # Sum.
- elif tensor_dtypes == ["float32"] and op_regex == "Sum":
- self.assertEqual(executed_op_types, ["Sum"])
- self.assertLen(tensor_values, 1)
- self.assertAllClose(tensor_values[0], 17.) # Sum.
- elif tensor_dtypes == (dtypes.float32,) and op_regex == "(?!Sum)":
- self.assertEqual(executed_op_types, ["Unique"])
- self.assertLen(tensor_values, 1)
- self.assertAllClose(tensor_values[0], [2., 6., 8., 1.]) # Unique values.
- elif tensor_dtypes == [dtypes.int32] and not op_regex:
- self.assertEqual(executed_op_types, ["Unique"])
- self.assertLen(tensor_values, 1)
- self.assertAllEqual(tensor_values[0], [0, 1, 2, 3, 0]) # Unique indices.
- elif callable(tensor_dtypes) and not op_regex:
- self.assertEqual(executed_op_types, ["Unique"])
- self.assertLen(tensor_values, 1)
- self.assertAllEqual(tensor_values[0], [0, 1, 2, 3, 0]) # Unique indices.
- elif not tensor_dtypes and op_regex == "(?!Sum)":
- self.assertEqual(executed_op_types, ["Unique", "Unique"])
- self.assertLen(tensor_values, 2)
- self.assertAllClose(tensor_values[0], [2., 6., 8., 1.]) # Unique values.
- self.assertAllEqual(tensor_values[1], [0, 1, 2, 3, 0]) # Unique indices.
- else: # "All".
- self.assertEqual(executed_op_types, ["Unique", "Unique", "Sum"])
- self.assertLen(tensor_values, 3)
- self.assertAllClose(tensor_values[0], [2., 6., 8., 1.]) # Unique values.
- self.assertAllEqual(tensor_values[1], [0, 1, 2, 3, 0]) # Unique indices.
- self.assertAllClose(tensor_values[2], 17.) # Sum.
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ graph_exec_digests = reader.graph_execution_traces(digest=True)
+ executed_op_types = [digest.op_type for digest in graph_exec_digests]
+ tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests]
+
+ if tensor_dtypes == [dtypes.float32] and not op_regex:
+ self.assertEqual(executed_op_types, ["Unique", "Sum"])
+ self.assertLen(tensor_values, 2)
+ self.assertAllClose(tensor_values[0], [2, 6, 8, 1]) # Unique values.
+ self.assertAllClose(tensor_values[1], 17.) # Sum.
+ elif tensor_dtypes == ["float32"] and op_regex == "Sum":
+ self.assertEqual(executed_op_types, ["Sum"])
+ self.assertLen(tensor_values, 1)
+ self.assertAllClose(tensor_values[0], 17.) # Sum.
+ elif tensor_dtypes == (dtypes.float32,) and op_regex == "(?!Sum)":
+ self.assertEqual(executed_op_types, ["Unique"])
+ self.assertLen(tensor_values, 1)
+ self.assertAllClose(tensor_values[0], [2, 6, 8, 1]) # Unique values.
+ elif tensor_dtypes == [dtypes.int32] and not op_regex:
+ self.assertEqual(executed_op_types, ["Unique"])
+ self.assertLen(tensor_values, 1)
+ self.assertAllEqual(
+ tensor_values[0], [0, 1, 2, 3, 0]) # Unique indices.
+ elif callable(tensor_dtypes) and not op_regex:
+ self.assertEqual(executed_op_types, ["Unique"])
+ self.assertLen(tensor_values, 1)
+ self.assertAllEqual(
+ tensor_values[0], [0, 1, 2, 3, 0]) # Unique indices.
+ elif not tensor_dtypes and op_regex == "(?!Sum)":
+ self.assertEqual(executed_op_types, ["Unique", "Unique"])
+ self.assertLen(tensor_values, 2)
+ self.assertAllClose(tensor_values[0], [2, 6, 8, 1]) # Unique values.
+ self.assertAllEqual(
+ tensor_values[1], [0, 1, 2, 3, 0]) # Unique indices.
+ else: # "All".
+ self.assertEqual(executed_op_types, ["Unique", "Unique", "Sum"])
+ self.assertLen(tensor_values, 3)
+ self.assertAllClose(tensor_values[0], [2, 6, 8, 1]) # Unique values.
+ self.assertAllEqual(
+ tensor_values[1], [0, 1, 2, 3, 0]) # Unique indices.
+ self.assertAllClose(tensor_values[2], 17) # Sum.
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
+ ("CurtHealth", "CURT_HEALTH"),
("FullTensor", "FULL_TENSOR"),
)
@test_util.run_in_graph_and_eager_modes
@@ -679,86 +772,78 @@
self.assertAllClose(self.evaluate(iterative_doubling(x, times)), 8.0)
writer.FlushNonExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ graph_op_digests = reader.graph_op_digests()
+ op_types = [digest.op_type for digest in graph_op_digests]
+ self.assertIn("Less", op_types)
+ self.assertIn("Mul", op_types)
+ self.assertIn("AddV2", op_types)
- # Verify the content of the .graphs file.
- context_ids, op_types, op_name_to_op_type, _ = (
- self._readAndCheckGraphsFile(stack_frame_by_id))
- self.assertIn("Less", op_types)
- self.assertIn("Mul", op_types)
- self.assertIn("AddV2", op_types)
-
- # Before FlushExecutionFiles() is called, the .execution and
- # .graph_execution_traces files should be both empty.
- with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
- execution_iter = reader.execution_iterator()
- graph_execution_traces_iter = reader.graph_execution_traces_iterator()
- with self.assertRaises(StopIteration):
- next(execution_iter)
- with self.assertRaises(StopIteration):
- next(graph_execution_traces_iter)
+ # Before FlushExecutionFiles() is called, the .execution and
+ # .graph_execution_traces files should be both empty.
+ self.assertEqual(reader.num_executions(), 0)
+ self.assertEqual(reader.num_graph_execution_traces(), 0)
# TODO(cais): Backport execution instrumentation to tf.Session.
writer.FlushExecutionFiles()
# After the flushing, the .execution file should hold the appropriate
# contents.
+ reader.update()
if context.executing_eagerly():
- (executed_op_types, _, input_tensor_ids, output_tensor_ids,
- tensor_debug_modes, tensor_values) = self._readAndCheckExecutionFile()
# NOTE(b/142486213): Execution of the TF function happens with
# Session.run() in v1 graph mode, hence it doesn't get logged to the
- # .execution file.
- self.assertLen(executed_op_types, 1)
- self.assertIn("iterative_doubling", executed_op_types[0])
- self.assertLen(input_tensor_ids[0], 2)
- self.assertLen(output_tensor_ids[0], 1)
+ executions = reader.executions()
+ self.assertLen(executions, 1)
+ executed_op_types = [execution.op_type for execution in executions]
+ self.assertIn("iterative_doubling", executions[0].op_type)
+ execution = executions[0]
+ self.assertLen(execution.input_tensor_ids, 2)
+ self.assertLen(execution.output_tensor_ids, 1)
self.assertEqual(
- tensor_debug_modes[0],
- debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode))
+ debug_event_pb2.TensorDebugMode.Name(execution.tensor_debug_mode),
+ tensor_debug_mode)
if tensor_debug_mode == "FULL_TENSOR":
- self.assertAllClose(tensor_values, [[8.0]])
+ tensor_values = reader.execution_to_tensor_values(execution)
+ self.assertAllClose(tensor_values, [8.0])
- (op_names, _, output_slots,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- # The Less op should have been executed 5 times.
- self.assertEqual(executed_op_types.count("Less"), 5)
- # The last executed op should be Less.
- self.assertEqual(executed_op_types[-1], "Less")
+ graph_exec_traces = reader.graph_execution_traces()
+ executed_op_types = [trace.op_type for trace in graph_exec_traces]
+ if tensor_debug_mode != "CURT_HEALTH":
+ # Less outputs a boolean tensor, which is not tracked under CURT_HEALTH.
+ # The Less op should have been executed 5 times.
+ self.assertEqual(executed_op_types.count("Less"), 5)
+ # The last executed op should be Less.
+ self.assertEqual(executed_op_types[-1], "Less")
+ # AddV2 produces an int tensor, which is not tracked under CURT_HEALTH.
+ # The AddV2 op should have been run, but we refrain from asserting on
+ # how many times it's executed.
+ self.assertIn("AddV2", executed_op_types)
+ for trace in graph_exec_traces:
+ self.assertEqual(trace.output_slot, 0)
# The Mul op should have been executed 4 times.
self.assertEqual(executed_op_types.count("Mul"), 4)
- # The AddV2 op should have been run, but we refrain from asserting on how
- # many times it's executed.
- self.assertIn("AddV2", executed_op_types)
- for output_slot in output_slots:
- self.assertEqual(output_slot, 0)
+
+ tensor_values = [reader.graph_execution_trace_to_tensor_value(trace)
+ for trace in graph_exec_traces]
if tensor_debug_mode == "NO_TENSOR":
# Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
# to be an empty float32 tensor.
for tensor_value in tensor_values:
- self.assertEqual(tensor_value.dtype, np.float32)
- self.assertEqual(tensor_value.shape, (0,))
- elif tensor_debug_mode == "CURT_TENSOR":
- for tensor_value in tensor_values:
- self.assertLen(tensor_value, 2)
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
- # 2nd element: 0 means there is no inf or nan.
- self.assertEqual(tensor_value[1], 0)
+ self.assertAllEqual(tensor_value, [])
+ elif tensor_debug_mode == "CURT_HEALTH":
+ for trace in graph_exec_traces:
+ tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+ # 1st element: tensor_id; 2nd element: 0 indicating no inf or nan.
+ self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0.0])
elif tensor_debug_mode == "FULL_TENSOR":
less_values = [
- tensor_values[i]
- for i, op_type in enumerate(executed_op_types)
- if op_type == "Less"
- ]
- self.assertAllClose(less_values, [True, True, True, True, False])
+ reader.graph_execution_trace_to_tensor_value(trace)
+ for trace in graph_exec_traces if trace.op_type == "Less"]
+ self.assertAllEqual(less_values, [True, True, True, True, False])
mul_values = [
- tensor_values[i]
- for i, op_type in enumerate(executed_op_types)
- if op_type == "Mul"
- ]
+ reader.graph_execution_trace_to_tensor_value(trace)
+ for trace in graph_exec_traces if trace.op_type == "Mul"]
self.assertAllClose(mul_values, [1.0, 2.0, 4.0, 8.0])
def testCallingEnableTracingTwiceWithTheSameDumpRootIsIdempotent(self):
@@ -772,17 +857,16 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
- execution_iter = reader.execution_iterator()
- for _ in range(2):
- debug_event = next(execution_iter)
- self.assertGreater(debug_event.wall_time, 0)
- execution = debug_event.execution
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ executions = reader.executions()
+ self.assertLen(executions, 2)
+ for execution in executions:
+ self.assertGreater(execution.wall_time, 0)
self.assertEqual(execution.op_type, "Unique")
self.assertEqual(execution.num_outputs, 2)
- self.assertTrue(execution.code_location)
- with self.assertRaises(StopIteration):
- next(execution_iter)
+ _, stack_frames = reader.read_execution_stack_trace(execution)
+ self._verifyStackFrames(stack_frames)
def testCallingEnableTracingTwiceWithDifferentDumpRootsOverwrites(self):
dumping_callback.enable_dump_debug_info(self.dump_root)
@@ -796,27 +880,26 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- with debug_events_reader.DebugEventsReader(new_dump_root) as reader:
- execution_iter = reader.execution_iterator()
- for _ in range(2):
- debug_event = next(execution_iter)
- self.assertGreater(debug_event.wall_time, 0)
- execution = debug_event.execution
+ with debug_events_reader.DebugDataReader(new_dump_root) as reader:
+ reader.update()
+ executions = reader.executions()
+ self.assertLen(executions, 2)
+ for execution in executions:
+ self.assertGreater(execution.wall_time, 0)
self.assertEqual(execution.op_type, "Unique")
self.assertEqual(execution.num_outputs, 2)
- self.assertTrue(execution.code_location)
- with self.assertRaises(StopIteration):
- next(execution_iter)
+ _, stack_frames = reader.read_execution_stack_trace(execution)
+ self._verifyStackFrames(stack_frames)
- with debug_events_reader.DebugEventsReader(
- self.dump_root) as old_dump_root_reader:
- execution_iter = old_dump_root_reader.execution_iterator()
- # The old dump root shouldn't have been written to.
- with self.assertRaises(StopIteration):
- next(execution_iter)
+ with debug_events_reader.DebugDataReader(
+ self.dump_root) as old_dump_root_reader:
+ old_dump_root_reader.update()
+ # The old dump root shouldn't have been written to.
+ self.assertEqual(old_dump_root_reader.num_executions(), 0)
+ self.assertFalse(old_dump_root_reader.outermost_graphs())
def testCallingEnableRepeatedlyWithDifferentTensorDebugMode(self):
- """Assert that calling enable_dump_debug_info() with different tensor-debug modes.
+ """Assert calling enable_dump_debug_info() with two tensor-debug modes.
It should lead to overwriting of the previously-configured mode.
"""
@@ -830,16 +913,16 @@
self.assertAllClose(add_1_divide_by_2(constant_op.constant(4.0)), 2.5)
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- context_ids, _, _, _ = self._readAndCheckGraphsFile(stack_frame_by_id)
- _, _, _, _, _, tensor_values = self._readAndCheckExecutionFile()
- self.assertEqual(tensor_values, [[]])
- (_, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- self.assertLen(tensor_values, 2)
- for tensor_value in tensor_values:
- self.assertEqual(tensor_value.dtype, np.float32)
- self.assertEqual(tensor_value.shape, (0,))
+
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ graph_exec_digests = reader.graph_execution_traces(digest=True)
+ tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests]
+ for tensor_value in tensor_values:
+ # Under NO_TENSOR mode, each tensor is summarized as an empty float32
+ # array.
+ self.assertAllEqual(tensor_value, [])
with self.assertRaisesRegexp(
ValueError, r"already.*NO_TENSOR.*FULL_TENSOR.*not be honored"):
@@ -862,17 +945,11 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
- source_files_iter = reader.source_files_iterator()
- stack_frames_iter = reader.stack_frames_iterator()
- execution_iter = reader.execution_iterator()
- # No source-file, stack-frame or execution data should have been dumped.
- with self.assertRaises(StopIteration):
- next(source_files_iter)
- with self.assertRaises(StopIteration):
- next(stack_frames_iter)
- with self.assertRaises(StopIteration):
- next(execution_iter)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ self.assertEqual(reader.num_executions(), 0)
+ self.assertEqual(reader.num_graph_execution_traces(), 0)
+ self.assertFalse(reader.outermost_graphs())
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
@@ -908,73 +985,54 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
- execution_iter = reader.execution_iterator()
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ exec_digests = reader.executions(digest=True)
prev_wall_time = 1
- for debug_event in execution_iter:
- self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
- prev_wall_time = debug_event.wall_time
+ for exec_digest in exec_digests:
+ self.assertGreaterEqual(exec_digest.wall_time, prev_wall_time)
+ prev_wall_time = exec_digest.wall_time
- (context_ids, _,
- op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
+ graph_exec_traces = reader.graph_execution_traces()
+ executed_op_types = [trace.op_type for trace in graph_exec_traces]
+ self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads)
+ self.assertEqual(
+ executed_op_types.count("ReadVariableOp"), 2 * (1 + num_threads))
+ for trace in graph_exec_traces:
+ # These are all single-output tensors.
+ self.assertEqual(trace.output_slot, 0)
- (op_names, _, output_slots,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- self.assertEqual(executed_op_types.count("Mul"), 1 + num_threads)
- self.assertEqual(
- executed_op_types.count("ReadVariableOp"), 2 * (1 + num_threads))
- for output_slot in output_slots:
- self.assertEqual(output_slot, 0)
+ tensor_values = [reader.graph_execution_trace_to_tensor_value(trace)
+ for trace in graph_exec_traces]
if tensor_debug_mode == "NO_TENSOR":
for tensor_value in tensor_values:
- self.assertEqual(tensor_value.dtype, np.float32)
- self.assertEqual(tensor_value.shape, (0,))
+ self.assertAllEqual(tensor_value, [])
elif tensor_debug_mode == "CURT_HEALTH":
- for tensor_value in tensor_values:
- self.assertLen(tensor_value, 2)
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
- # 2nd element: 0 means there is no inf or nan.
- self.assertEqual(tensor_value[1], 0)
+ for trace in graph_exec_traces:
+ tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+ # 1st element: tensor ID; 2nd element: 0 indicating no inf or nan.
+ self.assertAllClose(trace.debug_tensor_value, [tensor_id, 0])
elif tensor_debug_mode == "CONCISE_HEALTH":
- for tensor_value in tensor_values:
- self.assertLen(tensor_value, 5)
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(tensor_value[0], 0)
+ for trace in graph_exec_traces:
+ tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+ # 1st element: tensor ID (checked per trace, not just the last one).
# 2nd element: element count. Remaining elements: all zero because there
# is no -inf, inf or nan.
- self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
+ self.assertAllClose(trace.debug_tensor_value, [tensor_id, 1, 0, 0, 0])
elif tensor_debug_mode == "SHAPE":
- mul_values = [
- tensor_values[i]
- for i, op_type in enumerate(executed_op_types)
- if op_type == "Mul"
- ]
- for mul_value in mul_values:
- # 1st element: tensor_id, should be >= 0.
- # TODO(cais): Assert on detailed value once Function-graph association
- # is in place.
- self.assertGreaterEqual(mul_value[0], 0)
- # 2nd element: dtype enum value (float32).
- self.assertEqual(mul_value[1], 1)
- # 3rd element: rank.
- self.assertEqual(mul_value[2], 0)
- # 3rd element: element count.
- self.assertEqual(mul_value[3], 1)
- # Remaining elements: shape padded to a fixed length.
- self.assertAllClose(mul_value[4:], [0, 0, 0, 0, 0, 0])
+ for trace in graph_exec_traces:
+ if trace.op_type == "Mul":
+ tensor_id = reader.graph_execution_trace_to_tensor_id(trace)
+ mul_value = reader.graph_execution_trace_to_tensor_value(trace)
+ # 1st element: tensor_id, should be >= 0.
+ # 2nd element: dtype enum value (float32).
+ # 3rd element: rank.
+ # 4th element: element count.
+ self.assertAllClose(mul_value, [tensor_id, 1, 0, 1, 0, 0, 0, 0, 0, 0])
elif tensor_debug_mode == "FULL_TENSOR":
mul_values = [
- tensor_values[i]
- for i, op_type in enumerate(executed_op_types)
- if op_type == "Mul"
- ]
+ reader.graph_execution_trace_to_tensor_value(trace)
+ for trace in graph_exec_traces if trace.op_type == "Mul"]
self.assertAllClose(mul_values, [6.0, 6.0, 6.0, 6.0])
def testMultiThreadedDumpingWithDifferentSettings(self):
@@ -1017,23 +1075,28 @@
self.assertAllClose(v1.read_value(), -67084290.0)
self.assertAllClose(v2.read_value(), -6.0)
- (executed_op_types, _, _, _, _,
- tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_1)
- v1_squared_values = [
- tensor_values[i] for i, op_type in enumerate(executed_op_types)
- if op_type == "Pow"]
- negative_v1_squared_values = [
- tensor_values[i] for i, op_type in enumerate(executed_op_types)
- if op_type == "Neg"]
- self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
- self.assertAllClose(
- negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
+ with debug_events_reader.DebugDataReader(dump_root_1) as reader:
+ reader.update()
+ exec_digests = reader.executions(digest=True)
+ v1_squared_values = [
+ reader.execution_to_tensor_values(digest)
+ for digest in exec_digests if digest.op_type == "Pow"]
+ negative_v1_squared_values = [
+ reader.execution_to_tensor_values(digest)
+ for digest in exec_digests if digest.op_type == "Neg"]
+ self.assertAllClose(v1_squared_values, [[100.0], [8100.0], [67076100.0]])
+ self.assertAllClose(
+ negative_v1_squared_values, [[-100.0], [-8100.0], [-67076100.0]])
- (executed_op_types, _, _, _, _,
- tensor_values) = self._readAndCheckExecutionFile(dump_root=dump_root_2)
- self.assertNotIn("Neg", executed_op_types)
- v2_squared_values = tensor_values[executed_op_types.index("Pow")]
- self.assertAllClose(v2_squared_values, [9.0])
+ with debug_events_reader.DebugDataReader(dump_root_2) as reader:
+ reader.update()
+ exec_digests = reader.executions(digest=True)
+ executed_op_types = [digest.op_type for digest in exec_digests]
+ self.assertNotIn("Neg", executed_op_types)
+ v2_squared_values = [
+ reader.execution_to_tensor_values(digest)
+ for digest in exec_digests if digest.op_type == "Pow"]
+ self.assertAllClose(v2_squared_values, [[9.0]])
@test_util.run_in_graph_and_eager_modes
def testNestedContextIsCapturedByGraphOpCreationHistory(self):
@@ -1055,36 +1118,18 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
-
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (_, _, op_name_to_op_type,
- op_name_to_context_id) = self._readAndCheckGraphsFile(stack_frame_by_id)
-
- less_op_names = [op_name for op_name in op_name_to_op_type
- if op_name_to_op_type[op_name] == "Less"]
- less_context_ids = [op_name_to_context_id[op_name]
- for op_name in less_op_names]
- mul_op_names = [op_name for op_name in op_name_to_op_type
- if op_name_to_op_type[op_name] == "Mul"]
- mul_context_ids = [op_name_to_context_id[op_name]
- for op_name in mul_op_names]
- sub_op_names = [op_name for op_name in op_name_to_op_type
- if op_name_to_op_type[op_name] == "Sub"]
- sub_context_ids = [op_name_to_context_id[op_name]
- for op_name in sub_op_names]
- self.assertLen(less_context_ids, 1)
- self.assertLen(mul_context_ids, 1)
- self.assertLen(sub_context_ids, 1)
- self.assertTrue(less_context_ids[0])
- self.assertTrue(mul_context_ids[0])
- self.assertTrue(sub_context_ids[0])
- # The Less op is from the while-loop cond context and hence should have
- # a different innermost context ID from the mul and sub ops, which are both
- # from the while-loop body context.
- self.assertNotEqual(less_context_ids[0], mul_context_ids[0])
- self.assertNotEqual(less_context_ids[0], sub_context_ids[0])
- # The Mul and Sub ops are from the same innermost context.
- self.assertEqual(mul_context_ids[0], sub_context_ids[0])
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ less_op_digest = reader.graph_op_digests(op_type="Less")[-1]
+ mul_op_digest = reader.graph_op_digests(op_type="Mul")[-1]
+ sub_op_digest = reader.graph_op_digests(op_type="Sub")[-1]
+ # The Less op is from the while-loop cond context and hence should have
+ # a different innermost context ID from the mul and sub ops, which are
+ # both from the while-loop body context.
+ self.assertNotEqual(less_op_digest.graph_id, mul_op_digest.graph_id)
+ self.assertNotEqual(less_op_digest.graph_id, sub_op_digest.graph_id)
+ # The Mul and Sub ops are from the same innermost context.
+ self.assertEqual(mul_op_digest.graph_id, sub_op_digest.graph_id)
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
@@ -1102,53 +1147,38 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (context_ids, op_types,
- op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
- # Simply assert that graph are recorded and refrain from asserting on the
- # internal details of the Keras model.
- self.assertTrue(context_ids)
- self.assertTrue(op_types)
- self.assertTrue(op_name_to_op_type)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ if context.executing_eagerly():
+ # NOTE(b/142486213): Execution of the TF function happens with
+ # Session.run() in v1 graph mode, hence it doesn't get logged to the
+ # .execution file.
+ self.assertTrue(reader.executions(digest=True))
- if context.executing_eagerly():
- # NOTE(b/142486213): Execution of the TF function happens with
- # Session.run() in v1 graph mode, hence it doesn't get logged to the
- # .execution file.
- (executed_op_types, _, _, _, _,
- tensor_values) = self._readAndCheckExecutionFile()
- self.assertTrue(executed_op_types)
+ graph_exec_digests = reader.graph_execution_traces(digest=True)
+ executed_op_types = [digest.op_type for digest in graph_exec_digests]
+ # These are the ops that we can safely assume to have been executed during
+ # the model prediction.
+ self.assertIn("MatMul", executed_op_types)
+ self.assertIn("BiasAdd", executed_op_types)
+ # On the GPU, CudnnRNN is used in lieu of the default op-by-op
+ # implementation.
+ self.assertTrue(
+ ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
+ "CudnnRNN" in executed_op_types))
- for value_list in tensor_values:
- if tensor_debug_mode == "NO_TENSOR":
- self.assertFalse(value_list)
-
- (op_names, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- # These are the ops that we can safely assume to have been executed during
- # the model prediction.
- self.assertIn("MatMul", executed_op_types)
- self.assertIn("BiasAdd", executed_op_types)
- # On the GPU, CudnnRNN is used in lieu of the default op-by-op
- # implementation.
- self.assertTrue(
- ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
- "CudnnRNN" in executed_op_types))
- # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
- # be an empty float32 tensor.
- if tensor_debug_mode == "NO_TENSOR":
- for tensor_value in tensor_values:
- self.assertEqual(tensor_value.dtype, np.float32)
- self.assertEqual(tensor_value.shape, (0,))
- else:
- # Refrain from asserting the internal implementation details of the LSTM
- # layer.
- concrete_tensor_values = [
- value for value in tensor_values
- if value is not None and value.size > 0
- ]
- self.assertTrue(concrete_tensor_values)
+ # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
+ # be an empty float32 tensor.
+ tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests]
+ if tensor_debug_mode == "NO_TENSOR":
+ for tensor_value in tensor_values:
+ self.assertAllEqual(tensor_value, [])
+ else:
+ # Refrain from asserting the internal implementation details of the LSTM
+ # layer.
+ self.assertTrue(any(
+ bool(tensor_value.size) for tensor_value in tensor_values))
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
@@ -1169,48 +1199,38 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (context_ids, op_types,
- op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
- # Simply assert that graph are recorded and refrain from asserting on the
- # internal details of the Keras model.
- self.assertTrue(context_ids)
- self.assertTrue(op_types)
- self.assertTrue(op_name_to_op_type)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ if context.executing_eagerly():
+ exec_digests = reader.executions(digest=True)
+ self.assertTrue(exec_digests)
+ if tensor_debug_mode == "NO_TENSOR":
+ for digest in exec_digests:
+ tensor_values = reader.execution_to_tensor_values(digest)
+ for tensor_value in tensor_values:
+ self.assertAllEqual(tensor_value, [])
- if context.executing_eagerly():
- # NOTE(b/142486213): Execution of the TF function happens with
- # Session.run() in v1 graph mode, hence it doesn't get logged to the
- # .execution file.
- (executed_op_types, _, _, _, _,
- tensor_values) = self._readAndCheckExecutionFile()
- self.assertTrue(executed_op_types)
+ graph_exec_digests = reader.graph_execution_traces(digest=True)
+ executed_op_types = [digest.op_type for digest in graph_exec_digests]
+ # These are the ops that we can safely assume to have been executed during
+ # the recurrent model's fit() call.
+ self.assertIn("MatMul", executed_op_types)
+ self.assertIn("BiasAdd", executed_op_types)
+
+ # On the GPU, CudnnRNN is used in lieu of the default op-by-op
+ # implementation.
+ self.assertTrue(
+ ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
+ "CudnnRNN" in executed_op_types))
+ self.assertTrue(
+ ("SigmoidGrad" in executed_op_types and
+ "TanhGrad" in executed_op_types or
+ "CudnnRNNBackprop" in executed_op_types))
if tensor_debug_mode == "NO_TENSOR":
- for value_list in tensor_values:
- self.assertFalse(value_list)
-
- (op_names, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- # These are the ops that we can safely assume to have been executed during
- # the recurrent model's fit() call.
- self.assertIn("MatMul", executed_op_types)
- self.assertIn("BiasAdd", executed_op_types)
- # On the GPU, CudnnRNN is used in lieu of the default op-by-op
- # implementation.
- self.assertTrue(
- ("Sigmoid" in executed_op_types and "Tanh" in executed_op_types or
- "CudnnRNN" in executed_op_types))
- self.assertTrue(
- ("SigmoidGrad" in executed_op_types and
- "TanhGrad" in executed_op_types or
- "CudnnRNNBackprop" in executed_op_types))
- if tensor_debug_mode == "NO_TENSOR":
- # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
- # to be an empty float32 tensor.
- for tensor_value in tensor_values:
- self.assertEqual(tensor_value.dtype, np.float32)
- self.assertEqual(tensor_value.shape, (0,))
+ for digest in graph_exec_digests:
+ # Under NO_TENSOR mode, each tensor is summarized as an empty array.
+ tensor_value = reader.graph_execution_trace_to_tensor_value(digest)
+ self.assertAllEqual(tensor_value, [])
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
@@ -1242,72 +1262,60 @@
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
- stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
- (context_ids, op_types,
- op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
- # Simply assert that graph are recorded and refrain from asserting on the
- # internal details of the Keras model.
- self.assertTrue(context_ids)
- self.assertTrue(op_types)
- self.assertTrue(op_name_to_op_type)
+ with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+ reader.update()
+ if context.executing_eagerly():
+ # NOTE(b/142486213): Execution of the TF function happens with
+ # Session.run() in v1 graph mode, hence it doesn't get logged to the
+ # .execution file.
+ exec_digests = reader.executions(digest=True)
+ self.assertTrue(exec_digests)
- if context.executing_eagerly():
- # NOTE(b/142486213): Execution of the TF function happens with
- # Session.run() in v1 graph mode, hence it doesn't get logged to the
- # .execution file.
- executed_op_types, _, _, _, _, _ = self._readAndCheckExecutionFile()
- self.assertTrue(executed_op_types)
+ graph_exec_digests = reader.graph_execution_traces()
+ executed_op_types = [digest.op_type for digest in graph_exec_digests]
+ # These are the ops that we can safely assume to have been executed during
+ # the model's fit() call.
+ self.assertIn("Conv2D", executed_op_types)
+ self.assertIn("Relu6", executed_op_types)
+ self.assertIn("Conv2DBackpropFilter", executed_op_types)
+ self.assertIn("Relu6Grad", executed_op_types)
- (op_names, _, _,
- tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
- executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
- # These are the ops that we can safely assume to have been executed during
- # the model's fit() call.
- self.assertIn("Conv2D", executed_op_types)
- self.assertIn("Relu6", executed_op_types)
- self.assertIn("Conv2DBackpropFilter", executed_op_types)
- self.assertIn("Relu6Grad", executed_op_types)
- if tensor_debug_mode == "NO_TENSOR":
- # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought to
- # be an empty float32 tensor.
- for tensor_value in tensor_values:
- self.assertEqual(tensor_value.dtype, np.float32)
- self.assertEqual(tensor_value.shape, (0,))
- elif tensor_debug_mode == "FULL_TENSOR":
- conv2d_values = [
- tensor_values[i]
- for i, op_type in enumerate(executed_op_types)
- if op_type == "Conv2D"
- ]
- self.assertTrue(conv2d_values)
- for conv2d_value in conv2d_values:
- self.assertGreater(len(conv2d_value.shape), 1)
- self.assertEqual(conv2d_value.shape[0], batch_size)
- relu6_values = [
- tensor_values[i]
- for i, op_type in enumerate(executed_op_types)
- if op_type == "Relu6"
- ]
- self.assertTrue(relu6_values)
- for relu6_value in relu6_values:
- self.assertGreater(len(relu6_value.shape), 1)
- self.assertEqual(relu6_value.shape[0], batch_size)
- conv2d_bp_filter_values = [
- tensor_values[i]
- for i, op_type in enumerate(executed_op_types)
- if op_type == "Conv2DBackpropFilter"
- ]
- self.assertTrue(conv2d_bp_filter_values)
- for conv2d_bp_filter_value in conv2d_bp_filter_values:
- self.assertGreater(len(conv2d_bp_filter_value.shape), 1)
- relu6_grad_values = [
- tensor_values[i]
- for i, op_type in enumerate(executed_op_types)
- if op_type == "Relu6Grad"
- ]
- self.assertTrue(relu6_grad_values)
- for relu6_grad_value in relu6_grad_values:
- self.assertGreater(len(relu6_grad_value.shape), 1)
+ if tensor_debug_mode == "NO_TENSOR":
+ # Under the default NO_TENSOR tensor-debug mode, the tensor_proto ought
+ # to be an empty float32 tensor.
+ tensor_values = [
+ reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests]
+ for tensor_value in tensor_values:
+ self.assertAllEqual(tensor_value, [])
+ elif tensor_debug_mode == "FULL_TENSOR":
+ conv2d_values = [
+ reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests if digest.op_type == "Conv2D"]
+ self.assertTrue(conv2d_values)
+ for conv2d_value in conv2d_values:
+ self.assertGreater(len(conv2d_value.shape), 1)
+ self.assertEqual(conv2d_value.shape[0], batch_size)
+ relu6_values = [
+ reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests if digest.op_type == "Relu6"]
+ self.assertTrue(relu6_values)
+ for relu6_value in relu6_values:
+ self.assertGreater(len(relu6_value.shape), 1)
+ self.assertEqual(relu6_value.shape[0], batch_size)
+ conv2d_bp_filter_values = [
+ reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests
+ if digest.op_type == "Conv2DBackpropFilter"]
+ self.assertTrue(conv2d_bp_filter_values)
+ for conv2d_bp_filter_value in conv2d_bp_filter_values:
+ self.assertGreater(len(conv2d_bp_filter_value.shape), 1)
+ relu6_grad_values = [
+ reader.graph_execution_trace_to_tensor_value(digest)
+ for digest in graph_exec_digests if digest.op_type == "Relu6Grad"]
+ self.assertTrue(relu6_grad_values)
+ for relu6_grad_value in relu6_grad_values:
+ self.assertGreater(len(relu6_grad_value.shape), 1)
if __name__ == "__main__":
diff --git a/tensorflow/python/debug/lib/dumping_callback_test_lib.py b/tensorflow/python/debug/lib/dumping_callback_test_lib.py
index 6144f2b..1d449f6 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test_lib.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test_lib.py
@@ -52,7 +52,7 @@
"""Read and check the .metadata debug-events file."""
with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
metadata_iter = reader.metadata_iterator()
- metadata = next(metadata_iter).debug_metadata
+ metadata = next(metadata_iter).debug_event.debug_metadata
self.assertEqual(metadata.tensorflow_version, versions.__version__)
self.assertTrue(metadata.file_version.startswith("debug.Event"))
@@ -67,7 +67,7 @@
source_files_iter = reader.source_files_iterator()
source_file_paths = []
prev_wall_time = 1
- for debug_event in source_files_iter:
+ for debug_event, _ in source_files_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
source_file = debug_event.source_file
@@ -84,7 +84,7 @@
stack_frame_by_id = collections.OrderedDict()
stack_frames_iter = reader.stack_frames_iterator()
prev_wall_time = 0
- for debug_event in stack_frames_iter:
+ for debug_event, _ in stack_frames_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
stack_frame_with_id = debug_event.stack_frame_with_id
@@ -133,7 +133,7 @@
# outermost contexts).
context_id_to_outer_id = dict()
- for debug_event in graphs_iter:
+ for debug_event, _ in graphs_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
# A DebugEvent in the .graphs file contains either of the two fields:
@@ -219,7 +219,7 @@
output_tensor_ids = []
tensor_debug_modes = []
tensor_values = []
- for debug_event in execution_iter:
+ for debug_event, _ in execution_iter:
self.assertGreaterEqual(debug_event.wall_time, prev_wall_time)
prev_wall_time = debug_event.wall_time
execution = debug_event.execution
@@ -260,7 +260,7 @@
device_names = []
output_slots = []
tensor_values = []
- for debug_event in graph_execution_traces_iter:
+ for debug_event, _ in graph_execution_traces_iter:
self.assertGreaterEqual(debug_event.wall_time, 0)
graph_execution_trace = debug_event.graph_execution_trace
op_names.append(graph_execution_trace.op_name)
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index ff60fe6..de88124 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -172,9 +172,6 @@
srcs = ["distribute_lib_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
- tags = [
- "no_rocm",
- ],
deps = [
":combinations",
":distribute_lib",
@@ -941,7 +938,6 @@
tags = [
"multi_and_single_gpu",
"no_oss", # TODO(b/139815303): enable after this is fixed.
- "no_rocm",
"notap", # TODO(b/139815303): enable after this is fixed.
],
deps = [
@@ -995,7 +991,6 @@
main = "step_fn_test.py",
tags = [
"multi_and_single_gpu",
- "no_rocm",
],
deps = [
":single_loss_example",
@@ -1056,10 +1051,10 @@
srcs = ["mirrored_strategy_test.py"],
shard_count = 5,
tags = [
- "guitar",
"multi_and_single_gpu",
"no_rocm",
"no_windows_gpu", # TODO(b/130551176)
+ "noguitar",
],
deps = [
":combinations",
@@ -1258,7 +1253,7 @@
tags = [
"multi_and_single_gpu",
],
- xla_enable_strict_auto_jit = False,
+ xla_enable_strict_auto_jit = True,
deps = [
":collective_all_reduce_strategy",
":combinations",
diff --git a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py
index 53af0c7..55cb458 100644
--- a/tensorflow/python/distribute/custom_training_loop_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_test.py
@@ -32,6 +32,7 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
+from tensorflow.python.util import nest
class InputIterationTest(test.TestCase, parameterized.TestCase):
@@ -99,6 +100,37 @@
@combinations.generate(
combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy
+ ],
+ mode=["eager"]))
+ def testNestedOutput(self, distribution):
+ dataset = self._get_dataset()
+ input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
+
+ @def_function.function
+ def run(iterator):
+
+ def computation(x):
+ return [{
+ "a": x - 1,
+ "b": x + 1
+ }]
+
+ inputs = next(iterator)
+ outputs = distribution.experimental_run_v2(computation, args=(inputs,))
+ return nest.map_structure(distribution.experimental_local_results,
+ outputs)
+
+ results = run(input_iterator)
+ for replica in range(distribution.num_replicas_in_sync):
+ # The input dataset is range(10), so the replica id is same as input.
+ self.assertAllEqual(results[0]["a"][replica], [replica - 1])
+ self.assertAllEqual(results[0]["b"][replica], [replica + 1])
+
+ @combinations.generate(
+ combinations.combine(
distribution=strategy_combinations.all_strategies,
mode=["eager"]
))
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 216ec8b..552b739 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -102,7 +102,7 @@
import six
-from tensorflow.python.autograph.core import ag_ctx
+from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import device_util
@@ -629,6 +629,7 @@
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)
+
# Iterate over the distributed dataset
for x in dist_dataset:
# process dataset elements
@@ -665,6 +666,32 @@
please use `experimental_distribute_datasets_from_function` instead, which
does not do any automatic splitting or sharding.
+ You can also use the `element_spec` property of the distributed dataset
+ returned by this API to query the `tf.TypeSpec` of the elements returned
+ by the iterator. This can be used to set the `input_signature` property
+ of a `tf.function`.
+
+ ```python
+ strategy = tf.distribute.MirroredStrategy()
+
+ # Create a dataset
+ dataset = dataset_ops.Dataset.TFRecordDataset([
+ "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"])
+
+ # Distribute that dataset
+ dist_dataset = strategy.experimental_distribute_dataset(dataset)
+
+ @tf.function(input_signature=[dist_dataset.element_spec])
+ def train_step(inputs):
+ # train model with inputs
+ return
+
+ # Iterate over the distributed dataset
+ for x in dist_dataset:
+ # process dataset elements
+ strategy.experimental_run_v2(train_step, args=(x,))
+ ```
+
Args:
dataset: `tf.data.Dataset` that will be sharded across all replicas using
the rules stated above.
@@ -714,6 +741,26 @@
the global batch size. This may be computed using
`input_context.get_per_replica_batch_size`.
+ To query the `tf.TypeSpec` of the elements in the distributed dataset
+ returned by this API, you need to use the `element_spec` property of the
+ distributed iterator. This `tf.TypeSpec` can be used to set the
+ `input_signature` property of a `tf.function`.
+
+ ```python
+ # If you want to specify `input_signature` for a `tf.function` you must
+ # first create the iterator.
+ iterator = iter(inputs)
+
+ @tf.function(input_signature=[iterator.element_spec])
+ def replica_fn_with_signature(inputs):
+ # train the model with inputs
+ return
+
+ for _ in range(steps):
+ strategy.experimental_run_v2(replica_fn_with_signature,
+ args=(next(iterator),))
+ ```
+
Args:
dataset_fn: A function taking a `tf.distribute.InputContext` instance and
returning a `tf.data.Dataset`.
@@ -754,11 +801,15 @@
structure can either be "per-replica" `Tensor` objects or `Tensor`s
(for example, if running on a single replica).
"""
+ if not isinstance(args, (list, tuple)):
+ raise ValueError(
+ "positional args must be a list or tuple, got {}".format(type(args)))
+
with self.scope():
# tf.distribute supports Eager functions, so AutoGraph should not be
# applied when when the caller is also in Eager mode.
- fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx(),
- convert_by_default=False)
+ fn = autograph.tf_convert(
+ fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
def reduce(self, reduce_op, value, axis):
@@ -1539,7 +1590,7 @@
if kwargs is None:
kwargs = {}
fn = autograph.tf_convert(
- fn, ag_ctx.control_status_ctx(), convert_by_default=False)
+ fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
with self._container_strategy().scope():
return self._update(var, fn, args, kwargs, group)
@@ -1565,7 +1616,7 @@
if kwargs is None:
kwargs = {}
fn = autograph.tf_convert(
- fn, ag_ctx.control_status_ctx(), convert_by_default=False)
+ fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
with self._container_strategy().scope():
return self._update_non_slot(colocate_with, fn, args, kwargs, group)
@@ -1949,8 +2000,8 @@
require_replica_context(self)
if kwargs is None:
kwargs = {}
- merge_fn = autograph.tf_convert(merge_fn, ag_ctx.control_status_ctx(),
- convert_by_default=False)
+ merge_fn = autograph.tf_convert(
+ merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
return self._merge_call(merge_fn, args, kwargs)
def _merge_call(self, merge_fn, args, kwargs):
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index e35365a..0aa3786 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -339,6 +339,11 @@
init_ops.extend(it.initialize())
return control_flow_ops.group(init_ops)
+ @property
+ def element_spec(self):
+ """The type specification of an element of this iterator."""
+ return self._element_spec
+
class DistributedIteratorV1(DistributedIterator):
"""Input Iterator for tf.data.DatasetV1."""
@@ -524,10 +529,9 @@
self._cloned_datasets.append(cloned_dataset)
self._input_workers = input_workers
- # TODO(anjalisridhar): Identify if we need to set this property on the
- # iterator.
- self.element_spec = dataset.element_spec
self._strategy = strategy
+ self._element_spec = _create_distributed_tensor_spec(self._strategy,
+ dataset.element_spec) # pylint: disable=protected-access
def __iter__(self):
if not (context.executing_eagerly() or
@@ -539,9 +543,14 @@
self._input_workers)
iterator = DistributedIterator(self._input_workers, worker_iterators,
self._strategy)
- iterator.element_spec = self.element_spec # pylint: disable=protected-access
+ iterator._element_spec = self.element_spec # pylint: disable=protected-access
return iterator
+ @property
+ def element_spec(self):
+ """The type specification of an element of this dataset."""
+ return self._element_spec
+
class DistributedDatasetV1(DistributedDataset):
"""Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices."""
@@ -607,7 +616,7 @@
self._input_workers)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy)
- iterator.element_spec = self.element_spec # pylint: disable=protected-access
+ iterator._element_spec = self.element_spec # pylint: disable=protected-access
return iterator
@@ -640,6 +649,7 @@
self._input_workers = input_workers
self._input_contexts = input_contexts
self._strategy = strategy
+ self._element_spec = None
def __iter__(self):
if not (context.executing_eagerly() or
@@ -647,9 +657,25 @@
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
- iterators = _create_iterators_per_worker_with_input_context(
+ iterators, element_spec = _create_iterators_per_worker_with_input_context(
self._input_contexts, self._input_workers, self._dataset_fn)
- return DistributedIterator(self._input_workers, iterators, self._strategy)
+ iterator = DistributedIterator(self._input_workers, iterators,
+ self._strategy)
+ self._element_spec = _create_distributed_tensor_spec(self._strategy,
+ element_spec)
+ iterator._element_spec = self._element_spec # pylint: disable=protected-access
+ return iterator
+
+ @property
+ def element_spec(self):
+ """The type specification of an element of this dataset."""
+ if self._element_spec is None:
+ raise ValueError("You must create an iterator before calling "
+ "`element_spec` on the distributed dataset or iterator. "
+ "This is because the dataset function is not called "
+ "before an iterator is created.")
+
+ return self._element_spec
class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
@@ -676,9 +702,14 @@
return self._get_iterator()
def _get_iterator(self):
- iterators = _create_iterators_per_worker_with_input_context(
+ iterators, element_spec = _create_iterators_per_worker_with_input_context(
self._input_contexts, self._input_workers, self._dataset_fn)
- return DistributedIteratorV1(self._input_workers, iterators, self._strategy)
+ iterator = DistributedIteratorV1(self._input_workers, iterators,
+ self._strategy)
+ self._element_spec = _create_distributed_tensor_spec(self._strategy,
+ element_spec)
+ iterator._element_spec = self._element_spec # pylint: disable=protected-access
+ return iterator
# TODO(anjalisridhar): This class will be soon be removed in favor of newer
@@ -769,7 +800,7 @@
input_workers,
worker_iterators, # pylint: disable=protected-access
strategy)
- self.element_spec = dist_dataset.element_spec # pylint: disable=protected-access
+ self._element_spec = dist_dataset.element_spec
def _dummy_tensor_fn(value_structure):
@@ -1003,7 +1034,7 @@
devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
iterators.append(iterator)
- return iterators
+ return iterators, dataset.element_spec
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
@@ -1169,3 +1200,32 @@
distribution.experimental_local_results(value))
distribution_strategy_context.get_replica_context().merge_call(
merge_fn, args=(output,))
+
+
+def _create_distributed_tensor_spec(strategy, tensor_spec):
+ """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.
+
+ Args:
+ strategy: The given `tf.distribute` strategy.
+ tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
+ shape should be None if you have partial batches.
+
+ Returns:
+ A `tf.TypeSpec` that matches the values produced by a given strategy. This
+ can be a `tf.TensorSpec` or a `PerRelicaSpec`.
+ """
+ num_replicas = len(strategy.extended.worker_devices)
+
+ # If the number of devices used in the strategy is just 1 then we return
+ # the tensor_spec as is.
+ if num_replicas == 1:
+ return tensor_spec
+
+ # If the number of devices is greater than 1 then we assume the input to
+ # tf.function is a per replica type.
+ def _get_value_per_replica(tensor_spec_per_input):
+ value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
+ return values.PerReplicaSpec(*value_specs)
+
+ return nest.map_structure(_get_value_per_replica, tensor_spec)
+
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index ea02ba8..5df3a09 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -956,5 +956,68 @@
sess=sess)
+class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["eager"],
+ distribution=[
+ strategy_combinations.one_device_strategy,
+ strategy_combinations.mirrored_strategy_with_one_cpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.central_storage_strategy_with_two_gpus,
+ ],
+ input_type=["dataset", "dataset_fn"],
+ ))
+ def testInputSignatureForPerReplicaValues(self, distribution, input_type):
+ def dataset_fn(ctx):
+ del ctx # unused
+ return dataset_ops.DatasetV2.from_tensor_slices(
+ np.ones([10, 12]).astype(np.float32)).batch(4)
+
+ if input_type == "dataset":
+ ds = distribution.experimental_distribute_dataset(
+ dataset_fn(distribute_lib.InputContext()))
+ type_spec = ds.element_spec
+ else:
+ ds = distribution.experimental_distribute_datasets_from_function(
+ dataset_fn)
+ iterator = iter(ds)
+ type_spec = iterator.element_spec
+
+ @def_function.function(input_signature=[type_spec])
+ def process_inputs(inputs):
+ distribution.experimental_run_v2(lambda inputs: inputs, args=(inputs,))
+
+ for x in ds:
+ process_inputs(x)
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["eager"],
+ distribution=[
+ strategy_combinations.one_device_strategy,
+ strategy_combinations.mirrored_strategy_with_one_cpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.central_storage_strategy_with_two_gpus,
+ ],
+ ))
+ def testInputSignatureForNestedPerReplicaValues(self, distribution):
+ a = np.ones((10, 2)) * 5
+ b = np.ones((10, 3)) * 6
+ dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2)
+
+ dist_dataset = distribution.experimental_distribute_dataset(dataset)
+
+ @def_function.function(input_signature=[dist_dataset.element_spec])
+ def process_inputs(inputs):
+ distribution.experimental_run_v2(lambda inputs: inputs, args=(inputs,))
+
+ for x in dist_dataset:
+ process_inputs(x)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/distribute/minimize_loss_test.py b/tensorflow/python/distribute/minimize_loss_test.py
index 92e5f6d..d59d6d7 100644
--- a/tensorflow/python/distribute/minimize_loss_test.py
+++ b/tensorflow/python/distribute/minimize_loss_test.py
@@ -44,7 +44,11 @@
VAR_MAP_V1 = {
"GradientDescent": ("dense/kernel", "dense/bias"),
"Adagrad": ("dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad",
- "dense/bias")
+ "dense/bias"),
+ "Ftrl": ("dense/kernel/Ftrl", "dense/kernel", "dense/bias/Ftrl",
+ "dense/bias", "dense/kernel/Ftrl_1", "dense/bias/Ftrl_1"),
+ "RMSProp": ("dense/kernel", "dense/bias/RMSProp", "dense/bias/RMSProp_1",
+ "dense/bias", "dense/kernel/RMSProp_1", "dense/kernel/RMSProp")
}
VAR_MAP_V2 = {
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 729bb34..d04bde8 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -25,7 +25,7 @@
import weakref
from tensorflow.python import pywrap_tfe
-from tensorflow.python.autograph.core import ag_ctx
+from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
@@ -756,7 +756,7 @@
# _call_for_each_replica itself (TF library functions are whitelisted).
# This makes suresure that the Python function that originally passed to
# the tf.function is still converted.
- fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
+ fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
return _call_for_each_replica(self._container_strategy(), self._devices,
fn, args, kwargs)
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index ae5c4a0..95fc7b9 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -40,6 +40,7 @@
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
+from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop
@@ -130,11 +131,16 @@
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
adam_optimizer_v1_fn = combinations.NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
+ftrl_optimizer_v1_fn = combinations.NamedObject(
+ "FtrlV1", lambda: ftrl.FtrlOptimizer(0.001))
rmsprop_optimizer_v1_fn = combinations.NamedObject(
"RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
# TODO(shiningsun): consider adding the other v1 optimizers
-optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn]
+optimizers_v1 = [
+ gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn,
+ ftrl_optimizer_v1_fn, rmsprop_optimizer_v1_fn
+]
adadelta_optimizer_keras_v2_fn = combinations.NamedObject(
"AdadeltaKerasV2", lambda: adadelta_keras_v2.Adadelta(0.001))
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 85ff254..6f89ac6 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -25,7 +25,7 @@
import numpy as np
-from tensorflow.python.autograph.core import ag_ctx
+from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
@@ -163,7 +163,7 @@
# Note: the target function is converted to graph even when in Eager mode,
# so autograph is on by default here.
- fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
+ fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
return self.extended.tpu_run(fn, args, kwargs)
@@ -209,7 +209,7 @@
"""See base class."""
validate_experimental_run_function(fn)
- fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx())
+ fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
return self.extended.tpu_run(fn, args, kwargs)
@@ -817,7 +817,8 @@
# Remove all no ops that may have been added during 'tpu.replicate()'
if isinstance(result[0], list):
result[0] = [
- output for output in result[0] if tensor_util.is_tensor(output)
+ output for output in result[0] if not isinstance(
+ output, ops.Operation)
]
# Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 5c6c5e3..58b29c4 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -680,6 +680,35 @@
foo()
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.mirrored_strategy_with_one_cpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.tpu_strategy,
+ strategy_combinations.central_storage_strategy_with_two_gpus,
+ ],
+ mode=["graph", "eager"]))
+ def testAggregationOnlyFirstReplica(self, distribution):
+ with distribution.scope():
+ v = variable_scope.variable(
+ 15.,
+ synchronization=variables_lib.VariableSynchronization.ON_WRITE,
+ aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ @def_function.function
+ def assign():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ return v.assign(math_ops.cast(replica_id, dtypes.float32))
+ per_replica_results = self.evaluate(distribution.experimental_local_results(
+ distribution.experimental_run_v2(assign)))
+ # The per-replica values should always match the first replicas value.
+ self.assertAllEqual(
+ array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
+ per_replica_results)
+
_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)
@@ -757,6 +786,8 @@
mode=["graph", "eager"])
+# TODO(b/144432582): Add variable aggregation type to combinations to simplify
+# tests.
def strategy_and_run_tf_function_combinations():
# Test the combination of different strategies and whether a tf.function
# is passed into strategy.experimental_run_v2."""
@@ -1123,6 +1154,68 @@
expected = 0
self.assertEqual(expected, result, aggregation)
+ # TODO(b/145574622): Re-enable this test once ReduceOp argument is
+ # respected on GPUs.
+ @combinations.generate(strategy_and_run_tf_function_combinations())
+ def disable_testAllReduce(self, distribution,
+ experimental_run_tf_function):
+ with distribution.scope():
+ v = variable_scope.variable(
+ 2.,
+ synchronization=variables_lib.VariableSynchronization.ON_WRITE,
+ aggregation=variables_lib.VariableAggregation.MEAN)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ def all_reduce():
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
+ dtypes.float32)
+
+ if experimental_run_tf_function:
+ all_reduce = def_function.function(all_reduce)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.experimental_run_v2(all_reduce)))
+ expected_result = []
+ for i in range(distribution.num_replicas_in_sync):
+ expected_result.append(2.0 * distribution.num_replicas_in_sync +
+ 1.0 * i)
+ self.assertEqual(per_replica_results, tuple(expected_result))
+
+ @combinations.generate(strategy_and_run_tf_function_combinations())
+ def testAssignPerReplicaBeforeRead(self, distribution,
+ experimental_run_tf_function):
+ aggregations = [
+ variables_lib.VariableAggregation.SUM,
+ variables_lib.VariableAggregation.MEAN,
+ variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
+ ]
+ for aggregation in aggregations:
+ with distribution.scope():
+ v = variable_scope.variable(
+ 0.,
+ synchronization=variables_lib.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ self.evaluate(variables_lib.global_variables_initializer())
+
+ def assign(var=v):
+ ctx = distribution_strategy_context.get_replica_context()
+ replica_id = ctx.replica_id_in_sync_group
+ return var.assign(math_ops.cast(replica_id, dtypes.float32))
+
+ if experimental_run_tf_function:
+ assign = def_function.function(assign)
+
+ per_replica_results = self.evaluate(
+ distribution.experimental_local_results(
+ distribution.experimental_run_v2(assign)))
+ expected_result = []
+ for i in range(distribution.num_replicas_in_sync):
+ expected_result.append(1.0 * i)
+ self.assertEqual(per_replica_results, tuple(expected_result))
+
@combinations.generate(mirrored_and_tpu_strategy_combinations())
def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
with distribution.scope():
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index df1a409..809b4a8 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -44,7 +44,6 @@
"//tensorflow/python:cpp_python_util",
"//tensorflow/python:ndarray_tensor",
"//tensorflow/python:ndarray_tensor_bridge",
- "//tensorflow/python:ndarray_tensor_types",
"//tensorflow/python:numpy_lib",
"//tensorflow/python:py_seq_tensor",
"//tensorflow/python:safe_ptr",
@@ -266,7 +265,6 @@
srcs = ["backprop_test.py"],
python_version = "PY3",
tags = [
- "no_rocm",
"no_windows", #TODO(b/139745667)
],
deps = [
@@ -777,6 +775,8 @@
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
+ # TODO(b/145618471): Remove this transitive dependency.
+ "//tensorflow/python/distribute:input_lib",
"//tensorflow/python:framework_ops",
],
)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 7ffaefe..3fdbec2 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -127,20 +127,30 @@
@parameterized.named_parameters(
[('Function', def_function.function),
('NoFunction', lambda f: f)])
- def testIdentityBehaviorConsistent(self, decorator):
+ def testNoOpBehaviorConsistent(self, decorator):
@decorator
def f(x):
+ # Test all different types of no-ops
x1 = array_ops.identity(x)
+ x2 = math_ops.add_v2(x, 0)
+ x3 = math_ops.subtract(x, 0)
+ x4 = math_ops.multiply(x, 1)
with backprop.GradientTape() as t:
t.watch(x)
t.watch(x1)
+ t.watch(x2)
+ t.watch(x3)
+ t.watch(x4)
y1 = x * 2.
y2 = x1 * 3.
- loss = y1 + y2
- return t.gradient(loss, [x, x1])
+ y3 = x2 * 3.
+ y4 = x3 * 3.
+ y5 = x4 * 3.
+ loss = y1 + y2 + y3 + y4 + y5
+ return t.gradient(loss, [x, x1, x2, x3, x4])
- self.assertAllClose([2., 3.], f(constant_op.constant(10.)))
+ self.assertAllClose([2., 3., 3., 3., 3.], f(constant_op.constant(10.)))
def testGradientInsideLoop(self):
with ops.Graph().as_default():
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 50b8130..e7b90a1 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -65,6 +65,7 @@
CPU = "/device:CPU:0"
GPU = "/device:GPU:0"
+GLOBAL_TEST_VALUE = None
def c_tfe_py_fastpath_execute(a,
@@ -200,10 +201,20 @@
self._run(func, 30000)
- def _benchmark_create_constant(self, value, dtype):
- def func():
+ def _benchmark_create_constant(self, value, dtype, cached=True):
+ global GLOBAL_TEST_VALUE
+ GLOBAL_TEST_VALUE = value
+
+ def cached_func():
constant_op.constant(value, dtype=dtype)
+ def uncached_func():
+ global GLOBAL_TEST_VALUE
+ GLOBAL_TEST_VALUE += 1
+ constant_op.constant(GLOBAL_TEST_VALUE, dtype=dtype)
+
+ func = cached_func if cached else uncached_func
+
with ops.device("GPU:0" if context.num_gpus() else "CPU:0"):
for _ in range(1000):
func() # Warmup.
@@ -212,13 +223,22 @@
def benchmark_create_float_constant(self):
self._benchmark_create_constant(42.0, dtype=None)
+ def benchmark_create_float_constant_uncached(self):
+ self._benchmark_create_constant(42.0, dtype=None, cached=False)
+
def benchmark_create_int32_constant(self):
if context.num_gpus():
return # int32 constants are always allocated on CPU.
self._benchmark_create_constant(42, dtype=dtypes.int32)
- def _benchmark_add_scalars(self, a, b):
+ def benchmark_create_int32_constant_uncached(self):
+ if context.num_gpus():
+ return # int32 constants are always allocated on CPU.
+
+ self._benchmark_create_constant(42, dtype=dtypes.int32, cached=False)
+
+ def _benchmark_add(self, a, b):
def func():
return memoryview(math_ops.add(a, b))
@@ -228,10 +248,30 @@
self._run(func, 30000)
def benchmark_add_float_scalars(self):
- self._benchmark_add_scalars(42.0, 24.0)
+ self._benchmark_add(42.0, 24.0)
def benchmark_add_int32_scalars(self):
- self._benchmark_add_scalars(42, 24)
+ self._benchmark_add(42, 24)
+
+ def benchmark_add_float_scalar_tensor(self):
+ tensor_a = constant_op.constant(42.0)
+ tensor_b = constant_op.constant(24.0)
+ self._benchmark_add(tensor_a, tensor_b)
+
+ def benchmark_add_int32_scalar_tensor(self):
+ tensor_a = constant_op.constant(42)
+ tensor_b = constant_op.constant(24)
+ self._benchmark_add(tensor_a, tensor_b)
+
+ def benchmark_add_float_dense_tensor(self):
+ tensor_a = constant_op.constant([[42.0, 42.0], [42.0, 42.0]])
+ tensor_b = constant_op.constant([[24.0, 24.0], [24.0, 24.0]])
+ self._benchmark_add(tensor_a, tensor_b)
+
+ def benchmark_add_int32_dense_tensor(self):
+ tensor_a = constant_op.constant([[42, 42], [42, 42]])
+ tensor_b = constant_op.constant([[24, 24], [24, 24]])
+ self._benchmark_add(tensor_a, tensor_b)
def benchmark_create_float_tensor_from_list_CPU(self):
self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 2a3fedc..d7b5c50 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -23,6 +23,7 @@
import copy
import random
import threading
+
from absl import logging
import numpy as np
import six
@@ -1904,16 +1905,19 @@
@tf_contextlib.contextmanager
def execution_mode(mode):
"""Context manager for setting execution mode for current thread."""
- ctx = context()
- executor_new = executor.new_executor(mode == ASYNC)
- executor_old = ctx.executor
- try:
- executor_old.wait()
- ctx.executor = executor_new
+ if mode is None:
yield
- finally:
- ctx.executor = executor_old
- executor_new.wait()
+ else:
+ ctx = context()
+ executor_new = executor.new_executor(mode == ASYNC)
+ executor_old = ctx.executor
+ try:
+ executor_old.wait()
+ ctx.executor = executor_new
+ yield
+ finally:
+ ctx.executor = executor_old
+ executor_new.wait()
@tf_contextlib.contextmanager
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 1c67105..09a0722 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -449,7 +449,9 @@
if self._implements is not None:
attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
if self._experimental_compile is not None:
- attributes.update(_XlaCompile=bool(self._experimental_compile))
+ attributes.update(_XlaMustCompile=bool(self._experimental_compile))
+ if self._experimental_compile:
+ attributes.update(_noinline=True)
if not attributes:
attributes = None
return function_lib.defun_with_attributes(
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index 5338725..c69b5fe 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -20,6 +20,7 @@
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -45,6 +46,79 @@
# XLA support is not yet enabled for TF ROCm
self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
+ def testBasicInt32(self):
+
+ def fn(x, a):
+ return x + a
+
+ xla_func = def_function.function(fn, experimental_compile=True)
+
+ inputs = constant_op.constant([1, 2, 2, 3, 3], dtype=dtypes.int32)
+ if not test.is_built_with_rocm():
+ # XLA support is not yet enabled for TF ROCm
+ self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))
+
+ def testDerivative(self):
+ if test.is_built_with_rocm():
+ return
+
+ def fn(x, a):
+ return 2 * x + a
+
+ xla_func = def_function.function(fn, experimental_compile=True)
+
+ with backprop.GradientTape() as tape:
+ inputs = constant_op.constant([1., 2., 2., 3., 3.])
+ tape.watch(inputs)
+ outputs = xla_func(inputs, 1)
+
+ self.assertAllClose([2, 2, 2, 2, 2], tape.gradient(outputs, inputs))
+
+ # pylint: disable=protected-access
+ (forward, backward) = xla_func.get_concrete_function(
+ inputs, 1)._delayed_rewrite_functions.forward_backward()
+
+ # Check that the must-compile attribute gets correctly propagated to the
+ # created derivatives.
+ self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
+ self.assertTrue(forward.definition.attr['_XlaMustCompile'])
+
+ # Calling function with experimental_compile=True from
+ # experimental_compile=False should compile the inner func.
+ def testNestedCall(self):
+
+ def fn(x, a):
+ return x + a
+
+ xla_func = def_function.function(fn, experimental_compile=True)
+
+ def fn2(x, a):
+ return xla_func(x, a)
+
+ func = def_function.function(fn2, experimental_compile=False)
+
+ inputs = constant_op.constant([1, 2, 2, 3, 3])
+ if not test.is_built_with_rocm():
+ # XLA support is not yet enabled for TF ROCm
+ self.assertAllClose([2, 3, 3, 4, 4], func(inputs, 1))
+
+ def testNestedCallUnsupportedOps(self):
+
+ def fn(x):
+ return array_ops.unique(x).y
+
+ xla_func = def_function.function(fn, experimental_compile=True)
+
+ def fn2(x):
+ return xla_func(x)
+
+ func = def_function.function(fn2, experimental_compile=False)
+ inputs = constant_op.constant([1, 2, 2, 3, 3])
+ if not test.is_built_with_rocm():
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ 'not compilable'):
+ func(inputs)
+
def testUnsupportedOps(self):
def fn(x):
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index de35852..29b463a 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -2738,11 +2738,7 @@
def get_list():
return [constant_op.constant(0.), constant_op.constant(1.)]
- expected_msg = (
- 'Function to be traced should not modify structure of input '
- 'arguments. Check if your function has list and dictionary '
- 'operations that alter input arguments, '
- 'such as `list.pop`, `list.append`')
+ expected_msg = '.*() should not modify'
with self.assertRaisesRegexp(ValueError, expected_msg):
@@ -2818,11 +2814,7 @@
def get_dict():
return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
- expected_msg = (
- 'Function to be traced should not modify structure of input '
- 'arguments. Check if your function has list and dictionary '
- 'operations that alter input arguments, '
- 'such as `list.pop`, `list.append`')
+ expected_msg = '.* should not modify'
with self.assertRaisesRegexp(ValueError, expected_msg):
@@ -2865,14 +2857,8 @@
setdefault(get_dict())
def testFunctionModifiesInputNest(self):
- # Test on functions that modify structure of nested input arguments
- expected_msg = (
- 'Function to be traced should not modify structure of input '
- 'arguments. Check if your function has list and dictionary '
- 'operations that alter input arguments, '
- 'such as `list.pop`, `list.append`')
-
- with self.assertRaisesRegexp(ValueError, expected_msg):
+ with self.assertRaisesRegexp(
+ ValueError, 'modify.* should not modify'):
@def_function.function
def modify(n):
@@ -2886,7 +2872,8 @@
modify(nested_input)
- with self.assertRaisesRegexp(ValueError, expected_msg):
+ with self.assertRaisesRegexp(
+ ValueError, 'modify_same_flat.* should not modify'):
# The flat list doesn't change whereas the true structure changes
@def_function.function
diff --git a/tensorflow/python/eager/memory_tests/memory_test_util.py b/tensorflow/python/eager/memory_tests/memory_test_util.py
index 8e2fa00..0bb3089 100644
--- a/tensorflow/python/eager/memory_tests/memory_test_util.py
+++ b/tensorflow/python/eager/memory_tests/memory_test_util.py
@@ -21,6 +21,7 @@
import collections
import gc
import time
+
import six
from tensorflow.python.eager import context
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 66c2b85..519026f 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -31,7 +31,6 @@
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
-#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
@@ -289,15 +288,15 @@
if (PyArray_Check(value)) {
int desired_np_dtype = -1;
if (dtype != tensorflow::DT_INVALID) {
- PyArray_Descr* descr = nullptr;
- if (!tensorflow::DataTypeToPyArray_Descr(dtype, &descr).ok()) {
+ if (!tensorflow::TF_DataType_to_PyArray_TYPE(
+ static_cast<TF_DataType>(dtype), &desired_np_dtype)
+ .ok()) {
PyErr_SetString(
PyExc_TypeError,
tensorflow::strings::StrCat("Invalid dtype argument value ", dtype)
.c_str());
return nullptr;
}
- desired_np_dtype = descr->type_num;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
int current_np_dtype = PyArray_TYPE(array);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index f5508d7..8fe4b6a 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -72,11 +72,15 @@
TFE_Op* maybe_op = ReleaseThreadLocalOp();
if (maybe_op) {
TFE_OpReset(ctx, op_or_function_name, raw_device_name, status, maybe_op);
- return maybe_op;
- } else {
- return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
- nullptr);
+ if (status->status.ok()) {
+ return maybe_op;
+ }
+ // Delete op and create a fresh one
+ delete maybe_op;
}
+
+ return NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
+ nullptr);
}
void ReturnOp(TFE_Op* object) {
diff --git a/tensorflow/python/eager/remote.py b/tensorflow/python/eager/remote.py
index 64309f9..276f2de 100644
--- a/tensorflow/python/eager/remote.py
+++ b/tensorflow/python/eager/remote.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import copy
+
from absl import logging
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
diff --git a/tensorflow/python/feature_column/dense_features.py b/tensorflow/python/feature_column/dense_features.py
index e9b6393..e6dc842 100644
--- a/tensorflow/python/feature_column/dense_features.py
+++ b/tensorflow/python/feature_column/dense_features.py
@@ -36,11 +36,12 @@
This layer can be called multiple times with different features.
- This is the V1 version of this layer that uses variable_scope's to create
- variables which works well with PartitionedVariables. Variable scopes are
- deprecated in V2, so the V2 version uses name_scopes instead. But currently
- that lacks support for partitioned variables. Use this if you need
- partitioned variables.
+ This is the V1 version of this layer that uses variable_scopes or partitioner
+ to create variables which works well with PartitionedVariables. Variable
+ scopes are deprecated in V2, so the V2 version uses name_scopes instead. But
+ currently that lacks support for partitioned variables. Use this if you need
+ partitioned variables. Use the partitioner argument if you have a Keras model
+ and use `tf.compat.v1.keras.estimator.model_to_estimator` for training.
Example:
@@ -50,7 +51,9 @@
tf.feature_column.categorical_column_with_hash_bucket("keywords", 10K),
dimensions=16)
columns = [price, keywords_embedded, ...]
- feature_layer = tf.compat.v1.keras.layers.DenseFeatures(columns)
+ partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=4)
+ feature_layer = tf.compat.v1.keras.layers.DenseFeatures(
+ feature_columns=columns, partitioner=partitioner)
features = tf.io.parse_example(
..., features=tf.feature_column.make_parse_example_spec(columns))
@@ -62,7 +65,12 @@
```
"""
- def __init__(self, feature_columns, trainable=True, name=None, **kwargs):
+ def __init__(self,
+ feature_columns,
+ trainable=True,
+ name=None,
+ partitioner=None,
+ **kwargs):
"""Constructs a DenseFeatures layer.
Args:
@@ -75,6 +83,7 @@
trainable: Boolean, whether the layer's variables will be updated via
gradient descent during training.
name: Name to give to the DenseFeatures.
+ partitioner: Partitioner for input layer. Defaults to None.
**kwargs: Keyword arguments to construct a layer.
Raises:
@@ -84,6 +93,7 @@
feature_columns=feature_columns,
trainable=trainable,
name=name,
+ partitioner=partitioner,
expected_column_type=fc.DenseColumn,
**kwargs)
diff --git a/tensorflow/python/feature_column/dense_features_test.py b/tensorflow/python/feature_column/dense_features_test.py
index c1a970e..7cd523d 100644
--- a/tensorflow/python/feature_column/dense_features_test.py
+++ b/tensorflow/python/feature_column/dense_features_test.py
@@ -33,6 +33,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@@ -98,6 +99,58 @@
self.assertEqual(1, len(variables))
self.assertIs(variables[0], dense_features.variables[0])
+ def test_dense_feature_with_partitioner(self):
+ with context.eager_mode():
+ sparse_input = sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (2, 0), (3, 0)),
+ values=(0, 1, 3, 2),
+ dense_shape=(4, 4))
+
+ # Create feature columns (categorical and embedding).
+ categorical_column = fc.categorical_column_with_identity(
+ key='a', num_buckets=4)
+ embedding_dimension = 2
+
+ def _embedding_column_initializer(shape, dtype, partition_info=None):
+ offset = partition_info._var_offset[0]
+ del shape # unused
+ del dtype # unused
+ if offset == 0:
+ embedding_values = (
+ (1, 0), # id 0
+ (0, 1)) # id 1
+ else:
+ embedding_values = (
+ (1, 1), # id 2
+ (2, 2)) # id 3
+ return embedding_values
+
+ embedding_column = fc.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_embedding_column_initializer)
+
+ dense_features = df.DenseFeatures(
+ [embedding_column],
+ partitioner=partitioned_variables.fixed_size_partitioner(2))
+ features = {'a': sparse_input}
+
+ inputs = dense_features(features)
+ variables = dense_features.variables
+
+ # Sanity check: test that the inputs are correct.
+ self.assertAllEqual([[1, 0], [0, 1], [2, 2], [1, 1]], inputs)
+
+ # Check that two variables (one per partition) were created.
+ self.assertEqual(2, len(variables))
+
+ # Check that invoking dense_features on the same features does not create
+ # additional variables
+ _ = dense_features(features)
+ self.assertEqual(2, len(dense_features.variables))
+ self.assertIs(variables[0], dense_features.variables[0])
+ self.assertIs(variables[1], dense_features.variables[1])
+
def test_feature_column_dense_features_gradient(self):
with context.eager_mode():
sparse_input = sparse_tensor.SparseTensor(
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index b3a7523..0e8b076 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -306,8 +306,14 @@
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
getter=variable_scope.get_variable)
- if isinstance(var, trackable.Trackable):
- self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access
+ if isinstance(var, variables.PartitionedVariable):
+ for v in var:
+ part_name = name + '/' + str(v._get_save_slice_info().var_offset[0]) # pylint: disable=protected-access
+ self._layer._track_trackable(v, feature_column.name + '/' + part_name) # pylint: disable=protected-access
+ else:
+ if isinstance(var, trackable.Trackable):
+ self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access
+
self._cols_to_vars_map[feature_column][name] = var
return var
@@ -375,12 +381,19 @@
ValueError: if an item in `feature_columns` doesn't match
`expected_column_type`.
"""
- def __init__(self, feature_columns, expected_column_type, trainable, name,
+
+ def __init__(self,
+ feature_columns,
+ expected_column_type,
+ trainable,
+ name,
+ partitioner=None,
**kwargs):
super(_BaseFeaturesLayer, self).__init__(
name=name, trainable=trainable, **kwargs)
self._feature_columns = _normalize_feature_columns(feature_columns)
self._state_manager = _StateManagerImpl(self, self.trainable)
+ self._partitioner = partitioner
for column in self._feature_columns:
if not isinstance(column, expected_column_type):
raise ValueError(
@@ -391,7 +404,9 @@
def build(self, _):
for column in self._feature_columns:
- with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ with variable_scope._pure_variable_scope( # pylint: disable=protected-access
+ self.name,
+ partitioner=self._partitioner):
with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
column.create_state(self._state_manager)
super(_BaseFeaturesLayer, self).build(None)
@@ -438,6 +453,8 @@
column_configs = serialization.serialize_feature_columns(
self._feature_columns)
config = {'feature_columns': column_configs}
+ config['partitioner'] = generic_utils.serialize_keras_object(
+ self._partitioner)
base_config = super( # pylint: disable=bad-super-call
_BaseFeaturesLayer, self).get_config()
@@ -450,6 +467,8 @@
config_cp = config.copy()
config_cp['feature_columns'] = serialization.deserialize_feature_columns(
config['feature_columns'], custom_objects=custom_objects)
+ config_cp['partitioner'] = generic_utils.deserialize_keras_object(
+ config['partitioner'], custom_objects)
return cls(**config_cp)
diff --git a/tensorflow/python/framework/constant_op_test.py b/tensorflow/python/framework/constant_op_test.py
new file mode 100644
index 0000000..da0fb64
--- /dev/null
+++ b/tensorflow/python/framework/constant_op_test.py
@@ -0,0 +1,61 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework.constant_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+class ConstantOpTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(
+ dtypes.bfloat16,
+ dtypes.complex128,
+ dtypes.complex64,
+ dtypes.double,
+ dtypes.float16,
+ dtypes.float32,
+ dtypes.float64,
+ dtypes.half,
+ dtypes.int16,
+ dtypes.int32,
+ dtypes.int64,
+ dtypes.int8,
+ dtypes.qint16,
+ dtypes.qint32,
+ dtypes.qint8,
+ dtypes.quint16,
+ dtypes.quint8,
+ dtypes.uint16,
+ dtypes.uint32,
+ dtypes.uint64,
+ dtypes.uint8,
+ )
+ def test_convert_string_to_number(self, dtype):
+ with self.assertRaises(TypeError):
+ constant_op.constant("hello", dtype)
+
+
+if __name__ == "__main__":
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/framework/dtypes.cc b/tensorflow/python/framework/dtypes.cc
index c5efd68..7c8521b 100644
--- a/tensorflow/python/framework/dtypes.cc
+++ b/tensorflow/python/framework/dtypes.cc
@@ -17,7 +17,6 @@
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
namespace {
@@ -61,18 +60,6 @@
namespace py = pybind11;
PYBIND11_MODULE(_dtypes, m) {
- tensorflow::MaybeRegisterCustomNumPyTypes();
-
- m.attr("np_bfloat16") =
- reinterpret_cast<PyObject*>(tensorflow::BFLOAT16_DESCR);
- m.attr("np_qint8") = reinterpret_cast<PyObject*>(tensorflow::QINT8_DESCR);
- m.attr("np_qint16") = reinterpret_cast<PyObject*>(tensorflow::QINT16_DESCR);
- m.attr("np_qint32") = reinterpret_cast<PyObject*>(tensorflow::QINT32_DESCR);
- m.attr("np_quint8") = reinterpret_cast<PyObject*>(tensorflow::QUINT8_DESCR);
- m.attr("np_quint16") = reinterpret_cast<PyObject*>(tensorflow::QUINT16_DESCR);
- m.attr("np_resource") =
- reinterpret_cast<PyObject*>(tensorflow::RESOURCE_DESCR);
-
py::class_<tensorflow::DataType>(m, "DType")
.def(py::init([](py::object obj) {
auto id = static_cast<int>(py::int_(obj));
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 405184b..44d98a9 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -20,16 +20,17 @@
import numpy as np
from six.moves import builtins
-# TODO(b/143110113): This import has to come first. This is a temporary
-# workaround which fixes repeated proto registration on macOS.
-# pylint: disable=g-bad-import-order, unused-import
-from tensorflow.python import pywrap_tensorflow
-# pylint: enable=g-bad-import-order, unused-import
-
from tensorflow.core.framework import types_pb2
+# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
+# protobuf errors where a file is defined twice on MacOS.
+# pylint: disable=invalid-import-order,g-bad-import-order
+from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+from tensorflow.python import _pywrap_bfloat16
from tensorflow.python import _dtypes
from tensorflow.python.util.tf_export import tf_export
+_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
+
# pylint: disable=slots-on-old-class
@tf_export("dtypes.DType", "DType")
@@ -424,18 +425,20 @@
# Numpy representation for quantized dtypes.
#
-_np_qint8 = _dtypes.np_qint8
-_np_qint16 = _dtypes.np_qint16
-_np_qint32 = _dtypes.np_qint32
-_np_quint8 = _dtypes.np_quint8
-_np_quint16 = _dtypes.np_quint16
+# These are magic strings that are used in the swig wrapper to identify
+# quantized types.
+# TODO(mrry,keveman): Investigate Numpy type registration to replace this
+# hard-coding of names.
+_np_qint8 = np.dtype([("qint8", np.int8)])
+_np_quint8 = np.dtype([("quint8", np.uint8)])
+_np_qint16 = np.dtype([("qint16", np.int16)])
+_np_quint16 = np.dtype([("quint16", np.uint16)])
+_np_qint32 = np.dtype([("qint32", np.int32)])
-# Technically, _np_bfloat does not have to be a Python class, but existing
-# code expects it to.
-_np_bfloat16 = _dtypes.np_bfloat16.type
+# _np_bfloat16 is defined by a module import.
# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
-np_resource = _dtypes.np_resource
+np_resource = np.dtype([("resource", np.ubyte)])
# Standard mappings between types_pb2.DataType values and numpy.dtypes.
_NP_TO_TF = {
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index 18dddea..e4b086d 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -975,6 +975,9 @@
python_func = tf_decorator.rewrap(python_func, original_func,
converted_func)
+ else:
+ _, original_func = tf_decorator.unwrap(python_func)
+
func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors, CompositeTensors,
@@ -982,8 +985,8 @@
func_outputs = nest.map_structure(convert, func_outputs,
expand_composites=True)
- check_mutation(func_args_before, func_args)
- check_mutation(func_kwargs_before, func_kwargs)
+ check_mutation(func_args_before, func_args, original_func)
+ check_mutation(func_kwargs_before, func_kwargs, original_func)
finally:
current_scope.set_use_resource(default_use_recource)
@@ -1048,13 +1051,15 @@
for spec in device_stack.peek_objs())
-def check_mutation(n1, n2):
+def check_mutation(n1, n2, func):
"""Check if two list of arguments are exactly the same."""
- errmsg = ("Function to be traced should not modify structure of input "
- "arguments. Check if your function has list and dictionary "
- "operations that alter input arguments, "
- "such as `list.pop`, `list.append`")
+ func_name = getattr(func, "__name__", func)
+
+ errmsg = ("{}() should not modify its Python input arguments."
+ " Check if it modifies any lists or dicts passed as"
+ " arguments. Modifying a copy is allowed.".format(func_name))
try:
+ # TODO(mdan): Compare more robustly so that argument names can be reported.
nest.assert_same_structure(n1, n2, expand_composites=True)
except ValueError:
raise ValueError(errmsg)
diff --git a/tensorflow/python/framework/gpu_util.py b/tensorflow/python/framework/gpu_util.py
new file mode 100644
index 0000000..37ddc22
--- /dev/null
+++ b/tensorflow/python/framework/gpu_util.py
@@ -0,0 +1,57 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains GPU utility functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+
+
+# Matches the DeviceAttributes.physical_device_desc field.
+_PHYSICAL_DEVICE_DESCRIPTION_REGEX = re.compile(
+ r'name: ([^,]*), (?:.*compute capability: (\d+)\.(\d+))?')
+
+
+# compute_capability is a (major version, minor version) pair, or None if this
+# is not an Nvidia GPU.
+GpuInfo = collections.namedtuple('gpu_info', ['name', 'compute_capability'])
+
+
+def compute_capability_from_device_desc(device_attrs):
+ """Returns the GpuInfo given a DeviceAttributes proto.
+
+ Args:
+ device_attrs: A DeviceAttributes proto.
+
+ Returns:
+ A gpu_info tuple. Both fields are None if `device_attrs` does not have a
+ valid physical_device_desc field.
+ """
+ # TODO(jingyue): The device description generator has to be in sync with
+ # this file. Another option is to put compute capability in
+ # DeviceAttributes, but I avoided that to keep DeviceAttributes
+ # target-independent. Reconsider this option when we have more things like
+ # this to keep in sync.
+ # LINT.IfChange
+ match = _PHYSICAL_DEVICE_DESCRIPTION_REGEX.search(
+ device_attrs.physical_device_desc)
+ # LINT.ThenChange(//tensorflow/core/common_runtime/gpu/gpu_device.cc)
+ if not match:
+ return GpuInfo(None, None)
+ cc = (int(match.group(2)), int(match.group(3))) if match.group(2) else None
+ return GpuInfo(match.group(1), cc)
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index eff0434..6f41e8d 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -252,6 +252,29 @@
The reason we get 'A2' instead 'A1' on the second call of `tf.random.uniform`
above is because the secand call uses a different operation seed.
+ Note that `tf.function` acts like a re-run of a program in this case. When
+ the global seed is set but operation seeds are not set, the sequence of random
+ numbers is the same for each `tf.function`. For example:
+
+ ```python
+ tf.random.set_seed(1234)
+
+ @tf.function
+ def f():
+ a = tf.random.uniform([1])
+ b = tf.random.uniform([1])
+ return a, b
+
+ @tf.function
+ def g():
+ a = tf.random.uniform([1])
+ b = tf.random.uniform([1])
+ return a, b
+
+ print(f()) # prints '(A1, A2)'
+ print(g()) # prints '(A1, A2)'
+ ```
+
If the operation seed is set, we get different results for every call to the
random op, but the same sequence for every re-run of the program:
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 6f4794a..2eff46f 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -56,6 +56,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import gpu_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
@@ -1487,28 +1488,12 @@
Returns:
True if a GPU device of the requested kind is available.
"""
-
- def compute_capability_from_device_desc(device_desc):
- # TODO(jingyue): The device description generator has to be in sync with
- # this file. Another option is to put compute capability in
- # DeviceAttributes, but I avoided that to keep DeviceAttributes
- # target-independent. Reconsider this option when we have more things like
- # this to keep in sync.
- # LINT.IfChange
- match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc)
- # LINT.ThenChange(//tensorflow/core/\
- # common_runtime/gpu/gpu_device.cc)
- if not match:
- return 0, 0
- return int(match.group(1)), int(match.group(2))
-
try:
for local_device in device_lib.list_local_devices():
if local_device.device_type == "GPU":
- if (min_cuda_compute_capability is None or
- compute_capability_from_device_desc(
- local_device.physical_device_desc) >=
- min_cuda_compute_capability):
+ gpu_info = gpu_util.compute_capability_from_device_desc(local_device)
+ cc = gpu_info.compute_capability or (0, 0)
+ if not min_cuda_compute_capability or cc >= min_cuda_compute_capability:
return True
if local_device.device_type == "SYCL" and not cuda_only:
return True
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 1c14fb1..0134d6a 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -458,6 +458,7 @@
"layers/normalization.py",
"layers/normalization_v2.py",
"layers/pooling.py",
+ "layers/preprocessing/categorical.py",
"layers/preprocessing/image_preprocessing.py",
"layers/preprocessing/normalization.py",
"layers/preprocessing/normalization_v1.py",
@@ -657,6 +658,20 @@
)
tf_py_test(
+ name = "add_loss_correctness_test",
+ size = "medium",
+ srcs = ["add_loss_correctness_test.py"],
+ additional_deps = [
+ ":keras",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ ],
+ python_version = "PY3",
+ shard_count = 4,
+)
+
+tf_py_test(
name = "metrics_functional_test",
size = "small",
srcs = ["metrics_functional_test.py"],
@@ -728,6 +743,20 @@
size = "medium",
srcs = ["saving/metrics_serialization_test.py"],
python_version = "PY3",
+ shard_count = 8,
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+tf_py_test(
+ name = "losses_serialization_test",
+ size = "medium",
+ srcs = ["saving/losses_serialization_test.py"],
+ python_version = "PY3",
shard_count = 4,
deps = [
":keras",
@@ -772,7 +801,6 @@
srcs = ["layers/convolutional_recurrent_test.py"],
python_version = "PY3",
shard_count = 4,
- tags = ["no_rocm"],
deps = [
":keras",
"//tensorflow/python:client_testlib",
@@ -795,6 +823,31 @@
],
)
+filegroup(
+ name = "vocabulary_testdata",
+ srcs = [
+ "layers/preprocessing/testdata/wire_vocabulary.txt",
+ ],
+)
+
+cuda_py_test(
+ name = "categorical_test",
+ size = "medium",
+ srcs = ["layers/preprocessing/categorical_test.py"],
+ data = [":vocabulary_testdata"],
+ python_version = "PY3",
+ shard_count = 4,
+ tags = [
+ "no_oss",
+ ],
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
cuda_py_test(
name = "image_preprocessing_test",
size = "medium",
@@ -829,7 +882,6 @@
python_version = "PY3",
shard_count = 4,
tags = [
- "no_rocm",
"no_windows_gpu",
],
deps = [
@@ -972,7 +1024,6 @@
python_version = "PY3",
shard_count = 4,
tags = [
- "no_rocm",
"notsan",
],
deps = [
@@ -1580,7 +1631,6 @@
srcs = ["engine/training_arrays_test.py"],
python_version = "PY3",
tags = [
- "no_rocm",
"nomac", # TODO(mihaimaruseac): b/127695564
],
deps = [
@@ -1655,7 +1705,6 @@
srcs = ["engine/training_eager_test.py"],
python_version = "PY3",
tags = [
- "no_rocm",
"nomac", # TODO(mihaimaruseac): b/127695564
"notsan",
],
@@ -1805,7 +1854,6 @@
srcs = ["engine/base_layer_utils_test.py"],
python_version = "PY3",
tags = [
- "no_rocm",
"nomac", # TODO(mihaimaruseac): b/127695564
],
deps = [
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index f26c5a1..16f60a7 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -182,6 +182,19 @@
return nn.softsign(x)
+@keras_export('keras.activations.swish')
+def swish(x):
+  """Swish activation function; a thin wrapper around `nn.swish`.
+
+  Arguments:
+      x: Input tensor, forwarded unchanged to `nn.swish`.
+
+  Returns:
+      The swish activation applied to `x`, as computed by `nn.swish`.
+  """
+  return nn.swish(x)
+
+
@keras_export('keras.activations.relu')
def relu(x, alpha=0., max_value=None, threshold=0):
"""Applies the rectified linear unit activation function.
diff --git a/tensorflow/python/keras/add_loss_correctness_test.py b/tensorflow/python/keras/add_loss_correctness_test.py
new file mode 100644
index 0000000..2f02799
--- /dev/null
+++ b/tensorflow/python/keras/add_loss_correctness_test.py
@@ -0,0 +1,464 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests add_loss API correctness."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
+from tensorflow.python.keras import Input
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import losses
+from tensorflow.python.keras import Model
+from tensorflow.python.keras import optimizer_v2
+from tensorflow.python.keras import Sequential
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
+
+MAE = losses.MeanAbsoluteError
+mae = losses.mean_absolute_error
+
+
+def get_ctl_train_step(model):
+  optimizer = optimizer_v2.gradient_descent.SGD(0.05)
+  # Custom-loop step: minimize sum(model.losses) w.r.t. trainable weights.
+  def train_step(x, y, w=None):
+    with backprop.GradientTape() as tape:
+      if w is not None:
+        model([x, y, w])
+      else:
+        model([x, y])
+      loss = math_ops.reduce_sum(model.losses)
+    gradients = tape.gradient(loss, model.trainable_weights)
+    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
+    return loss
+  # `optimizer` is created once and shared by every call of `train_step`.
+  return train_step
+
+
+# TODO(psv): Add test cases where a model is used in a loss function but is
+# not part of the training model.
+
+
+class TestAddLossCorrectness(keras_parameterized.TestCase):
+  """Checks `add_loss` losses under fit/evaluate and custom training loops."""
+  def setUp(self):
+    super(TestAddLossCorrectness, self).setUp()
+    self.x = np.array([[0.], [1.], [2.]], dtype='float32')
+    self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
+    self.w = np.array([[1.25], [0.5], [1.25]], dtype='float32')
+
+  @keras_parameterized.run_all_keras_modes
+  def test_loss_on_model_fit(self):
+    inputs = Input(shape=(1,))
+    targets = Input(shape=(1,))
+    outputs = testing_utils.Bias()(inputs)
+    model = Model([inputs, targets], outputs)
+    model.add_loss(MAE()(targets, outputs))
+    model.add_loss(math_ops.reduce_mean(mae(targets, outputs)))
+    model.compile(
+        optimizer_v2.gradient_descent.SGD(0.05),
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    # Both added terms are mean absolute errors of the same targets/outputs.
+    history = model.fit([self.x, self.y], batch_size=3, epochs=5)
+    self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
+
+  @keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_loss_callable_on_model_fit(self):
+    model = testing_utils.get_model_from_layers([testing_utils.Bias()],
+                                                input_shape=(1,))
+
+    def callable_loss():
+      return math_ops.reduce_sum(model.weights)
+
+    model.add_loss(callable_loss)
+    model.compile(
+        optimizer_v2.gradient_descent.SGD(0.1),
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    history = model.fit(self.x, batch_size=3, epochs=5)
+    self.assertAllClose(history.history['loss'], [0., -.1, -.2, -.3, -.4], 1e-3)
+
+  def test_loss_on_model_ctl(self):
+    with context.eager_mode():
+      # Each run gets a fresh model, so both runs expect the same losses.
+      def get_model_and_train_step():
+        inputs = Input(shape=(1,))
+        targets = Input(shape=(1,))
+        outputs = testing_utils.Bias()(inputs)
+        model = Model([inputs, targets], outputs)
+        model.add_loss(MAE()(targets, outputs))
+        model.add_loss(math_ops.reduce_mean(mae(targets, outputs)))
+        return get_ctl_train_step(model)
+
+      train_step = get_model_and_train_step()
+      loss = [train_step(self.x, self.y) for _ in range(5)]
+      self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
+
+      train_step = def_function.function(get_model_and_train_step())
+      loss = [train_step(self.x, self.y) for _ in range(5)]
+      self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
+
+  def test_loss_callable_on_model_ctl(self):
+    with context.eager_mode():
+      # Each run gets a fresh model, so both runs expect the same losses.
+      def get_model_and_train_step():
+        inputs = Input(shape=(1,))
+        targets = Input(shape=(1,))
+        outputs = testing_utils.Bias()(inputs)
+        model = Model([inputs, targets], outputs)
+
+        def callable_loss():
+          return math_ops.reduce_sum(model.weights)
+
+        model.add_loss(callable_loss)
+        return get_ctl_train_step(model)
+
+      train_step = get_model_and_train_step()
+      loss = [train_step(self.x, self.y) for _ in range(5)]
+      self.assertAllClose(loss, [0., -0.05, -0.1, -0.15, -0.2], 1e-3)
+
+      train_step = def_function.function(get_model_and_train_step())
+      loss = [train_step(self.x, self.y) for _ in range(5)]
+      self.assertAllClose(loss, [0., -0.05, -0.1, -0.15, -0.2], 1e-3)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_loss_with_sample_weight_on_model_fit(self):
+    inputs = Input(shape=(1,))
+    targets = Input(shape=(1,))
+    sw = Input(shape=(1,))
+    outputs = testing_utils.Bias()(inputs)
+    model = Model([inputs, targets, sw], outputs)
+    model.add_loss(MAE()(targets, outputs, sw))
+    model.add_loss(3 * math_ops.reduce_mean(sw * mae(targets, outputs)))
+    model.compile(
+        optimizer_v2.gradient_descent.SGD(0.025),
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    history = model.fit([self.x, self.y, self.w], batch_size=3, epochs=5)
+    self.assertAllClose(history.history['loss'], [4., 3.6, 3.2, 2.8, 2.4], 1e-3)
+
+  def test_loss_with_sample_weight_on_model_ctl(self):
+    with context.eager_mode():
+      # Each run gets a fresh model, so both runs expect the same losses.
+      def get_model_and_train_step():
+        inputs = Input(shape=(1,))
+        targets = Input(shape=(1,))
+        sw = Input(shape=(1,))
+        outputs = testing_utils.Bias()(inputs)
+        model = Model([inputs, targets, sw], outputs)
+        model.add_loss(MAE()(targets, outputs, sw))
+        model.add_loss(math_ops.reduce_mean(sw * mae(targets, outputs)))
+        return get_ctl_train_step(model)
+
+      train_step = get_model_and_train_step()
+      loss = [train_step(self.x, self.y, self.w) for _ in range(5)]
+      self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
+
+      train_step = def_function.function(get_model_and_train_step())
+      loss = [train_step(self.x, self.y, self.w) for _ in range(5)]
+      self.assertAllClose(loss, [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_loss_with_sample_weight_in_model_call(self):
+
+    class MyModel(Model):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.bias = testing_utils.Bias()
+
+      def call(self, inputs):
+        outputs = self.bias(inputs[0])
+        self.add_loss(MAE()(inputs[1], outputs, inputs[2]))
+        self.add_loss(math_ops.reduce_mean(inputs[2] * mae(inputs[1], outputs)))
+        return outputs
+
+    model = MyModel()
+    model.predict([self.x, self.y, self.w])
+    model.compile(
+        optimizer_v2.gradient_descent.SGD(0.05),
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    history = model.fit([self.x, self.y, self.w], batch_size=3, epochs=5)
+    self.assertEqual(len(model.losses), 2)
+    self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
+
+    eval_out = model.evaluate([self.x, self.y, self.w])
+    self.assertAlmostEqual(eval_out, 1.0, 3)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_loss_with_sample_weight_in_layer_call(self):
+
+    class MyLayer(layers.Layer):
+
+      def __init__(self):
+        super(MyLayer, self).__init__()
+        self.bias = testing_utils.Bias()
+
+      def call(self, inputs):
+        out = self.bias(inputs[0])
+        self.add_loss(MAE()(inputs[1], out, inputs[2]))
+        self.add_loss(math_ops.reduce_mean(inputs[2] * mae(inputs[1], out)))
+        return out
+
+    inputs = Input(shape=(1,))
+    targets = Input(shape=(1,))
+    sw = Input(shape=(1,))
+
+    outputs = MyLayer()([inputs, targets, sw])
+    model = Model([inputs, targets, sw], outputs)
+    model.predict([self.x, self.y, self.w])
+    model.compile(
+        optimizer_v2.gradient_descent.SGD(0.05),
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+
+    history = model.fit([self.x, self.y, self.w], batch_size=3, epochs=5)
+    self.assertAllClose(history.history['loss'], [2., 1.8, 1.6, 1.4, 1.2], 1e-3)
+
+    output = model.evaluate([self.x, self.y, self.w])
+    self.assertAlmostEqual(output, 1.0, 3)
+
+    output = model.test_on_batch([self.x, self.y, self.w])
+    self.assertAlmostEqual(output, 1.0, 3)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_loss_on_layer(self):
+
+    class MyLayer(layers.Layer):
+
+      def call(self, inputs):
+        self.add_loss(math_ops.reduce_sum(inputs))
+        return inputs
+
+    inputs = Input((3,))
+    layer = MyLayer()
+    outputs = layer(inputs)
+    model = Model(inputs, outputs)
+    self.assertEqual(len(model.losses), 1)
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
+    self.assertEqual(loss, 2 * 3)
+
+  @keras_parameterized.run_all_keras_modes
+  @keras_parameterized.run_with_all_model_types
+  def test_activity_regularizer(self):
+    loss = {}
+    for reg in [None, 'l2']:
+      model_layers = [
+          layers.Dense(
+              10,
+              activation='relu',
+              activity_regularizer=reg,
+              kernel_initializer='ones',
+              use_bias=False),
+          layers.Dense(
+              1,
+              activation='sigmoid',
+              kernel_initializer='ones',
+              use_bias=False),
+      ]
+
+      model = testing_utils.get_model_from_layers(
+          model_layers, input_shape=(10,))
+
+      x = np.ones((10, 10), 'float32')
+      y = np.ones((10, 1), 'float32')
+
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
+      model.compile(
+          optimizer,
+          'binary_crossentropy',
+          run_eagerly=testing_utils.should_run_eagerly(),
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
+      model.fit(x, y, batch_size=2, epochs=5)
+      loss[reg] = model.evaluate(x, y)
+    self.assertLess(loss[None], loss['l2'])
+
+  @keras_parameterized.run_all_keras_modes
+  @keras_parameterized.run_with_all_model_types
+  def test_activity_regularizer_loss_value(self):
+    layer = layers.Dense(
+        1,
+        kernel_initializer='zeros',
+        bias_initializer='ones',
+        activity_regularizer='l2')
+
+    model = testing_utils.get_model_from_layers([layer], input_shape=(10,))
+
+    x = np.ones((10, 10), 'float32')
+    optimizer = RMSPropOptimizer(learning_rate=0.001)
+    model.compile(
+        optimizer,
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    loss = model.test_on_batch(x)
+    self.assertAlmostEqual(0.01, loss, places=4)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_activity_regularizer_batch_independent(self):
+    inputs = layers.Input(shape=(10,))
+    x = layers.Dense(10, activation='relu', activity_regularizer='l2')(inputs)
+    outputs = layers.Dense(1, activation='sigmoid')(x)
+    model = Model(inputs, outputs)
+
+    optimizer = RMSPropOptimizer(learning_rate=0.001)
+    model.compile(
+        optimizer,
+        run_eagerly=testing_utils.should_run_eagerly(),
+        experimental_run_tf_function=testing_utils.should_run_tf_function())
+    # The activity-regularization loss must not depend on the batch size.
+    loss_small_batch = model.test_on_batch(np.ones((10, 10), 'float32'))
+    loss_big_batch = model.test_on_batch(np.ones((20, 10), 'float32'))
+    self.assertAlmostEqual(loss_small_batch, loss_big_batch, places=4)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_with_shared_layer(self):
+
+    class LayerWithLoss(layers.Layer):
+
+      def call(self, inputs):
+        self.add_loss(math_ops.reduce_sum(inputs), inputs)
+        return inputs * 2
+
+    shared_layer = LayerWithLoss()
+    # m2 applies the layer twice (directly, then via m): sums 6, then 12.
+    m = Sequential([shared_layer])
+    m2 = Sequential([shared_layer, m])
+    m2(array_ops.constant([1, 2, 3]))
+    self.assertEqual(len(m2.losses), 2)
+    self.assertAllClose(m2.losses, [6, 12])
+
+  @keras_parameterized.run_all_keras_modes
+  def test_with_shared_nested_layer(self):
+
+    class LayerWithLoss(layers.Layer):
+
+      def call(self, inputs):
+        self.add_loss(math_ops.reduce_sum(inputs), inputs)
+        return inputs * 2
+
+    class LayerWithNestedLayerWithLoss(layers.Layer):
+
+      def __init__(self):
+        super(LayerWithNestedLayerWithLoss, self).__init__()
+        self.loss_layer = LayerWithLoss()
+
+      def call(self, inputs):
+        return self.loss_layer(inputs)
+
+    shared_layer = LayerWithNestedLayerWithLoss()
+
+    m = Sequential([shared_layer])
+    m2 = Sequential([shared_layer, m])
+    m2(array_ops.constant([1, 2, 3]))
+    self.assertEqual(len(m2.losses), 2)
+    self.assertAllClose(m2.losses, [6, 12])
+
+  @keras_parameterized.run_all_keras_modes
+  def test_clear_losses(self):
+
+    class LayerWithSharedNestedLossLayer(layers.Layer):
+
+      def __init__(self):
+        super(LayerWithSharedNestedLossLayer, self).__init__()
+        self.loss_layer = layers.ActivityRegularization(l2=0.001)
+        self.add_weight(shape=(1,), regularizer='l2')
+
+      def call(self, x):
+        x = self.loss_layer(x)
+        return self.loss_layer(x)
+
+    inputs = Input(shape=(1,))
+    l = LayerWithSharedNestedLossLayer()  # Weight loss + 2 activity losses.
+
+    x1 = array_ops.ones((1, 1))
+    _ = l(x1)
+    if not context.executing_eagerly():
+      self.assertEqual(len(l.get_losses_for(x1)), 2)
+      self.assertEqual(len(l.get_losses_for(None)), 1)
+
+    x2 = array_ops.ones((1, 1))
+    _ = l(x2)
+    if not context.executing_eagerly():
+      self.assertEqual(len(l.get_losses_for(x1)), 2)
+      self.assertEqual(len(l.get_losses_for(x2)), 2)
+      self.assertEqual(len(l.get_losses_for(None)), 1)
+
+    outputs = l(inputs)
+    model = Model(inputs, outputs)
+    if not context.executing_eagerly():
+      self.assertEqual(len(model.losses), 7)
+      self.assertEqual(len(l.get_losses_for(x1)), 2)
+      self.assertEqual(len(l.get_losses_for(x2)), 2)
+      self.assertEqual(len(l.get_losses_for(None)), 1)
+
+    x3 = array_ops.ones((1, 1))
+    model(x3)
+    x4 = array_ops.ones((1, 1))
+    model(x4)
+    if context.executing_eagerly():
+      # Eager losses are cleared every `__call__`.
+      self.assertEqual(len(model.losses), 3)
+    else:
+      self.assertEqual(len(model.losses), 11)
+      self.assertEqual(len(model.get_losses_for(x3)), 2)
+      self.assertEqual(len(model.get_losses_for(x4)), 2)
+      self.assertEqual(len(model.get_losses_for(None)), 1)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_invalid_constant_input(self):
+    with context.eager_mode():
+      inputs = Input(shape=(1,))
+      outputs = testing_utils.Bias()(inputs)
+      model = Model(inputs, outputs)
+      with self.assertRaisesRegexp(
+          ValueError,
+          'Expected a symbolic Tensors or a callable for the loss value'):
+        model.add_loss(1.)
+
+  @keras_parameterized.run_all_keras_modes
+  def test_invalid_variable_input(self):
+    with context.eager_mode():
+      inputs = Input(shape=(1,))
+      outputs = testing_utils.Bias()(inputs)
+      model = Model(inputs, outputs)
+      with self.assertRaisesRegexp(
+          ValueError,
+          'Expected a symbolic Tensors or a callable for the loss value'):
+        model.add_loss(model.weights[0])
+
+
+if __name__ == '__main__':
+  test.main()  # Runs the test cases defined in this module.
diff --git a/tensorflow/python/keras/applications/BUILD b/tensorflow/python/keras/applications/BUILD
index f5faae0..17998dff 100644
--- a/tensorflow/python/keras/applications/BUILD
+++ b/tensorflow/python/keras/applications/BUILD
@@ -15,6 +15,7 @@
srcs = [
"__init__.py",
"densenet.py",
+ "efficientnet.py",
"imagenet_utils.py",
"inception_resnet_v2.py",
"inception_v3.py",
diff --git a/tensorflow/python/keras/applications/applications_test.py b/tensorflow/python/keras/applications/applications_test.py
index b790eb8..198bebd 100644
--- a/tensorflow/python/keras/applications/applications_test.py
+++ b/tensorflow/python/keras/applications/applications_test.py
@@ -22,6 +22,7 @@
from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import densenet
+from tensorflow.python.keras.applications import efficientnet
from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
@@ -52,6 +53,14 @@
(densenet.DenseNet121, 1024),
(densenet.DenseNet169, 1664),
(densenet.DenseNet201, 1920),
+ (efficientnet.EfficientNetB0, 1280),
+ (efficientnet.EfficientNetB1, 1280),
+ (efficientnet.EfficientNetB2, 1408),
+ (efficientnet.EfficientNetB3, 1536),
+ (efficientnet.EfficientNetB4, 1792),
+ (efficientnet.EfficientNetB5, 2048),
+ (efficientnet.EfficientNetB6, 2304),
+ (efficientnet.EfficientNetB7, 2560),
]
NASNET_LIST = [
@@ -73,6 +82,16 @@
raise AssertionError('Shapes differ: %s vs %s' % (shape1, shape2))
@parameterized.parameters(*MODEL_LIST)
+  def test_application_base(self, app, _):
+    # Builds with default arguments and no pretrained weights.
+    model = app(weights=None)
+    # A get_config()/from_config() round trip keeps the same number of weights.
+    config = model.get_config()
+    reconstructed_model = model.__class__.from_config(config)
+    self.assertEqual(len(model.weights), len(reconstructed_model.weights))
+    backend.clear_session()
+
+ @parameterized.parameters(*MODEL_LIST)
def test_application_notop(self, app, last_dim):
if 'NASNet' in app.__name__:
only_check_last_dim = True
diff --git a/tensorflow/python/keras/applications/efficientnet.py b/tensorflow/python/keras/applications/efficientnet.py
new file mode 100644
index 0000000..f3d0f1e
--- /dev/null
+++ b/tensorflow/python/keras/applications/efficientnet.py
@@ -0,0 +1,654 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=invalid-name
+"""EfficientNet models for Keras.
+
+Reference paper:
+ - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks]
+ (https://arxiv.org/abs/1905.11946) (ICML 2019)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import math
+import os
+
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import layers
+from tensorflow.python.keras.applications import imagenet_utils
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.utils import data_utils
+from tensorflow.python.keras.utils import layer_utils
+from tensorflow.python.util.tf_export import keras_export
+
+
+BASE_WEIGHTS_PATH = 'https://storage.googleapis.com/keras-applications/'
+
+WEIGHTS_HASHES = {  # model_name[-2:] -> (top, notop) weight-file hashes.
+    'b0': ('902e53a9f72be733fc0bcb005b3ebbac',
+           '50bc09e76180e00e4465e1a485ddc09d'),
+    'b1': ('1d254153d4ab51201f1646940f018540',
+           '74c4e6b3e1f6a1eea24c589628592432'),
+    'b2': ('b15cce36ff4dcbd00b6dd88e7857a6ad',
+           '111f8e2ac8aa800a7a99e3239f7bfb39'),
+    'b3': ('ffd1fdc53d0ce67064dc6a9c7960ede0',
+           'af6d107764bb5b1abb91932881670226'),
+    'b4': ('18c95ad55216b8f92d7e70b3a046e2fc',
+           'ebc24e6d6c33eaebbd558eafbeedf1ba'),
+    'b5': ('ace28f2a6363774853a83a0b21b9421a',
+           '38879255a25d3c92d5e44e04ae6cec6f'),
+    'b6': ('165f6e37dce68623721b423839de8be5',
+           '9ecce42647a20130c1f39a5d4cb75743'),
+    'b7': ('8c03f828fec3ef71311cd463b6759d99',
+           'cbcfe4450ddf6f3ad90b1b398090fe4a'),
+}
+
+DEFAULT_BLOCKS_ARGS = [{  # One dict per stage; rescaled by EfficientNet().
+    'kernel_size': 3,
+    'repeats': 1,
+    'filters_in': 32,
+    'filters_out': 16,
+    'expand_ratio': 1,
+    'id_skip': True,
+    'strides': 1,
+    'se_ratio': 0.25
+}, {
+    'kernel_size': 3,
+    'repeats': 2,
+    'filters_in': 16,
+    'filters_out': 24,
+    'expand_ratio': 6,
+    'id_skip': True,
+    'strides': 2,
+    'se_ratio': 0.25
+}, {
+    'kernel_size': 5,
+    'repeats': 2,
+    'filters_in': 24,
+    'filters_out': 40,
+    'expand_ratio': 6,
+    'id_skip': True,
+    'strides': 2,
+    'se_ratio': 0.25
+}, {
+    'kernel_size': 3,
+    'repeats': 3,
+    'filters_in': 40,
+    'filters_out': 80,
+    'expand_ratio': 6,
+    'id_skip': True,
+    'strides': 2,
+    'se_ratio': 0.25
+}, {
+    'kernel_size': 5,
+    'repeats': 3,
+    'filters_in': 80,
+    'filters_out': 112,
+    'expand_ratio': 6,
+    'id_skip': True,
+    'strides': 1,
+    'se_ratio': 0.25
+}, {
+    'kernel_size': 5,
+    'repeats': 4,
+    'filters_in': 112,
+    'filters_out': 192,
+    'expand_ratio': 6,
+    'id_skip': True,
+    'strides': 2,
+    'se_ratio': 0.25
+}, {
+    'kernel_size': 3,
+    'repeats': 1,
+    'filters_in': 192,
+    'filters_out': 320,
+    'expand_ratio': 6,
+    'id_skip': True,
+    'strides': 1,
+    'se_ratio': 0.25
+}]
+
+CONV_KERNEL_INITIALIZER = {  # Used by the stem and top Conv2D layers.
+    'class_name': 'VarianceScaling',
+    'config': {
+        'scale': 2.0,
+        'mode': 'fan_out',
+        'distribution': 'truncated_normal'
+    }
+}
+
+DENSE_KERNEL_INITIALIZER = {  # Used by the top classifier Dense layer.
+    'class_name': 'VarianceScaling',
+    'config': {
+        'scale': 1. / 3.,
+        'mode': 'fan_out',
+        'distribution': 'uniform'
+    }
+}
+
+
+def EfficientNet(width_coefficient,
+                 depth_coefficient,
+                 default_size,
+                 dropout_rate=0.2,
+                 drop_connect_rate=0.2,
+                 depth_divisor=8,
+                 activation='swish',
+                 blocks_args='default',
+                 model_name='efficientnet',
+                 include_top=True,
+                 weights='imagenet',
+                 input_tensor=None,
+                 input_shape=None,
+                 pooling=None,
+                 classes=1000):
+  """Instantiates EfficientNet from the given width/depth scaling coefficients.
+
+  Optionally loads weights pre-trained on ImageNet.
+  Note that the data format convention used by the model is
+  the one specified in your Keras config at `~/.keras/keras.json`.
+
+  Arguments:
+    width_coefficient: float, scaling coefficient for network width.
+    depth_coefficient: float, scaling coefficient for network depth.
+    default_size: integer, default input image size.
+    dropout_rate: float, dropout rate before final classifier layer.
+    drop_connect_rate: float, maximum per-block drop rate at skip connections.
+    depth_divisor: integer, a unit of network width.
+    activation: activation function.
+    blocks_args: list of dicts (or 'default'), parameters for block modules.
+    model_name: string, model name.
+    include_top: whether to include the fully-connected
+        layer at the top of the network.
+    weights: one of `None` (random initialization),
+          'imagenet' (pre-training on ImageNet),
+          or the path to the weights file to be loaded.
+    input_tensor: optional Keras tensor
+        (i.e. output of `layers.Input()`)
+        to use as image input for the model.
+    input_shape: optional shape tuple, only to be specified
+        if `include_top` is False.
+        It should have exactly 3 input channels.
+    pooling: optional pooling mode for feature extraction
+        when `include_top` is `False`.
+        - `None` means that the output of the model will be
+            the 4D tensor output of the
+            last convolutional layer.
+        - `avg` means that global average pooling
+            will be applied to the output of the
+            last convolutional layer, and thus
+            the output of the model will be a 2D tensor.
+        - `max` means that global max pooling will
+            be applied.
+    classes: optional number of classes to classify images
+        into, only to be specified if `include_top` is True, and
+        if `weights` is not `'imagenet'`.
+
+  Returns:
+    A Keras model instance.
+
+  Raises:
+    ValueError: in case of invalid argument for `weights` or `classes`,
+      or invalid input shape.
+  """
+  if blocks_args == 'default':
+    blocks_args = DEFAULT_BLOCKS_ARGS
+
+  if not (weights in {'imagenet', None} or os.path.exists(weights)):
+    raise ValueError('The `weights` argument should be either '
+                     '`None` (random initialization), `imagenet` '
+                     '(pre-training on ImageNet), '
+                     'or the path to the weights file to be loaded.')
+
+  if weights == 'imagenet' and include_top and classes != 1000:
+    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
+                     ' as true, `classes` should be 1000')
+
+  # Determine proper input shape
+  input_shape = imagenet_utils.obtain_input_shape(
+      input_shape,
+      default_size=default_size,
+      min_size=32,
+      data_format=backend.image_data_format(),
+      require_flatten=include_top,
+      weights=weights)
+
+  if input_tensor is None:
+    img_input = layers.Input(shape=input_shape)
+  else:
+    if not backend.is_keras_tensor(input_tensor):
+      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
+    else:
+      img_input = input_tensor
+
+  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
+
+  def round_filters(filters, divisor=depth_divisor):
+    """Round number of filters based on width multiplier."""
+    filters *= width_coefficient
+    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_filters < 0.9 * filters:
+      new_filters += divisor
+    return int(new_filters)
+
+  def round_repeats(repeats):
+    """Round number of repeats based on depth multiplier."""
+    return int(math.ceil(depth_coefficient * repeats))
+
+  # Build stem
+  x = img_input
+  x = layers.Rescaling(1. / 255.)(x)
+  x = layers.Normalization(axis=bn_axis)(x)
+
+  x = layers.ZeroPadding2D(
+      padding=imagenet_utils.correct_pad(x, 3),
+      name='stem_conv_pad')(x)
+  x = layers.Conv2D(
+      round_filters(32),
+      3,
+      strides=2,
+      padding='valid',
+      use_bias=False,
+      kernel_initializer=CONV_KERNEL_INITIALIZER,
+      name='stem_conv')(x)
+  x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
+  x = layers.Activation(activation, name='stem_activation')(x)
+
+  # Build blocks on a deep copy, because the loop below mutates each `args`.
+  blocks_args = copy.deepcopy(blocks_args)
+  # Drop rate ramps from 0 up to drop_connect_rate over the *scaled* blocks.
+  b = 0
+  blocks = float(sum(round_repeats(args['repeats']) for args in blocks_args))
+  for (i, args) in enumerate(blocks_args):
+    assert args['repeats'] > 0
+    # Update block input and output filters based on width multiplier.
+    args['filters_in'] = round_filters(args['filters_in'])
+    args['filters_out'] = round_filters(args['filters_out'])
+
+    for j in range(round_repeats(args.pop('repeats'))):
+      # The first block needs to take care of stride and filter size increase.
+      if j > 0:
+        args['strides'] = 1
+        args['filters_in'] = args['filters_out']
+      x = block(
+          x,
+          activation,
+          drop_connect_rate * b / blocks,
+          name='block{}{}_'.format(i + 1, chr(j + 97)),
+          **args)
+      b += 1
+
+  # Build top
+  x = layers.Conv2D(
+      round_filters(1280),
+      1,
+      padding='same',
+      use_bias=False,
+      kernel_initializer=CONV_KERNEL_INITIALIZER,
+      name='top_conv')(x)
+  x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
+  x = layers.Activation(activation, name='top_activation')(x)
+  if include_top:
+    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
+    if dropout_rate > 0:
+      x = layers.Dropout(dropout_rate, name='top_dropout')(x)
+    x = layers.Dense(
+        classes,
+        activation='softmax',
+        kernel_initializer=DENSE_KERNEL_INITIALIZER,
+        name='probs')(x)
+  else:
+    if pooling == 'avg':
+      x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
+    elif pooling == 'max':
+      x = layers.GlobalMaxPooling2D(name='max_pool')(x)
+
+  # Ensure that the model takes into account
+  # any potential predecessors of `input_tensor`.
+  if input_tensor is not None:
+    inputs = layer_utils.get_source_inputs(input_tensor)
+  else:
+    inputs = img_input
+
+  # Create model.
+  model = training.Model(inputs, x, name=model_name)
+
+  # Load weights.
+  if weights == 'imagenet':
+    if include_top:
+      file_suffix = '.h5'
+      file_hash = WEIGHTS_HASHES[model_name[-2:]][0]
+    else:
+      file_suffix = '_notop.h5'
+      file_hash = WEIGHTS_HASHES[model_name[-2:]][1]
+    file_name = model_name + file_suffix
+    weights_path = data_utils.get_file(
+        file_name,
+        BASE_WEIGHTS_PATH + file_name,
+        cache_subdir='models',
+        file_hash=file_hash)
+    model.load_weights(weights_path)
+  elif weights is not None:
+    model.load_weights(weights)
+  return model
+
+
+def block(inputs,
+          activation='swish',
+          drop_rate=0.,
+          name='',
+          filters_in=32,
+          filters_out=16,
+          kernel_size=3,
+          strides=1,
+          expand_ratio=1,
+          se_ratio=0.,
+          id_skip=True):
+  """An inverted residual block.
+
+  Arguments:
+      inputs: input tensor.
+      activation: activation function.
+      drop_rate: float between 0 and 1, fraction of the input units to drop.
+        Applied as `Dropout` with `noise_shape=(None, 1, 1, 1)` to the
+        residual branch, and only when the identity shortcut is added.
+      name: string, block label (used as prefix for every sub-layer name).
+      filters_in: integer, the number of input filters.
+      filters_out: integer, the number of output filters.
+      kernel_size: integer, the dimension of the convolution window.
+      strides: integer, the stride of the convolution.
+      expand_ratio: integer, scaling coefficient for the input filters.
+      se_ratio: float between 0 and 1, fraction to squeeze the input filters.
+        Squeeze-and-excitation is skipped unless `0 < se_ratio <= 1`.
+      id_skip: boolean, whether to add the identity shortcut. It is only
+        added when `strides == 1` and `filters_in == filters_out`.
+
+  Returns:
+      output tensor for the block.
+  """
+  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
+
+  # Expansion phase
+  filters = filters_in * expand_ratio
+  if expand_ratio != 1:
+    x = layers.Conv2D(
+        filters,
+        1,
+        padding='same',
+        use_bias=False,
+        kernel_initializer=CONV_KERNEL_INITIALIZER,
+        name=name + 'expand_conv')(
+            inputs)
+    x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
+    x = layers.Activation(activation, name=name + 'expand_activation')(x)
+  else:
+    x = inputs
+
+  # Depthwise Convolution
+  if strides == 2:
+    x = layers.ZeroPadding2D(
+        padding=imagenet_utils.correct_pad(x, kernel_size),
+        name=name + 'dwconv_pad')(x)
+    conv_pad = 'valid'
+  else:
+    conv_pad = 'same'
+  x = layers.DepthwiseConv2D(
+      kernel_size,
+      strides=strides,
+      padding=conv_pad,
+      use_bias=False,
+      depthwise_initializer=CONV_KERNEL_INITIALIZER,
+      name=name + 'dwconv')(x)
+  x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
+  x = layers.Activation(activation, name=name + 'activation')(x)
+
+  # Squeeze and Excitation phase
+  if 0 < se_ratio <= 1:
+    filters_se = max(1, int(filters_in * se_ratio))
+    se = layers.GlobalAveragePooling2D(name=name + 'se_squeeze')(x)
+    # Restore the pooled vector to a 1x1 map whose channel axis is where `x`
+    # keeps its channels (`bn_axis`). A fixed `(1, 1, filters)` is only
+    # correct for 'channels_last' and yields wrongly shaped SE tensors (and a
+    # failing multiply below) under 'channels_first'.
+    if bn_axis == 1:
+      se_shape = (filters, 1, 1)
+    else:
+      se_shape = (1, 1, filters)
+    se = layers.Reshape(se_shape, name=name + 'se_reshape')(se)
+    se = layers.Conv2D(
+        filters_se,
+        1,
+        padding='same',
+        activation=activation,
+        kernel_initializer=CONV_KERNEL_INITIALIZER,
+        name=name + 'se_reduce')(
+            se)
+    se = layers.Conv2D(
+        filters,
+        1,
+        padding='same',
+        activation='sigmoid',
+        kernel_initializer=CONV_KERNEL_INITIALIZER,
+        name=name + 'se_expand')(se)
+    x = layers.multiply([x, se], name=name + 'se_excite')
+
+  # Output phase
+  x = layers.Conv2D(
+      filters_out,
+      1,
+      padding='same',
+      use_bias=False,
+      kernel_initializer=CONV_KERNEL_INITIALIZER,
+      name=name + 'project_conv')(x)
+  x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)
+  if id_skip and strides == 1 and filters_in == filters_out:
+    if drop_rate > 0:
+      # noise_shape (None, 1, 1, 1) drops the whole residual branch per sample.
+      x = layers.Dropout(
+          drop_rate, noise_shape=(None, 1, 1, 1), name=name + 'drop')(x)
+    x = layers.add([x, inputs], name=name + 'add')
+  return x
+
+
+@keras_export('keras.applications.efficientnet.EfficientNetB0',
+ 'keras.applications.EfficientNetB0')
+def EfficientNetB0(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000,
+ **kwargs):
+ """Instantiates the EfficientNetB0 architecture.
+
+ Thin wrapper around `EfficientNet`: every argument is forwarded unchanged,
+ together with the fixed positional values `1.0, 1.0, 224, 0.2` and
+ `model_name='efficientnetb0'`. The last two characters of `model_name`
+ (`'b0'`) select the `WEIGHTS_HASHES` entry used for `weights='imagenet'`.
+ NOTE(review): the positional values are presumably the width coefficient,
+ depth coefficient, default input size and dropout rate -- confirm against
+ the `EfficientNet` signature.
+ """
+ return EfficientNet(
+ 1.0,
+ 1.0,
+ 224,
+ 0.2,
+ model_name='efficientnetb0',
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ input_shape=input_shape,
+ pooling=pooling,
+ classes=classes,
+ **kwargs)
+
+
+@keras_export('keras.applications.efficientnet.EfficientNetB1',
+ 'keras.applications.EfficientNetB1')
+def EfficientNetB1(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000,
+ **kwargs):
+ """Instantiates the EfficientNetB1 architecture.
+
+ Thin wrapper around `EfficientNet`: every argument is forwarded unchanged,
+ together with the fixed positional values `1.0, 1.1, 240, 0.2` and
+ `model_name='efficientnetb1'`. The last two characters of `model_name`
+ (`'b1'`) select the `WEIGHTS_HASHES` entry used for `weights='imagenet'`.
+ NOTE(review): the positional values are presumably the width coefficient,
+ depth coefficient, default input size and dropout rate -- confirm against
+ the `EfficientNet` signature.
+ """
+ return EfficientNet(
+ 1.0,
+ 1.1,
+ 240,
+ 0.2,
+ model_name='efficientnetb1',
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ input_shape=input_shape,
+ pooling=pooling,
+ classes=classes,
+ **kwargs)
+
+
+@keras_export('keras.applications.efficientnet.EfficientNetB2',
+ 'keras.applications.EfficientNetB2')
+def EfficientNetB2(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000,
+ **kwargs):
+ """Instantiates the EfficientNetB2 architecture.
+
+ Thin wrapper around `EfficientNet`: every argument is forwarded unchanged,
+ together with the fixed positional values `1.1, 1.2, 260, 0.3` and
+ `model_name='efficientnetb2'`. The last two characters of `model_name`
+ (`'b2'`) select the `WEIGHTS_HASHES` entry used for `weights='imagenet'`.
+ NOTE(review): the positional values are presumably the width coefficient,
+ depth coefficient, default input size and dropout rate -- confirm against
+ the `EfficientNet` signature.
+ """
+ return EfficientNet(
+ 1.1,
+ 1.2,
+ 260,
+ 0.3,
+ model_name='efficientnetb2',
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ input_shape=input_shape,
+ pooling=pooling,
+ classes=classes,
+ **kwargs)
+
+
+@keras_export('keras.applications.efficientnet.EfficientNetB3',
+ 'keras.applications.EfficientNetB3')
+def EfficientNetB3(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000,
+ **kwargs):
+ """Instantiates the EfficientNetB3 architecture.
+
+ Thin wrapper around `EfficientNet`: every argument is forwarded unchanged,
+ together with the fixed positional values `1.2, 1.4, 300, 0.3` and
+ `model_name='efficientnetb3'`. The last two characters of `model_name`
+ (`'b3'`) select the `WEIGHTS_HASHES` entry used for `weights='imagenet'`.
+ NOTE(review): the positional values are presumably the width coefficient,
+ depth coefficient, default input size and dropout rate -- confirm against
+ the `EfficientNet` signature.
+ """
+ return EfficientNet(
+ 1.2,
+ 1.4,
+ 300,
+ 0.3,
+ model_name='efficientnetb3',
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ input_shape=input_shape,
+ pooling=pooling,
+ classes=classes,
+ **kwargs)
+
+
+@keras_export('keras.applications.efficientnet.EfficientNetB4',
+ 'keras.applications.EfficientNetB4')
+def EfficientNetB4(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000,
+ **kwargs):
+ """Instantiates the EfficientNetB4 architecture.
+
+ Thin wrapper around `EfficientNet`: every argument is forwarded unchanged,
+ together with the fixed positional values `1.4, 1.8, 380, 0.4` and
+ `model_name='efficientnetb4'`. The last two characters of `model_name`
+ (`'b4'`) select the `WEIGHTS_HASHES` entry used for `weights='imagenet'`.
+ NOTE(review): the positional values are presumably the width coefficient,
+ depth coefficient, default input size and dropout rate -- confirm against
+ the `EfficientNet` signature.
+ """
+ return EfficientNet(
+ 1.4,
+ 1.8,
+ 380,
+ 0.4,
+ model_name='efficientnetb4',
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ input_shape=input_shape,
+ pooling=pooling,
+ classes=classes,
+ **kwargs)
+
+
+@keras_export('keras.applications.efficientnet.EfficientNetB5',
+ 'keras.applications.EfficientNetB5')
+def EfficientNetB5(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000,
+ **kwargs):
+ """Instantiates the EfficientNetB5 architecture.
+
+ Thin wrapper around `EfficientNet`: every argument is forwarded unchanged,
+ together with the fixed positional values `1.6, 2.2, 456, 0.4` and
+ `model_name='efficientnetb5'`. The last two characters of `model_name`
+ (`'b5'`) select the `WEIGHTS_HASHES` entry used for `weights='imagenet'`.
+ NOTE(review): the positional values are presumably the width coefficient,
+ depth coefficient, default input size and dropout rate -- confirm against
+ the `EfficientNet` signature.
+ """
+ return EfficientNet(
+ 1.6,
+ 2.2,
+ 456,
+ 0.4,
+ model_name='efficientnetb5',
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ input_shape=input_shape,
+ pooling=pooling,
+ classes=classes,
+ **kwargs)
+
+
+@keras_export('keras.applications.efficientnet.EfficientNetB6',
+ 'keras.applications.EfficientNetB6')
+def EfficientNetB6(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000,
+ **kwargs):
+ """Instantiates the EfficientNetB6 architecture.
+
+ Thin wrapper around `EfficientNet`: every argument is forwarded unchanged,
+ together with the fixed positional values `1.8, 2.6, 528, 0.5` and
+ `model_name='efficientnetb6'`. The last two characters of `model_name`
+ (`'b6'`) select the `WEIGHTS_HASHES` entry used for `weights='imagenet'`.
+ NOTE(review): the positional values are presumably the width coefficient,
+ depth coefficient, default input size and dropout rate -- confirm against
+ the `EfficientNet` signature.
+ """
+ return EfficientNet(
+ 1.8,
+ 2.6,
+ 528,
+ 0.5,
+ model_name='efficientnetb6',
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ input_shape=input_shape,
+ pooling=pooling,
+ classes=classes,
+ **kwargs)
+
+
+@keras_export('keras.applications.efficientnet.EfficientNetB7',
+ 'keras.applications.EfficientNetB7')
+def EfficientNetB7(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000,
+ **kwargs):
+ """Instantiates the EfficientNetB7 architecture.
+
+ Thin wrapper around `EfficientNet`: every argument is forwarded unchanged,
+ together with the fixed positional values `2.0, 3.1, 600, 0.5` and
+ `model_name='efficientnetb7'`. The last two characters of `model_name`
+ (`'b7'`) select the `WEIGHTS_HASHES` entry used for `weights='imagenet'`.
+ NOTE(review): the positional values are presumably the width coefficient,
+ depth coefficient, default input size and dropout rate -- confirm against
+ the `EfficientNet` signature.
+ """
+ return EfficientNet(
+ 2.0,
+ 3.1,
+ 600,
+ 0.5,
+ model_name='efficientnetb7',
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ input_shape=input_shape,
+ pooling=pooling,
+ classes=classes,
+ **kwargs)
+
+
+@keras_export('keras.applications.efficientnet.preprocess_input')
+def preprocess_input(x, data_format=None): # pylint: disable=unused-argument
+ """Returns `x` unchanged.
+
+ The EfficientNet graph built in this module applies its own input
+ scaling (`Rescaling(1. / 255.)`) and `Normalization` layers, so no
+ external preprocessing is performed here. `data_format` is accepted but
+ ignored.
+ """
+ return x
+
+
+@keras_export('keras.applications.efficientnet.decode_predictions')
+def decode_predictions(preds, top=5):
+ """Decodes model predictions by delegating to `imagenet_utils`.
+
+ Arguments:
+ preds: model predictions, passed through unchanged to
+ `imagenet_utils.decode_predictions`.
+ top: number of top predictions requested (default 5).
+
+ Returns:
+ Whatever `imagenet_utils.decode_predictions` returns.
+ """
+ return imagenet_utils.decode_predictions(preds, top=top)
diff --git a/tensorflow/python/keras/applications/resnet_v2.py b/tensorflow/python/keras/applications/resnet_v2.py
index 4c78204..ce56fbb 100644
--- a/tensorflow/python/keras/applications/resnet_v2.py
+++ b/tensorflow/python/keras/applications/resnet_v2.py
@@ -80,7 +80,7 @@
@keras_export('keras.applications.resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
return imagenet_utils.preprocess_input(
- x, data_format=data_format, mode='caffe')
+ x, data_format=data_format, mode='tf')
@keras_export('keras.applications.resnet_v2.decode_predictions')
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index f1e199e..f63b6e6 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -1300,7 +1300,6 @@
v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
if py_all(v.shape.as_list()):
return variable(v, dtype=dtype, name=name)
- track_variable(v)
return v
@@ -1335,7 +1334,6 @@
v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
if py_all(v.shape.as_list()):
return variable(v, dtype=dtype, name=name)
- track_variable(v)
return v
@@ -1450,12 +1448,10 @@
Example:
- # TensorFlow example
>>> kvar = tf.keras.backend.random_uniform_variable((2,3), 0, 1)
>>> kvar
<tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
dtype=float32)>
-
"""
if dtype is None:
dtype = floatx()
@@ -1486,12 +1482,10 @@
Example:
- # TensorFlow example
>>> kvar = tf.keras.backend.random_normal_variable((2,3), 0, 1)
>>> kvar
<tf.Variable 'Variable:0' shape=(2, 3) dtype=float32, numpy=...,
dtype=float32)>
-
"""
if dtype is None:
dtype = floatx()
@@ -1628,27 +1622,23 @@
Examples:
- # dot product between tensors
>>> x = tf.keras.backend.placeholder(shape=(2, 3))
>>> y = tf.keras.backend.placeholder(shape=(3, 4))
>>> xy = tf.keras.backend.dot(x, y)
>>> xy
<tf.Tensor ... shape=(2, 4) dtype=float32>
- # dot product between tensors
>>> x = tf.keras.backend.placeholder(shape=(32, 28, 3))
>>> y = tf.keras.backend.placeholder(shape=(3, 4))
>>> xy = tf.keras.backend.dot(x, y)
>>> xy
<tf.Tensor ... shape=(32, 28, 4) dtype=float32>
- # Theano-like behavior example
>>> x = tf.keras.backend.random_uniform_variable(shape=(2, 3), low=0, high=1)
>>> y = tf.keras.backend.ones((4, 3, 5))
>>> xy = tf.keras.backend.dot(x, y)
>>> tf.keras.backend.int_shape(xy)
(2, 4, 5)
-
"""
if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
x_shape = []
@@ -2400,7 +2390,6 @@
Examples:
- # maximum of two tensors
>>> x = tf.Variable([[1, 2], [3, 4]])
>>> y = tf.Variable([[2, 1], [0, -1]])
>>> m = tf.keras.backend.maximum(x, y)
@@ -2408,7 +2397,6 @@
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[2, 2],
[3, 4]], dtype=int32)>
-
"""
return math_ops.maximum(x, y)
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index de47f90..126fbf5 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -199,7 +199,6 @@
shard_count = 4,
tags = [
"multi_and_single_gpu",
- "no_rocm", # times out on ROCm
"no_windows_gpu",
"notsan",
],
diff --git a/tensorflow/python/keras/distribute/keras_premade_models_test.py b/tensorflow/python/keras/distribute/keras_premade_models_test.py
index fa77ca2..d57f50a 100644
--- a/tensorflow/python/keras/distribute/keras_premade_models_test.py
+++ b/tensorflow/python/keras/distribute/keras_premade_models_test.py
@@ -78,8 +78,8 @@
linear_model = linear.LinearModel(units=1)
dnn_model = sequential.Sequential([core.Dense(units=1)])
wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
- linear_opt = gradient_descent.SGD(learning_rate=0.1)
- dnn_opt = adagrad.Adagrad(learning_rate=0.2)
+ linear_opt = gradient_descent.SGD(learning_rate=0.05)
+ dnn_opt = adagrad.Adagrad(learning_rate=0.1)
wide_deep_model.compile(
optimizer=[linear_opt, dnn_opt],
loss='mse',
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 0d1bdc4..0b9c658 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -444,8 +444,6 @@
synchronization=synchronization,
aggregation=aggregation,
caching_device=caching_device)
- backend.track_variable(variable)
-
if regularizer is not None:
# TODO(fchollet): in the future, this should be handled at the
# level of variable creation, and weight regularization losses
@@ -454,10 +452,19 @@
self._handle_weight_regularization(name_in_scope,
variable,
regularizer)
- if trainable:
- self._trainable_weights.append(variable)
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ backend.track_variable(v)
+ if trainable:
+ self._trainable_weights.append(v)
+ else:
+ self._non_trainable_weights.append(v)
else:
- self._non_trainable_weights.append(variable)
+ backend.track_variable(variable)
+ if trainable:
+ self._trainable_weights.append(variable)
+ else:
+ self._non_trainable_weights.append(variable)
return variable
@base_layer_utils.default
@@ -888,22 +895,24 @@
@property
def trainable_weights(self):
- if self.trainable:
- nested = self._gather_children_attribute('trainable_weights')
- return self._dedup_weights(self._trainable_weights + nested)
- else:
- return []
+ collected_weights = []
+ all_layers = self._gather_unique_layers()
+ for layer in all_layers:
+ if layer.trainable:
+ collected_weights.extend(layer._trainable_weights)
+ return self._dedup_weights(collected_weights)
@property
def non_trainable_weights(self):
- if self.trainable:
- nested = self._gather_children_attribute('non_trainable_weights')
- non_trainable_weights = self._non_trainable_weights + nested
- else:
- nested = self._gather_children_attribute('weights')
- non_trainable_weights = (
- self._trainable_weights + self._non_trainable_weights + nested)
- return self._dedup_weights(non_trainable_weights)
+ collected_weights = []
+ all_layers = self._gather_unique_layers()
+ for layer in all_layers:
+ if layer.trainable:
+ collected_weights.extend(layer._non_trainable_weights)
+ else:
+ collected_weights.extend(layer._trainable_weights +
+ layer._non_trainable_weights)
+ return self._dedup_weights(collected_weights)
@property
def weights(self):
@@ -916,21 +925,23 @@
@property
def updates(self):
- if not self.trainable and not self.stateful:
- return []
+ collected_updates = []
+ all_layers = self._gather_unique_layers()
with backend.get_graph().as_default():
- updates = []
- for u in self._updates:
- if callable(u):
- try:
- u = u()
- except errors.InaccessibleTensorError:
- base_layer_utils.check_graph_consistency(
- method='add_update', force_raise=True)
- raise # check_graph_consistency may not always raise.
- base_layer_utils.check_graph_consistency(u, method='add_update')
- updates.append(u)
- return updates + self._gather_children_attribute('updates')
+ for layer in all_layers:
+ if not layer.trainable and not layer.stateful:
+ continue
+ for u in layer._updates:
+ if callable(u):
+ try:
+ u = u()
+ except errors.InaccessibleTensorError:
+ base_layer_utils.check_graph_consistency(
+ method='add_update', force_raise=True)
+ raise # check_graph_consistency may not always raise.
+ base_layer_utils.check_graph_consistency(u, method='add_update')
+ collected_updates.append(u)
+ return collected_updates
@property
def losses(self):
@@ -944,20 +955,20 @@
A list of tensors.
"""
collected_losses = []
-
- # If any eager losses are present, we assume the model to be part of an
- # eager training loop (either a custom one or the one used when
- # `run_eagerly=True`), and so we always return just the eager losses in that
- # case.
- if self._eager_losses:
- collected_losses.extend(self._eager_losses)
- else:
- collected_losses.extend(self._losses)
- for regularizer in self._callable_losses:
- loss_tensor = regularizer()
- if loss_tensor is not None:
- collected_losses.append(loss_tensor)
- return collected_losses + self._gather_children_attribute('losses')
+ all_layers = self._gather_unique_layers()
+ for layer in all_layers:
+ # If any eager losses are present, we assume the model to be part of an
+ # eager training loop (either a custom one or the one used when
+ # `run_eagerly=True`) and so we always return just the eager losses.
+ if layer._eager_losses:
+ collected_losses.extend(layer._eager_losses)
+ else:
+ collected_losses.extend(layer._losses)
+ for regularizer in layer._callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ collected_losses.append(loss_tensor)
+ return collected_losses
@doc_controls.for_subclass_implementers
def add_loss(self, losses, inputs=None):
@@ -1094,7 +1105,11 @@
@property
def metrics(self):
- return self._metrics + self._gather_children_attribute('metrics')
+ collected_metrics = []
+ all_layers = self._gather_unique_layers()
+ for layer in all_layers:
+ collected_metrics.extend(layer._metrics)
+ return collected_metrics
@doc_controls.for_subclass_implementers
def add_metric(self, value, aggregation=None, name=None):
@@ -2370,18 +2385,29 @@
# at __delattr__.
super(tracking.AutoTrackable, self).__setattr__(name, value)
- def _gather_children_attribute(self, attribute):
- assert attribute in {
- 'weights', 'trainable_weights', 'non_trainable_weights', 'updates',
- 'losses', 'metrics'
- }
+  def _gather_unique_layers(self):
+    """Returns this layer and every nested layer, depth first, deduplicated.
+
+    Deduplication runs after the full `_gather_layers` traversal so the
+    first-visit order is preserved.
+    """
+    seen = object_identity.ObjectIdentitySet()
+    ordered = []
+    for candidate in self._gather_layers():
+      # Membership is tested by object identity, never via a layer's __eq__.
+      if candidate in seen:
+        continue
+      seen.add(candidate)
+      ordered.append(candidate)
+    return ordered
+
+ def _gather_layers(self):
+ """Returns the current layer and all its children depth first."""
+ all_layers = [self]
if hasattr(self, '_layers'):
- nested_layers = trackable_layer_utils.filter_empty_layer_containers(
+ child_layers = trackable_layer_utils.filter_empty_layer_containers(
self._layers)
- return list(
- itertools.chain.from_iterable(
- getattr(layer, attribute) for layer in nested_layers))
- return []
+ for child_layer in child_layers:
+ all_layers.extend(child_layer._gather_layers())
+ return all_layers
@property
@tracking.cached_per_instance
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index 201c269..fa77088 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -226,28 +226,6 @@
self.assertEqual(new_layer.bias_regularizer, bias_reg)
self.assertEqual(layer.get_config(), new_layer.get_config())
- @keras_parameterized.run_all_keras_modes
- def test_add_loss_correctness(self):
-
- class MyLayer(keras.layers.Layer):
-
- def call(self, inputs, training=None):
- self.add_loss(math_ops.reduce_sum(inputs))
- return inputs
-
- inputs = keras.Input((3,))
- layer = MyLayer()
- outputs = layer(inputs)
- model = keras.Model(inputs, outputs)
- self.assertEqual(len(model.losses), 1)
- model.compile(
- 'sgd',
- 'mse',
- run_eagerly=testing_utils.should_run_eagerly(),
- experimental_run_tf_function=testing_utils.should_run_tf_function())
- loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
- self.assertEqual(loss, 2 * 3)
-
@test_util.run_in_graph_and_eager_modes
def test_invalid_forward_pass(self):
inputs = keras.Input((3,))
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 13fb866..ebcf0db 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -21,6 +21,7 @@
import abc
import collections
import contextlib
+import functools
import itertools
import math
import random
@@ -1148,7 +1149,6 @@
# TODO(omalleyt): Handle `validation_split` with separate utility.
# TODO(omalleyt): Handle `validation_data` batch size when `x` is a gen.
- # TODO(omalleyt): Handle `class_weight` in `DataAdapter`s.
def __init__(self,
x,
y=None,
@@ -1158,6 +1158,7 @@
initial_epoch=0,
epochs=1,
shuffle=False,
+ class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False):
@@ -1182,6 +1183,8 @@
strategy = ds_context.get_strategy()
dataset = self._train_adapter.get_dataset()
+ if class_weight:
+ dataset = dataset.map(_make_class_weight_map_fn(class_weight))
self._train_dataset = strategy.experimental_distribute_dataset(dataset)
self._steps_per_epoch = self._infer_steps(steps_per_epoch)
@@ -1252,3 +1255,116 @@
if size >= 0:
return size
return None
+
+
+def _make_class_weight_map_fn(class_weight):
+  """Applies class weighting to a `Dataset`.
+
+  The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
+  `y` must be a single `Tensor` whose values are used as indices into the
+  class weights.
+
+  Arguments:
+    class_weight: A map where the keys are integer class ids and values are
+      the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`. The keys must be
+      exactly `0, 1, ..., num_classes - 1`.
+
+  Returns:
+    A function that can be used with `tf.data.Dataset.map` to apply class
+    weighting. It returns `(x, y, sw)`, where `sw` is the class weight of each
+    label, multiplied by the incoming sample weight when one is present.
+
+  Raises:
+    ValueError: if the keys of `class_weight` are not `0..num_classes - 1`,
+      or (when the returned function is called) if `y` is a nested structure.
+  """
+  class_ids = sorted(class_weight.keys())
+  expected_class_ids = list(range(len(class_ids)))
+  if class_ids != expected_class_ids:
+    error_msg = (
+        "Expected `class_weight` to be a dict with keys from 0 to one less "
+        "than the number of classes, found {}").format(class_weight)
+    raise ValueError(error_msg)
+
+  # Convert to float explicitly: integer-valued weights such as
+  # `{0: 1, 1: 2}` would otherwise produce an int32 weight tensor.
+  class_weight_tensor = ops.convert_to_tensor(
+      [float(class_weight[c]) for c in class_ids])
+
+  def _class_weights_map_fn(*data):
+    """Convert `class_weight` to `sample_weight`."""
+    if len(data) == 2:
+      x, y = data
+      sw = None
+    else:
+      x, y, sw = data
+
+    if nest.is_sequence(y):
+      raise ValueError(
+          "`class_weight` is only supported for `Model`s with a single output.")
+
+    cw = array_ops.gather_v2(class_weight_tensor, y)
+    if sw is not None:
+      cw = math_ops.cast(cw, sw.dtype)
+      # Drop trailing size-1 axes of `cw` (e.g. the label axis of sparse
+      # targets shaped `(batch, 1)`) until its rank matches `sw`. An axis-less
+      # squeeze would also remove size-1 axes that `sw` still has (e.g.
+      # `sw` of shape `(batch, 1)` with `cw` of shape `(batch, 1, 1)`), and the
+      # product below would then broadcast to `(batch, batch)`.
+      extra_dims = len(cw.shape.as_list()) - len(sw.shape.as_list())
+      for _ in range(extra_dims):
+        cw = array_ops.squeeze(cw, axis=-1)
+      # `class_weight` and `sample_weight` are multiplicative.
+      sw = sw * cw
+    else:
+      sw = cw
+
+    return x, y, sw
+
+  return _class_weights_map_fn
+
+
+def train_validation_split(arrays, validation_split, shuffle=True):
+  """Splits (possibly nested) arrays into training and validation subsets.
+
+  Every non-`None` leaf of `arrays` is converted to a tensor and split along
+  its first axis with one shared index order, so corresponding rows stay
+  aligned across leaves. `None` leaves pass through unchanged.
+
+  Arguments:
+    arrays: Arbitrarily nested structure of Tensors, NumPy arrays or `None`
+      (plus `pd.Series` / `pd.DataFrame` when the module-level `pd` is set).
+    validation_split: Float between 0 and 1. The first
+      `floor(num_rows * (1 - validation_split))` rows of the (optionally
+      shuffled) order form the training subset; the rest form the validation
+      subset.
+    shuffle: Bool. If `True`, rows are randomly permuted before splitting. If
+      `False`, the trailing rows become the validation subset.
+
+  Returns:
+    `(train_arrays, validation_arrays)`, each with the structure of `arrays`.
+
+  Raises:
+    ValueError: if any leaf is neither a supported array type nor `None`.
+  """
+  allowed_types = (ops.Tensor, np.ndarray)
+  if pd:
+    allowed_types += (pd.Series, pd.DataFrame)
+
+  leaves = nest.flatten(arrays)
+  if any(leaf is not None and not isinstance(leaf, allowed_types)
+         for leaf in leaves):
+    raise ValueError(
+        "`validation_split` is only supported for Tensors or NumPy "
+        "arrays, found: {}".format(arrays))
+
+  present = [leaf for leaf in leaves if leaf is not None]
+  if not present:
+    return arrays, arrays
+
+  # All leaves are assumed to share the same leading (batch) dimension.
+  num_rows = int(present[0].shape[0])
+  order = ops.convert_to_tensor(range(num_rows))
+  if shuffle:
+    order = random_ops.random_shuffle(order)
+  cut = int(math.floor(num_rows * (1. - validation_split)))
+
+  def _take(leaf, rows):
+    if leaf is None:
+      return None
+    return array_ops.gather_v2(ops.convert_to_tensor(leaf), rows)
+
+  train_arrays = nest.map_structure(
+      functools.partial(_take, rows=order[:cut]), arrays)
+  val_arrays = nest.map_structure(
+      functools.partial(_take, rows=order[cut:]), arrays)
+  return train_arrays, val_arrays
diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py
index 5b0f119..b399c6b 100644
--- a/tensorflow/python/keras/engine/data_adapter_test.py
+++ b/tensorflow/python/keras/engine/data_adapter_test.py
@@ -963,6 +963,161 @@
self.assertEqual(returned_data, [[([0],), ([1],),
([2],)], [([0],), ([1],), ([2],)]])
+ def test_class_weight(self):
+ """Class weights alone are emitted as the per-element sample weights.
+
+ Each `(x, y)` element should come back as `(x, y, class_weight[y])`,
+ identically in both epochs.
+ """
+ data_handler = data_adapter.DataHandler(
+ x=[[0], [1], [2]],
+ y=[[2], [1], [0]],
+ class_weight={
+ 0: 0.5,
+ 1: 1.,
+ 2: 1.5
+ },
+ epochs=2,
+ steps_per_epoch=3)
+ returned_data = []
+ # Drain the handler like a training loop: one `next()` per step per epoch.
+ for _, iterator in data_handler.enumerate_epochs():
+ epoch_data = []
+ for _ in data_handler.steps():
+ epoch_data.append(next(iterator))
+ returned_data.append(epoch_data)
+ returned_data = self.evaluate(returned_data)
+ self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [1.]),
+ ([2], [0], [0.5])],
+ [([0], [2], [1.5]), ([1], [1], [1.]),
+ ([2], [0], [0.5])]])
+
+ def test_class_weight_and_sample_weight(self):
+ """Class weights and sample weights are combined multiplicatively.
+
+ Expected weight per element is `sample_weight * class_weight[y]`:
+ `1. * 1.5`, `2. * 1.` and `4. * 0.5`.
+ """
+ data_handler = data_adapter.DataHandler(
+ x=[[0], [1], [2]],
+ y=[[2], [1], [0]],
+ sample_weight=[[1.], [2.], [4.]],
+ class_weight={
+ 0: 0.5,
+ 1: 1.,
+ 2: 1.5
+ },
+ epochs=2,
+ steps_per_epoch=3)
+ returned_data = []
+ # Drain the handler like a training loop: one `next()` per step per epoch.
+ for _, iterator in data_handler.enumerate_epochs():
+ epoch_data = []
+ for _ in data_handler.steps():
+ epoch_data.append(next(iterator))
+ returned_data.append(epoch_data)
+ returned_data = self.evaluate(returned_data)
+ self.assertEqual(returned_data, [[([0], [2], [1.5]), ([1], [1], [2.]),
+ ([2], [0], [2.])],
+ [([0], [2], [1.5]), ([1], [1], [2.]),
+ ([2], [0], [2.])]])
+
+  def test_class_weight_user_errors(self):
+    """Invalid `class_weight` usage raises a descriptive `ValueError`."""
+    # Keys must be exactly 0..num_classes-1; class `2` is skipped here.
+    non_contiguous_keys = dict(
+        x=[[0], [1], [2]],
+        y=[[2], [1], [0]],
+        batch_size=1,
+        sample_weight=[[1.], [2.], [4.]],
+        class_weight={0: 0.5, 1: 1., 3: 1.5})
+    # Class weighting is rejected when `y` is a nested (multi-output) structure.
+    multi_output = dict(
+        x=np.ones((10, 1)),
+        y=[np.ones((10, 1)), np.zeros((10, 1))],
+        batch_size=2,
+        class_weight={0: 0.5, 1: 1., 2: 1.5})
+    for pattern, handler_kwargs in (
+        ('to be a dict with keys', non_contiguous_keys),
+        ('with a single output', multi_output)):
+      with self.assertRaisesRegexp(ValueError, pattern):
+        data_adapter.DataHandler(**handler_kwargs)
+
+
+class TestValidationSplit(keras_parameterized.TestCase):
+  """Tests for `data_adapter.train_validation_split`."""
+
+  @parameterized.named_parameters(('numpy_arrays', True), ('tensors', False))
+  def test_validation_split_shuffled(self, use_numpy):
+    """Shuffled split keeps rows aligned across arrays in both subsets."""
+    if use_numpy:
+      x = np.array([0, 1, 2, 3, 4])
+      y = np.array([0, 2, 4, 6, 8])
+      sw = np.array([0, 4, 8, 12, 16])
+    else:
+      x = ops.convert_to_tensor([0, 1, 2, 3, 4])
+      y = ops.convert_to_tensor([0, 2, 4, 6, 8])
+      sw = ops.convert_to_tensor([0, 4, 8, 12, 16])
+
+    (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
+        data_adapter.train_validation_split((x, y, sw), validation_split=0.2))
+
+    self.assertEqual(int(train_x.shape[0]), 4)
+    self.assertEqual(int(train_y.shape[0]), 4)
+    self.assertEqual(int(train_sw.shape[0]), 4)
+    for i in range(4):
+      # Check that all training arrays were shuffled in identical order.
+      self.assertEqual(2 * train_x[i].numpy(), train_y[i].numpy())
+      self.assertEqual(2 * train_y[i].numpy(), train_sw[i].numpy())
+
+    self.assertEqual(int(val_x.shape[0]), 1)
+    self.assertEqual(int(val_y.shape[0]), 1)
+    self.assertEqual(int(val_sw.shape[0]), 1)
+    for i in range(1):
+      # Check that all validation arrays were shuffled in identical order.
+      self.assertEqual(2 * val_x[i].numpy(), val_y[i].numpy())
+      self.assertEqual(2 * val_y[i].numpy(), val_sw[i].numpy())
+
+    # Check that arrays contain expected values.
+    self.assertEqual(
+        sorted(array_ops.concat([train_x, val_x], axis=0).numpy().tolist()),
+        sorted(ops.convert_to_tensor(x).numpy().tolist()))
+    self.assertEqual(
+        sorted(array_ops.concat([train_y, val_y], axis=0).numpy().tolist()),
+        sorted(ops.convert_to_tensor(y).numpy().tolist()))
+    self.assertEqual(
+        sorted(array_ops.concat([train_sw, val_sw], axis=0).numpy().tolist()),
+        sorted(ops.convert_to_tensor(sw).numpy().tolist()))
+
+  @parameterized.named_parameters(('numpy_arrays', True), ('tensors', False))
+  def test_validation_split_unshuffled(self, use_numpy):
+    """Without shuffling, the trailing rows form the validation subset."""
+    if use_numpy:
+      x = np.array([0, 1, 2, 3, 4])
+      y = np.array([0, 2, 4, 6, 8])
+      sw = np.array([0, 4, 8, 12, 16])
+    else:
+      x = ops.convert_to_tensor([0, 1, 2, 3, 4])
+      y = ops.convert_to_tensor([0, 2, 4, 6, 8])
+      sw = ops.convert_to_tensor([0, 4, 8, 12, 16])
+
+    (train_x, train_y, train_sw), (val_x, val_y, val_sw) = (
+        data_adapter.train_validation_split((x, y, sw),
+                                            validation_split=0.2,
+                                            shuffle=False))
+
+    self.assertEqual(train_x.numpy().tolist(), [0, 1, 2, 3])
+    self.assertEqual(train_y.numpy().tolist(), [0, 2, 4, 6])
+    self.assertEqual(train_sw.numpy().tolist(), [0, 4, 8, 12])
+
+    self.assertEqual(val_x.numpy().tolist(), [4])
+    self.assertEqual(val_y.numpy().tolist(), [8])
+    self.assertEqual(val_sw.numpy().tolist(), [16])
+
+  def test_validation_split_user_error(self):
+    """Unsupported leaf types (here a callable) are rejected."""
+    with self.assertRaisesRegexp(ValueError, 'is only supported for Tensors'):
+      data_adapter.train_validation_split(
+          lambda: np.ones((10, 1)), validation_split=0.2)
+
+  def test_validation_split_none(self):
+    """`None` inputs and `None` leaves pass through to both subsets."""
+    train_sw, val_sw = data_adapter.train_validation_split(
+        None, validation_split=0.2)
+    self.assertIsNone(train_sw)
+    self.assertIsNone(val_sw)
+
+    (_, train_sw), (_, val_sw) = data_adapter.train_validation_split(
+        (np.ones((10, 1)), None), validation_split=0.2)
+    self.assertIsNone(train_sw)
+    self.assertIsNone(val_sw)
+
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index f279c70..4313b37 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -330,8 +330,6 @@
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
layer._attribute_sentinel.add_parent(self._attribute_sentinel)
- self._track_layers(layers)
-
# Create the node linking internal inputs to internal outputs.
node_module.Node(
outbound_layer=self,
@@ -397,18 +395,25 @@
return any(layer.dynamic for layer in self.layers)
return self._dynamic or any(layer.dynamic for layer in self.layers)
- def _track_layers(self, layers):
- """Add Trackable dependencies on a list of Layers."""
+ @property
+ def _layer_checkpoint_dependencies(self):
+ """Dictionary of layer dependencies to be included in the checkpoint."""
+ # Use getattr becuase this function can be called from __setattr__, at which
+ # point the _is_graph_network attribute has not been created.
+ if (not getattr(self, '_is_graph_network', False) and
+ base_layer_utils.is_subclassed(self)):
+ return {} # Only add layer dependencies for graph networks
+
weight_layer_index = 0
- for layer_index, layer in enumerate(layers):
+
+ dependencies = {}
+ for layer_index, layer in enumerate(self.layers):
try:
if layer.weights:
# Keep a separate index for layers which have weights. This allows
# users to insert Layers without weights anywhere in the network
# without breaking checkpoints.
- self._track_trackable(
- layer, name='layer_with_weights-%d' % weight_layer_index,
- overwrite=True)
+ dependencies['layer_with_weights-%d' % weight_layer_index] = layer
weight_layer_index += 1
except ValueError:
# The layer might have weights, but may not be built yet. We just treat
@@ -417,8 +422,31 @@
# Even if it doesn't have weights, we should still track everything in
# case it has/will have Trackable dependencies.
- self._track_trackable(
- layer, name='layer-%d' % layer_index, overwrite=True)
+ dependencies['layer-%d' % layer_index] = layer
+ return dependencies
+
+  @property
+  def _checkpoint_dependencies(self):
+    """Layer references first, followed by the base class's dependencies."""
+    layer_refs = [
+        trackable.TrackableReference(name=dep_name, ref=dep_layer)
+        for dep_name, dep_layer in self._layer_checkpoint_dependencies.items()
+    ]
+    return layer_refs + list(super(Network, self)._checkpoint_dependencies)
+
+  def _lookup_dependency(self, name):
+    """Resolves `name`, preferring layer dependencies over base-class ones."""
+    by_name = self._layer_checkpoint_dependencies
+    if name not in by_name:
+      return super(Network, self)._lookup_dependency(name)
+    return by_name[name]
+
+  def _handle_deferred_layer_dependencies(self, layers):
+    """Handles layer checkpoint dependencies that are added after init.
+
+    Every layer in `layers` that has a name in `_layer_checkpoint_dependencies`
+    is passed, together with that name, to `_handle_deferred_dependencies`;
+    layers without such a name are ignored.
+    """
+    name_of = {
+        dep: dep_name
+        for dep_name, dep in self._layer_checkpoint_dependencies.items()}
+    for new_layer in layers:
+      if new_layer not in name_of:
+        continue
+      self._handle_deferred_dependencies(name=name_of[new_layer],
+                                         trackable=new_layer)
def __setattr__(self, name, value):
if not getattr(self, '_self_setattr_tracking', True):
@@ -686,8 +714,7 @@
'Instead, in order to instantiate and build your '
'model, `call` your model on real tensor data (of '
'the correct dtype).')
- if self._layers:
- self._track_layers(self._layers)
+
self.built = True
def call(self, inputs, training=None, mask=None):
@@ -1437,15 +1464,18 @@
# Insert layers and update other layer attrs.
layer_set = set(self._layers)
+ deferred_layers = []
for layer in layers:
if layer not in layer_set:
self._layers.append(layer)
+ deferred_layers.append(layer)
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
# This allows the added layer to broadcast mutations to the current
# layer, which is necessary to ensure cache correctness.
layer._attribute_sentinel.add_parent(self._attribute_sentinel)
layer_set.add(layer)
+ self._handle_deferred_layer_dependencies(deferred_layers)
def _assert_weights_created(self):
"""Asserts that all the weights for the network have been created.
diff --git a/tensorflow/python/keras/engine/network_test.py b/tensorflow/python/keras/engine/network_test.py
index efa151d..fd4f47a 100644
--- a/tensorflow/python/keras/engine/network_test.py
+++ b/tensorflow/python/keras/engine/network_test.py
@@ -121,61 +121,6 @@
self.assertEqual(len(layer.get_updates_for(x1)), 2)
self.assertEqual(len(layer.get_updates_for(None)), 0)
- @test_util.run_deprecated_v1
- def test_get_losses(self):
-
- class MyLayer(keras.layers.Layer):
-
- def build(self, input_shape):
- self.a = self.add_variable('a',
- (1, 1),
- 'float32',
- trainable=False)
- self.b = self.add_variable('b',
- (1, 1),
- 'float32',
- trainable=False)
- self.add_loss(math_ops.reduce_sum(self.a))
- self.built = True
-
- def call(self, inputs):
- self.add_loss(math_ops.reduce_sum(inputs),
- inputs=True)
- return inputs + 1
-
- x1 = input_layer_lib.Input(shape=(1,))
- layer = MyLayer()
- _ = layer(x1)
-
- self.assertEqual(len(layer.losses), 2)
- self.assertEqual(len(layer.get_losses_for(x1)), 1)
- self.assertEqual(len(layer.get_losses_for(None)), 1)
-
- x2 = input_layer_lib.Input(shape=(1,))
- y2 = layer(x2)
-
- self.assertEqual(len(layer.losses), 3)
- self.assertEqual(len(layer.get_losses_for(x1)), 1)
- self.assertEqual(len(layer.get_losses_for(x2)), 1)
- self.assertEqual(len(layer.get_losses_for(None)), 1)
-
- network = network_lib.Network(x2, y2)
- self.assertEqual(len(network.losses), 3)
- self.assertEqual(len(network.get_losses_for(x1)), 1)
- self.assertEqual(len(network.get_losses_for(x2)), 1)
- self.assertEqual(len(network.get_losses_for(None)), 1)
-
- x3 = input_layer_lib.Input(shape=(1,))
- _ = layer(x3)
- self.assertEqual(len(network.losses), 4)
-
- x4 = input_layer_lib.Input(shape=(1,))
- _ = network(x4)
- self.assertEqual(len(network.losses), 5)
- self.assertEqual(len(network.get_losses_for(x2)), 1)
- self.assertEqual(len(network.get_losses_for(x4)), 1)
- self.assertEqual(len(network.get_losses_for(None)), 1)
-
@test_util.run_in_graph_and_eager_modes()
def testTopologicalAttributes(self):
# test layer attributes / methods related to cross-layer connectivity.
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 369cd31..5557a00 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -217,8 +217,7 @@
self._init_graph_network(self.inputs, self.outputs, name=self.name)
else:
self._layers.append(layer)
- if self._layers:
- self._track_layers(self._layers)
+ self._handle_deferred_layer_dependencies([layer])
self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
# Different Model types add to `._layers` in different ways, so for safety
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 771c6e8..d6ef71b 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -256,7 +256,7 @@
will then be the *weighted sum* of all individual losses,
weighted by the `loss_weights` coefficients.
If a list, it is expected to have a 1:1 mapping
- to the model's outputs. If a tensor, it is expected to map
+ to the model's outputs. If a dict, it is expected to map
output names (strings) to scalar coefficients.
sample_weight_mode: If you need to do timestep-wise
sample weighting (2D weights), set this to `"temporal"`.
@@ -371,8 +371,9 @@
metrics = []
if self._is_compiled:
metrics += self._compile_metric_functions
- metrics.extend(self._metrics)
- metrics.extend(_get_metrics_from_layers(self._layers))
+ all_layers = self._gather_unique_layers()
+ for l in all_layers:
+ metrics.extend(l._metrics) # pylint: disable=protected-access
return metrics
@property
@@ -2941,27 +2942,3 @@
return sparse_tensor.SparseTensor(indices, data, shape)
else:
return value
-
-
-def _get_metrics_from_layers(layers):
- """Returns list of metrics from the given layers.
-
- This will not include the `compile` metrics of a model layer.
-
- Arguments:
- layers: List of layers.
-
- Returns:
- List of metrics.
- """
- metrics = []
- layers = trackable_layer_utils.filter_empty_layer_containers(layers)
- for layer in layers:
- if isinstance(layer, Model):
- # We cannot call 'metrics' on the model because we do not want to
- # include the metrics that were added in compile API of a nested model.
- metrics.extend(layer._metrics) # pylint: disable=protected-access
- metrics.extend(_get_metrics_from_layers(layer.layers))
- else:
- metrics.extend(layer.metrics)
- return metrics
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index 1fabe36..3fdc723 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -280,18 +280,6 @@
history = model.fit(dataset, epochs=1, steps_per_epoch=10)
self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4)
- def test_loss_in_call(self):
-
- class HasLoss(keras.layers.Layer):
-
- def call(self, x):
- self.add_loss(x)
- return x
-
- layer = HasLoss()
- layer(1.) # Plain-value inputs are only valid in eager mode.
- self.assertEqual(1, len(layer.losses))
-
@parameterized.named_parameters([
('_None', contextlib.contextmanager(lambda: iter([None])), 0., 4.),
('_0', lambda: keras.backend.learning_phase_scope(0), 4., 4.),
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index a83689c..3674c1e 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -34,7 +34,6 @@
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
-from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
@@ -48,6 +47,7 @@
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as variables_lib
@@ -857,96 +857,6 @@
model.predict(x_function(use_namedtuple=True), **predict_kwargs)
@keras_parameterized.run_all_keras_modes
- @keras_parameterized.run_with_all_model_types
- def test_activity_regularizer_fit(self):
- loss = {}
- for reg in [None, 'l2']:
- layers = [
- keras.layers.Dense(
- 10, activation='relu', activity_regularizer=reg,
- kernel_initializer='ones', use_bias=False),
- keras.layers.Dense(
- 1, activation='sigmoid', kernel_initializer='ones',
- use_bias=False),
- ]
-
- model = testing_utils.get_model_from_layers(
- layers, input_shape=(10,))
-
- x = np.ones((10, 10), 'float32')
- y = np.ones((10, 1), 'float32')
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- model.compile(
- optimizer,
- 'binary_crossentropy',
- run_eagerly=testing_utils.should_run_eagerly(),
- experimental_run_tf_function=testing_utils.should_run_tf_function())
- model.fit(x, y, batch_size=2, epochs=5)
- loss[reg] = model.evaluate(x, y)
- self.assertLess(loss[None], loss['l2'])
-
- @keras_parameterized.run_all_keras_modes
- @keras_parameterized.run_with_all_model_types
- def test_activity_regularizer_loss_value(self):
- layer = keras.layers.Dense(
- 1, kernel_initializer=keras.initializers.zeros(),
- bias_initializer=keras.initializers.ones(), activity_regularizer='l2')
-
- model = testing_utils.get_model_from_layers([layer], input_shape=(10,))
-
- x = np.ones((10, 10), 'float32')
- y = np.ones((10, 1), 'float32')
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- model.compile(
- optimizer,
- 'binary_crossentropy',
- run_eagerly=testing_utils.should_run_eagerly(),
- experimental_run_tf_function=testing_utils.should_run_tf_function())
- loss = model.test_on_batch(x, y)
- self.assertAlmostEqual(0.01, loss, places=4)
-
- @keras_parameterized.run_all_keras_modes
- def test_activity_regularizer_batch_independent(self):
- inputs = keras.layers.Input(shape=(10,))
- x = keras.layers.Dense(
- 10, activation='relu', activity_regularizer='l2')(
- inputs)
- outputs = keras.layers.Dense(1, activation='sigmoid')(x)
- model = keras.Model(inputs, outputs)
-
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- model.compile(
- optimizer,
- 'binary_crossentropy',
- run_eagerly=testing_utils.should_run_eagerly(),
- experimental_run_tf_function=testing_utils.should_run_tf_function())
-
- x = np.ones((10, 10), 'float32')
- y = np.ones((10, 1), 'float32')
- loss_small_batch = model.test_on_batch(x, y)
-
- x2 = np.ones((20, 10), 'float32')
- y2 = np.ones((20, 1), 'float32')
- loss_big_batch = model.test_on_batch(x2, y2)
-
- self.assertAlmostEqual(loss_small_batch, loss_big_batch, places=4)
-
- @keras_parameterized.run_all_keras_modes
- def test_activity_regularizer_in_model_call(self):
-
- class MyModel(keras.Model):
-
- def call(self, inputs):
- self.add_loss(inputs)
- return inputs
-
- x = ops.convert_to_tensor(1.)
- model = MyModel()
- _ = model(x)
- self.assertEqual(1, len(model.losses))
-
- @keras_parameterized.run_all_keras_modes
def test_custom_mapping_in_config(self):
class MyModel(keras.Model):
@@ -1140,6 +1050,42 @@
# be ~0.15, compared to the correct answer of O(1e-7).
self.assertLess(history.history['loss'][-1], 1e-6)
+  @keras_parameterized.run_all_keras_modes
+  def test_weight_shared_across_layers(self):
+    """Checks that variables shared by two sublayers are reported once."""
+
+    class AddWeightLayer(keras.layers.Layer):
+
+      def __init__(self, trainable_var, non_trainable_var):
+        # NOTE(review): the variables are assigned before `super().__init__()`;
+        # presumably deliberate, to exercise variable tracking on a layer that
+        # is not initialized yet -- confirm before reordering.
+        self.trainable_var = trainable_var
+        self.non_trainable_var = non_trainable_var
+        super(AddWeightLayer, self).__init__()
+
+      def call(self, inputs):
+        return inputs + self.trainable_var
+
+    class LayerWithWeightSharedLayers(keras.layers.Layer):
+
+      def __init__(self):
+        super(LayerWithWeightSharedLayers, self).__init__()
+        # Both sublayers receive the very same two variable objects.
+        shared_trainable_var = resource_variable_ops.ResourceVariable(1.)
+        shared_non_trainable_var = resource_variable_ops.ResourceVariable(
+            1., trainable=False)
+        self.layer1 = AddWeightLayer(shared_trainable_var,
+                                     shared_non_trainable_var)
+        self.layer2 = AddWeightLayer(shared_trainable_var,
+                                     shared_non_trainable_var)
+
+      def call(self, inputs):
+        return self.layer2(self.layer1(inputs))
+
+    l = LayerWithWeightSharedLayers()
+    self.assertEqual(l._layers, [l.layer1, l.layer2])
+    # Each shared variable must appear exactly once even though it is reachable
+    # through both sublayers.
+    self.assertEqual(l.variables,
+                     [l.layer1.trainable_var, l.layer1.non_trainable_var])
+    self.assertEqual(l.trainable_variables, [l.layer1.trainable_var])
+    self.assertEqual(l.non_trainable_variables, [l.layer1.non_trainable_var])
+    self.assertLen(l.get_weights(), 2)
+
def test_logs_passed_to_callbacks(self):
with self.cached_session():
input_dim = 5
@@ -1474,107 +1420,6 @@
'`validation_data` is None.'):
model.fit(x, y, epochs=4, validation_data=None, validation_steps=3)
- @keras_parameterized.run_all_keras_modes
- def test_add_loss_correctness(self):
- inputs = keras.Input(shape=(1,))
- targets = keras.Input(shape=(1,))
- outputs = testing_utils.Bias()(inputs)
- model = keras.Model([inputs, targets], outputs)
-
- model.add_loss(2 * math_ops.reduce_mean(
- keras.losses.mean_absolute_error(targets, outputs)))
-
- model.add_loss(keras.losses.MeanAbsoluteError()(targets, outputs))
-
- model.compile(
- keras.optimizer_v2.gradient_descent.SGD(0.025),
- loss=keras.losses.MeanAbsoluteError(),
- run_eagerly=testing_utils.should_run_eagerly(),
- experimental_run_tf_function=testing_utils.should_run_tf_function())
-
- x = np.array([[0.], [1.], [2.]])
- y = np.array([[0.5], [2.], [3.5]])
- history = model.fit([x, y], y, batch_size=3, epochs=5)
- self.assertAllClose(history.history['loss'], [4., 3.6, 3.2, 2.8, 2.4], 1e-3)
-
- @keras_parameterized.run_all_keras_modes
- def test_add_loss_with_sample_weight_correctness(self):
- inputs = keras.Input(shape=(1,))
- targets = keras.Input(shape=(1,))
- sw = keras.Input(shape=(1,))
- outputs = testing_utils.Bias()(inputs)
- model = keras.Model([inputs, targets, sw], outputs)
-
- model.add_loss(2 * math_ops.reduce_mean(
- sw * keras.losses.mean_absolute_error(targets, outputs)))
- model.add_loss(keras.losses.MeanAbsoluteError()(targets, outputs, sw))
- model.compile(
- keras.optimizer_v2.gradient_descent.SGD(0.025),
- loss=keras.losses.MeanAbsoluteError(),
- run_eagerly=testing_utils.should_run_eagerly(),
- experimental_run_tf_function=testing_utils.should_run_tf_function())
-
- x = np.array([[0.], [1.], [2.]])
- y = np.array([[0.5], [2.], [3.5]])
- w = np.array([1.25, 0.5, 1.25])
- history = model.fit([x, y, w], y, batch_size=3, epochs=5, sample_weight=w)
- self.assertAllClose(history.history['loss'], [4., 3.6, 3.2, 2.8, 2.4], 1e-3)
-
- @keras_parameterized.run_all_keras_modes
- def test_unconditional_add_loss_correctness(self):
-
- class MyLayer(keras.layers.Layer):
-
- def call(self, inputs, training=None):
- # Reachable from the inputs but marked as unconditional.
- self.add_loss(math_ops.reduce_sum(inputs))
- return inputs
-
- inputs = keras.Input((3,))
- layer = MyLayer()
- outputs = layer(inputs)
- model = keras.Model(inputs, outputs)
- self.assertEqual(len(model.losses), 1)
- model.compile(
- 'sgd',
- 'mse',
- run_eagerly=testing_utils.should_run_eagerly(),
- experimental_run_tf_function=testing_utils.should_run_tf_function())
- loss = model.train_on_batch(np.ones((2, 3)), np.ones((2, 3)))
- self.assertEqual(loss, 2 * 3)
-
- @keras_parameterized.run_all_keras_modes
- def test_clear_losses(self):
-
- class LayerWithSharedNestedLossLayer(keras.layers.Layer):
-
- def __init__(self):
- super(LayerWithSharedNestedLossLayer, self).__init__()
- self.loss_layer = keras.layers.ActivityRegularization(l2=0.001)
- self.add_weight(shape=(1,), regularizer='l2')
-
- def call(self, x):
- x = self.loss_layer(x)
- return self.loss_layer(x)
-
- inputs = keras.Input(shape=(1,))
- outputs = LayerWithSharedNestedLossLayer()(inputs)
- model = keras.Model(inputs, outputs)
- # Weight loss + 2 activity losses.
- self.assertEqual(len(model.losses), 3)
-
- x = array_ops.ones((1, 1))
- model(x)
- y = array_ops.ones((1, 1))
- model(y)
- if context.executing_eagerly():
- # Eager losses are cleared every `__call__`.
- self.assertEqual(len(model.losses), 3)
- else:
- self.assertEqual(len(model.get_losses_for(x)), 2)
- self.assertEqual(len(model.get_losses_for(y)), 2)
- self.assertEqual(len(model.get_losses_for(None)), 1)
-
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_layer_with_variable_output(self):
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index 07cb1bd..3f648b4 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -44,6 +44,7 @@
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
TextVectorizationV1 = TextVectorization
+from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Rescaling
# Advanced activations.
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
diff --git a/tensorflow/python/keras/layers/dense_attention.py b/tensorflow/python/keras/layers/dense_attention.py
index 210cc16..ba249d7 100644
--- a/tensorflow/python/keras/layers/dense_attention.py
+++ b/tensorflow/python/keras/layers/dense_attention.py
@@ -219,7 +219,7 @@
2. Use scores to calculate a distribution with shape
`[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
3. Use `distribution` to create a linear combination of `value` with
- shape `batch_size, Tq, dim]`:
+ shape `[batch_size, Tq, dim]`:
`return tf.matmul(distribution, value)`.
Args:
@@ -406,7 +406,7 @@
# Query embeddings of shape [batch_size, Tq, dimension].
query_embeddings = token_embedding(query_input)
# Value embeddings of shape [batch_size, Tv, dimension].
- value_embeddings = token_embedding(query_input)
+ value_embeddings = token_embedding(value_input)
# CNN layer.
cnn_layer = tf.keras.layers.Conv1D(
diff --git a/tensorflow/python/keras/layers/preprocessing/categorical.py b/tensorflow/python/keras/layers/preprocessing/categorical.py
new file mode 100644
index 0000000..8572a60
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/categorical.py
@@ -0,0 +1,106 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras categorical preprocessing layers."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.ops import lookup_ops
+
+
+class CategoryLookup(Layer):
+  """Category lookup layer.
+
+  This layer looks up tokens (int or string) in a vocabulary table,
+  and returns their indices (int). It converts a sequence of int or string to a
+  sequence of int.
+
+  Attributes:
+    vocabulary: The vocabulary to lookup the input. If it is a file, it
+      represents the source vocab file; If it is a list/tuple, it represents the
+      source vocab list. Currently required: passing None raises `ValueError`.
+    max_tokens: The maximum size of the vocabulary for this layer. Not supported
+      yet (it is meant for `adapt`): any value other than None raises
+      `ValueError`.
+    num_oov_tokens: Non-negative integer. The number of out-of-vocab tokens.
+      In-vocab tokens get IDs in [0, vocab_size); out-of-vocab inputs are
+      assigned IDs in the range [vocab_size, vocab_size + num_oov_tokens) based
+      on a hash. With `num_oov_tokens=0`, out-of-vocab inputs map to -1.
+    name: Name to give to the layer.
+    **kwargs: Keyword arguments to construct a layer.
+  Input shape: A string or int tensor (dense or `SparseTensor`) of shape
+    `[batch_size, d1, ..., dm]`
+  Output shape: An int64 tensor of the same kind, of shape
+    `[batch_size, d1, .., dm]`
+  Example: Consider a batch of a single input sample, `[["a", "c", "d", "a",
+    "x"]]`. Let's say the vocabulary is `["a", "b", "c", "d"]` and a single OOV
+    token is used (`num_oov_tokens=1`). Then the corresponding output is `[[0,
+    2, 3, 0, 4]]`. 4 stands for an OOV token.
+  """
+
+  def __init__(self,
+               max_tokens=None,
+               num_oov_tokens=1,
+               vocabulary=None,
+               name=None,
+               **kwargs):
+    if max_tokens is not None:
+      raise ValueError('`max_tokens` and `adapt` is not supported yet.')
+    if vocabulary is None:
+      raise ValueError('for now, you must pass a `vocabulary` argument')
+    self.max_tokens = max_tokens
+    self.num_oov_tokens = num_oov_tokens
+    self.vocabulary = vocabulary
+    # `name` must be passed by keyword: the first positional parameter of
+    # `Layer.__init__` is `trainable`, so passing it positionally would set
+    # `trainable=name` and leave the layer unnamed.
+    super(CategoryLookup, self).__init__(name=name, **kwargs)
+
+  def __call__(self, inputs, *args, **kwargs):
+    # Record the input dtype so `build` can create a table whose key dtype
+    # matches the inputs.
+    if isinstance(inputs, (np.ndarray, float, int)):
+      inputs = ops.convert_to_tensor(inputs)
+    self._input_dtype = inputs.dtype
+    return super(CategoryLookup, self).__call__(inputs, *args, **kwargs)
+
+  def build(self, input_shape):
+    # Relies on `__call__` having set `_input_dtype` before the base class
+    # builds the layer.
+    # categorical with vocabulary list.
+    if isinstance(self.vocabulary, (tuple, list, np.ndarray)):
+      self.table = lookup_ops.index_table_from_tensor(
+          vocabulary_list=self.vocabulary,
+          num_oov_buckets=self.num_oov_tokens,
+          dtype=self._input_dtype)
+    # categorical with vocabulary file.
+    elif self.vocabulary:
+      self.table = lookup_ops.index_table_from_file(
+          vocabulary_file=self.vocabulary,
+          num_oov_buckets=self.num_oov_tokens,
+          key_dtype=self._input_dtype)
+
+  def call(self, inputs):
+    return self.table.lookup(inputs)
+
+  def compute_output_shape(self, input_shape):
+    # A lookup maps each token to one ID, so the shape is unchanged.
+    return input_shape
+
+  def compute_output_signature(self, input_spec):
+    # Same shape as the input; IDs are always int64, and sparse inputs yield
+    # sparse outputs.
+    output_shape = self.compute_output_shape(input_spec.shape.as_list())
+    output_dtype = dtypes.int64
+    if isinstance(input_spec, sparse_tensor.SparseTensorSpec):
+      return sparse_tensor.SparseTensorSpec(
+          shape=output_shape, dtype=output_dtype)
+    else:
+      return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_test.py b/tensorflow/python/keras/layers/preprocessing/categorical_test.py
new file mode 100644
index 0000000..78a08e9
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/categorical_test.py
@@ -0,0 +1,145 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for categorical preprocessing layers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.layers.preprocessing import categorical
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class CategoryLookupVocabListTest(keras_parameterized.TestCase):
+  """Tests for CategoryLookup backed by an in-memory vocabulary list."""
+
+  def _make_layer(self, **layer_kwargs):
+    """Returns a CategoryLookup over a fresh ['A', 'B', 'C', 'D', 'E'] list."""
+    return categorical.CategoryLookup(
+        vocabulary=['A', 'B', 'C', 'D', 'E'], **layer_kwargs)
+
+  def test_vocab_list_basic(self):
+    layer = self._make_layer(num_oov_tokens=0)
+    tokens = np.asarray([['A', 'D'], ['E', 'C'], ['D', 'A']])
+    expected_ids = np.asarray([[0, 3], [4, 2], [3, 0]])
+    self.assertAllClose(expected_ids, layer(tokens))
+
+  def test_vocab_list_unknown_input(self):
+    # Default single OOV token: the unknown '' gets ID 5, one past the vocab.
+    layer = self._make_layer()
+    tokens = np.asarray([['A', ''], ['E', 'C'], ['D', 'A']])
+    self.assertAllClose(np.asarray([[0, 5], [4, 2], [3, 0]]), layer(tokens))
+
+  def test_vocab_list_invalid_input(self):
+    # Without OOV tokens the unknown '' maps to -1.
+    layer = self._make_layer(num_oov_tokens=0)
+    tokens = np.asarray([['A', ''], ['E', 'C'], ['D', 'A']])
+    self.assertAllClose(np.asarray([[0, -1], [4, 2], [3, 0]]), layer(tokens))
+
+  def test_vocab_list_compute_output_signature(self):
+    input_shape = tensor_shape.TensorShape([2, 3])
+    layer = self._make_layer(num_oov_tokens=0)
+    output_spec = layer.compute_output_signature(
+        tensor_spec.TensorSpec(input_shape, dtypes.string))
+    self.assertEqual(output_spec.shape.dims, input_shape.dims)
+    self.assertEqual(output_spec.dtype, dtypes.int64)
+
+  def test_vocab_list_sparse_input(self):
+    layer = self._make_layer(num_oov_tokens=0)
+    dense = np.asarray([['A', ''], ['E', 'C'], ['D', 'A']])
+    present = array_ops.where_v2(math_ops.not_equal(dense, ''))
+    sparse_tokens = sparse_tensor.SparseTensor(
+        present,
+        array_ops.gather_nd_v2(dense, present),
+        dense_shape=array_ops.shape_v2(dense, out_type=dtypes.int64))
+    result = layer(sparse_tokens)
+    self.assertIsInstance(result, sparse_tensor.SparseTensor)
+    self.assertAllClose(
+        np.asarray([[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]), result.indices)
+    self.assertAllClose(np.asarray([0, 4, 2, 3, 0]), result.values)
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class CategoryLookupVocabFileTest(keras_parameterized.TestCase):
+  """Tests for CategoryLookup backed by a vocabulary file."""
+
+  def setUp(self):
+    super(CategoryLookupVocabFileTest, self).setUp()
+
+    # Contains strings, character names from 'The Wire', one per line in the
+    # order omar, stringer, marlo, so their IDs are 0, 1 and 2.
+    self._wire_vocabulary_file_name = test.test_src_dir_path(
+        'python/keras/layers/preprocessing/testdata/wire_vocabulary.txt')
+    self._wire_vocabulary_size = 3
+
+  def test_vocab_file_basic(self):
+    layer = categorical.CategoryLookup(
+        vocabulary=self._wire_vocabulary_file_name, num_oov_tokens=0)
+    inp = np.asarray([['marlo', 'omar'], ['stringer', 'omar']])
+    output = layer(inp)
+    self.assertAllClose(np.asarray([[2, 0], [1, 0]]), output)
+
+  def test_vocab_file_unknown_input(self):
+    # With the default single OOV token, the unknown 'skywalker' gets the first
+    # ID past the vocabulary.
+    layer = categorical.CategoryLookup(
+        vocabulary=self._wire_vocabulary_file_name)
+    inp = np.asarray([['marlo', 'omar'], ['skywalker', 'omar']])
+    output = layer(inp)
+    self.assertAllClose(
+        np.asarray([[2, 0], [self._wire_vocabulary_size, 0]]), output)
+
+  def test_vocab_file_invalid_input(self):
+    # Without OOV tokens, the unknown 'skywalker' maps to -1.
+    layer = categorical.CategoryLookup(
+        vocabulary=self._wire_vocabulary_file_name, num_oov_tokens=0)
+    inp = np.asarray([['marlo', 'omar'], ['skywalker', 'omar']])
+    output = layer(inp)
+    self.assertAllClose(np.asarray([[2, 0], [-1, 0]]), output)
+
+  def test_vocab_file_compute_output_signature(self):
+    input_shape = tensor_shape.TensorShape([2, 3])
+    input_spec = tensor_spec.TensorSpec(input_shape, dtypes.string)
+    layer = categorical.CategoryLookup(
+        vocabulary=self._wire_vocabulary_file_name, num_oov_tokens=0)
+    output_spec = layer.compute_output_signature(input_spec)
+    self.assertEqual(output_spec.shape.dims, input_shape.dims)
+    self.assertEqual(output_spec.dtype, dtypes.int64)
+
+  def test_vocab_file_sparse_input(self):
+    # Only the non-empty strings become sparse values; their IDs come from the
+    # vocabulary file.
+    layer = categorical.CategoryLookup(
+        vocabulary=self._wire_vocabulary_file_name, num_oov_tokens=0)
+    inp = np.asarray([['omar', ''], ['stringer', 'marlo'], ['marlo', 'omar']])
+    indices = array_ops.where_v2(math_ops.not_equal(inp, ''))
+    sp_inp = sparse_tensor.SparseTensor(
+        indices,
+        array_ops.gather_nd_v2(inp, indices),
+        dense_shape=array_ops.shape_v2(inp, out_type=dtypes.int64))
+    output = layer(sp_inp)
+    self.assertIsInstance(output, sparse_tensor.SparseTensor)
+    self.assertAllClose(
+        np.asarray([[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]), output.indices)
+    self.assertAllClose(np.asarray([0, 1, 2, 2, 0]), output.values)
+
+
+# Entry point when this test module is executed directly.
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/testdata/wire_vocabulary.txt b/tensorflow/python/keras/layers/preprocessing/testdata/wire_vocabulary.txt
new file mode 100644
index 0000000..32c6b56
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/testdata/wire_vocabulary.txt
@@ -0,0 +1,3 @@
+omar
+stringer
+marlo
diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py
index a7c43c6..afefcc3 100644
--- a/tensorflow/python/keras/layers/serialization.py
+++ b/tensorflow/python/keras/layers/serialization.py
@@ -40,6 +40,7 @@
from tensorflow.python.keras.layers.normalization import *
from tensorflow.python.keras.layers.pooling import *
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import *
+from tensorflow.python.keras.layers.preprocessing.normalization_v1 import *
from tensorflow.python.keras.layers.recurrent import *
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import *
from tensorflow.python.keras.layers.wrappers import *
@@ -49,7 +50,8 @@
if tf2.enabled():
from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top
- from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top
+ from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top
+ from tensorflow.python.keras.layers.preprocessing.normalization import * # pylint: disable=g-import-not-at-top
# This deserialization table is added for backward compatibility, as in TF 1.13,
# BatchNormalizationV1 and BatchNormalizationV2 are used as class name for v1
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 5e89cf4..e7008b5 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -151,8 +151,11 @@
"""Invokes the `Loss` instance.
Args:
- y_true: Ground truth values, with the same shape as 'y_pred'.
- y_pred: The predicted values.
+ y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`
+ y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`
+
+ Returns:
+ Loss values with the shape `[batch_size, d0, .. dN-1]`.
"""
NotImplementedError('Must be implemented in subclasses.')
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 5776ebd..3a500bf 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -18,9 +18,6 @@
from __future__ import division
from __future__ import print_function
-import os
-import shutil
-
import numpy as np
from tensorflow.python import keras
@@ -29,15 +26,9 @@
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
-from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.platform import test
-try:
- import h5py # pylint:disable=g-import-not-at-top
-except ImportError:
- h5py = None
-
ALL_LOSSES = [keras.losses.mean_squared_error,
keras.losses.mean_absolute_error,
keras.losses.mean_absolute_percentage_error,
@@ -53,20 +44,6 @@
keras.losses.categorical_hinge]
-class _MSEMAELoss(object):
- """Loss function with internal state, for testing serialization code."""
-
- def __init__(self, mse_fraction):
- self.mse_fraction = mse_fraction
-
- def __call__(self, y_true, y_pred, sample_weight=None):
- return (self.mse_fraction * keras.losses.mse(y_true, y_pred) +
- (1 - self.mse_fraction) * keras.losses.mae(y_true, y_pred))
-
- def get_config(self):
- return {'mse_fraction': self.mse_fraction}
-
-
class KerasLossesTest(test.TestCase):
def test_objective_shapes_3d(self):
@@ -200,39 +177,6 @@
loss = keras.backend.eval(keras.losses.categorical_hinge(y_true, y_pred))
self.assertAllClose(expected_loss, np.mean(loss))
- def test_serializing_loss_class(self):
- orig_loss_class = _MSEMAELoss(0.3)
- with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
- serialized = keras.losses.serialize(orig_loss_class)
-
- with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
- deserialized = keras.losses.deserialize(serialized)
- assert isinstance(deserialized, _MSEMAELoss)
- assert deserialized.mse_fraction == 0.3
-
- def test_serializing_model_with_loss_class(self):
- tmpdir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, tmpdir)
- model_filename = os.path.join(tmpdir, 'custom_loss.h5')
-
- with self.cached_session():
- with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
- loss = _MSEMAELoss(0.3)
- inputs = keras.layers.Input((2,))
- outputs = keras.layers.Dense(1, name='model_output')(inputs)
- model = keras.models.Model(inputs, outputs)
- model.compile(optimizer='sgd', loss={'model_output': loss})
- model.fit(np.random.rand(256, 2), np.random.rand(256, 1))
-
- if h5py is None:
- return
-
- model.save(model_filename)
-
- with generic_utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
- loaded_model = keras.models.load_model(model_filename)
- loaded_model.predict(np.random.rand(128, 2))
-
def test_loss_wrapper(self):
loss_fn = keras.losses.get('mse')
mse_obj = keras.losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD
index ff595cd..c39e70c 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/BUILD
+++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -43,6 +43,7 @@
],
srcs_version = "PY2AND3",
deps = [
+ ":device_compatibility_check",
"//tensorflow/python:framework",
"//tensorflow/python:mixed_precision_global_state",
],
@@ -67,6 +68,27 @@
)
py_library(
+ name = "device_compatibility_check",
+ srcs = ["device_compatibility_check.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:device_lib",
+ "//tensorflow/python:gpu_util",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "device_compatibility_check_test",
+ srcs = ["device_compatibility_check_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":device_compatibility_check",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
name = "autocast_variable",
srcs = [
"autocast_variable.py",
diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py
new file mode 100644
index 0000000..d92c16d
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py
@@ -0,0 +1,153 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains function to log if devices are compatible with mixed precision."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+from tensorflow.python.client import device_lib
+from tensorflow.python.eager import context
+from tensorflow.python.framework import gpu_util
+from tensorflow.python.platform import tf_logging
+
+
+_COMPAT_CHECK_PREFIX = 'Mixed precision compatibility check (mixed_float16): '
+_COMPAT_CHECK_OK_PREFIX = _COMPAT_CHECK_PREFIX + 'OK'
+_COMPAT_CHECK_WARNING_PREFIX = _COMPAT_CHECK_PREFIX + 'WARNING'
+_COMPAT_CHECK_WARNING_SUFFIX = (
+ 'If you will use compatible GPU(s) not attached to this host, e.g. by '
+ 'running a multi-worker model, you can ignore this warning. This message '
+ 'will only be logged once')
+
+
+def _dedup_strings(device_strs):
+ """Groups together consecutive identical strings.
+
+ For example, given:
+ ['GPU 1', 'GPU 2', 'GPU 2', 'GPU 3', 'GPU 3', 'GPU 3']
+ This function returns:
+ ['GPU 1', 'GPU 2 (x2)', 'GPU 3 (x3)']
+
+ Args:
+ device_strs: A list of strings, each representing a device.
+
+ Returns:
+ A copy of the input, but identical consecutive strings are merged into a
+ single string.
+ """
+ new_device_strs = []
+ for device_str, vals in itertools.groupby(device_strs):
+ num = len(list(vals))
+ if num == 1:
+ new_device_strs.append(device_str)
+ else:
+ new_device_strs.append('%s (x%d)' % (device_str, num))
+ return new_device_strs
+
+
+def _log_device_compatibility_check(policy_name, device_attr_list):
+ """Logs a compatibility check if the devices support the policy.
+
+ Currently only logs for the policy mixed_float16.
+
+ Args:
+ policy_name: The name of the dtype policy.
+ device_attr_list: A list of DeviceAttributes.
+ """
+ if policy_name != 'mixed_float16':
+ # TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This requires
+ # checking if a TPU is available.
+ return
+ supported_device_strs = []
+ unsupported_device_strs = []
+ for device in device_attr_list:
+ if device.device_type == 'GPU':
+ name, cc = gpu_util.compute_capability_from_device_desc(device)
+ name = name or 'Unknown GPU'
+ if cc:
+ device_str = '%s, compute capability %s.%s' % (name, cc[0], cc[1])
+ if cc >= (7, 0):
+ supported_device_strs.append(device_str)
+ else:
+ unsupported_device_strs.append(device_str)
+ else:
+ unsupported_device_strs.append(
+ name + ', no compute capability (probably not an Nvidia GPU)')
+
+ if unsupported_device_strs:
+ warning_str = _COMPAT_CHECK_WARNING_PREFIX + '\n'
+ if supported_device_strs:
+ warning_str += ('Some of your GPUs may run slowly with dtype policy '
+ 'mixed_float16 because they do not all have compute '
+ 'capability of at least 7.0. Your GPUs:\n')
+ elif len(unsupported_device_strs) == 1:
+ warning_str += ('Your GPU may run slowly with dtype policy mixed_float16 '
+ 'because it does not have compute capability of at least '
+ '7.0. Your GPU:\n')
+ else:
+ warning_str += ('Your GPUs may run slowly with dtype policy '
+ 'mixed_float16 because they do not have compute '
+ 'capability of at least 7.0. Your GPUs:\n')
+ for device_str in _dedup_strings(supported_device_strs +
+ unsupported_device_strs):
+ warning_str += ' ' + device_str + '\n'
+ warning_str += ('See https://developer.nvidia.com/cuda-gpus for a list of '
+ 'GPUs and their compute capabilities.\n')
+ warning_str += _COMPAT_CHECK_WARNING_SUFFIX
+ tf_logging.warn(warning_str)
+ elif not supported_device_strs:
+ tf_logging.warn('%s\n'
+ 'The dtype policy mixed_float16 may run slowly because '
+ 'this machine does not have a GPU. Only Nvidia GPUs with '
+ 'compute capability of at least 7.0 run quickly with '
+ 'mixed_float16.\n%s' % (_COMPAT_CHECK_WARNING_PREFIX,
+ _COMPAT_CHECK_WARNING_SUFFIX))
+ elif len(supported_device_strs) == 1:
+ tf_logging.info('%s\n'
+ 'Your GPU will likely run quickly with dtype policy '
+ 'mixed_float16 as it has compute capability of at least '
+ '7.0. Your GPU: %s' % (_COMPAT_CHECK_OK_PREFIX,
+ supported_device_strs[0]))
+ else:
+ tf_logging.info('%s\n'
+ 'Your GPUs will likely run quickly with dtype policy '
+ 'mixed_float16 as they all have compute capability of at '
+ 'least 7.0' % _COMPAT_CHECK_OK_PREFIX)
+
+
+_logged_compatibility_check = False
+
+
+def log_device_compatibility_check(policy_name):
+ """Logs a compatibility check if the devices support the policy.
+
+ Currently only logs for the policy mixed_float16. A log is shown only the
+ first time this function is called.
+
+ Args:
+ policy_name: The name of the dtype policy.
+ """
+ global _logged_compatibility_check
+ # In graph mode, calling list_local_devices may initialize some session state,
+ # so we only call it in eager mode.
+ if not context.executing_eagerly() or _logged_compatibility_check:
+ return
+ _logged_compatibility_check = True
+ device_attr_list = device_lib.list_local_devices()
+ _log_device_compatibility_check(policy_name, device_attr_list)
+
diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py
new file mode 100644
index 0000000..c3315ca
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py
@@ -0,0 +1,162 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests the device compatibility check."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.core.framework import device_attributes_pb2
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def _get_device_attrs(device_type, device_name=None, cc_major=None,
+ cc_minor=None):
+ if device_type == 'CPU':
+ return device_attributes_pb2.DeviceAttributes(device_type='CPU')
+ assert device_type == 'GPU', 'Invalid device type: %s' % (device_type,)
+ if not device_name:
+ return device_attributes_pb2.DeviceAttributes(device_type='GPU')
+ physical_device_desc = (
+ 'device: 0, name: %s, pci bus id: 0:0:0.0' % device_name)
+ if cc_major:
+ physical_device_desc += ', compute capability: %d.%d' % (cc_major, cc_minor)
+ return device_attributes_pb2.DeviceAttributes(
+ device_type='GPU', physical_device_desc=physical_device_desc)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class DeviceCompatibilityCheckTest(test.TestCase):
+
+ def _test_compat_check(self, device_attr_list, should_warn, expected_regex,
+ policy_name='mixed_float16'):
+ with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
+ test.mock.patch.object(tf_logging, 'info') as mock_info:
+ device_compatibility_check._log_device_compatibility_check(
+ policy_name, device_attr_list)
+ if should_warn:
+ self.assertRegexpMatches(mock_warn.call_args[0][0], expected_regex)
+ mock_info.assert_not_called()
+ else:
+ self.assertRegexpMatches(mock_info.call_args[0][0], expected_regex)
+ mock_warn.assert_not_called()
+
+ def test_supported(self):
+ device_attrs_list = [_get_device_attrs('GPU', 'GPU 1', 7, 1)]
+ regex = re.compile(
+ r'.*compatibility check \(mixed_float16\): OK\n'
+ r'Your GPU will likely run quickly with dtype policy mixed_float16 as '
+ r'it has compute capability of at least 7.0. Your GPU: GPU 1, compute '
+ r'capability 7.1', flags=re.MULTILINE)
+ self._test_compat_check(device_attrs_list, False, regex)
+ device_attrs_list.append(_get_device_attrs('CPU'))
+ self._test_compat_check(device_attrs_list, False, regex)
+
+ device_attrs_list = [
+ _get_device_attrs('CPU', 'CPU'),
+ _get_device_attrs('GPU', 'GPU 1', 7, 0),
+ _get_device_attrs('GPU', 'GPU 2', 7, 1),
+ _get_device_attrs('GPU', 'GPU 3', 8, 0),
+ ]
+ regex = re.compile(
+ r'.*compatibility check \(mixed_float16\): OK\n'
+ r'Your GPUs will likely run quickly with dtype policy mixed_float16 as '
+ r'they all have compute capability of at least 7.0', flags=re.MULTILINE)
+ self._test_compat_check(device_attrs_list, False, regex)
+
+ def test_unsupported(self):
+ device_attrs_list = [
+ _get_device_attrs('GPU', 'GPU 1', 6, 0)
+ ]
+ regex = re.compile(
+ r'.*compatibility check \(mixed_float16\): WARNING\n'
+ r'Your GPU may run slowly with dtype policy mixed_float16.*\n'
+ r' GPU 1, compute capability 6.0\n'
+ r'See.*', flags=re.MULTILINE)
+ self._test_compat_check(device_attrs_list, True, regex)
+ device_attrs_list.append(_get_device_attrs('CPU'))
+ self._test_compat_check(device_attrs_list, True, regex)
+
+ device_attrs_list = [
+ _get_device_attrs('GPU')
+ ]
+ regex = re.compile(
+ r'.*compatibility check \(mixed_float16\): WARNING\n'
+ r'Your GPU may run slowly with dtype policy mixed_float16.*\n'
+ r' Unknown GPU, no compute capability \(probably not an Nvidia GPU\)\n'
+ r'See.*', flags=re.MULTILINE)
+ self._test_compat_check(device_attrs_list, True, regex)
+ device_attrs_list.append(_get_device_attrs('CPU'))
+ self._test_compat_check(device_attrs_list, True, regex)
+
+ device_attrs_list = [
+ _get_device_attrs('CPU', 'CPU'),
+ _get_device_attrs('GPU', 'GPU 1', 6, 0),
+ _get_device_attrs('GPU', 'GPU 2', 3, 10),
+ ]
+ regex = re.compile(
+ r'.*compatibility check \(mixed_float16\): WARNING\n'
+ r'Your GPUs may run slowly with dtype policy mixed_float16.*\n'
+ r' GPU 1, compute capability 6.0\n'
+ r' GPU 2, compute capability 3.10\n'
+ r'See.*', flags=re.MULTILINE)
+ self._test_compat_check(device_attrs_list, True, regex)
+
+ device_attrs_list = [
+ _get_device_attrs('CPU', 'CPU'),
+ _get_device_attrs('GPU', 'GPU 1', 6, 0),
+ _get_device_attrs('GPU', 'GPU 1', 6, 0),
+ _get_device_attrs('GPU', 'GPU 1', 6, 0),
+ _get_device_attrs('GPU', 'GPU 2', 3, 10),
+ ]
+ regex = re.compile(
+ r'.*compatibility check \(mixed_float16\): WARNING\n'
+ r'Your GPUs may run slowly with dtype policy mixed_float16.*\n'
+ r' GPU 1, compute capability 6.0 \(x3\)\n'
+ r' GPU 2, compute capability 3.10\n'
+ r'See.*', flags=re.MULTILINE)
+ self._test_compat_check(device_attrs_list, True, regex)
+
+ device_attrs_list = [_get_device_attrs('CPU')]
+ regex = re.compile(
+ r'.*compatibility check \(mixed_float16\): WARNING\n'
+ r'The dtype policy mixed_float16 may run slowly because this machine '
+ r'does not have a GPU', flags=re.MULTILINE)
+ self._test_compat_check(device_attrs_list, True, regex)
+
+ def test_mix_of_supported_and_unsupported(self):
+ device_attrs_list = [
+ _get_device_attrs('GPU', 'GPU 1', 7, 0),
+ _get_device_attrs('GPU', 'GPU 1', 7, 0),
+ _get_device_attrs('GPU', 'GPU 2', 6, 0)
+ ]
+ regex = re.compile(
+ r'.*compatibility check \(mixed_float16\): WARNING\n'
+ r'Some of your GPUs may run slowly with dtype policy mixed_float16.*\n'
+ r' GPU 1, compute capability 7.0 \(x2\)\n'
+ r' GPU 2, compute capability 6.0\n'
+ r'See.*', flags=re.MULTILINE)
+ self._test_compat_check(device_attrs_list, True, regex)
+ device_attrs_list.append(_get_device_attrs('CPU'))
+ self._test_compat_check(device_attrs_list, True, regex)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
index 6bb73cd..a8fc290 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py
@@ -39,6 +39,7 @@
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision.experimental import policy
@@ -456,6 +457,12 @@
'save_format': 'tf',
'use_regularizer': True,
}, {
+ 'testcase_name': 'saved_model_input_spec',
+ 'strategy_fn': default_strategy_fn,
+ 'save_format': 'tf',
+ 'use_regularizer': True,
+ 'use_input_spec': True,
+ }, {
'testcase_name': 'h5',
'strategy_fn': default_strategy_fn,
'save_format': 'h5',
@@ -466,6 +473,12 @@
'save_format': 'tf',
'use_regularizer': True,
}, {
+ 'testcase_name': 'saved_model_input_spec_distribute',
+ 'strategy_fn': create_mirrored_strategy,
+ 'save_format': 'tf',
+ 'use_regularizer': True,
+ 'use_input_spec': True,
+ }, {
'testcase_name': 'h5_distribute',
'strategy_fn': create_mirrored_strategy,
'save_format': 'h5',
@@ -482,6 +495,7 @@
policy_name='mixed_float16',
get_config=False,
save_format=None,
+ use_input_spec=False,
experimental_run_tf_function=True):
self._skip_if_strategy_unsupported(strategy_fn, check_model_type=True)
self._skip_if_save_format_unsupported(save_format)
@@ -496,6 +510,8 @@
use_operator=use_operator,
regularizer=regularizer,
input_shape=(1,))
+ if use_input_spec:
+ layer.input_spec = input_spec.InputSpec(shape=(2, 1))
cast_f32_layer = layers.Lambda(lambda x: math_ops.cast(x, 'float32'))
model = testing_utils.get_model_from_layers(
[layer, cast_f32_layer], input_shape=(1,),
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py
index 67dc947..6cdec23 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py
@@ -24,6 +24,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check
from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import tf_logging
@@ -281,6 +282,9 @@
(loss_scale, name))
self._loss_scale = keras_loss_scale_module.get(loss_scale)
+ if name in ('mixed_float16', 'mixed_bfloat16'):
+ device_compatibility_check.log_device_compatibility_check(name)
+
def _parse_name(self, name):
"""Parses a Policy name into a compute and variable dtype.
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
index 7859251..330ed9e 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
@@ -18,11 +18,13 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check
from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.platform import test
@@ -166,6 +168,30 @@
mock_warn.assert_not_called()
@testing_utils.enable_v2_dtype_behavior
+ def test_device_compatibility_warning(self):
+ with context.eager_mode():
+ device_compatibility_check._logged_compatibility_check = False
+ with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
+ test.mock.patch.object(tf_logging, 'info') as mock_info:
+ mp_policy.Policy('mixed_float16')
+ if mock_warn.called:
+ self.assertRegexpMatches(
+ mock_warn.call_args[0][0],
+ r'Mixed precision compatibility check \(mixed_float16\): WARNING.*')
+ mock_info.assert_not_called()
+ else:
+ self.assertRegexpMatches(
+ mock_info.call_args[0][0],
+ r'Mixed precision compatibility check \(mixed_float16\): OK.*')
+
+ # Assert message is only logged once
+ with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
+ test.mock.patch.object(tf_logging, 'info') as mock_info:
+ mp_policy.Policy('mixed_float16')
+ mock_warn.assert_not_called()
+ mock_info.assert_not_called()
+
+ @testing_utils.enable_v2_dtype_behavior
def test_policy_scope(self):
if base_layer_utils.v2_dtype_behavior_enabled():
default_policy = 'float32'
diff --git a/tensorflow/python/keras/mixed_precision/experimental/test_util.py b/tensorflow/python/keras/mixed_precision/experimental/test_util.py
index aefe3ae..fff2689 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/test_util.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/test_util.py
@@ -124,7 +124,7 @@
for inp in inputs_flattened:
assert inp.dtype.base_dtype == self._assert_type, (
'Input tensor has type %s which does not match assert type %s' %
- (inp.dtype.name, self._assert_type.name))
+ (inp.dtype.name, self._assert_type))
class AddLayer(AssertTypeLayer):
diff --git a/tensorflow/python/keras/model_subclassing_compiled_test.py b/tensorflow/python/keras/model_subclassing_compiled_test.py
index 54a91bd..18cb6e5 100644
--- a/tensorflow/python/keras/model_subclassing_compiled_test.py
+++ b/tensorflow/python/keras/model_subclassing_compiled_test.py
@@ -27,7 +27,6 @@
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import model_subclassing_test_util as model_util
from tensorflow.python.keras import testing_utils
-from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
try:
@@ -455,29 +454,6 @@
loss = model.train_on_batch(x, y)
self.assertGreater(loss, 0.1)
- def test_no_loss_in_compile(self):
-
- class InternalLossModel(keras.Model):
-
- def __init__(self):
- super(InternalLossModel, self).__init__()
- self.dense = keras.layers.Dense(1)
-
- def call(self, inputs):
- out = self.dense(inputs)
- self.add_loss(math_ops.reduce_sum(out))
- return out
-
- model = InternalLossModel()
- x = np.ones((10, 10))
- model.predict(x)
- model.compile(
- optimizer='rmsprop',
- run_eagerly=testing_utils.should_run_eagerly(),
- experimental_run_tf_function=testing_utils.should_run_tf_function())
- model.fit(x)
- model.evaluate(x)
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 81f419c..8a10180 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -28,6 +28,7 @@
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import metrics
@@ -312,10 +313,10 @@
return any('Placeholder' in s for s in ops_types)
-@keras_parameterized.run_with_all_model_types
-@keras_parameterized.run_all_keras_modes
class CheckpointingTests(keras_parameterized.TestCase):
+ @keras_parameterized.run_with_all_model_types
+ @keras_parameterized.run_all_keras_modes
def test_optimizer_dependency(self):
model = _get_model()
opt = adam.AdamOptimizer(.01)
@@ -337,6 +338,37 @@
model.load_weights(save_prefix)
self.assertEqual(12., self.evaluate(beta1_power))
+ @keras_parameterized.run_with_all_model_types(exclude_models=['subclass'])
+ def test_layer_tracking(self):
+ with self.cached_session():
+ model = _get_model(input_shape=(4,))
+
+ if testing_utils.get_model_type() == 'subclass':
+ # Subclassed model must be built separately.
+ model._set_inputs(tensor_spec.TensorSpec((None, 4)))
+
+ # Ensure that checkpoints are compatible with another model with the same
+ # layers, even if the model isn't built until after initialization.
+ layers = _get_layers(input_shape=None, add_input_layer=False)
+ model2 = models.Sequential(layers)
+ # Build model by calling it.
+ model2.predict_on_batch(np.random.random((10, 4)))
+
+ model_path = os.path.join(self.get_temp_dir(), 'model_ckpt')
+ model.save_weights(model_path)
+ model2_path = os.path.join(self.get_temp_dir(), 'model2_ckpt')
+ model2.save_weights(model2_path)
+
+ # Check that the checkpoints are compatible with both models.
+ model.load_weights(model2_path)
+ self.assertAllClose(self.evaluate(model.weights),
+ self.evaluate(model2.weights))
+
+ model.load_weights(model_path)
+ model2.load_weights(model_path)
+ self.assertAllClose(self.evaluate(model.weights),
+ self.evaluate(model2.weights))
+
@keras_parameterized.run_all_keras_modes
class TestModelBackend(keras_parameterized.TestCase):
diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD
index bb20d4b..6e0153f 100644
--- a/tensorflow/python/keras/optimizer_v2/BUILD
+++ b/tensorflow/python/keras/optimizer_v2/BUILD
@@ -102,7 +102,6 @@
size = "medium",
srcs = ["adamax_test.py"],
shard_count = 4,
- tags = ["no_rocm"],
deps = [
":optimizer_v2",
"//tensorflow/python:client_testlib",
@@ -201,7 +200,6 @@
tags = [
"manual",
"no_oss",
- "no_rocm",
"no_windows",
"notap",
],
diff --git a/tensorflow/python/keras/premade/linear.py b/tensorflow/python/keras/premade/linear.py
index 451ac3d..dd3e1fd 100644
--- a/tensorflow/python/keras/premade/linear.py
+++ b/tensorflow/python/keras/premade/linear.py
@@ -64,7 +64,7 @@
units=1,
activation=None,
use_bias=True,
- kernel_initializer='glorot_uniform',
+ kernel_initializer='zeros',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
diff --git a/tensorflow/python/keras/premade/wide_deep.py b/tensorflow/python/keras/premade/wide_deep.py
index f9314ef..bf90314 100644
--- a/tensorflow/python/keras/premade/wide_deep.py
+++ b/tensorflow/python/keras/premade/wide_deep.py
@@ -101,8 +101,7 @@
dnn_output = self.dnn_model(dnn_inputs, training=training)
else:
dnn_output = self.dnn_model(dnn_inputs)
- output = nest.map_structure(lambda x, y: 0.5 * (x + y), linear_output,
- dnn_output)
+ output = nest.map_structure(lambda x, y: (x + y), linear_output, dnn_output)
if self.activation:
return nest.map_structure(self.activation, output)
return output
diff --git a/tensorflow/python/keras/premade/wide_deep_test.py b/tensorflow/python/keras/premade/wide_deep_test.py
index f7e10fc..e2f471e 100644
--- a/tensorflow/python/keras/premade/wide_deep_test.py
+++ b/tensorflow/python/keras/premade/wide_deep_test.py
@@ -78,9 +78,9 @@
self.evaluate(variables.global_variables_initializer())
wide_deep_model.fit(inputs, output, epochs=1)
self.assertAllClose(
- [[0.3]],
+ [[0.6]],
self.evaluate(wide_deep_model.linear_model.dense_layers[0].kernel))
- self.assertAllClose([[0.9]],
+ self.assertAllClose([[1.8]],
self.evaluate(
wide_deep_model.dnn_model.layers[0].kernel))
@@ -112,15 +112,15 @@
wide_deep_model = wide_deep.WideDeepModel(linear_model, dnn_model)
inp_np = np.asarray([[1.]])
out1, out2 = wide_deep_model(inp_np)
- # output should be 0.5 * (0.5 + 0.1), and 0.5 * (0.3 - 0.5)
- self.assertAllClose([[0.3]], out1)
- self.assertAllClose([[-0.1]], out2)
+ # output should be (0.5 + 0.1), and (0.3 - 0.5)
+ self.assertAllClose([[0.6]], out1)
+ self.assertAllClose([[-0.2]], out2)
wide_deep_model = wide_deep.WideDeepModel(
linear_model, dnn_model, activation='relu')
out1, out2 = wide_deep_model(inp_np)
- # output should be relu(0.5 * (0.5 + 0.1)), and relu(0.5 * (0.3 - 0.5))
- self.assertAllClose([[0.3]], out1)
+ # output should be relu((0.5 + 0.1)), and relu((0.3 - 0.5))
+ self.assertAllClose([[0.6]], out1)
self.assertAllClose([[0.]], out2)
def test_wide_deep_model_with_single_optimizer(self):
diff --git a/tensorflow/python/keras/regularizers_test.py b/tensorflow/python/keras/regularizers_test.py
index c8180ca..f700fb9 100644
--- a/tensorflow/python/keras/regularizers_test.py
+++ b/tensorflow/python/keras/regularizers_test.py
@@ -146,7 +146,7 @@
optimizer='sgd',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
- self.assertEqual(len(model.losses), 5)
+ self.assertLen(model.losses, 5)
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters([
@@ -169,7 +169,7 @@
optimizer='sgd',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
- self.assertEqual(len(model.losses), 6)
+ self.assertLen(model.losses, 6)
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters([
@@ -197,7 +197,13 @@
optimizer='sgd',
run_eagerly=testing_utils.should_run_eagerly(),
experimental_run_tf_function=testing_utils.should_run_tf_function())
- self.assertEqual(len(model.losses), 14)
+
+ # We expect to see 9 losses on the model:
+ # - 2 from the 2 add_loss calls on the outer model.
+ # - 3 from the weight regularizers on the shared_dense layer, unshared_dense
+ # in inner model 1, unshared_dense in inner model 2.
+ # - 4 from activity regularizers on the shared_dense layer.
+ self.assertLen(model.losses, 9)
if __name__ == '__main__':
diff --git a/tensorflow/python/keras/saving/hdf5_format_test.py b/tensorflow/python/keras/saving/hdf5_format_test.py
index 28101cf..532379d 100644
--- a/tensorflow/python/keras/saving/hdf5_format_test.py
+++ b/tensorflow/python/keras/saving/hdf5_format_test.py
@@ -21,6 +21,7 @@
import os
import shutil
import tempfile
+
from absl.testing import parameterized
import numpy as np
@@ -827,6 +828,20 @@
evaluation_results['sparse_categorical_crossentropy'] +
evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
+ def test_save_uncompiled_model_with_optimizer(self):
+ saved_model_dir = self._save_model_dir()
+ save_format = testing_utils.get_save_format()
+ model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(3,))])
+ # Set the model's optimizer but don't compile. This can happen if the model
+ # is trained with a custom training loop.
+ model.optimizer = keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001)
+ model.save(saved_model_dir, save_format=save_format)
+
+ if save_format in ['tf', 'tensorflow']:
+ loaded = keras.models.load_model(saved_model_dir)
+ self.assertIsInstance(loaded.optimizer,
+ keras.optimizer_v2.optimizer_v2.OptimizerV2)
+
# Factory functions to create models that will be serialized inside a Network.
def _make_graph_network(input_size, output_size):
diff --git a/tensorflow/python/keras/saving/losses_serialization_test.py b/tensorflow/python/keras/saving/losses_serialization_test.py
new file mode 100644
index 0000000..60252b1
--- /dev/null
+++ b/tensorflow/python/keras/saving/losses_serialization_test.py
@@ -0,0 +1,198 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras losses serialization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import losses
+from tensorflow.python.keras import optimizer_v2
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils import losses_utils
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+try:
+ import h5py # pylint:disable=g-import-not-at-top
+except ImportError:
+ h5py = None
+
+
+# Custom loss class
+class MyMeanAbsoluteError(losses.LossFunctionWrapper):
+
+ def __init__(self,
+ reduction=losses_utils.ReductionV2.AUTO,
+ name='mean_absolute_error'):
+ super(MyMeanAbsoluteError, self).__init__(
+ _my_mae, name=name, reduction=reduction)
+
+
+# Custom loss function
+def _my_mae(y_true, y_pred):
+ return keras.backend.mean(math_ops.abs(y_pred - y_true), axis=-1)
+
+
+def _get_multi_io_model():
+ inp_1 = layers.Input(shape=(1,), name='input_1')
+ inp_2 = layers.Input(shape=(1,), name='input_2')
+ d = testing_utils.Bias(name='output')
+ out_1 = d(inp_1)
+ out_2 = d(inp_2)
+ return keras.Model([inp_1, inp_2], [out_1, out_2])
+
+
+@keras_parameterized.run_all_keras_modes
+@parameterized.named_parameters([
+ dict(testcase_name='string', value='mae'),
+ dict(testcase_name='built_in_fn', value=losses.mae),
+ dict(testcase_name='built_in_class', value=losses.MeanAbsoluteError()),
+ dict(testcase_name='custom_fn', value=_my_mae),
+ dict(testcase_name='custom_class', value=MyMeanAbsoluteError()),
+ dict(testcase_name='list_of_strings', value=['mae', 'mae']),
+ dict(testcase_name='list_of_built_in_fns', value=[losses.mae, losses.mae]),
+ dict(
+ testcase_name='list_of_built_in_classes',
+ value=[losses.MeanAbsoluteError(),
+ losses.MeanAbsoluteError()]),
+ dict(testcase_name='list_of_custom_fns', value=[_my_mae, _my_mae]),
+ dict(
+ testcase_name='list_of_custom_classes',
+ value=[MyMeanAbsoluteError(),
+ MyMeanAbsoluteError()]),
+ dict(
+ testcase_name='dict_of_string',
+ value={
+ 'output': 'mae',
+ 'output_1': 'mae',
+ }),
+ dict(
+ testcase_name='dict_of_built_in_fn',
+ value={
+ 'output': losses.mae,
+ 'output_1': losses.mae,
+ }),
+ dict(
+ testcase_name='dict_of_built_in_class',
+ value={
+ 'output': losses.MeanAbsoluteError(),
+ 'output_1': losses.MeanAbsoluteError(),
+ }),
+ dict(
+ testcase_name='dict_of_custom_fn',
+ value={
+ 'output': _my_mae,
+ 'output_1': _my_mae
+ }),
+ dict(
+ testcase_name='dict_of_custom_class',
+ value={
+ 'output': MyMeanAbsoluteError(),
+ 'output_1': MyMeanAbsoluteError(),
+ }),
+])
+class LossesSerialization(keras_parameterized.TestCase):
+
+ def setUp(self):
+ super(LossesSerialization, self).setUp()
+ tmpdir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, tmpdir)
+ self.model_filename = os.path.join(tmpdir, 'tmp_model_loss.h5')
+ self.x = np.array([[0.], [1.], [2.]], dtype='float32')
+ self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
+ self.w = np.array([1.25, 0.5, 1.25], dtype='float32')
+
+ def test_serializing_model_with_loss_with_custom_object_scope(self, value):
+ with generic_utils.custom_object_scope({
+ 'MyMeanAbsoluteError': MyMeanAbsoluteError,
+ '_my_mae': _my_mae,
+ 'Bias': testing_utils.Bias,
+ }):
+ model = _get_multi_io_model()
+ model.compile(
+ optimizer_v2.gradient_descent.SGD(0.1),
+ loss=value,
+ run_eagerly=testing_utils.should_run_eagerly(),
+ experimental_run_tf_function=testing_utils.should_run_tf_function())
+ history = model.fit([self.x, self.x], [self.y, self.y],
+ batch_size=3,
+ epochs=3,
+ sample_weight=[self.w, self.w])
+
+ # Assert training.
+ self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3)
+ eval_results = model.evaluate([self.x, self.x], [self.y, self.y],
+ sample_weight=[self.w, self.w])
+
+ if h5py is None:
+ return
+ model.save(self.model_filename)
+ loaded_model = keras.models.load_model(self.model_filename)
+ loaded_model.predict([self.x, self.x])
+ loaded_eval_results = loaded_model.evaluate(
+ [self.x, self.x], [self.y, self.y], sample_weight=[self.w, self.w])
+
+ # Assert all evaluation results are the same.
+ self.assertAllClose(eval_results, loaded_eval_results, 1e-9)
+
+ def test_serializing_model_with_loss_with_custom_objects(self, value):
+ model = _get_multi_io_model()
+ model.compile(
+ optimizer_v2.gradient_descent.SGD(0.1),
+ loss=value,
+ run_eagerly=testing_utils.should_run_eagerly(),
+ experimental_run_tf_function=testing_utils.should_run_tf_function())
+ history = model.fit([self.x, self.x], [self.y, self.y],
+ batch_size=3,
+ epochs=3,
+ sample_weight=[self.w, self.w])
+
+ # Assert training.
+ self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3)
+ eval_results = model.evaluate([self.x, self.x], [self.y, self.y],
+ sample_weight=[self.w, self.w])
+
+ if h5py is None:
+ return
+ model.save(self.model_filename)
+ loaded_model = keras.models.load_model(
+ self.model_filename,
+ custom_objects={
+ 'MyMeanAbsoluteError': MyMeanAbsoluteError,
+ '_my_mae': _my_mae,
+ 'Bias': testing_utils.Bias,
+ })
+ loaded_model.predict([self.x, self.x])
+ loaded_eval_results = loaded_model.evaluate([self.x, self.x],
+ [self.y, self.y],
+ sample_weight=[self.w, self.w])
+
+ # Assert all evaluation results are the same.
+ self.assertAllClose(eval_results, loaded_eval_results, 1e-9)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/saving/metrics_serialization_test.py b/tensorflow/python/keras/saving/metrics_serialization_test.py
index e0a7fc9..10eee4d 100644
--- a/tensorflow/python/keras/saving/metrics_serialization_test.py
+++ b/tensorflow/python/keras/saving/metrics_serialization_test.py
@@ -20,6 +20,7 @@
import os
import shutil
+
from absl.testing import parameterized
import numpy as np
@@ -173,7 +174,7 @@
def get_instance(x):
if isinstance(x, str):
return x
- if issubclass(x, metrics.Metric):
+ if isinstance(x, type) and issubclass(x, metrics.Metric):
return x()
return x
@@ -219,7 +220,7 @@
def get_instance(x):
if isinstance(x, str):
return x
- if issubclass(x, metrics.Metric):
+ if isinstance(x, type) and issubclass(x, metrics.Metric):
return x()
return x
diff --git a/tensorflow/python/keras/saving/saved_model/save_impl.py b/tensorflow/python/keras/saving/saved_model/save_impl.py
index dca39ba..580e452 100644
--- a/tensorflow/python/keras/saving/saved_model/save_impl.py
+++ b/tensorflow/python/keras/saving/saved_model/save_impl.py
@@ -365,7 +365,7 @@
elif layer.input_spec is not None:
def to_tensor_spec_or_none(x):
- spec = input_spec.to_tensor_spec(x, layer.dtype)
+ spec = input_spec.to_tensor_spec(x, layer._compute_dtype) # pylint: disable=protected-access
# If the shape is too general (e.g. multiple dimensions are allowed),
# return None so that separate functions can be generated for each
# inferred input signature.
diff --git a/tensorflow/python/keras/saving/saved_model_experimental.py b/tensorflow/python/keras/saving/saved_model_experimental.py
index a4cce31..0c6714b 100644
--- a/tensorflow/python/keras/saving/saved_model_experimental.py
+++ b/tensorflow/python/keras/saving/saved_model_experimental.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import os
+
import six
from tensorflow.python.client import session
diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py
index 1bbae94..0949aa1 100644
--- a/tensorflow/python/keras/saving/saving_utils.py
+++ b/tensorflow/python/keras/saving/saving_utils.py
@@ -188,26 +188,31 @@
'Prefer using a Keras optimizer instead '
'(see keras.io/optimizers).')
else:
- metadata['training_config'] = {
- 'loss': model.loss,
- # pylint: disable=protected-access
- 'metrics': model._compile_metrics,
- 'weighted_metrics': model._compile_weighted_metrics,
- # pylint: enable=protected-access
- 'sample_weight_mode': model.sample_weight_mode,
- 'loss_weights': model.loss_weights,
- }
- if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
- raise NotImplementedError(
- 'As of now, Optimizers loaded from SavedModel cannot be saved. '
- 'If you\'re calling `model.save` or `tf.keras.models.save_model`, '
- 'please set the `include_optimizer` option to `False`. For '
- '`tf.saved_model.save`, delete the optimizer from the model.')
- else:
- optimizer_config = {
- 'class_name': model.optimizer.__class__.__name__,
- 'config': model.optimizer.get_config()}
- metadata['training_config']['optimizer_config'] = optimizer_config
+ try:
+ metadata['training_config'] = {
+ 'loss': model.loss,
+ # pylint: disable=protected-access
+ 'metrics': model._compile_metrics,
+ 'weighted_metrics': model._compile_weighted_metrics,
+ # pylint: enable=protected-access
+ 'sample_weight_mode': model.sample_weight_mode,
+ 'loss_weights': model.loss_weights,
+ }
+ if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
+ raise NotImplementedError(
+ 'As of now, Optimizers loaded from SavedModel cannot be saved. '
+ 'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
+ ' please set the `include_optimizer` option to `False`. For '
+ '`tf.saved_model.save`, delete the optimizer from the model.')
+ else:
+ optimizer_config = {
+ 'class_name': model.optimizer.__class__.__name__,
+ 'config': model.optimizer.get_config()}
+ metadata['training_config']['optimizer_config'] = optimizer_config
+ except AttributeError:
+ pass # If the model has an optimizer, but not all of the attributes
+ # loss, _compile_metrics, etc., then it was not compiled using
+ # model.compile. In this case, do not save the training config.
return metadata
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index 4fbb6d6..ebab3d7 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -96,8 +96,8 @@
```
Arguments:
- *args: Variable length list of dictionaries of name,
- class pairs to add to custom objects.
+ *args: Variable length list of dictionaries of name, class pairs to add to
+ custom objects.
Returns:
Object of type `CustomObjectScope`.
@@ -180,13 +180,63 @@
return decorator
-def _get_name_or_custom_name(obj):
+@keras_export('keras.utils.get_registered_name')
+def get_registered_name(obj):
+ """Returns the name registered to an object within the Keras framework.
+
+ This function is part of the Keras serialization and deserialization
+ framework. It maps objects to the string names associated with those objects
+ for serialization/deserialization.
+
+ Args:
+ obj: The object to look up.
+
+ Returns:
+ The name associated with the object, or the default Python name if the
+ object is not registered.
+ """
if obj in _GLOBAL_CUSTOM_NAMES:
return _GLOBAL_CUSTOM_NAMES[obj]
else:
return obj.__name__
+@keras_export('keras.utils.get_registered_object')
+def get_registered_object(name, custom_objects=None, module_objects=None):
+ """Returns the class associated with `name` if it is registered with Keras.
+
+ This function is part of the Keras serialization and deserialization
+ framework. It maps strings to the objects associated with them for
+ serialization/deserialization.
+
+ Example:
+ ```
+ def from_config(cls, config, custom_objects=None):
+ if 'my_custom_object_name' in config:
+ config['hidden_cls'] = tf.keras.utils.get_registered_object(
+ config['my_custom_object_name'], custom_objects=custom_objects)
+ ```
+
+ Args:
+ name: The name to look up.
+ custom_objects: A dictionary of custom objects to look the name up in.
+ Generally, custom_objects is provided by the user.
+ module_objects: A dictionary of custom objects to look the name up in.
+ Generally, module_objects is provided by midlevel library implementers.
+
+ Returns:
+ An instantiable class associated with 'name', or None if no such class
+ exists.
+ """
+ if name in _GLOBAL_CUSTOM_OBJECTS:
+ return _GLOBAL_CUSTOM_OBJECTS[name]
+ elif custom_objects and name in custom_objects:
+ return custom_objects[name]
+ elif module_objects and name in module_objects:
+ return module_objects[name]
+ return None
+
+
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
"""Serialize Keras object into JSON."""
@@ -212,22 +262,13 @@
except ValueError:
serialization_config[key] = item
- name = _get_name_or_custom_name(instance.__class__)
+ name = get_registered_name(instance.__class__)
return serialize_keras_class_and_config(name, serialization_config)
if hasattr(instance, '__name__'):
- return _get_name_or_custom_name(instance)
+ return get_registered_name(instance)
raise ValueError('Cannot serialize', instance)
-def _get_custom_objects_by_name(item, custom_objects=None):
- """Returns the item if it is in either local or global custom objects."""
- if item in _GLOBAL_CUSTOM_OBJECTS:
- return _GLOBAL_CUSTOM_OBJECTS[item]
- elif custom_objects and item in custom_objects:
- return custom_objects[item]
- return None
-
-
def class_and_config_for_serialized_keras_object(
config,
module_objects=None,
@@ -239,15 +280,9 @@
raise ValueError('Improper config format: ' + str(config))
class_name = config['class_name']
- if custom_objects and class_name in custom_objects:
- cls = custom_objects[class_name]
- elif class_name in _GLOBAL_CUSTOM_OBJECTS:
- cls = _GLOBAL_CUSTOM_OBJECTS[class_name]
- else:
- module_objects = module_objects or {}
- cls = module_objects.get(class_name)
- if cls is None:
- raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
+ cls = get_registered_object(class_name, custom_objects, module_objects)
+ if cls is None:
+ raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
cls_config = config['config']
deserialized_objects = {}
@@ -258,9 +293,9 @@
module_objects=module_objects,
custom_objects=custom_objects,
printable_module_name='config_item')
+ # TODO(momernick): Should this also have 'module_objects'?
elif (isinstance(item, six.string_types) and
- tf_inspect.isfunction(
- _get_custom_objects_by_name(item, custom_objects))):
+ tf_inspect.isfunction(get_registered_object(item, custom_objects))):
# Handle custom functions here. When saving functions, we only save the
# function's name as a string. If we find a matching string in the custom
# objects during deserialization, we convert the string back to the
@@ -269,8 +304,7 @@
# conflict with a custom function name, but this should be a rare case.
# This issue does not occur if a string field has a naming conflict with
# a custom object, since the config of an object will always be a dict.
- deserialized_objects[key] = _get_custom_objects_by_name(
- item, custom_objects)
+ deserialized_objects[key] = get_registered_object(item, custom_objects)
for key, item in deserialized_objects.items():
cls_config[key] = deserialized_objects[key]
@@ -382,6 +416,7 @@
Returns:
A value wrapped as a cell object (see function "func_load")
"""
+
def dummy_fn():
# pylint: disable=pointless-statement
value # just access it so it gets captured in .__closure__
@@ -410,8 +445,8 @@
Arguments:
fn: Callable to inspect.
name: Check if `fn` can be called with `name` as a keyword argument.
- accept_all: What to return if there is no parameter called `name`
- but the function accepts a `**kwargs` argument.
+ accept_all: What to return if there is no parameter called `name` but the
+ function accepts a `**kwargs` argument.
Returns:
bool, whether `fn` accepts a `name` keyword argument.
@@ -430,16 +465,20 @@
target: Total number of steps expected, None if unknown.
width: Progress bar width on screen.
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
- stateful_metrics: Iterable of string names of metrics that
- should *not* be averaged over time. Metrics in this list
- will be displayed as-is. All others will be averaged
- by the progbar before display.
+ stateful_metrics: Iterable of string names of metrics that should *not* be
+ averaged over time. Metrics in this list will be displayed as-is. All
+ others will be averaged by the progbar before display.
interval: Minimum visual progress update interval (in seconds).
unit_name: Display name for step counts (usually "step" or "sample").
"""
- def __init__(self, target, width=30, verbose=1, interval=0.05,
- stateful_metrics=None, unit_name='step'):
+ def __init__(self,
+ target,
+ width=30,
+ verbose=1,
+ interval=0.05,
+ stateful_metrics=None,
+ unit_name='step'):
self.target = target
self.width = width
self.verbose = verbose
@@ -469,11 +508,9 @@
Arguments:
current: Index of current step.
- values: List of tuples:
- `(name, value_for_last_step)`.
- If `name` is in `stateful_metrics`,
- `value_for_last_step` will be displayed as-is.
- Else, an average of the metric over time will be displayed.
+ values: List of tuples: `(name, value_for_last_step)`. If `name` is in
+ `stateful_metrics`, `value_for_last_step` will be displayed as-is.
+ Else, an average of the metric over time will be displayed.
"""
values = values or []
for k, v in values:
@@ -538,8 +575,7 @@
eta = time_per_unit * (self.target - current)
if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600,
- (eta % 3600) // 60,
- eta % 60)
+ (eta % 3600) // 60, eta % 60)
elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60)
else:
@@ -625,10 +661,8 @@
Arguments:
arrays: Single array or list of arrays.
- start: can be an integer index (start index)
- or a list/array of indices
- stop: integer (stop index); should be None if
- `start` was a list.
+ start: can be an integer index (start index) or a list/array of indices
+ stop: integer (stop index); should be None if `start` was a list.
Returns:
A slice of the array(s).
@@ -711,7 +745,8 @@
expected_values))
-def validate_kwargs(kwargs, allowed_kwargs,
+def validate_kwargs(kwargs,
+ allowed_kwargs,
error_message='Keyword argument not understood:'):
"""Checks that all keyword arguments are in the set of allowed keys."""
for kwarg in kwargs:
diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py
index 619d31e..3347588 100644
--- a/tensorflow/python/keras/utils/generic_utils_test.py
+++ b/tensorflow/python/keras/utils/generic_utils_test.py
@@ -129,6 +129,13 @@
inst = OtherTestClass(val=5)
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[OtherTestClass]
self.assertEqual(serialized_name, class_name)
+ fn_class_name = keras.utils.generic_utils.get_registered_name(
+ OtherTestClass)
+ self.assertEqual(fn_class_name, class_name)
+
+ cls = keras.utils.generic_utils.get_registered_object(fn_class_name)
+ self.assertEqual(OtherTestClass, cls)
+
config = keras.utils.generic_utils.serialize_keras_object(inst)
self.assertEqual(class_name, config['class_name'])
new_inst = keras.utils.generic_utils.deserialize_keras_object(config)
@@ -145,11 +152,17 @@
serialized_name = 'Custom>my_fn'
class_name = keras.utils.generic_utils._GLOBAL_CUSTOM_NAMES[my_fn]
self.assertEqual(serialized_name, class_name)
+ fn_class_name = keras.utils.generic_utils.get_registered_name(my_fn)
+ self.assertEqual(fn_class_name, class_name)
+
config = keras.utils.generic_utils.serialize_keras_object(my_fn)
self.assertEqual(class_name, config)
fn = keras.utils.generic_utils.deserialize_keras_object(config)
self.assertEqual(42, fn())
+ fn_2 = keras.utils.generic_utils.get_registered_object(fn_class_name)
+ self.assertEqual(42, fn_2())
+
def test_serialize_custom_class_without_get_config_fails(self):
with self.assertRaisesRegex(
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index fb25e50..6ea17b4 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -120,7 +120,6 @@
size = "small",
srcs = ["list_ops_test.py"],
grpc_enabled = True,
- tags = ["no_rocm"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -1287,7 +1286,6 @@
shard_count = 10,
tags = [
"no_oss", # TODO(b/142818120): Re-enable.
- "no_rocm",
],
deps = [
"//tensorflow/python:client_testlib",
@@ -2717,7 +2715,6 @@
srcs = ["tensor_array_ops_test.py"],
flaky = 1, # create_local_cluster sometimes times out.
shard_count = 10,
- tags = ["no_rocm"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -3448,9 +3445,6 @@
size = "medium",
srcs = ["qr_op_test.py"],
shard_count = 20,
- tags = [
- "no_rocm", # TODO(rocm): feature not supported on ROCm platform
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index 402c6f0..c5f58f1 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -46,6 +46,24 @@
axis=2)
return stats_summary
+ def add_f_dim_and_append_zeros(self, stats_summaries):
+ """Transform a list of stats summaries, adding a feature dimension.
+
+ The input shape is a list of arrays of shape [max_splits, num_buckets,
+ logits+hess dim]. This transformation returns a list of arrays of shape
+ [max_splits, 1, num_buckets + 1, logits+hess dim].
+
+ Args:
+ stats_summaries: a list of numpy arrays.
+
+ Returns:
+ A list of numpy arrays.
+ """
+ return [
+ self._append_zeros_for_default_bucket(np.expand_dims(feature, axis=1))
+ for feature in stats_summaries
+ ]
+
def _get_stats_summary_for_split(self):
return [
[
@@ -160,7 +178,7 @@
self.assertAllClose([[0.833333], [0.8]], right_node_contribs)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestGainsWithoutRegularization(self):
+ def testCalculateBestGainsWithoutRegularization_v1_op(self):
"""Testing Gain calculation without any regularization."""
with self.cached_session() as sess:
max_splits = 7
@@ -189,19 +207,40 @@
self.assertAllClose([[[-.592593], [-.75]], [[-.076923], [.568966]]],
self.evaluate(right_node_contribs_list))
- def testCalculateBestMultiDimFeatureSplitsWithoutRegularization(self):
+ def testCalculateBestFeaturesInvalidSplitType_v2_op(self):
"""Testing best split calculation without any regularization."""
+ candidate_feature_ids = [9, 12]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
- stats_summary = self._append_zeros_for_default_bucket(stats_summary)
+ stats_summaries = self._get_stats_summary_for_split()
+ stats_summaries = self.add_f_dim_and_append_zeros(stats_summaries)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
+ with self.assertRaisesRegexp(Exception, 'Incorrect split type'):
+ self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range,
+ stats_summaries,
+ split_types=['INVALID'] * len(candidate_feature_ids),
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
+ min_node_weight=0,
+ logits_dimension=1))
+
+ def testCalculateBestFeaturesWithoutRegularization_v2_op(self):
+ """Testing best split calculation without any regularization."""
+ candidate_feature_ids = [9, 12]
+ node_id_range = [1, 3] # node 1 through 2 will be processed.
+ stats_summaries = self._get_stats_summary_for_split()
+ stats_summaries = self.add_f_dim_and_append_zeros(stats_summaries)
+
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
node_id_range,
- stats_summary,
+ stats_summaries,
+ split_types=['inequality'] * len(candidate_feature_ids),
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -209,10 +248,47 @@
logits_dimension=1))
# Get same result as v1 op (CalculateBestGainsPerFeature), and find the
- # feature dimension that has the best gain.
+ # feature_id and dimension that has the best gain per node.
self.assertAllEqual([1, 2], node_ids)
self.assertAllClose([0.02823, 0.41184], gains)
self.assertAllEqual([1, 1], thresholds)
+ self.assertAllEqual([12, 9], feature_ids)
+ f_dim = 0 # Both features only have one dimension.
+ self.assertAllEqual([f_dim] * 2, feature_dimensions)
+ # The left node contrib will be later added to the previous node value to
+ # make the left node value, and the same for right node contrib.
+ self.assertAllClose([[-.6], [.568966]], left_node_contribs)
+ self.assertAllClose([[-.076923], [-.75]], right_node_contribs)
+ self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
+
+ def testCalculateBestMultiDimFeatureSplitsWithoutRegularization_v2_op(self):
+ """Testing best split without any regularization for a multi-dim feature."""
+ candidate_feature_ids = [4]
+ node_id_range = [1, 3] # node 1 through 2 will be processed.
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
+ stats_summary = self._append_zeros_for_default_bucket(stats_summary)
+
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
+ min_node_weight=0,
+ logits_dimension=1))
+
+ # Get same result as v1 op (CalculateBestGainsPerFeature), and find the
+ # feature_id and dimension that has the best gain per node.
+ self.assertAllEqual([1, 2], node_ids)
+ self.assertAllClose([0.02823, 0.41184], gains)
+ self.assertAllEqual([1, 1], thresholds)
+ self.assertAllEqual([4, 4], feature_ids)
self.assertAllEqual([1, 0], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
@@ -220,18 +296,22 @@
self.assertAllClose([[-.076923], [-.75]], right_node_contribs)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestMultiDimFeatureSplitWMissingValuesWORegularization(self):
+ def testCalculateBestMultiDimFeatureSplitWMissingValuesWORegularization_v2_op(
+ self):
"""Testing best split calculation without any regularization."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -242,39 +322,44 @@
# feature dimension that has the best gain.
self.assertAllEqual([1, 2], node_ids)
self.assertAllClose([0.116495, 0.60429], gains)
- self.assertAllEqual([1, 1], thresholds)
+ self.assertAllEqual([4, 4], feature_ids)
self.assertAllEqual([1, 1], feature_dimensions)
+ self.assertAllEqual([1, 1], thresholds)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
self.assertAllClose([[-0.631579], [-0.770833]], left_node_contribs)
self.assertAllClose([[0.833333], [0.8]], right_node_contribs)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestMultiDimFeatureEqualitySplitsWithoutRegularization(self):
+ def testCalculateBestMultiDimFeatureEqualitySplitsWithoutRegularization_v2_op(
+ self):
"""Testing best split calculation without any regularization."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['equality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
min_node_weight=0,
- logits_dimension=1,
- split_type='equality'))
+ logits_dimension=1))
self.assertAllEqual([1, 2], node_ids)
# 0.116495 = (-0.05)^2/0.06 + 0.36^2/0.57 - 0.31^2/0.63
# 0.60429 = (-0.4)^2/0.5 + 0.37^2/0.48 - 0.03^2/0.98
self.assertAllClose([0.116495, 0.60429], gains)
- self.assertAllEqual([2, 2], thresholds)
+ self.assertAllEqual([4, 4], feature_ids)
self.assertAllEqual([1, 1], feature_dimensions)
+ self.assertAllEqual([2, 2], thresholds)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
# left contrib 0.83 = 0.05/0.06, 0.8 = 0.4/0.5
@@ -283,7 +368,48 @@
self.assertAllClose([[-0.631579], [-0.770833]], right_node_contribs)
self.assertAllEqual([_EQUALITY_DEFAULT_RIGHT] * 2, split_types)
- def testCalculateBestGainsWithL2(self):
+ def testCalculateBestMultiDimFeatureMixedSplitTypeWithoutRegularization_v2_op(
+ self):
+ """Testing best split calculation without any regularization."""
+ candidate_feature_ids = [9, 12]
+ node_id_range = [1, 3] # node 1 through 2 will be processed.
+ stats_summaries = self._get_stats_summary_for_split()
+ # Add in feature dimension.
+ stats_summaries = [
+ np.expand_dims(feature, axis=1) for feature in stats_summaries
+ ]
+
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range,
+ stats_summaries,
+ split_types=['inequality', 'equality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
+ min_node_weight=0,
+ logits_dimension=1))
+
+ self.assertAllEqual([1, 2], node_ids)
+ # 0.116495 = (-0.05)^2/0.06 + 0.36^2/0.57 - 0.31^2/0.63
+ # 0.60429 = (-0.4)^2/0.5 + 0.37^2/0.48 - 0.03^2/0.98
+ self.assertAllClose([0.116495, 0.60429], gains)
+ self.assertAllEqual([12, 12], feature_ids)
+ f_dim = 0 # Both features only have one dimension.
+ self.assertAllEqual([f_dim, f_dim], feature_dimensions)
+ self.assertAllEqual([2, 2], thresholds)
+ # Same result as equality only test, as feature_1 is chosen for both nodes.
+ # left contrib 0.83 = 0.05/0.06, 0.8 = 0.4/0.5
+ self.assertAllClose([[0.833333], [.8]], left_node_contribs)
+ # right contrib -0.6315 = -0.36/0.57, -0.7708 = -0.37/0.48
+ self.assertAllClose([[-0.631579], [-0.770833]], right_node_contribs)
+ # Feature 1 (id 12) is equality, so both splits are equality splits.
+ self.assertAllEqual([_EQUALITY_DEFAULT_RIGHT, _EQUALITY_DEFAULT_RIGHT],
+ split_types)
+
+ def testCalculateBestGainsWithL2_v1_op(self):
"""Testing Gain calculation with L2."""
with self.cached_session() as sess:
max_splits = 7
@@ -312,19 +438,22 @@
self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
self.evaluate(right_node_contribs_list))
- def testCalculateMultiDimBestFeatureSplitsWithL2(self):
+ def testCalculateMultiDimBestFeatureSplitsWithL2_v2_op(self):
"""Testing best split calculation with L2."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
stats_summary = self._append_zeros_for_default_bucket(stats_summary)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.1,
tree_complexity=0.0,
@@ -334,27 +463,31 @@
# Get same result as v1 op (CalculateBestGainsPerFeature), and find the
# feature dimension that has the best gain.
self.assertAllEqual([1, 2], node_ids)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 0], feature_dimensions)
self.assertAllClose([0.01879096, 0.33931375], gains)
self.assertAllEqual([1, 1], thresholds)
- self.assertAllEqual([1, 0], feature_dimensions)
# # The left node contrib will be later added to the previous node value to
# # make the left node value, and the same for right node contrib.
self.assertAllClose([[-.5], [.485294]], left_node_contribs)
self.assertAllClose([[-.043478], [-.6]], right_node_contribs)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateMultiDimBestFeatureSplitsWithMissingValuesL2(self):
+ def testCalculateMultiDimBestFeatureSplitsWithMissingValuesL2_v2_op(self):
"""Testing best split calculation with L2."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.1,
tree_complexity=0.0,
@@ -364,40 +497,44 @@
# Get same result as v1 op (CalculateBestGainsPerFeature), and find the
# feature dimension that has the best gain.
self.assertAllEqual([1, 2], node_ids)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllClose([0.077414, 0.501868], gains)
self.assertAllEqual([1, 1], thresholds)
- self.assertAllEqual([1, 1], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
self.assertAllClose([[-0.537313], [-0.637931]], left_node_contribs)
self.assertAllClose([[0.3125], [0.666667]], right_node_contribs)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateMultiDimBestFeatureEqualitySplitsWithL2(self):
+ def testCalculateMultiDimBestFeatureEqualitySplitsWithL2_v2_op(self):
"""Testing best split calculation with L2."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['equality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.1,
tree_complexity=0.0,
min_node_weight=0,
- logits_dimension=1,
- split_type='equality'))
+ logits_dimension=1))
self.assertAllEqual([1, 2], node_ids)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 1], feature_dimensions)
# 0.077414 = 0.05^2/0.16 + 0.36^2/0.67 - 0.31^2/0.73
# 0.501868 = 0.4^2/0.6 + 0.37^2/0.58 - 0.03^2/1.08
self.assertAllClose([0.077414, 0.501868], gains)
self.assertAllEqual([2, 2], thresholds)
- self.assertAllEqual([1, 1], feature_dimensions)
# # The left node contrib will be later added to the previous node value to
# # make the left node value, and the same for right node contrib.
# left contrib 0.3125 = 0.05/0.16, 0.6667 = 0.4/0.6
@@ -434,7 +571,7 @@
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT, _INEQUALITY_DEFAULT_LEFT],
split_types)
- def testCalculateBestGainsWithL1(self):
+ def testCalculateBestGainsWithL1_v1_op(self):
"""Testing Gain calculation with L1."""
with self.cached_session() as sess:
max_splits = 7
@@ -466,22 +603,24 @@
self.assertAllClose([[0.0, 0.191207], [0.01, 0.191207]],
self.evaluate(gains_list))
- def testCalculateBestMultiDimFeatureSplitsWithL1(self):
+ def testCalculateBestMultiDimFeatureSplitsWithL1_v2_op(self):
"""Testing best split calculation with L1."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
stats_summary = self._append_zeros_for_default_bucket(stats_summary)
- l1 = 0.1
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
- l1=l1,
- l2=0.,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.1,
+ l2=0.0,
tree_complexity=0.0,
min_node_weight=0,
logits_dimension=1))
@@ -489,29 +628,32 @@
# Get same result as v1 op (CalculateBestGainsPerFeature), and find the
# feature dimension that has the best gain.
self.assertAllEqual([1, 2], node_ids)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 1], feature_dimensions)
# Gain should also include an adjustment of the gradient by l1.
self.assertAllClose([0.01, 0.191207], gains)
self.assertAllEqual([1, 1], thresholds)
self.assertAllClose([[-0.4], [-0.5]], left_node_contribs)
self.assertAllClose([[0.], [0.396552]], right_node_contribs)
- self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestMultiDimFeatureSplitsWithMissingValuesL1(self):
+ def testCalculateBestMultiDimFeatureSplitsWithMissingValuesL1_v2_op(self):
"""Testing best split calculation with L1."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
- l1 = 0.1
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
- l1=l1,
- l2=0.,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.1,
+ l2=0.0,
tree_complexity=0.0,
min_node_weight=0,
logits_dimension=1))
@@ -519,6 +661,8 @@
# Get same result as v1 op (CalculateBestGainsPerFeature), and find the
# feature dimension that has the best gain.
self.assertAllEqual([1, 2], node_ids)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 1], feature_dimensions)
# Gain should also include an adjustment of the gradient by l1.
# (0.36-0.1)^2/0.57 + 0 - (0.31-0.1)^2/0.63 = 0.048597
# (0.37-0.1)^2/0.48 + (-0.4+0.1)^2/0.5 = 0.331875
@@ -529,35 +673,37 @@
self.assertAllClose([[-0.45614], [-0.5625]], left_node_contribs)
# -(-0.4+0.1)/0.5 = 0.6
self.assertAllClose([[0.], [0.6]], right_node_contribs)
- self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestMultiDimFeatureEqualitySplitsWithL1(self):
+ def testCalculateBestMultiDimFeatureEqualitySplitsWithL1_v2_op(self):
"""Testing best split calculation with L1."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
+ stats_summary = self._append_zeros_for_default_bucket(stats_summary)
- l1 = 0.1
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
- l1=l1,
- l2=0.,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['equality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.1,
+ l2=0.0,
tree_complexity=0.0,
min_node_weight=0,
- logits_dimension=1,
- split_type='equality'))
+ logits_dimension=1))
self.assertAllEqual([1, 2], node_ids)
# 0.048597 = 0 + 0.26^2/0.57 - 0.21^2/0.63
# 0.501868 = 0.3^2/0.5 + 0.27^2/0.48 - 0
self.assertAllClose([0.048597, 0.331875], gains)
- self.assertAllEqual([2, 2], thresholds)
+ self.assertAllEqual([4, 4], feature_ids)
self.assertAllEqual([1, 1], feature_dimensions)
+ self.assertAllEqual([2, 2], thresholds)
# # The left node contrib will be later added to the previous node value to
# # make the left node value, and the same for right node contrib.
# left contrib 0 (-0.05>-0.1), 0.6 = 0.3/0.5
@@ -593,7 +739,7 @@
self.assertAllClose([[0.0], [0.6]], right_node_contribs)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestGainsWithTreeComplexity(self):
+ def testCalculateBestGainsWithTreeComplexity_v1_op(self):
"""Testing best gain calculation with tree complexity."""
with self.cached_session() as sess:
max_splits = 7
@@ -626,24 +772,25 @@
self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
self.evaluate(right_node_contribs_list))
- def testCalculateBestMultiDimFeatureSplitsWithTreeComplexity(self):
+ def testCalculateBestMultiDimFeatureSplitsWithTreeComplexity_v2_op(self):
"""Testing best split calculation with tree complexity."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
stats_summary = self._append_zeros_for_default_bucket(stats_summary)
- l2 = 0.1
- tree_complexity = 3.
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
- l1=0.,
- l2=l2,
- tree_complexity=tree_complexity,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.0,
+ l2=0.1,
+ tree_complexity=3,
min_node_weight=0,
logits_dimension=1))
@@ -652,29 +799,32 @@
self.assertAllEqual([1, 2], node_ids)
# Gain should also include an adjustment of the gradient by l1.
self.assertAllClose([-2.98120904, -2.66068625], gains)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 0], feature_dimensions)
self.assertAllEqual([1, 1], thresholds)
self.assertAllClose([[-0.5], [0.485294]], left_node_contribs)
self.assertAllClose([[-0.043478], [-.6]], right_node_contribs)
- self.assertAllEqual([1, 0], feature_dimensions)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestMultiDimFeatureSplitsWMissingValsTreeComplexity(self):
+ def testCalculateBestMultiDimFeatureSplitsWMissingValsTreeComplexity_v2_op(
+ self):
"""Testing best split calculation with tree complexity."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
- l2 = 0.1
- tree_complexity = 3.
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
- l1=0.,
- l2=l2,
- tree_complexity=tree_complexity,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.0,
+ l2=0.1,
+ tree_complexity=3,
min_node_weight=0,
logits_dimension=1))
@@ -683,38 +833,41 @@
self.assertAllEqual([1, 2], node_ids)
# Gain should also include an adjustment of the gradient by l1.
self.assertAllClose([-2.922586, -2.498132], gains)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllEqual([1, 1], thresholds)
self.assertAllClose([[-0.537313], [-0.637931]], left_node_contribs)
self.assertAllClose([[0.3125], [0.666667]], right_node_contribs)
- self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestMultiDimFeatureEqualitySplitsWithTreeComplexity(self):
+ def testCalculateBestMultiDimFeatureEqualitySplitsWithTreeComplexity_v2_op(
+ self):
"""Testing best split calculation with tree complexity."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
- stats_summary = np.asarray(self._get_stats_summary_for_split())
- # reshape to [max_splits, feature_dim, num_buckets, 2]
- stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summaries = self._get_stats_summary_for_split()
+ # Convert from list of arrays to a single array and reshape to [max_splits,
+ # feature_dim, num_buckets, 2].
+ stats_summary = np.moveaxis(stats_summaries, 0, 1)
- l2 = 0.1
- tree_complexity = 3.
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
- l1=0.,
- l2=l2,
- tree_complexity=tree_complexity,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['equality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.0,
+ l2=0.1,
+ tree_complexity=3,
min_node_weight=0,
- logits_dimension=1,
- split_type='equality'))
+ logits_dimension=1))
self.assertAllEqual([1, 2], node_ids)
# -2.922586 = 0.05^2/0.16 + 0.36^2/0.67 - 0.31^2/0.73 - 3
# -2.498132 = 0.4^2/0.6 + 0.37^2/0.58 - 0.03^2/1.08 - 3
self.assertAllClose([-2.922586, -2.498132], gains)
self.assertAllEqual([2, 2], thresholds)
+ self.assertAllEqual([4, 4], feature_ids)
self.assertAllEqual([1, 1], feature_dimensions)
# # The left node contrib will be later added to the previous node value to
# # make the left node value, and the same for right node contrib.
@@ -751,7 +904,7 @@
self.assertAllClose([[0.3125], [0.666667]], right_node_contribs)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateBestGainsWithMinNodeWeight(self):
+ def testCalculateBestGainsWithMinNodeWeight_v1_op(self):
"""Testing Gain calculation with min node weight."""
with self.cached_session() as sess:
max_splits = 7
@@ -798,8 +951,9 @@
self.assertAllClose([[[-0.75]], [[-0.014925]]],
self.evaluate(right_node_contribs_list))
- def testCalculateMultiDimBestSplitsWithMinNodeWeight(self):
+ def testCalculateMultiDimBestSplitsWithMinNodeWeight_v2_op(self):
"""Testing best split calculation with min node weight."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary = np.asarray([
[
@@ -810,7 +964,7 @@
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
- ], # feature 0
+ ], # f_dim 0
[
[[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
[[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1
@@ -819,34 +973,37 @@
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
- ], # feature 1
+ ], # f_dim 1
]) # feature_dim * shape=[max_splits, num_buckets, 2]
- # reshape to [max_splits, feature_dim, num_buckets, 2]
+ # Reshape to [max_splits, feature_dim, num_buckets, 2].
stats_summary = np.moveaxis(stats_summary, 0, 1)
stats_summary = self._append_zeros_for_default_bucket(stats_summary)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
- l1=0.,
- l2=0.,
- tree_complexity=0.,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
min_node_weight=1,
logits_dimension=1))
self.assertAllEqual([1, 2], node_ids)
# Gain should also include an adjustment of the gradient by l1.
self.assertAllClose([0.098013, 0.931596], gains)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllEqual([1, 1], thresholds)
self.assertAllClose([[-.6], [-0.315789]], left_node_contribs)
self.assertAllClose([[-0.014925], [2.53846]], right_node_contribs)
- self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
- def testCalculateMultiDimBestSplitsWithMissingValuesMinNodeWeight(self):
+ def testCalculateMultiDimBestSplitsWithMissingValuesMinNodeWeight_v2_op(self):
"""Testing best split calculation with min node weight."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary = np.asarray([
[
@@ -857,7 +1014,7 @@
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
- ], # feature 0
+ ], # f_dim 0
[
[[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
[[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1
@@ -866,29 +1023,31 @@
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
- ], # feature 1
+ ], # f_dim 1
]) # feature_dim * shape=[max_splits, num_buckets, 2]
- # reshape to [max_splits, feature_dim, num_buckets, 2]
+ # Reshape to [max_splits, feature_dim, num_buckets, 2].
stats_summary = np.moveaxis(stats_summary, 0, 1)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
- l1=0.,
- l2=0.,
- tree_complexity=0.,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
min_node_weight=1,
logits_dimension=1))
self.assertAllEqual([1, 2], node_ids)
# Gain should also include an adjustment of the gradient by l1.
self.assertAllClose([0.149398, 3.332075], gains)
+ self.assertAllEqual([4, 4], feature_ids)
+ self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllEqual([1, 1], thresholds)
self.assertAllClose([[-0.631579], [-0.359223]], left_node_contribs)
self.assertAllClose([[0.083333], [7.999989]], right_node_contribs)
- self.assertAllEqual([1, 1], feature_dimensions)
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
def testSparseCalculateBestSplitsWithMinNodeWeight(self):
@@ -942,7 +1101,8 @@
self.assertAllEqual([_INEQUALITY_DEFAULT_RIGHT, _INEQUALITY_DEFAULT_LEFT],
split_types)
- def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self):
+ def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeaturePossible_v1_op(
+ self):
"""Testing Gain calculation without any regularization."""
with self.cached_session() as sess:
max_splits = 7
@@ -995,8 +1155,10 @@
max_splits=max_splits)
self.assertAllEqual([[], []], self.evaluate(node_ids_list))
- def testCalculateBestMultiDimFeatureSplitsWithNoSplitOnFeaturePossible(self):
+ def testCalculateBestMultiDimFeatureSplitsWithNoSplitOnFeaturePossible_v2_op(
+ self):
"""Testing best split calculation with min node weight and no split."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary = np.asarray([
[
@@ -1007,7 +1169,7 @@
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
- ], # feature 0
+ ], # f_dim 0
[
[[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
[[0., 0.], [.3, .5], [-.05, .06], [.06, .7]], # node 1
@@ -1016,29 +1178,32 @@
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
- ], # feature 1
+ ], # f_dim 1
]) # feature_dim * shape=[max_splits, num_buckets, 2]
- # reshape to [max_splits, feature_dim, num_buckets, 2]
+ # Reshape to [max_splits, feature_dim, num_buckets, 2].
stats_summary = np.moveaxis(stats_summary, 0, 1)
+ stats_summary = self._append_zeros_for_default_bucket(stats_summary)
- (node_ids, _, _, _, _, _,
- _) = boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, _, _, _, _, _, _,
+ _) = boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
min_node_weight=1,
logits_dimension=1)
- # We can't split either of the nodes on the first feature
+ # We can't split either of the nodes on the first feature.
self.assertAllEqual([1], node_ids)
- # Now check when we can't split on any feature
- (node_ids, _, _, _, _, _,
- _) = boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ # Now check when we can't split on any feature.
+ (node_ids, _, _, _, _, _, _,
+ _) = boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -1046,8 +1211,10 @@
logits_dimension=1)
self.assertAllEqual([], node_ids)
- def testCalculateBestMultiDimFeatureEqualitySplitsWithNoSplitPossible(self):
+ def testCalculateBestMultiDimFeatureEqualitySplitsWithNoSplitPossible_v2_op(
+ self):
"""Testing best split calculation with min node weight and no split."""
+ candidate_feature_ids = [4]
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary = np.asarray([
[
@@ -1058,7 +1225,7 @@
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
- ], # feature 0
+ ], # f_dim 0
[
[[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
[[0., 0.], [.3, .5], [-.05, .06], [.06, .7]], # node 1
@@ -1067,30 +1234,31 @@
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
- ], # feature 1
+ ], # f_dim 1
]) # feature_dim * shape=[max_splits, num_buckets, 2]
- # reshape to [max_splits, feature_dim, num_buckets, 2]
+ # Reshape to [max_splits, feature_dim, num_buckets, 2].
stats_summary = np.moveaxis(stats_summary, 0, 1)
- (node_ids, _, _, _, _, _,
- _) = boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, _, _, _, _, _, _,
+ _) = boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['equality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
min_node_weight=1,
- logits_dimension=1,
- split_type='equality')
+ logits_dimension=1)
# We can't split either of the nodes on the first feature
self.assertAllEqual([1], node_ids)
# Now check when we can't split on any feature
- (node_ids, _, _, _, _, _,
- _) = boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, _, _, _, _, _, _,
+ _) = boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['equality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -1502,8 +1670,8 @@
self._verify_precision(length=50000000)
-class BestMultiDimFeatureSplitMultiClass(StatsOpsTest):
- """Tests multi-class/multi-regression for best splits."""
+class BestMultiDimFeatureSplitMultiClassV2Op(StatsOpsTest):
+ """Tests multi-class/multi-regression for best splits using V2 op."""
logits_dim = 2
@@ -1566,6 +1734,7 @@
def testCalculateBestFeatureSplitsSingleClassVsMultiClass(self):
"""Testing same results using same grads/hess with both single and multi."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# Build same stats summary in single class and multi-class form (using
@@ -1589,23 +1758,25 @@
# [max_splits, feature_dim, num_buckets, 4]
diag_stats_summary = self._add_feature_dim(diag_stats_summary)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
min_node_weight=0,
logits_dimension=1))
- (diag_node_ids, diag_gains, diag_feature_dimensions, diag_thresholds,
- diag_left_node_contribs, diag_right_node_contribs,
+ (diag_node_ids, diag_gains, diag_feature_ids, diag_feature_dimensions,
+ diag_thresholds, diag_left_node_contribs, diag_right_node_contribs,
diag_split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- diag_stats_summary,
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [diag_stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -1614,8 +1785,9 @@
self.assertAllEqual(node_ids, diag_node_ids)
self.assertAllClose(gains, diag_gains)
- self.assertAllEqual(thresholds, diag_thresholds)
+ self.assertAllEqual(feature_ids, diag_feature_ids)
self.assertAllEqual(feature_dimensions, diag_feature_dimensions)
+ self.assertAllEqual(thresholds, diag_thresholds)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
zeros = np.zeros_like(left_node_contribs)
@@ -1629,6 +1801,7 @@
def testCalculateBestFeatureSplitsDiagonalVsFull(self):
"""Test results are same using diagonal hessian and full hessian."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# Build same stats summary in diagonal and full hessian form, respectively.
@@ -1651,24 +1824,26 @@
]
# [max_splits, feature_dim, num_buckets, logits_dim + logits_dim**2]
full_stats_summary = self._add_feature_dim(full_stats_summary)
- (diag_node_ids, diag_gains, diag_feature_dimensions, diag_thresholds,
- diag_left_node_contribs, diag_right_node_contribs,
+ (diag_node_ids, diag_gains, diag_feature_ids, diag_feature_dimensions,
+ diag_thresholds, diag_left_node_contribs, diag_right_node_contribs,
diag_split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- diag_stats_summary,
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [diag_stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
min_node_weight=0,
logits_dimension=self.logits_dim))
- (full_node_ids, full_gains, full_feature_dimensions, full_thresholds,
- full_left_node_contribs, full_right_node_contribs,
+ (full_node_ids, full_gains, full_feature_ids, full_feature_dimensions,
+ full_thresholds, full_left_node_contribs, full_right_node_contribs,
full_split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- full_stats_summary,
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [full_stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -1677,8 +1852,9 @@
self.assertAllEqual(diag_node_ids, full_node_ids)
self.assertAllClose(diag_gains, full_gains)
- self.assertAllEqual(diag_thresholds, full_thresholds)
+ self.assertAllEqual(diag_feature_ids, full_feature_ids)
self.assertAllEqual(diag_feature_dimensions, full_feature_dimensions)
+ self.assertAllEqual(diag_thresholds, full_thresholds)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
self.assertAllClose(diag_left_node_contribs, full_left_node_contribs)
@@ -1687,16 +1863,18 @@
def testCalculateBestFeatureSplitsWithoutRegularization(self):
"""Testing best split calculation without any regularization."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# [max_splits, feature_dim, num_buckets, 2*logits_dim]
stats_summary = self._get_stats_summary_for_split_diagonal_hessian()
stats_summary = self._append_zeros_for_default_bucket(stats_summary)
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -1706,6 +1884,7 @@
self.assertAllEqual([1, 2], node_ids)
self.assertAllClose([0.912981, 1.446218], gains)
self.assertAllEqual([2, 1], thresholds)
+ self.assertAllEqual([14, 14], feature_ids)
self.assertAllEqual([0, 1], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
@@ -1717,15 +1896,17 @@
def testCalculateBestFeatureSplitsWMissingValuesWoRegularization(self):
"""Testing best split calculation without any regularization."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# [max_splits, feature_dim, num_buckets, 2*logits_dim]
stats_summary = self._get_stats_summary_for_split_diagonal_hessian()
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -1735,6 +1916,7 @@
self.assertAllEqual([1, 2], node_ids)
self.assertAllClose([0.912981, 2.79444], gains)
self.assertAllEqual([0, 1], thresholds)
+ self.assertAllEqual([14, 14], feature_ids)
self.assertAllEqual([0, 1], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
@@ -1746,17 +1928,19 @@
def testCalculateBestFeatureSplitsWithL2(self):
"""Testing best split calculation inith L2 regularization."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# [max_splits, feature_dim, num_buckets, 2*logits_dim]
stats_summary = self._get_stats_summary_for_split_diagonal_hessian()
stats_summary = self._append_zeros_for_default_bucket(stats_summary)
l2 = 0.1
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=l2,
tree_complexity=0.0,
@@ -1766,6 +1950,7 @@
self.assertAllEqual([1, 2], node_ids)
self.assertAllClose([0.475669, 1.009791], gains)
self.assertAllEqual([1, 1], thresholds)
+ self.assertAllEqual([14, 14], feature_ids)
self.assertAllEqual([0, 1], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
@@ -1777,16 +1962,18 @@
def testCalculateBestFeatureSplitsWithMissingValuesL2(self):
"""Testing best split calculation inith L2 regularization."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# [max_splits, feature_dim, num_buckets, 2*logits_dim]
stats_summary = self._get_stats_summary_for_split_diagonal_hessian()
l2 = 0.1
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=l2,
tree_complexity=0.0,
@@ -1796,6 +1983,7 @@
self.assertAllEqual([1, 2], node_ids)
self.assertAllClose([0.475669, 3.467833], gains)
self.assertAllEqual([1, 0], thresholds)
+ self.assertAllEqual([14, 14], feature_ids)
self.assertAllEqual([0, 1], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
@@ -1808,15 +1996,17 @@
def testCalculateBestFeatureSplitsWithMinNodeWeight(self):
"""Testing best split calculation with min_node_weight."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# [max_splits, feature_dim, num_buckets, 2*logits_dim]
stats_summary = self._get_stats_summary_for_split_diagonal_hessian()
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -1827,6 +2017,7 @@
self.assertAllEqual([1, 2], node_ids)
self.assertAllClose([0.912981, 2.79444], gains)
self.assertAllEqual([0, 1], thresholds)
+ self.assertAllEqual([14, 14], feature_ids)
self.assertAllEqual([0, 1], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
@@ -1838,17 +2029,19 @@
def testCalculateBestFeatureSplitsWithTreeComplexity(self):
"""Testing best split calculation with tree complexity."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# [max_splits, feature_dim, num_buckets, 2*logits_dim]
stats_summary = self._get_stats_summary_for_split_diagonal_hessian()
l2 = 0.1
tree_complexity = 3.
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=l2,
tree_complexity=tree_complexity,
@@ -1860,6 +2053,7 @@
# L2 test result, but subtracted by tree_complexity.
self.assertAllClose([-2.524331, 0.467833], gains)
self.assertAllEqual([1, 0], thresholds)
+ self.assertAllEqual([14, 14], feature_ids)
self.assertAllEqual([0, 1], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
@@ -1872,16 +2066,18 @@
def testCalculateBestFeatureSplitsWithMinNodeNoSplitOnFeaturePossible(self):
"""Test when parent node hessian doesn't meet min node weight."""
+ candidate_feature_ids = [14]
node_id_range = [1, 3] # node 1 through 2 will be processed.
# [max_splits, feature_dim, num_buckets, 2*logits_dim]
stats_summary = self._get_stats_summary_for_split_diagonal_hessian()
min_node_weight = 0.8
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
- right_node_contribs, split_types) = self.evaluate(
- boosted_trees_ops.calculate_best_feature_split(
- node_id_range,
- stats_summary,
+ (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+ left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+ boosted_trees_ops.calculate_best_feature_split_v2(
+ node_id_range, [stats_summary],
+ split_types=['inequality'],
+ candidate_feature_ids=candidate_feature_ids,
l1=0.0,
l2=0.0,
tree_complexity=0.0,
@@ -1892,6 +2088,7 @@
self.assertAllEqual([2], node_ids)
self.assertAllClose([2.79444], gains)
self.assertAllEqual([1], thresholds)
+ self.assertAllEqual([14], feature_ids)
self.assertAllEqual([1], feature_dimensions)
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 36eb854..b4abcfa 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -363,7 +363,7 @@
tf_logging.debug("actual = %s", value)
tol_to_use = fp16_tol if value.dtype == np.float16 else tol
if np.issubdtype(value.dtype, np.integer):
- self.assertAllEqual(expected, np.ravel(value))
+ self.assertAllEqual(np.rint(expected), np.ravel(value))
else:
self.assertAllClose(expected, np.ravel(value), atol=tol_to_use,
rtol=tol_to_use)
@@ -2659,7 +2659,7 @@
value = self.evaluate(conv)
tf_logging.debug("value = %s", value)
- self.assertArrayNear(expected, np.ravel(value), 1e-3)
+ self.assertArrayNear(expected, np.ravel(value), 2e-3)
self.assertShapeEqual(value, conv)
def _testSeparableConv2D(self, data_format):
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index c1e09e3..5383410 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -321,7 +321,7 @@
y, = script_ops.py_func(bad, [], [dtypes.float32])
with self.assertRaisesRegexp(errors.InternalError,
- "Unsupported NumPy struct data type"):
+ "Unsupported numpy data type"):
self.evaluate(y)
@test_util.run_v1_only("b/120545219")
diff --git a/tensorflow/python/kernel_tests/signal/BUILD b/tensorflow/python/kernel_tests/signal/BUILD
index 49076bd..230b35c 100644
--- a/tensorflow/python/kernel_tests/signal/BUILD
+++ b/tensorflow/python/kernel_tests/signal/BUILD
@@ -25,7 +25,6 @@
name = "dct_ops_test",
srcs = ["dct_ops_test.py"],
python_version = "PY3",
- tags = ["no_rocm"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -70,7 +69,6 @@
name = "mfcc_ops_test",
srcs = ["mfcc_ops_test.py"],
python_version = "PY3",
- tags = ["no_rocm"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
diff --git a/tensorflow/python/kernel_tests/signal/dct_ops_test.py b/tensorflow/python/kernel_tests/signal/dct_ops_test.py
index 2899f4d..d4f9e39 100644
--- a/tensorflow/python/kernel_tests/signal/dct_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/dct_ops_test.py
@@ -87,7 +87,7 @@
phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
dct[..., k] = np.sum(signals_mod * phi, axis=-1)
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
- # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
if norm == "ortho":
# The orthonormal scaling includes a factor of 0.5 which we combine with
# the overall scaling of 2.0 to cancel.
@@ -101,7 +101,7 @@
def _np_dct3(signals, n=None, norm=None):
"""Computes the DCT-III manually with NumPy."""
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
- # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
signals_mod = _modify_input_for_dct(signals, n=n)
dct_size = signals_mod.shape[-1]
signals_mod = np.array(signals_mod) # make a copy so we can modify
@@ -120,8 +120,30 @@
return dct
-NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3}
-NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2}
+def _np_dct4(signals, n=None, norm=None):
+  """Computes the type-IV DCT of `signals` manually with NumPy."""
+  # Scaling follows SciPy's `dct`, which applies an overall factor of 2.0:
+  # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
+  x = _modify_input_for_dct(signals, n=n)
+  length = x.shape[-1]
+  x = np.array(x)  # Copy, because the scaling below is done in place.
+  if norm == "ortho":
+    scale = np.sqrt(2.0 / length)
+  else:
+    scale = 2.0
+  x *= scale
+  out = np.zeros_like(x)
+  # X_k = sum_{n=0}^{N-1} x_n * cos(pi / (4N) * (2n + 1) * (2k + 1)),
+  # for k = 0, ..., N-1.
+  odd = 2 * np.arange(0, length) + 1  # (2n + 1) for n = 0, ..., N-1.
+  for k in range(length):
+    phi = np.cos(np.pi * odd * (2 * k + 1) / (4.0 * length))
+    out[..., k] = np.sum(x * phi, axis=-1)
+  return out
+
+
+NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3, 4: _np_dct4}
+NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2, 4: _np_dct4}
@test_util.run_all_in_graph_and_eager_modes
@@ -137,7 +159,7 @@
tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm)
self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype)
self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
- if fftpack:
+ if fftpack and dct_type != 4:
scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm)
self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
@@ -159,7 +181,7 @@
self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
@parameterized.parameters(itertools.product(
- [1, 2, 3],
+ [1, 2, 3, 4],
[None, "ortho"],
[[2], [3], [10], [2, 20], [2, 3, 25]],
[np.float32, np.float64]))
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 7fda88e..42b248a 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -21,19 +21,11 @@
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/python/lib/core/numpy.h"
+#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
namespace {
-struct PyDecrefDeleter {
- void operator()(PyObject* p) const { Py_DECREF(p); }
-};
-
-using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
-Safe_PyObjectPtr make_safe(PyObject* object) {
- return Safe_PyObjectPtr(object);
-}
-
// Workarounds for Python 2 vs 3 API differences.
#if PY_MAJOR_VERSION < 3
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index ee17fea..32453ae 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -24,10 +24,12 @@
import numpy as np
# pylint: disable=unused-import,g-bad-import-order
+from tensorflow.python import _pywrap_bfloat16
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
-bfloat16 = dtypes._np_bfloat16 # pylint: disable=protected-access
+
+bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
class Bfloat16Test(test.TestCase):
diff --git a/tensorflow/python/lib/core/bfloat16_wrapper.cc b/tensorflow/python/lib/core/bfloat16_wrapper.cc
new file mode 100644
index 0000000..4a8e180
--- /dev/null
+++ b/tensorflow/python/lib/core/bfloat16_wrapper.cc
@@ -0,0 +1,24 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/python/lib/core/bfloat16.h"
+
+PYBIND11_MODULE(_pywrap_bfloat16, m) {
+  tensorflow::RegisterNumpyBfloat16();
+  // Exposes tensorflow::Bfloat16PyType() to Python as TF_bfloat16_type().
+  m.def("TF_bfloat16_type",
+        [] { return pybind11::handle(tensorflow::Bfloat16PyType()); });
+}
diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc
index fcf41c2..8c83629 100644
--- a/tensorflow/python/lib/core/ndarray_tensor.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor.cc
@@ -21,13 +21,171 @@
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
-#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
#include "tensorflow/python/lib/core/numpy.h"
namespace tensorflow {
namespace {
+// Name of the NPY_* enumerator equal to numpy_type; used in error messages.
+char const* numpy_type_name(int numpy_type) {
+  switch (numpy_type) {
+#define TYPE_CASE(s) \
+  case s: return #s
+  TYPE_CASE(NPY_BOOL);
+  TYPE_CASE(NPY_BYTE);
+  TYPE_CASE(NPY_UBYTE);
+  TYPE_CASE(NPY_SHORT);
+  TYPE_CASE(NPY_USHORT);
+  TYPE_CASE(NPY_INT);
+  TYPE_CASE(NPY_UINT);
+  TYPE_CASE(NPY_LONG);
+  TYPE_CASE(NPY_ULONG);
+  TYPE_CASE(NPY_LONGLONG);
+  TYPE_CASE(NPY_ULONGLONG);
+  TYPE_CASE(NPY_FLOAT);
+  TYPE_CASE(NPY_DOUBLE);
+  TYPE_CASE(NPY_LONGDOUBLE);
+  TYPE_CASE(NPY_CFLOAT);
+  TYPE_CASE(NPY_CDOUBLE);
+  TYPE_CASE(NPY_CLONGDOUBLE);
+  TYPE_CASE(NPY_OBJECT);
+  TYPE_CASE(NPY_STRING);
+  TYPE_CASE(NPY_UNICODE);
+  TYPE_CASE(NPY_VOID);
+  TYPE_CASE(NPY_DATETIME);
+  TYPE_CASE(NPY_TIMEDELTA);
+  TYPE_CASE(NPY_HALF);
+  TYPE_CASE(NPY_NTYPES);
+  TYPE_CASE(NPY_NOTYPE);
+  TYPE_CASE(NPY_CHAR);
+  TYPE_CASE(NPY_USERDEF);
+    default:
+      return "not a numpy type";
+  }
+#undef TYPE_CASE
+}
+
+// Maps the first field name of a numpy struct dtype to its TF_DataType.
+Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
+                                   TF_DataType* out_tf_datatype) {
+  PyObject* key;
+  PyObject* value;
+  Py_ssize_t pos = 0;
+  // `fields` may be NULL (no struct fields); PyDict_Next must not see NULL.
+  if (descr->fields && PyDict_Next(descr->fields, &pos, &key, &value)) {
+    Safe_PyObjectPtr ascii_key;  // Py3 keys are unicode; owns their ASCII.
+    if (!PyBytes_Check(key)) ascii_key.reset(PyUnicode_AsASCIIString(key));
+    PyObject* key_bytes = PyBytes_Check(key) ? key : ascii_key.get();
+    const char* key_string = key_bytes ? PyBytes_AsString(key_bytes) : nullptr;
+    if (!key_string) {
+      PyErr_Clear();  // Drop the error from a failed ASCII conversion.
+      return errors::Internal("Corrupt numpy type descriptor");
+    }
+    tensorflow::string key_name = key_string;
+    // Names must match the custom struct types built in test_util.py.
+    // TODO(mrry,keveman): Use numpy type registration instead of hard-coding.
+    if (key_name == "quint8") {
+      *out_tf_datatype = TF_QUINT8;
+    } else if (key_name == "qint8") {
+      *out_tf_datatype = TF_QINT8;
+    } else if (key_name == "qint16") {
+      *out_tf_datatype = TF_QINT16;
+    } else if (key_name == "quint16") {
+      *out_tf_datatype = TF_QUINT16;
+    } else if (key_name == "qint32") {
+      *out_tf_datatype = TF_QINT32;
+    } else if (key_name == "resource") {
+      *out_tf_datatype = TF_RESOURCE;
+    } else {
+      return errors::Internal("Unsupported numpy data type");
+    }
+    return Status::OK();
+  }
+  return errors::Internal("Unsupported numpy data type");
+}
+
+// Maps the numpy type of `array` to a TF_DataType stored in
+// `*out_tf_datatype`. Returns an Internal error for unsupported numpy types.
+Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
+                                   TF_DataType* out_tf_datatype) {
+  int pyarray_type = PyArray_TYPE(array);
+  PyArray_Descr* descr = PyArray_DESCR(array);
+  switch (pyarray_type) {
+    case NPY_FLOAT16:
+      *out_tf_datatype = TF_HALF;
+      break;
+    case NPY_FLOAT32:
+      *out_tf_datatype = TF_FLOAT;
+      break;
+    case NPY_FLOAT64:
+      *out_tf_datatype = TF_DOUBLE;
+      break;
+    case NPY_INT32:
+      *out_tf_datatype = TF_INT32;
+      break;
+    case NPY_UINT8:
+      *out_tf_datatype = TF_UINT8;
+      break;
+    case NPY_UINT16:
+      *out_tf_datatype = TF_UINT16;
+      break;
+    case NPY_UINT32:
+      *out_tf_datatype = TF_UINT32;
+      break;
+    case NPY_UINT64:
+      *out_tf_datatype = TF_UINT64;
+      break;
+    case NPY_INT8:
+      *out_tf_datatype = TF_INT8;
+      break;
+    case NPY_INT16:
+      *out_tf_datatype = TF_INT16;
+      break;
+    case NPY_INT64:
+      *out_tf_datatype = TF_INT64;
+      break;
+    case NPY_BOOL:
+      *out_tf_datatype = TF_BOOL;
+      break;
+    case NPY_COMPLEX64:
+      *out_tf_datatype = TF_COMPLEX64;
+      break;
+    case NPY_COMPLEX128:
+      *out_tf_datatype = TF_COMPLEX128;
+      break;
+    case NPY_OBJECT:
+    case NPY_STRING:
+    case NPY_UNICODE:
+      *out_tf_datatype = TF_STRING;
+      break;
+    case NPY_VOID:
+      // Struct dtypes (quantized types, direct ResourceHandle feeds) report
+      // NPY_VOID; the struct's field name in `descr` identifies the type.
+      return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
+    default:
+      // Bfloat16NumpyType() is not a constant, so it can't be a case label.
+      if (pyarray_type == Bfloat16NumpyType()) {
+        *out_tf_datatype = TF_BFLOAT16;
+        break;
+      } else if (pyarray_type == NPY_ULONGLONG) {
+        // NPY_ULONGLONG may equal NPY_UINT64 on some platforms, so a separate
+        // case label could collide; either way it maps to TF_UINT64.
+        *out_tf_datatype = TF_UINT64;
+        break;
+      } else if (pyarray_type == NPY_LONGLONG) {
+        // NPY_LONGLONG may equal NPY_INT64 on some platforms, so a separate
+        // case label could collide; either way it maps to TF_INT64.
+        *out_tf_datatype = TF_INT64;
+        break;
+      }
+      return errors::Internal("Unsupported numpy type: ",
+                              numpy_type_name(pyarray_type));
+  }
+  return Status::OK();
+}
+
Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
PyObject** ptr_owner) {
*ptr_owner = nullptr;
@@ -186,6 +344,38 @@
return Status::OK();
}
+// Determines the PyArray_Descr for the numpy ndarray that will hold the
+// contents of `tensor`; TF_RESOURCE gets a struct dtype named "resource".
+Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
+                                PyArray_Descr** descr) {
+  if (TF_TensorType(tensor) == TF_RESOURCE) {
+    PyObject* field = PyTuple_New(3);
+#if PY_MAJOR_VERSION < 3
+    PyTuple_SetItem(field, 0, PyBytes_FromString("resource"));
+#else
+    PyTuple_SetItem(field, 0, PyUnicode_FromString("resource"));
+#endif
+    PyTuple_SetItem(field, 1, PyArray_TypeObjectFromType(NPY_UBYTE));
+    PyTuple_SetItem(field, 2, PyLong_FromLong(1));
+    PyObject* fields = PyList_New(1);
+    PyList_SetItem(fields, 0, field);
+    int convert_result = PyArray_DescrConverter(fields, descr);
+    // PyList_SetItem stole `field`; releasing `fields` also releases it.
+    Py_CLEAR(fields);
+    if (convert_result != 1) {
+      return errors::Internal("Failed to create numpy array description for ",
+                              "TF_RESOURCE-type tensor");
+    }
+  } else {
+    int type_num = -1;
+    TF_RETURN_IF_ERROR(
+        TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
+    *descr = PyArray_DescrFromType(type_num);
+  }
+
+  return Status::OK();
+}
+
inline void FastMemcpy(void* dst, const void* src, size_t size) {
// clang-format off
switch (size) {
@@ -271,8 +461,7 @@
// Copy the TF_TensorData into a newly-created ndarray and return it.
PyArray_Descr* descr = nullptr;
- TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(
- static_cast<DataType>(TF_TensorType(tensor.get())), &descr));
+ TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
Safe_PyObjectPtr safe_out_array =
tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
if (!safe_out_array) {
@@ -310,11 +499,7 @@
// Convert numpy dtype to TensorFlow dtype.
TF_DataType dtype = TF_FLOAT;
- {
- DataType tmp;
- TF_RETURN_IF_ERROR(PyArray_DescrToDataType(PyArray_DESCR(array), &tmp));
- dtype = static_cast<TF_DataType>(tmp);
- }
+ TF_RETURN_IF_ERROR(PyArray_TYPE_to_TF_DataType(array, &dtype));
tensorflow::int64 nelems = 1;
gtl::InlinedVector<int64_t, 4> dims;
diff --git a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc
index 485b34c..03ff771 100644
--- a/tensorflow/python/lib/core/ndarray_tensor_bridge.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor_bridge.cc
@@ -13,19 +13,16 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
+// Must be included first.
+#include "tensorflow/python/lib/core/numpy.h"
#include <vector>
-// Must be included first.
-// clang-format: off
-#include "tensorflow/python/lib/core/numpy.h"
-// clang-format: on
-
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
+#include "tensorflow/python/lib/core/bfloat16.h"
+#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
namespace tensorflow {
@@ -110,6 +107,85 @@
nullptr, /* tp_richcompare */
};
+// Maps tf_datatype to a numpy type number, or returns an Internal error.
+Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
+                                   int* out_pyarray_type) {
+  switch (tf_datatype) {
+    case TF_HALF:
+      *out_pyarray_type = NPY_FLOAT16;
+      break;
+    case TF_FLOAT:
+      *out_pyarray_type = NPY_FLOAT32;
+      break;
+    case TF_DOUBLE:
+      *out_pyarray_type = NPY_FLOAT64;
+      break;
+    case TF_INT32:
+      *out_pyarray_type = NPY_INT32;
+      break;
+    case TF_UINT32:
+      *out_pyarray_type = NPY_UINT32;
+      break;
+    case TF_UINT8:
+      *out_pyarray_type = NPY_UINT8;
+      break;
+    case TF_UINT16:
+      *out_pyarray_type = NPY_UINT16;
+      break;
+    case TF_INT8:
+      *out_pyarray_type = NPY_INT8;
+      break;
+    case TF_INT16:
+      *out_pyarray_type = NPY_INT16;
+      break;
+    case TF_INT64:
+      *out_pyarray_type = NPY_INT64;
+      break;
+    case TF_UINT64:
+      *out_pyarray_type = NPY_UINT64;
+      break;
+    case TF_BOOL:
+      *out_pyarray_type = NPY_BOOL;
+      break;
+    case TF_COMPLEX64:
+      *out_pyarray_type = NPY_COMPLEX64;
+      break;
+    case TF_COMPLEX128:
+      *out_pyarray_type = NPY_COMPLEX128;
+      break;
+    case TF_STRING:
+      *out_pyarray_type = NPY_OBJECT;
+      break;
+    case TF_RESOURCE:
+      *out_pyarray_type = NPY_VOID;
+      break;
+    // TODO(keveman): Map these to NPY_VOID with the custom struct types we
+    // expect for quantized types, not to their underlying integer types.
+    case TF_QINT8:
+      *out_pyarray_type = NPY_INT8;
+      break;
+    case TF_QUINT8:
+      *out_pyarray_type = NPY_UINT8;
+      break;
+    case TF_QINT16:
+      *out_pyarray_type = NPY_INT16;
+      break;
+    case TF_QUINT16:
+      *out_pyarray_type = NPY_UINT16;
+      break;
+    case TF_QINT32:
+      *out_pyarray_type = NPY_INT32;
+      break;
+    case TF_BFLOAT16:
+      *out_pyarray_type = Bfloat16NumpyType();
+      break;
+    default:
+      return errors::Internal("Tensorflow type ", tf_datatype,
+                              " not convertible to numpy dtype.");
+  }
+  return Status::OK();
+}
+
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
std::function<void()> destructor, PyObject** result) {
if (dtype == DT_STRING || dtype == DT_RESOURCE) {
@@ -117,11 +193,15 @@
"Cannot convert string or resource Tensors.");
}
- PyArray_Descr* descr = nullptr;
- TF_RETURN_IF_ERROR(DataTypeToPyArray_Descr(dtype, &descr));
+ int type_num = -1;
+ Status s =
+ TF_DataType_to_PyArray_TYPE(static_cast<TF_DataType>(dtype), &type_num);
+ if (!s.ok()) {
+ return s;
+ }
+
auto* np_array = reinterpret_cast<PyArrayObject*>(
- PyArray_SimpleNewFromData(dim_size, dims, descr->type_num, data));
- CHECK_NE(np_array, nullptr);
+ PyArray_SimpleNewFromData(dim_size, dims, type_num, data));
PyArray_CLEARFLAGS(np_array, NPY_ARRAY_OWNDATA);
if (PyType_Ready(&TensorReleaserType) == -1) {
return errors::Unknown("Python type initialization failed.");
diff --git a/tensorflow/python/lib/core/ndarray_tensor_bridge.h b/tensorflow/python/lib/core/ndarray_tensor_bridge.h
index d6943af..029c0d3 100644
--- a/tensorflow/python/lib/core/ndarray_tensor_bridge.h
+++ b/tensorflow/python/lib/core/ndarray_tensor_bridge.h
@@ -42,6 +42,10 @@
Status ArrayFromMemory(int dim_size, npy_intp* dims, void* data, DataType dtype,
std::function<void()> destructor, PyObject** result);
+// Converts TF_DataType to the corresponding numpy type.
+Status TF_DataType_to_PyArray_TYPE(TF_DataType tf_datatype,
+ int* out_pyarray_type);
+
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_BRIDGE_H_
diff --git a/tensorflow/python/lib/core/ndarray_tensor_types.cc b/tensorflow/python/lib/core/ndarray_tensor_types.cc
deleted file mode 100644
index c255db4..0000000
--- a/tensorflow/python/lib/core/ndarray_tensor_types.cc
+++ /dev/null
@@ -1,287 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/python/lib/core/ndarray_tensor_types.h"
-
-#include <Python.h>
-
-// Must be included first.
-// clang-format: off
-#include "tensorflow/python/lib/core/numpy.h"
-// clang-format: on
-
-#include "absl/container/flat_hash_set.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/python/lib/core/bfloat16.h"
-
-namespace tensorflow {
-
-PyArray_Descr* BFLOAT16_DESCR = nullptr;
-PyArray_Descr* QINT8_DESCR = nullptr;
-PyArray_Descr* QINT16_DESCR = nullptr;
-PyArray_Descr* QINT32_DESCR = nullptr;
-PyArray_Descr* QUINT8_DESCR = nullptr;
-PyArray_Descr* QUINT16_DESCR = nullptr;
-PyArray_Descr* RESOURCE_DESCR = nullptr;
-
-// Define a struct array data type `[(tag, type_num)]`.
-PyArray_Descr* DefineStructTypeAlias(const char* tag, int type_num) {
-#if PY_MAJOR_VERSION < 3
- auto* py_tag = PyBytes_FromString(tag);
-#else
- auto* py_tag = PyUnicode_FromString(tag);
-#endif
- auto* descr = PyArray_DescrFromType(type_num);
- auto* py_tag_and_descr = PyTuple_Pack(2, py_tag, descr);
- auto* obj = PyList_New(1);
- PyList_SetItem(obj, 0, py_tag_and_descr);
- PyArray_Descr* alias_descr = nullptr;
- // TODO(slebedev): Switch to PyArray_DescrNewFromType because struct
- // array dtypes could not be used for scalars. Note that this will
- // require registering type conversions and UFunc specializations.
- // See b/144230631.
- CHECK_EQ(PyArray_DescrConverter(obj, &alias_descr), NPY_SUCCEED);
- Py_DECREF(obj);
- Py_DECREF(py_tag_and_descr);
- Py_DECREF(py_tag);
- Py_DECREF(descr);
- CHECK_NE(alias_descr, nullptr);
- return alias_descr;
-}
-
-void MaybeRegisterCustomNumPyTypes() {
- static bool registered = false;
- if (registered) return;
- ImportNumpy(); // Ensure NumPy is loaded.
- // Make sure the tags are consistent with DataTypeToPyArray_Descr.
- QINT8_DESCR = DefineStructTypeAlias("qint8", NPY_INT8);
- QINT16_DESCR = DefineStructTypeAlias("qint16", NPY_INT16);
- QINT32_DESCR = DefineStructTypeAlias("qint32", NPY_INT32);
- QUINT8_DESCR = DefineStructTypeAlias("quint8", NPY_UINT8);
- QUINT16_DESCR = DefineStructTypeAlias("quint16", NPY_UINT16);
- RESOURCE_DESCR = DefineStructTypeAlias("resource", NPY_UBYTE);
- RegisterNumpyBfloat16();
- BFLOAT16_DESCR = PyArray_DescrFromType(Bfloat16NumpyType());
- registered = true;
-}
-
-const char* PyArray_DescrReprAsString(PyArray_Descr* descr) {
- auto* descr_repr = PyObject_Repr(reinterpret_cast<PyObject*>(descr));
- const char* result;
-#if PY_MAJOR_VERSION < 3
- result = PyBytes_AsString(descr_repr);
-#else
- auto* tmp = PyUnicode_AsASCIIString(descr_repr);
- result = PyBytes_AsString(tmp);
- Py_DECREF(tmp);
-#endif
-
- Py_DECREF(descr_repr);
- return result;
-}
-
-Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr) {
- switch (dt) {
- case DT_HALF:
- *out_descr = PyArray_DescrFromType(NPY_FLOAT16);
- break;
- case DT_FLOAT:
- *out_descr = PyArray_DescrFromType(NPY_FLOAT32);
- break;
- case DT_DOUBLE:
- *out_descr = PyArray_DescrFromType(NPY_FLOAT64);
- break;
- case DT_INT32:
- *out_descr = PyArray_DescrFromType(NPY_INT32);
- break;
- case DT_UINT32:
- *out_descr = PyArray_DescrFromType(NPY_UINT32);
- break;
- case DT_UINT8:
- *out_descr = PyArray_DescrFromType(NPY_UINT8);
- break;
- case DT_UINT16:
- *out_descr = PyArray_DescrFromType(NPY_UINT16);
- break;
- case DT_INT8:
- *out_descr = PyArray_DescrFromType(NPY_INT8);
- break;
- case DT_INT16:
- *out_descr = PyArray_DescrFromType(NPY_INT16);
- break;
- case DT_INT64:
- *out_descr = PyArray_DescrFromType(NPY_INT64);
- break;
- case DT_UINT64:
- *out_descr = PyArray_DescrFromType(NPY_UINT64);
- break;
- case DT_BOOL:
- *out_descr = PyArray_DescrFromType(NPY_BOOL);
- break;
- case DT_COMPLEX64:
- *out_descr = PyArray_DescrFromType(NPY_COMPLEX64);
- break;
- case DT_COMPLEX128:
- *out_descr = PyArray_DescrFromType(NPY_COMPLEX128);
- break;
- case DT_STRING:
- *out_descr = PyArray_DescrFromType(NPY_OBJECT);
- break;
- case DT_QINT8:
- *out_descr = PyArray_DescrFromType(NPY_INT8);
- break;
- case DT_QINT16:
- *out_descr = PyArray_DescrFromType(NPY_INT16);
- break;
- case DT_QINT32:
- *out_descr = PyArray_DescrFromType(NPY_INT32);
- break;
- case DT_QUINT8:
- *out_descr = PyArray_DescrFromType(NPY_UINT8);
- break;
- case DT_QUINT16:
- *out_descr = PyArray_DescrFromType(NPY_UINT16);
- break;
- case DT_RESOURCE:
- *out_descr = PyArray_DescrFromType(NPY_UBYTE);
- break;
- case DT_BFLOAT16:
- Py_INCREF(BFLOAT16_DESCR);
- *out_descr = BFLOAT16_DESCR;
- break;
- default:
- return errors::Internal("TensorFlow data type ", DataType_Name(dt),
- " cannot be converted to a NumPy data type.");
- }
-
- return Status::OK();
-}
-
-// NumPy defines fixed-width aliases for platform integer types. However,
-// some types do not have a fixed-width alias. Specifically
-//
-// * on a LLP64 system NPY_INT32 == NPY_LONG therefore NPY_INT is not aliased;
-// * on a LP64 system NPY_INT64 == NPY_LONG and NPY_LONGLONG is not aliased.
-//
-int MaybeResolveNumPyPlatformType(int type_num) {
- switch (type_num) {
-#if NPY_BITS_OF_INT == 32 && NPY_BITS_OF_LONGLONG == 32
- case NPY_INT:
- return NPY_INT32;
- case NPY_UINT:
- return NPY_UINT32;
-#endif
-#if NPY_BITSOF_INT == 32 && NPY_BITSOF_LONGLONG == 64
- case NPY_LONGLONG:
- return NPY_INT64;
- case NPY_ULONGLONG:
- return NPY_UINT64;
-#endif
- default:
- return type_num;
- }
-}
-
-Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt) {
- const int type_num = MaybeResolveNumPyPlatformType(descr->type_num);
- switch (type_num) {
- case NPY_FLOAT16:
- *out_dt = DT_HALF;
- break;
- case NPY_FLOAT32:
- *out_dt = DT_FLOAT;
- break;
- case NPY_FLOAT64:
- *out_dt = DT_DOUBLE;
- break;
- case NPY_INT8:
- *out_dt = DT_INT8;
- break;
- case NPY_INT16:
- *out_dt = DT_INT16;
- break;
- case NPY_INT32:
- *out_dt = DT_INT32;
- break;
- case NPY_INT64:
- *out_dt = DT_INT64;
- break;
- case NPY_UINT8:
- *out_dt = DT_UINT8;
- break;
- case NPY_UINT16:
- *out_dt = DT_UINT16;
- break;
- case NPY_UINT32:
- *out_dt = DT_UINT32;
- break;
- case NPY_UINT64:
- *out_dt = DT_UINT64;
- break;
- case NPY_BOOL:
- *out_dt = DT_BOOL;
- break;
- case NPY_COMPLEX64:
- *out_dt = DT_COMPLEX64;
- break;
- case NPY_COMPLEX128:
- *out_dt = DT_COMPLEX128;
- break;
- case NPY_OBJECT:
- case NPY_STRING:
- case NPY_UNICODE:
- *out_dt = DT_STRING;
- break;
- case NPY_VOID: {
- if (descr == QINT8_DESCR) {
- *out_dt = DT_QINT8;
- break;
- } else if (descr == QINT16_DESCR) {
- *out_dt = DT_QINT16;
- break;
- } else if (descr == QINT32_DESCR) {
- *out_dt = DT_QINT32;
- break;
- } else if (descr == QUINT8_DESCR) {
- *out_dt = DT_QUINT8;
- break;
- } else if (descr == QUINT16_DESCR) {
- *out_dt = DT_QUINT16;
- break;
- } else if (descr == RESOURCE_DESCR) {
- *out_dt = DT_RESOURCE;
- break;
- }
-
- return errors::Internal("Unsupported NumPy struct data type: ",
- PyArray_DescrReprAsString(descr));
- }
- default:
- if (type_num == Bfloat16NumpyType()) {
- *out_dt = DT_BFLOAT16;
- break;
- }
-
- return errors::Internal("Unregistered NumPy data type: ",
- PyArray_DescrReprAsString(descr));
- }
- return Status::OK();
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/python/lib/core/ndarray_tensor_types.h b/tensorflow/python/lib/core/ndarray_tensor_types.h
deleted file mode 100644
index 5a4a905..0000000
--- a/tensorflow/python/lib/core/ndarray_tensor_types.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
-#define TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
-
-// Must be included first.
-// clang-format: off
-#include "tensorflow/python/lib/core/numpy.h"
-// clang-format: on
-
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/core/status.h"
-
-namespace tensorflow {
-
-extern PyArray_Descr* QINT8_DESCR;
-extern PyArray_Descr* QINT16_DESCR;
-extern PyArray_Descr* QINT32_DESCR;
-extern PyArray_Descr* QUINT8_DESCR;
-extern PyArray_Descr* QUINT16_DESCR;
-extern PyArray_Descr* RESOURCE_DESCR;
-extern PyArray_Descr* BFLOAT16_DESCR;
-
-// Register custom NumPy types.
-//
-// This function must be called in order to be able to map TensorFlow
-// data types which do not have a corresponding standard NumPy data type
-// (e.g. bfloat16 or qint8).
-//
-// TODO(b/144230631): The name is slightly misleading, as the function only
-// registers bfloat16 and defines structured aliases for other data types
-// (e.g. qint8).
-void MaybeRegisterCustomNumPyTypes();
-
-// Returns a NumPy data type matching a given tensorflow::DataType. If the
-// function call succeeds, the caller is responsible for DECREF'ing the
-// resulting PyArray_Descr*.
-//
-// NumPy does not support quantized integer types, so TensorFlow defines
-// structured aliases for them, e.g. tf.qint8 is represented as
-// np.dtype([("qint8", np.int8)]). However, for historical reasons this
-// function does not use these aliases, and instead returns the *aliased*
-// types (np.int8 in the example).
-// TODO(b/144230631): Return an alias instead of the aliased type.
-Status DataTypeToPyArray_Descr(DataType dt, PyArray_Descr** out_descr);
-
-// Returns a tensorflow::DataType corresponding to a given NumPy data type.
-Status PyArray_DescrToDataType(PyArray_Descr* descr, DataType* out_dt);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_PYTHON_LIB_CORE_NDARRAY_TENSOR_TYPES_H_
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 8770b36..5d4916f 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -21,6 +21,7 @@
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/py_util.h"
@@ -396,6 +397,21 @@
// Floating-point support
+// Returns `true` if `out` overflows when converted from `as_double`.
+template <class T>
+static inline bool CheckForOverflow(double as_double, T* out) {
+ return (sizeof(T) < sizeof(double) && std::isinf(*out) &&
+ std::isfinite(as_double));
+}
+
+// There is no `std::isinf` that takes `Eigen::half` as argument but Eigen
+// provides `Eigen::half_impl::isinf` instead.
+template <>
+inline bool CheckForOverflow<Eigen::half>(double as_double, Eigen::half* out) {
+ return (sizeof(Eigen::half) < sizeof(double) &&
+ Eigen::half_impl::isinf(*out) && std::isfinite(as_double));
+}
+
template <class T>
static const char* ConvertOneFloat(PyObject* v, T* out) {
if (PyErr_Occurred()) {
@@ -405,20 +421,19 @@
const double as_double = PyFloat_AS_DOUBLE(v);
*out = static_cast<T>(as_double);
// Check for overflow
- if (TF_PREDICT_FALSE(sizeof(T) < sizeof(double) && std::isinf(*out) &&
- std::isfinite(as_double))) {
+ if (TF_PREDICT_FALSE(CheckForOverflow<T>(as_double, out))) {
return ErrorOutOfRangeDouble;
}
return nullptr;
}
#if PY_MAJOR_VERSION < 3
if (PyInt_Check(v)) {
- *out = PyInt_AS_LONG(v);
+ *out = static_cast<T>(PyInt_AS_LONG(v));
return nullptr;
}
#endif
if (PyLong_Check(v)) {
- *out = PyLong_AsDouble(v);
+ *out = static_cast<T>(PyLong_AsDouble(v));
if (PyErr_Occurred()) return ErrorOutOfRangeDouble;
return nullptr;
}
@@ -467,13 +482,7 @@
static const tensorflow::DataType kTypeEnum = DT_HALF;
static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
- // NOTE(nareshmodi): Is there a way to convert to C double without the
- // intermediate Python double? This will help with ConvertOneFloat as well.
- Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
- double v_double = PyFloat_AS_DOUBLE(as_float.get());
- *out = Eigen::half(v_double);
-
- return nullptr;
+ return ConvertOneFloat<Eigen::half>(v, out);
}
};
@@ -613,7 +622,9 @@
break;
case DT_HALF:
- RETURN_STRING_AS_STATUS(NumpyHalfConverter::Convert(obj, &state, ret));
+ if (NumpyHalfConverter::Convert(obj, &state, ret) == nullptr)
+ return Status::OK();
+ break;
case DT_INT64:
if (Int64Converter::Convert(obj, &state, ret) == nullptr)
diff --git a/tensorflow/python/lib/io/file_io.i b/tensorflow/python/lib/io/file_io.i
deleted file mode 100644
index cbd619b..0000000
--- a/tensorflow/python/lib/io/file_io.i
+++ /dev/null
@@ -1,302 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "tensorflow/python/lib/core/strings.i"
-%include "tensorflow/python/platform/base.i"
-
-%{
-#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
-#include "tensorflow/core/lib/io/inputstream_interface.h"
-#include "tensorflow/core/lib/io/random_inputstream.h"
-#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/file_statistics.h"
-#include "tensorflow/core/platform/file_system.h"
-#include "tensorflow/core/protobuf/meta_graph.pb.h"
-%}
-
-%{
-inline void FileExists(const string& filename, TF_Status* status) {
- tensorflow::Status s = tensorflow::Env::Default()->FileExists(filename);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-inline void FileExists(const tensorflow::StringPiece& filename,
- TF_Status* status) {
- tensorflow::Status s =
- tensorflow::Env::Default()->FileExists(string(filename));
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-inline void DeleteFile(const string& filename, TF_Status* status) {
- tensorflow::Status s = tensorflow::Env::Default()->DeleteFile(filename);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-string ReadFileToString(const string& filename, TF_Status* status) {
- string file_content;
- tensorflow::Status s = ReadFileToString(tensorflow::Env::Default(),
- filename, &file_content);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
- return file_content;
-}
-
-void WriteStringToFile(const string& filename, const string& file_content,
- TF_Status* status) {
- tensorflow::Status s = WriteStringToFile(tensorflow::Env::Default(),
- filename, file_content);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-std::vector<string> GetChildren(const string& dir, TF_Status* status) {
- std::vector<string> results;
- tensorflow::Status s = tensorflow::Env::Default()->GetChildren(
- dir, &results);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
- return results;
-}
-
-std::vector<string> GetMatchingFiles(const string& filename, TF_Status* status) {
- std::vector<string> results;
- tensorflow::Status s = tensorflow::Env::Default()->GetMatchingPaths(
- filename, &results);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
- return results;
-}
-
-void CreateDir(const string& dirname, TF_Status* status) {
- tensorflow::Status s = tensorflow::Env::Default()->CreateDir(dirname);
- if (!s.ok() && s.code() != tensorflow::error::ALREADY_EXISTS) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-void RecursivelyCreateDir(const string& dirname, TF_Status* status) {
- tensorflow::Status s = tensorflow::Env::Default()->RecursivelyCreateDir(
- dirname);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-void CopyFile(const string& src, const string& target, bool overwrite,
- TF_Status* status) {
- // If overwrite is false and the target file exists then its an error.
- if (!overwrite && tensorflow::Env::Default()->FileExists(target).ok()) {
- TF_SetStatus(status, TF_ALREADY_EXISTS, "file already exists");
- return;
- }
- tensorflow::Status s = tensorflow::Env::Default()->CopyFile(src, target);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-void RenameFile(const string& src, const string& target, bool overwrite,
- TF_Status* status) {
- // If overwrite is false and the target file exists then its an error.
- if (!overwrite && tensorflow::Env::Default()->FileExists(target).ok()) {
- TF_SetStatus(status, TF_ALREADY_EXISTS, "file already exists");
- return;
- }
- tensorflow::Status s = tensorflow::Env::Default()->RenameFile(src, target);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-using tensorflow::int64;
-
-void DeleteRecursively(const string& dirname, TF_Status* status) {
- int64 undeleted_files, undeleted_dirs;
- tensorflow::Status s = tensorflow::Env::Default()->DeleteRecursively(
- dirname, &undeleted_files, &undeleted_dirs);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- return;
- }
- if (undeleted_files > 0 || undeleted_dirs > 0) {
- TF_SetStatus(status, TF_PERMISSION_DENIED, "could not fully delete dir");
- return;
- }
-}
-
-bool IsDirectory(const string& dirname, TF_Status* out_status) {
- tensorflow::Status status = tensorflow::Env::Default()->IsDirectory(dirname);
- if (status.ok()) {
- return true;
- }
- // FAILED_PRECONDITION Status response means path exists but isn't a dir.
- if (status.code() != tensorflow::error::FAILED_PRECONDITION) {
- Set_TF_Status_from_Status(out_status, status);
- }
- return false;
-}
-
-using tensorflow::FileStatistics;
-
-void Stat(const string& filename, FileStatistics* stats, TF_Status* status) {
- tensorflow::Status s = tensorflow::Env::Default()->Stat(filename,
- stats);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-tensorflow::io::BufferedInputStream* CreateBufferedInputStream(
- const string& filename, size_t buffer_size, TF_Status* status) {
- std::unique_ptr<tensorflow::RandomAccessFile> file;
- tensorflow::Status s =
- tensorflow::Env::Default()->NewRandomAccessFile(filename, &file);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- return nullptr;
- }
- std::unique_ptr<tensorflow::io::RandomAccessInputStream> input_stream(
- new tensorflow::io::RandomAccessInputStream(
- file.release(), true /* owns_file */));
- std::unique_ptr<tensorflow::io::BufferedInputStream> buffered_input_stream(
- new tensorflow::io::BufferedInputStream(
- input_stream.release(), buffer_size, true /* owns_input_stream */));
- return buffered_input_stream.release();
-}
-
-tensorflow::WritableFile* CreateWritableFile(
- const string& filename, const string& mode, TF_Status* status) {
- std::unique_ptr<tensorflow::WritableFile> file;
- tensorflow::Status s;
- if (mode.find("a") != std::string::npos) {
- s = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
- } else {
- s = tensorflow::Env::Default()->NewWritableFile(filename, &file);
- }
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- return nullptr;
- }
- return file.release();
-}
-
-void AppendToFile(const string& file_content, tensorflow::WritableFile* file,
- TF_Status* status) {
- tensorflow::Status s = file->Append(file_content);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
-}
-
-int64 TellFile(tensorflow::WritableFile* file, TF_Status* status) {
- int64 position = -1;
- tensorflow::Status s = file->Tell(&position);
- if (!s.ok()) {
- Set_TF_Status_from_Status(status, s);
- }
- return position;
-}
-
-
-string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
- size_t bytes,
- TF_Status* status) {
- tensorflow::tstring result;
- tensorflow::Status s = stream->ReadNBytes(bytes, &result);
- if (!s.ok() && s.code() != tensorflow::error::OUT_OF_RANGE) {
- Set_TF_Status_from_Status(status, s);
- result.clear();
- }
- return result;
-}
-
-%}
-
-// Ensure that the returned object is destroyed when its wrapper is
-// garbage collected.
-%newobject CreateBufferedInputStream;
-%newobject CreateWritableFile;
-
-// Wrap the above functions.
-inline void FileExists(const string& filename, TF_Status* status);
-inline void DeleteFile(const string& filename, TF_Status* status);
-string ReadFileToString(const string& filename, TF_Status* status);
-void WriteStringToFile(const string& filename, const string& file_content,
- TF_Status* status);
-std::vector<string> GetChildren(const string& dir, TF_Status* status);
-std::vector<string> GetMatchingFiles(const string& filename,
- TF_Status* status);
-void CreateDir(const string& dirname, TF_Status* status);
-void RecursivelyCreateDir(const string& dirname, TF_Status* status);
-void CopyFile(const string& oldpath, const string& newpath, bool overwrite,
- TF_Status* status);
-void RenameFile(const string& oldname, const string& newname, bool overwrite,
- TF_Status* status);
-void DeleteRecursively(const string& dirname, TF_Status* status);
-bool IsDirectory(const string& dirname, TF_Status* out_status);
-void Stat(const string& filename, tensorflow::FileStatistics* stats,
- TF_Status* status);
-tensorflow::io::BufferedInputStream* CreateBufferedInputStream(
- const string& filename, size_t buffer_size, TF_Status* status);
-tensorflow::WritableFile* CreateWritableFile(const string& filename,
- const string& mode,
- TF_Status* status);
-void AppendToFile(const string& file_content, tensorflow::WritableFile* file,
- TF_Status* status);
-int64 TellFile(tensorflow::WritableFile* file, TF_Status* status);
-string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
- size_t bytes,
- TF_Status* status);
-
-%ignore tensorflow::Status::operator=;
-%include "tensorflow/core/platform/status.h"
-
-%ignoreall
-%unignore tensorflow::io;
-%unignore tensorflow::io::BufferedInputStream;
-%unignore tensorflow::io::BufferedInputStream::~BufferedInputStream;
-%unignore tensorflow::io::BufferedInputStream::ReadLineAsString;
-%unignore tensorflow::io::BufferedInputStream::Seek;
-%unignore tensorflow::io::BufferedInputStream::Tell;
-%unignore tensorflow::WritableFile;
-%unignore tensorflow::WritableFile::Close;
-%unignore tensorflow::WritableFile::Flush;
-%unignore tensorflow::WritableFile::~WritableFile;
-%include "tensorflow/core/platform/file_system.h"
-%include "tensorflow/core/lib/io/inputstream_interface.h"
-%include "tensorflow/core/lib/io/buffered_inputstream.h"
-%unignoreall
-
-%include "tensorflow/c/tf_status_helper.h"
-
-%ignore tensorflow::io::internal::JoinPathImpl;
-%include "tensorflow/core/lib/io/path.h"
-
-%include "tensorflow/core/platform/file_statistics.h"
diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py
index 65c0f08..55b4359 100644
--- a/tensorflow/python/lib/io/file_io.py
+++ b/tensorflow/python/lib/io/file_io.py
@@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""File IO methods that wrap the C++ FileSystem API.
-
-The C++ FileSystem API is SWIG wrapped in file_io.i. These functions call those
-to accomplish basic File IO operations.
-"""
+"""File IO methods that wrap the C++ FileSystem API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -27,8 +23,7 @@
import six
-from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.framework import c_api_util
+from tensorflow.python import _pywrap_file_io
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
@@ -80,15 +75,15 @@
if not self._read_check_passed:
raise errors.PermissionDeniedError(None, None,
"File isn't open for reading")
- self._read_buf = pywrap_tensorflow.CreateBufferedInputStream(
- compat.as_bytes(self.__name), 1024 * 512)
+ self._read_buf = _pywrap_file_io.BufferedInputStream(
+ self.__name, 1024 * 512)
def _prewrite_check(self):
if not self._writable_file:
if not self._write_check_passed:
raise errors.PermissionDeniedError(None, None,
"File isn't open for writing")
- self._writable_file = pywrap_tensorflow.CreateWritableFile(
+ self._writable_file = _pywrap_file_io.WritableFile(
compat.as_bytes(self.__name), compat.as_bytes(self.__mode))
def _prepare_value(self, val):
@@ -104,8 +99,7 @@
def write(self, file_content):
"""Writes file_content to the file. Appends to the end of the file."""
self._prewrite_check()
- pywrap_tensorflow.AppendToFile(
- compat.as_bytes(file_content), self._writable_file)
+ self._writable_file.append(compat.as_bytes(file_content))
def read(self, n=-1):
"""Returns the contents of a file as a string.
@@ -124,8 +118,7 @@
length = self.size() - self.tell()
else:
length = n
- return self._prepare_value(
- pywrap_tensorflow.ReadFromStream(self._read_buf, length))
+ return self._prepare_value(self._read_buf.read(length))
@deprecation.deprecated_args(
None, "position is deprecated in favor of the offset argument.",
@@ -158,25 +151,23 @@
if position is not None:
offset = position
- with errors.raise_exception_on_not_ok_status() as status:
- if whence == 0:
- pass
- elif whence == 1:
- offset += self.tell()
- elif whence == 2:
- offset += self.size()
- else:
- raise errors.InvalidArgumentError(
- None, None,
- "Invalid whence argument: {}. Valid values are 0, 1, or 2.".format(
- whence))
- ret_status = self._read_buf.Seek(offset)
- pywrap_tensorflow.Set_TF_Status_from_Status(status, ret_status)
+ if whence == 0:
+ pass
+ elif whence == 1:
+ offset += self.tell()
+ elif whence == 2:
+ offset += self.size()
+ else:
+ raise errors.InvalidArgumentError(
+ None, None,
+ "Invalid whence argument: {}. Valid values are 0, 1, or 2.".format(
+ whence))
+ self._read_buf.seek(offset)
def readline(self):
r"""Reads the next line from the file. Leaves the '\n' at the end."""
self._preread_check()
- return self._prepare_value(self._read_buf.ReadLineAsString())
+ return self._prepare_value(self._read_buf.readline())
def readlines(self):
"""Returns all lines from the file in a list."""
@@ -193,11 +184,11 @@
"""Returns the current position in the file."""
if self._read_check_passed:
self._preread_check()
- return self._read_buf.Tell()
+ return self._read_buf.tell()
else:
self._prewrite_check()
- return pywrap_tensorflow.TellFile(self._writable_file)
+ return self._writable_file.tell()
def __enter__(self):
"""Make usable with "with" statement."""
@@ -227,18 +218,14 @@
data would survive an application crash but not necessarily an OS crash.
"""
if self._writable_file:
- with errors.raise_exception_on_not_ok_status() as status:
- ret_status = self._writable_file.Flush()
- pywrap_tensorflow.Set_TF_Status_from_Status(status, ret_status)
+ self._writable_file.flush()
def close(self):
"""Closes FileIO. Should be called for the WritableFile to be flushed."""
self._read_buf = None
if self._writable_file:
- with errors.raise_exception_on_not_ok_status() as status:
- ret_status = self._writable_file.Close()
- pywrap_tensorflow.Set_TF_Status_from_Status(status, ret_status)
- self._writable_file = None
+ self._writable_file.close()
+ self._writable_file = None
def seekable(self):
"""Returns True as FileIO supports random access ops of seek()/tell()"""
@@ -277,7 +264,7 @@
errors.OpError: Propagates any errors reported by the FileSystem API.
"""
try:
- pywrap_tensorflow.FileExists(compat.as_bytes(path))
+ _pywrap_file_io.FileExists(compat.as_bytes(path))
except errors.NotFoundError:
return False
return True
@@ -308,7 +295,7 @@
errors.OpError: Propagates any errors reported by the FileSystem API. E.g.,
`NotFoundError` if the path does not exist.
"""
- pywrap_tensorflow.DeleteFile(compat.as_bytes(path))
+ _pywrap_file_io.DeleteFile(compat.as_bytes(path))
def read_file_to_string(filename, binary_mode=False):
@@ -380,7 +367,7 @@
return [
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename)
- for matching_filename in pywrap_tensorflow.GetMatchingFiles(
+ for matching_filename in _pywrap_file_io.GetMatchingFiles(
compat.as_bytes(pattern))
]
else:
@@ -388,7 +375,7 @@
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename) # pylint: disable=g-complex-comprehension
for single_filename in pattern
- for matching_filename in pywrap_tensorflow.GetMatchingFiles(
+ for matching_filename in _pywrap_file_io.GetMatchingFiles(
compat.as_bytes(single_filename))
]
@@ -422,7 +409,7 @@
Raises:
errors.OpError: If the operation fails.
"""
- pywrap_tensorflow.CreateDir(compat.as_bytes(path))
+ _pywrap_file_io.CreateDir(compat.as_bytes(path))
@tf_export(v1=["gfile.MakeDirs"])
@@ -452,7 +439,7 @@
Raises:
errors.OpError: If the operation fails.
"""
- pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(path))
+ _pywrap_file_io.RecursivelyCreateDir(compat.as_bytes(path))
@tf_export(v1=["gfile.Copy"])
@@ -484,7 +471,7 @@
Raises:
errors.OpError: If the operation fails.
"""
- pywrap_tensorflow.CopyFile(
+ _pywrap_file_io.CopyFile(
compat.as_bytes(src), compat.as_bytes(dst), overwrite)
@@ -517,7 +504,7 @@
Raises:
errors.OpError: If the operation fails.
"""
- pywrap_tensorflow.RenameFile(
+ _pywrap_file_io.RenameFile(
compat.as_bytes(src), compat.as_bytes(dst), overwrite)
@@ -568,7 +555,7 @@
Raises:
errors.OpError: If the operation fails.
"""
- pywrap_tensorflow.DeleteRecursively(compat.as_bytes(path))
+ _pywrap_file_io.DeleteRecursively(compat.as_bytes(path))
@tf_export(v1=["gfile.IsDirectory"])
@@ -594,8 +581,10 @@
Returns:
True, if the path is a directory; False otherwise
"""
- status = c_api_util.ScopedTFStatus()
- return pywrap_tensorflow.IsDirectory(compat.as_bytes(path), status)
+ try:
+ return _pywrap_file_io.IsDirectory(compat.as_bytes(path))
+ except errors.OpError:
+ return False
@tf_export(v1=["gfile.ListDirectory"])
@@ -643,7 +632,7 @@
# vector of string should be interpreted as strings, not bytes.
return [
compat.as_str_any(filename)
- for filename in pywrap_tensorflow.GetChildren(compat.as_bytes(path))
+ for filename in _pywrap_file_io.GetChildren(compat.as_bytes(path))
]
@@ -742,9 +731,7 @@
Raises:
errors.OpError: If the operation fails.
"""
- file_statistics = pywrap_tensorflow.FileStatistics()
- pywrap_tensorflow.Stat(compat.as_bytes(path), file_statistics)
- return file_statistics
+ return _pywrap_file_io.Stat(path)
def filecmp(filename_a, filename_b):
diff --git a/tensorflow/python/lib/io/file_io_wrapper.cc b/tensorflow/python/lib/io/file_io_wrapper.cc
new file mode 100644
index 0000000..28e55f1
--- /dev/null
+++ b/tensorflow/python/lib/io/file_io_wrapper.cc
@@ -0,0 +1,250 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Python bindings for the file-system helpers used by
+// tensorflow/python/lib/io/file_io.py, exposed as module `_pywrap_file_io`.
+//
+// Error convention: a non-OK tensorflow::Status is always reported through
+// tensorflow::MaybeRaiseRegisteredFromStatus, which raises the Python exception
+// registered for its error code (file_io.py catches these as errors.OpError).
+// No binding returns a raw Status to Python.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "include/pybind11/pybind11.h"
+#include "include/pybind11/stl.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/file_statistics.h"
+#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/platform/tstring.h"
+#include "tensorflow/python/lib/core/pybind11_absl.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+namespace {
+namespace py = pybind11;
+
+PYBIND11_MODULE(_pywrap_file_io, m) {
+  // Returns None when Env::FileExists reports OK; any other status (e.g. a
+  // missing file) raises.
+  m.def("FileExists", [](const std::string& filename) {
+    tensorflow::MaybeRaiseRegisteredFromStatus(
+        tensorflow::Env::Default()->FileExists(filename));
+  });
+  // Thin wrapper over Env::DeleteFile; raises on failure.
+  m.def("DeleteFile", [](const std::string& filename) {
+    tensorflow::MaybeRaiseRegisteredFromStatus(
+        tensorflow::Env::Default()->DeleteFile(filename));
+  });
+  // Returns the entire contents of `filename` as `bytes` (not `str`).
+  m.def("ReadFileToString", [](const std::string& filename) {
+    std::string data;
+    const auto status =
+        ReadFileToString(tensorflow::Env::Default(), filename, &data);
+    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    return py::bytes(data);
+  });
+  // Writes `data` to `filename`. The status is raised here rather than
+  // returned so that failures surface as the registered exception types, the
+  // same as every other binding in this module; a returned Status would be
+  // translated by whichever pybind11 caster handles it instead.
+  m.def("WriteStringToFile",
+        [](const std::string& filename, tensorflow::StringPiece data) {
+          const auto status =
+              WriteStringToFile(tensorflow::Env::Default(), filename, data);
+          tensorflow::MaybeRaiseRegisteredFromStatus(status);
+        });
+  // Returns the names of the entries in `dirname` as a list.
+  m.def("GetChildren", [](const std::string& dirname) {
+    std::vector<std::string> results;
+    const auto status =
+        tensorflow::Env::Default()->GetChildren(dirname, &results);
+    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    return results;
+  });
+  // Returns the paths matching `pattern`, as resolved by
+  // Env::GetMatchingPaths.
+  m.def("GetMatchingFiles", [](const std::string& pattern) {
+    std::vector<std::string> results;
+    const auto status =
+        tensorflow::Env::Default()->GetMatchingPaths(pattern, &results);
+    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    return results;
+  });
+  // Creates a single directory. An ALREADY_EXISTS status is deliberately
+  // ignored so that creating an existing directory is not an error.
+  m.def("CreateDir", [](const std::string& dirname) {
+    const auto status = tensorflow::Env::Default()->CreateDir(dirname);
+    if (tensorflow::errors::IsAlreadyExists(status)) {
+      return;
+    }
+    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+  });
+  // Thin wrapper over Env::RecursivelyCreateDir; raises on failure.
+  m.def("RecursivelyCreateDir", [](const std::string& dirname) {
+    tensorflow::MaybeRaiseRegisteredFromStatus(
+        tensorflow::Env::Default()->RecursivelyCreateDir(dirname));
+  });
+  // Copies `src` to `target`. When `overwrite` is false and `target` already
+  // exists, raises ALREADY_EXISTS without copying. NOTE: the existence check
+  // and the copy are two separate calls, so they are not atomic.
+  m.def("CopyFile",
+        [](const std::string& src, const std::string& target, bool overwrite) {
+          auto* env = tensorflow::Env::Default();
+          tensorflow::Status status;
+          if (!overwrite && env->FileExists(target).ok()) {
+            status = tensorflow::errors::AlreadyExists("file already exists");
+          } else {
+            status = env->CopyFile(src, target);
+          }
+          tensorflow::MaybeRaiseRegisteredFromStatus(status);
+        });
+  // Renames `src` to `target`, with the same `overwrite` semantics (and the
+  // same non-atomic existence check) as CopyFile.
+  m.def("RenameFile",
+        [](const std::string& src, const std::string& target, bool overwrite) {
+          auto* env = tensorflow::Env::Default();
+          tensorflow::Status status;
+          if (!overwrite && env->FileExists(target).ok()) {
+            status = tensorflow::errors::AlreadyExists("file already exists");
+          } else {
+            status = env->RenameFile(src, target);
+          }
+          tensorflow::MaybeRaiseRegisteredFromStatus(status);
+        });
+  // Deletes `dirname` and its contents. A call that reports OK but leaves
+  // files or directories behind is turned into PERMISSION_DENIED.
+  m.def("DeleteRecursively", [](const std::string& dirname) {
+    // Zero-initialized so the counters never hold indeterminate values.
+    tensorflow::int64 undeleted_files = 0;
+    tensorflow::int64 undeleted_dirs = 0;
+    auto status = tensorflow::Env::Default()->DeleteRecursively(
+        dirname, &undeleted_files, &undeleted_dirs);
+    if (status.ok() && (undeleted_files > 0 || undeleted_dirs > 0)) {
+      status =
+          tensorflow::errors::PermissionDenied("could not fully delete dir");
+    }
+    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+  });
+  // Returns True for a directory and False when the path exists but is not a
+  // directory; any other non-OK status raises.
+  m.def("IsDirectory", [](const std::string& dirname) {
+    const auto status = tensorflow::Env::Default()->IsDirectory(dirname);
+    // FAILED_PRECONDITION response means path exists but isn't a dir.
+    if (tensorflow::errors::IsFailedPrecondition(status)) {
+      return false;
+    }
+
+    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    return true;
+  });
+
+  // Read-only view of tensorflow::FileStatistics, as returned by Stat().
+  py::class_<tensorflow::FileStatistics>(m, "FileStatistics")
+      .def_readonly("length", &tensorflow::FileStatistics::length)
+      .def_readonly("mtime_nsec", &tensorflow::FileStatistics::mtime_nsec)
+      .def_readonly("is_directory", &tensorflow::FileStatistics::is_directory);
+
+  // Returns a newly allocated FileStatistics for `filename`. The released raw
+  // pointer is adopted by pybind11 (its default policy for returned pointers
+  // takes ownership), so the object is not leaked.
+  m.def("Stat", [](const std::string& filename) {
+    std::unique_ptr<tensorflow::FileStatistics> self(
+        new tensorflow::FileStatistics);
+    const auto status = tensorflow::Env::Default()->Stat(filename, self.get());
+    tensorflow::MaybeRaiseRegisteredFromStatus(status);
+    return self.release();
+  });
+
+  // Python-side writable file handle. A `mode` containing "a" opens the file
+  // with Env::NewAppendableFile; any other mode uses Env::NewWritableFile.
+  using tensorflow::WritableFile;
+  py::class_<WritableFile>(m, "WritableFile")
+      .def(py::init([](const std::string& filename, const std::string& mode) {
+        auto* env = tensorflow::Env::Default();
+        std::unique_ptr<WritableFile> self;
+        const auto status = mode.find("a") == std::string::npos
+                                ? env->NewWritableFile(filename, &self)
+                                : env->NewAppendableFile(filename, &self);
+        tensorflow::MaybeRaiseRegisteredFromStatus(status);
+        return self.release();
+      }))
+      .def("append",
+           [](WritableFile* self, tensorflow::StringPiece data) {
+             tensorflow::MaybeRaiseRegisteredFromStatus(self->Append(data));
+           })
+      // TODO(slebedev): Make WritableFile::Tell const and change self
+      // to be a reference.
+      .def("tell",
+           [](WritableFile* self) {
+             tensorflow::int64 pos = -1;
+             const auto status = self->Tell(&pos);
+             tensorflow::MaybeRaiseRegisteredFromStatus(status);
+             return pos;
+           })
+      .def("flush",
+           [](WritableFile* self) {
+             tensorflow::MaybeRaiseRegisteredFromStatus(self->Flush());
+           })
+      .def("close", [](WritableFile* self) {
+        tensorflow::MaybeRaiseRegisteredFromStatus(self->Close());
+      });
+
+  // Buffered, read-only file stream. The constructor builds the ownership
+  // chain RandomAccessFile -> RandomAccessInputStream -> BufferedInputStream,
+  // each layer owning the one below it (owns_file / owns_input_stream).
+  using tensorflow::io::BufferedInputStream;
+  py::class_<BufferedInputStream>(m, "BufferedInputStream")
+      .def(py::init([](const std::string& filename, size_t buffer_size) {
+        std::unique_ptr<tensorflow::RandomAccessFile> file;
+        const auto status =
+            tensorflow::Env::Default()->NewRandomAccessFile(filename, &file);
+        tensorflow::MaybeRaiseRegisteredFromStatus(status);
+        std::unique_ptr<tensorflow::io::RandomAccessInputStream> input_stream(
+            new tensorflow::io::RandomAccessInputStream(file.release(),
+                                                        /*owns_file=*/true));
+        return new BufferedInputStream(input_stream.release(), buffer_size,
+                                       /*owns_input_stream=*/true);
+      }))
+      // Reads up to `bytes_to_read` bytes. An OUT_OF_RANGE status (presumably
+      // end of file) returns whatever was read instead of raising.
+      .def("read",
+           [](BufferedInputStream* self, tensorflow::int64 bytes_to_read) {
+             tensorflow::tstring result;
+             const auto status = self->ReadNBytes(bytes_to_read, &result);
+             if (!status.ok() && !tensorflow::errors::IsOutOfRange(status)) {
+               result.clear();
+               tensorflow::MaybeRaiseRegisteredFromStatus(status);
+             }
+             return py::bytes(result);
+           })
+      .def("readline",
+           [](BufferedInputStream* self) {
+             return py::bytes(self->ReadLineAsString());
+           })
+      .def("seek",
+           [](BufferedInputStream* self, tensorflow::int64 pos) {
+             tensorflow::MaybeRaiseRegisteredFromStatus(self->Seek(pos));
+           })
+      .def("tell", [](BufferedInputStream* self) { return self->Tell(); });
+}
+}  // namespace
diff --git a/tensorflow/python/module/module.py b/tensorflow/python/module/module.py
index bdc4d2a..714ab6a 100644
--- a/tensorflow/python/module/module.py
+++ b/tensorflow/python/module/module.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import re
+
import six
from tensorflow.python import tf2
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 125c9d6..6811044 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -210,32 +210,33 @@
For example:
- ```
- # Output tensor has shape [2, 3].
- fill([2, 3], 9) ==> [[9, 9, 9]
- [9, 9, 9]]
- ```
+ >>> tf.fill([2, 3], 9)
+ <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
+ array([[9, 9, 9],
+ [9, 9, 9]], dtype=int32)>
- `tf.fill` differs from `tf.constant` in a few ways:
-
- * `tf.fill` only supports scalar contents, whereas `tf.constant` supports
- Tensor values.
- * `tf.fill` creates an Op in the computation graph that constructs the
- actual
- Tensor value at runtime. This is in contrast to `tf.constant` which embeds
- the entire Tensor into the graph with a `Const` node.
- * Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
- based on other runtime Tensors, unlike `tf.constant`.
+ `tf.fill` evaluates at graph runtime and supports dynamic shapes based on
+ other runtime `tf.Tensor`s, unlike `tf.constant(value, shape=dims)`, which
+ embeds the value as a `Const` node.
Args:
- dims: A `Tensor`. Must be one of the following types: `int32`, `int64`. 1-D.
- Represents the shape of the output tensor.
- value: A `Tensor`. 0-D (scalar). Value to fill the returned tensor.
- @compatibility(numpy) Equivalent to np.full @end_compatibility
- name: A name for the operation (optional).
+ dims: A 1-D sequence of non-negative numbers. Represents the shape of the
+ output `tf.Tensor`. Entries should be of type: `int32`, `int64`.
+ value: A value to fill the returned `tf.Tensor`.
+ name: Optional string. The name of the output `tf.Tensor`.
Returns:
- A `Tensor`. Has the same type as `value`.
+ A `tf.Tensor` with shape `dims` and the same dtype as `value`.
+
+ Raises:
+ InvalidArgumentError: `dims` contains negative entries.
+ NotFoundError: `dims` contains non-integer entries.
+
+ @compatibility(numpy)
+ Similar to `np.full`. In `numpy`, more parameters are supported. Passing a
+ number argument as the shape (`np.full(5, value)`) is valid in `numpy` for
+ specifying a 1-D shaped result, while TensorFlow does not support this syntax.
+ @end_compatibility
"""
result = gen_array_ops.fill(dims, value, name=name)
tensor_util.maybe_set_static_shape(result, dims)
@@ -542,6 +543,7 @@
"""Returns the shape of a tensor.
This operation returns a 1-D integer tensor representing the shape of `input`.
+ This represents the minimal set of known information at definition time.
For example:
@@ -563,6 +565,10 @@
>>> a.shape
TensorShape([None, None, 10])
+ `tf.shape` and `Tensor.shape` should be identical in eager mode. Within
+ `tf.function` or within a `compat.v1` context, not all dimensions may be
+ known until execution time.
+
Args:
input: A `Tensor` or `SparseTensor`.
out_type: (Optional) The specified output type of the operation (`int32` or
@@ -1881,11 +1887,11 @@
@tf_export("split")
def split(value, num_or_size_splits, axis=0, num=None, name="split"):
- """Splits a tensor `value` into a list of sub tensors.
+ """Splits a tensor into sub tensors.
- If `num_or_size_splits` is an integer, then `value` is split along the
- dimension `axis` into `num_split` smaller tensors. This requires that
- `value.shape[axis]` is divisible by `num_split`.
+ If `num_or_size_splits` is an integer, then `value` is split along dimension
+ `axis` into `num_split` smaller tensors. This requires that `num_split` evenly
+ divides `value.shape[axis]`.
If `num_or_size_splits` is a 1-D Tensor (or list), we call it `size_splits`
and `value` is split into `len(size_splits)` elements. The shape of the `i`-th
@@ -1894,14 +1900,15 @@
For example:
- >>> x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
-
Split `x` into 3 tensors along dimension 1
+
+ >>> x = tf.Variable(tf.random.uniform([5, 30], -1, 1))
>>> s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
>>> tf.shape(s0).numpy()
array([ 5, 10], dtype=int32)
Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
+
>>> split0, split1, split2 = tf.split(x, [4, 15, 11], 1)
>>> tf.shape(split0).numpy()
array([5, 4], dtype=int32)
@@ -1924,8 +1931,8 @@
name: A name for the operation (optional).
Returns:
- if `num_or_size_splits` is a scalar returns a list of `num_or_size_splits`
- `Tensor` objects; if `num_or_size_splits` is a 1-D Tensor returns
+ if `num_or_size_splits` is a scalar returns `num_or_size_splits` `Tensor`
+ objects; if `num_or_size_splits` is a 1-D Tensor returns
`num_or_size_splits.get_shape[0]` `Tensor` objects resulting from splitting
`value`.
@@ -1956,16 +1963,17 @@
@tf_export("transpose", v1=[])
def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
- """Transposes `a`.
+ """Transposes `a`, where `a` is a Tensor.
- Permutes the dimensions according to `perm`.
+ Permutes the dimensions according to the value of `perm`.
- The returned tensor's dimension i will correspond to the input dimension
- `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
- the rank of the input tensor. Hence by default, this operation performs a
- regular matrix transpose on 2-D input Tensors. If conjugate is True and
- `a.dtype` is either `complex64` or `complex128` then the values of `a`
- are conjugated and transposed.
+ The returned tensor's dimension `i` will correspond to the input dimension
+ `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is the rank
+ of the input tensor. Hence by default, this operation performs a regular
+ matrix transpose on 2-D input Tensors.
+
+ If conjugate is `True` and `a.dtype` is either `complex64` or `complex128`
+ then the values of `a` are conjugated and transposed.
@compatibility(numpy)
In `numpy` transposes are memory-efficient constant time operations as they
@@ -1977,43 +1985,52 @@
For example:
- ```python
- x = tf.constant([[1, 2, 3], [4, 5, 6]])
- tf.transpose(x) # [[1, 4]
- # [2, 5]
- # [3, 6]]
+ >>> x = tf.constant([[1, 2, 3], [4, 5, 6]])
+ >>> tf.transpose(x)
+ <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
+ array([[1, 4],
+ [2, 5],
+ [3, 6]], dtype=int32)>
- # Equivalently
- tf.transpose(x, perm=[1, 0]) # [[1, 4]
- # [2, 5]
- # [3, 6]]
+ Equivalently, you could call `tf.transpose(x, perm=[1, 0])`.
- # If x is complex, setting conjugate=True gives the conjugate transpose
- x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
- [4 + 4j, 5 + 5j, 6 + 6j]])
- tf.transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
- # [2 - 2j, 5 - 5j],
- # [3 - 3j, 6 - 6j]]
+ If `x` is complex, setting conjugate=True gives the conjugate transpose:
- # 'perm' is more useful for n-dimensional tensors, for n > 2
- x = tf.constant([[[ 1, 2, 3],
- [ 4, 5, 6]],
- [[ 7, 8, 9],
- [10, 11, 12]]])
+ >>> x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
+ ... [4 + 4j, 5 + 5j, 6 + 6j]])
+ >>> tf.transpose(x, conjugate=True)
+ <tf.Tensor: shape=(3, 2), dtype=complex128, numpy=
+ array([[1.-1.j, 4.-4.j],
+ [2.-2.j, 5.-5.j],
+ [3.-3.j, 6.-6.j]])>
- # Take the transpose of the matrices in dimension-0
- # (this common operation has a shorthand `linalg.matrix_transpose`)
- tf.transpose(x, perm=[0, 2, 1]) # [[[1, 4],
- # [2, 5],
- # [3, 6]],
- # [[7, 10],
- # [8, 11],
- # [9, 12]]]
- ```
+ `perm` is more useful for n-dimensional tensors where n > 2:
+
+ >>> x = tf.constant([[[ 1, 2, 3],
+ ... [ 4, 5, 6]],
+ ... [[ 7, 8, 9],
+ ... [10, 11, 12]]])
+
+ As above, simply calling `tf.transpose` will default to `perm=[2,1,0]`.
+
+ To take the transpose of the matrices in dimension-0 (such as when you are
+ transposing matrices where 0 is the batch dimension), you would set
+ `perm=[0,2,1]`.
+
+ >>> tf.transpose(x, perm=[0, 2, 1])
+ <tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
+ array([[[ 1, 4],
+ [ 2, 5],
+ [ 3, 6]],
+ [[ 7, 10],
+ [ 8, 11],
+ [ 9, 12]]], dtype=int32)>
+
+ Note: This has a shorthand `linalg.matrix_transpose`.
Args:
a: A `Tensor`.
- perm: A permutation of the dimensions of `a`.
+ perm: A permutation of the dimensions of `a`. This should be a vector.
conjugate: Optional bool. Setting it to `True` is mathematically equivalent
to tf.math.conj(tf.transpose(input)).
name: A name for the operation (optional).
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index 844b428..354180f 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -27,6 +27,7 @@
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_aggregate_stats
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split as calculate_best_feature_split
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split_v2 as calculate_best_feature_split_v2
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 3ffaf18..242c41b 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -1927,6 +1927,13 @@
See also: `is_strictly_increasing`
+ >>> x1 = tf.constant([1.0, 1.0, 3.0])
+ >>> tf.math.is_non_decreasing(x1)
+ <tf.Tensor: shape=(), dtype=bool, numpy=True>
+ >>> x2 = tf.constant([3.0, 1.0, 2.0])
+ >>> tf.math.is_non_decreasing(x2)
+ <tf.Tensor: shape=(), dtype=bool, numpy=False>
+
Args:
x: Numeric `Tensor`.
name: A name for this operation (optional). Defaults to "is_non_decreasing"
@@ -1961,6 +1968,13 @@
See also: `is_non_decreasing`
+ >>> x1 = tf.constant([1.0, 2.0, 3.0])
+ >>> tf.math.is_strictly_increasing(x1)
+ <tf.Tensor: shape=(), dtype=bool, numpy=True>
+ >>> x2 = tf.constant([3.0, 1.0, 2.0])
+ >>> tf.math.is_strictly_increasing(x2)
+ <tf.Tensor: shape=(), dtype=bool, numpy=False>
+
Args:
x: Numeric `Tensor`.
name: A name for this operation (optional).
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index d33a9ad..8460301 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2411,7 +2411,7 @@
```python
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
- b = lambda i: tf.add(i, 1)
+ b = lambda i: (tf.add(i, 1), )
r = tf.while_loop(c, b, [i])
```
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index c3a050f..69a19e7 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -518,13 +518,13 @@
elif combiner == "mean":
embeddings = math_ops.segment_sum(embeddings, segment_ids)
weight_sum = math_ops.segment_sum(weights, segment_ids)
- embeddings = math_ops.div(embeddings, weight_sum, name=name)
+ embeddings = math_ops.divide(embeddings, weight_sum, name=name)
elif combiner == "sqrtn":
embeddings = math_ops.segment_sum(embeddings, segment_ids)
weights_squared = math_ops.pow(weights, 2)
weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
weight_sum_sqrt = math_ops.sqrt(weight_sum)
- embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
+ embeddings = math_ops.divide(embeddings, weight_sum_sqrt, name=name)
else:
assert False, "Unrecognized combiner"
else:
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a4c3caa..4698f87 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -1125,7 +1125,7 @@
if executor_type is None:
executor_type = ""
- if executing_eagerly or len(tout):
+ if executing_eagerly:
if f.stateful_ops:
outputs = gen_functional_ops.stateful_partitioned_call(
args=args,
@@ -1158,23 +1158,24 @@
# When running in graph mode, the graph and function graphs are optimized
# (i.e. run through grappler) per the session options, so we can disable any
# eager-specific rewriting.
- config_proto = attr_value_pb2.AttrValue(
- s=function_utils.get_disabled_rewriter_config())
+ config_proto = attr_value_pb2.AttrValue(s=config)
graph = ops.get_default_graph()
f.add_to_graph(graph)
op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
- op = graph.create_op(
- op_name,
- args,
- tout,
- name="PartitionedFunctionCall",
- attrs={
- "Tin": tin_attr,
- "Tout": tout_attr,
- "f": func_attr,
- "config_proto": config_proto,
- "executor_type": executor_type_attr,
- })
+
+ # Propagate the attribute indicating the need to compile from function to the
+ # call itself.
+ xla_compile_attr = "_XlaMustCompile"
+ op_attrs = {
+ "Tin": tin_attr,
+ "Tout": tout_attr,
+ "f": func_attr,
+ "config_proto": config_proto,
+ "executor_type": executor_type_attr,
+ }
+ if xla_compile_attr in f.definition.attr:
+ op_attrs[xla_compile_attr] = f.definition.attr[xla_compile_attr]
+ op = graph.create_op(op_name, args, tout, name=op_name, attrs=op_attrs)
outputs = op.outputs
return outputs if outputs else op
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 2db32c6..0ca39af 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -329,18 +329,12 @@
# override names. Use a dummy class to track the runtime division behavior
return DivideDelegateWithName(x, name) / y
else:
- # We could short-circuit when y is 1, but we'd still have to cast to float,
- # hence it doesn't seem to be worth optimizing.
return x / y
@tf_export("math.multiply", "multiply")
@dispatch.add_dispatch_support
-def multiply(x, y, name=None): # pylint: disable=missing-docstring
- # Do an is comparison here since this is cheaper than isinstance or __eq__
- if y is 1: # pylint: disable=literal-comparison
- return x
-
+def multiply(x, y, name=None):
return gen_math_ops.mul(x, y, name)
@@ -352,28 +346,16 @@
"2016-12-30",
"`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`")
def _mul(x, y, name=None):
- return multiply(x, y, name=name)
+ return gen_math_ops.mul(x, y, name)
_mul.__doc__ = (
gen_math_ops.mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__))
-def add_v2(x, y, name=None):
- # Do an is comparison here since this is cheaper than isinstance or __eq__
- if y is 0: # pylint: disable=literal-comparison
- return x
-
- return gen_math_ops.add_v2(x, y, name=name)
-
-
@tf_export("math.subtract", "subtract")
@dispatch.add_dispatch_support
def subtract(x, y, name=None):
- # Do an is comparison here since this is cheaper than isinstance or __eq__
- if y is 0: # pylint: disable=literal-comparison
- return x
-
return gen_math_ops.sub(x, y, name)
@@ -385,7 +367,7 @@
"2016-12-30",
"`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`")
def _sub(x, y, name=None):
- return subtract(x, y, name)
+ return gen_math_ops.sub(x, y, name)
_sub.__doc__ = (
@@ -1213,7 +1195,7 @@
if x.dtype == dtypes.string:
return gen_math_ops.add(x, y, name=name)
else:
- return add_v2(x, y, name=name)
+ return gen_math_ops.add_v2(x, y, name=name)
def _mul_dispatch(x, y, name=None):
@@ -1239,7 +1221,7 @@
sparse_tensor.SparseTensor)
_OverrideBinaryOperatorHelper(_add_dispatch, "add")
-_OverrideBinaryOperatorHelper(subtract, "sub")
+_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub")
_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
_OverrideBinaryOperatorHelper(_div_python2, "div")
_OverrideBinaryOperatorHelper(_truediv_python3, "truediv")
@@ -1360,6 +1342,7 @@
boolean values.
For example:
+
>>> x = tf.constant([2, 4])
>>> y = tf.constant(2)
>>> tf.math.equal(x, y)
@@ -1395,6 +1378,7 @@
of boolean values.
For example:
+
>>> x = tf.constant([2, 4])
>>> y = tf.constant(2)
>>> tf.math.not_equal(x, y)
@@ -1462,7 +1446,6 @@
For example:
- ```python
>>> start = 3
>>> limit = 18
>>> delta = 3
@@ -1482,8 +1465,6 @@
<tf.Tensor: shape=(5,), dtype=int32,
numpy=array([0, 1, 2, 3, 4], dtype=int32)>
- ```
-
Args:
start: A 0-D `Tensor` (scalar). Acts as first entry in the range if `limit`
is not None; otherwise, acts as range limit and first entry defaults to 0.
@@ -2729,6 +2710,7 @@
datatypes `bfloat16` or `float32`.
A simple 2-D tensor matrix multiplication:
+
>>> a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
>>> a # 2-D tensor
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
@@ -2746,7 +2728,8 @@
array([[ 58, 64],
[139, 154]], dtype=int32)>
- A batch matrix multiplication with batch shape [2]
+ A batch matrix multiplication with batch shape [2]:
+
>>> a = tf.constant(np.arange(1, 13, dtype=np.int32), shape=[2, 2, 3])
>>> a # 3-D tensor
<tf.Tensor: shape=(2, 2, 3), dtype=int32, numpy=
@@ -2775,6 +2758,7 @@
(see [PEP 465](https://www.python.org/dev/peps/pep-0465/)). In TensorFlow,
it simply calls the `tf.matmul()` function, so the following lines are
equivalent:
+
>>> d = a @ b @ [[10], [11]]
>>> d = tf.matmul(tf.matmul(a, b), [[10], [11]])
@@ -3375,28 +3359,27 @@
element of the input is identical to the first element of the output:
For example:
- # tf.cumsum([a, b, c]) # [a, a + b, a + b + c]
+ >>> # tf.cumsum([a, b, c]) # [a, a + b, a + b + c]
>>> x = tf.constant([2, 4, 6, 8])
>>> tf.cumsum(x)
<tf.Tensor: shape=(4,), dtype=int32,
numpy=array([ 2, 6, 12, 20], dtype=int32)>
-
- # using varying `axis` values
+
+ >>> # using varying `axis` values
>>> y = tf.constant([[2, 4, 6, 8], [1,3,5,7]])
>>> tf.cumsum(y, axis=0)
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[ 2, 4, 6, 8],
[ 3, 7, 11, 15]], dtype=int32)>
-
>>> tf.cumsum(y, axis=1)
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[ 2, 6, 12, 20],
[ 1, 4, 9, 16]], dtype=int32)>
-
+
By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
instead:
-
- # tf.cumsum([a, b, c], exclusive=True) => [0, a, a + b]
+
+ >>> # tf.cumsum([a, b, c], exclusive=True) => [0, a, a + b]
>>> x = tf.constant([2, 4, 6, 8])
>>> tf.cumsum(x, exclusive=True)
<tf.Tensor: shape=(4,), dtype=int32,
@@ -3404,17 +3387,17 @@
By setting the `reverse` kwarg to `True`, the cumsum is performed in the
opposite direction:
-
- # tf.cumsum([a, b, c], reverse=True) # [a + b + c, b + c, c]
+
+ >>> # tf.cumsum([a, b, c], reverse=True) # [a + b + c, b + c, c]
>>> x = tf.constant([2, 4, 6, 8])
- >>> tf.cumsum(x, reverse=True)
+ >>> tf.cumsum(x, reverse=True)
<tf.Tensor: shape=(4,), dtype=int32,
numpy=array([20, 18, 14, 8], dtype=int32)>
This is more efficient than using separate `tf.reverse` ops.
The `reverse` and `exclusive` kwargs can also be combined:
-
- # tf.cumsum([a, b, c], exclusive=True, reverse=True) # [b + c, c, 0]
+
+ >>> # tf.cumsum([a, b, c], exclusive=True, reverse=True) # [b + c, c, 0]
>>> x = tf.constant([2, 4, 6, 8])
>>> tf.cumsum(x, exclusive=True, reverse=True)
<tf.Tensor: shape=(4,), dtype=int32,
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 54df055..37669bf 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -699,42 +699,5 @@
self.assertAllEqual(values, self.evaluate(tensor))
-@test_util.run_all_in_graph_and_eager_modes
-class ScalarOptimizationTest(test_util.TensorFlowTestCase):
-
- def testAddZero(self):
- x = constant_op.constant(1)
- y = math_ops.add_v2(x, 0)
- self.assertAllEqual(x, y)
- self.assertIs(x, y)
-
- # Optimization not applied
- y = math_ops.add_v2(x, constant_op.constant(0))
- self.assertAllEqual(x, y)
- self.assertIsNot(x, y)
-
- def testSubtractZero(self):
- x = constant_op.constant(1)
- y = math_ops.subtract(x, 0)
- self.assertAllEqual(x, y)
- self.assertIs(x, y)
-
- # Optimization not applied
- y = math_ops.subtract(x, constant_op.constant(0))
- self.assertAllEqual(x, y)
- self.assertIsNot(x, y)
-
- def testMultiplyOne(self):
- x = constant_op.constant(1)
- y = math_ops.multiply(x, 1)
- self.assertAllEqual(x, y)
- self.assertIs(x, y)
-
- # Optimization not applied
- y = math_ops.multiply(x, constant_op.constant(1))
- self.assertAllEqual(x, y)
- self.assertIsNot(x, y)
-
-
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 54994d6..a4437d6 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -814,13 +814,13 @@
elif summation_method == 'careful_interpolation':
# This one is a bit tricky and is handled separately.
return interpolate_pr_auc(tp, fp, fn)
- rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
+ rec = math_ops.divide(tp + epsilon, tp + fn + epsilon)
if curve == 'ROC':
- fp_rate = math_ops.div(fp, fp + tn + epsilon)
+ fp_rate = math_ops.divide(fp, fp + tn + epsilon)
x = fp_rate
y = rec
else: # curve == 'PR'.
- prec = math_ops.div(tp + epsilon, tp + fp + epsilon)
+ prec = math_ops.divide(tp + epsilon, tp + fp + epsilon)
x = rec
y = prec
if summation_method in ('trapezoidal', 'careful_interpolation'):
@@ -1184,7 +1184,7 @@
denominator = array_ops.where(
math_ops.greater(denominator, 0), denominator,
array_ops.ones_like(denominator))
- iou = math_ops.div(cm_diag, denominator)
+ iou = math_ops.divide(cm_diag, denominator)
# If the number of valid entries is 0 (no classes) we return 0.
result = array_ops.where(
@@ -1266,7 +1266,7 @@
predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
relative_errors = array_ops.where(
math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
- math_ops.div(math_ops.abs(labels - predictions), normalizer))
+ math_ops.divide(math_ops.abs(labels - predictions), normalizer))
return mean(relative_errors, weights, metrics_collections,
updates_collections, name or 'mean_relative_error')
@@ -2032,7 +2032,7 @@
def compute_precision(tp, fp, name):
return array_ops.where(
- math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
+ math_ops.greater(tp + fp, 0), math_ops.divide(tp, tp + fp), 0, name)
def once_across_replicas(_, true_p, false_p):
return compute_precision(true_p, false_p, 'value')
@@ -2113,7 +2113,7 @@
epsilon = 1e-7
def compute_precision(tp, fp, name):
- return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
+ return math_ops.divide(tp, epsilon + tp + fp, name='precision_' + name)
def precision_across_replicas(_, values):
return compute_precision(values['tp'], values['fp'], 'value')
@@ -2206,7 +2206,7 @@
def compute_recall(true_p, false_n, name):
return array_ops.where(
math_ops.greater(true_p + false_n, 0),
- math_ops.div(true_p, true_p + false_n), 0, name)
+ math_ops.divide(true_p, true_p + false_n), 0, name)
def once_across_replicas(_, true_p, false_n):
return compute_recall(true_p, false_n, 'value')
@@ -2645,12 +2645,12 @@
weights=weights)
def compute_recall(_, tp, fn):
- return math_ops.div(tp, math_ops.add(tp, fn), name=scope)
+ return math_ops.divide(tp, math_ops.add(tp, fn), name=scope)
metric = _aggregate_across_replicas(
metrics_collections, compute_recall, tp, fn)
- update = math_ops.div(
+ update = math_ops.divide(
tp_update, math_ops.add(tp_update, fn_update), name='update')
if updates_collections:
ops.add_to_collections(updates_collections, update)
@@ -2720,7 +2720,7 @@
epsilon = 1e-7
def compute_recall(tp, fn, name):
- return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
+ return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name)
def recall_across_replicas(_, values):
return compute_recall(values['tp'], values['fn'], 'value')
@@ -2884,13 +2884,13 @@
labels, predictions, thresholds, weights)
def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
- specificities = math_ops.div(tn, tn + fp + kepsilon)
+ specificities = math_ops.divide(tn, tn + fp + kepsilon)
tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
tf_index = math_ops.cast(tf_index, dtypes.int32)
# Now, we have the implicit threshold, so compute the sensitivity:
- return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
- name)
+ return math_ops.divide(tp[tf_index],
+ tp[tf_index] + fn[tf_index] + kepsilon, name)
def sensitivity_across_replicas(_, values):
return compute_sensitivity_at_specificity(
@@ -3070,7 +3070,7 @@
tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
retrieved_per_k = math_ops.cumsum(
array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
- precision_per_k = math_ops.div(
+ precision_per_k = math_ops.divide(
math_ops.cast(tp_per_k, dtypes.float64),
math_ops.cast(retrieved_per_k, dtypes.float64),
name='precision_per_k')
@@ -3086,7 +3086,7 @@
# Divide by number of relevant items to get average precision. These are
# the "num_relevant_items" and "AveP" terms from the formula above.
num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64)
- return math_ops.div(precision_sum, num_relevant_items, name=scope)
+ return math_ops.divide(precision_sum, num_relevant_items, name=scope)
def _streaming_sparse_average_precision_at_top_k(labels,
@@ -3500,12 +3500,12 @@
weights=weights)
def precision_across_replicas(_, tp, fp):
- return math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+ return math_ops.divide(tp, math_ops.add(tp, fp), name=scope)
metric = _aggregate_across_replicas(
metrics_collections, precision_across_replicas, tp, fp)
- update = math_ops.div(
+ update = math_ops.divide(
tp_update, math_ops.add(tp_update, fp_update), name='update')
if updates_collections:
ops.add_to_collections(updates_collections, update)
@@ -3718,7 +3718,7 @@
Returns:
The specificity using the aggregated values.
"""
- sensitivities = math_ops.div(tp, tp + fn + kepsilon)
+ sensitivities = math_ops.divide(tp, tp + fn + kepsilon)
# We'll need to use this trick until tf.argmax allows us to specify
# whether we should use the first or last index in case of ties.
@@ -3731,8 +3731,8 @@
tf_index = math_ops.cast(tf_index, dtypes.int32)
# Now, we have the implicit threshold, so compute the specificity:
- return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
- name)
+ return math_ops.divide(tn[tf_index],
+ tn[tf_index] + fp[tf_index] + kepsilon, name)
def specificity_across_replicas(_, values):
return compute_specificity_at_sensitivity(
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 7e443b9..51eec89 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -1138,4 +1138,4 @@
grad = array_ops.expand_dims(grad, -1)
num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)
- return [math_ops.div(indicators, num_selected) * grad, None]
+ return [math_ops.divide(indicators, num_selected) * grad, None]
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 0058e96..25da703 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -24,7 +24,6 @@
import numpy as np
-from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -45,7 +44,6 @@
from tensorflow.python.ops.gen_nn_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import device_context
-from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated_args
@@ -1850,7 +1848,7 @@
filter[di, dj, q, k]
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
- horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
+ horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
Args:
input: A `Tensor`. Must be one of the following types:
@@ -1936,7 +1934,7 @@
* filter[di, dj, q, k]
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
- horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
+ horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
Args:
input: A `Tensor`. Must be one of the following types:
@@ -4403,100 +4401,51 @@
which is likely not what was intended.
"""
with ops.name_scope(name, "dropout", [x]) as name:
- # TODO(b/144930399): Remove this once the compatible window is passed.
- if compat.forward_compatible(2019, 12, 16):
- is_rate_number = isinstance(rate, numbers.Real)
- if is_rate_number and (rate < 0 or rate >= 1):
- raise ValueError("rate must be a scalar tensor or a float in the "
- "range [0, 1), got %g" % rate)
- x = ops.convert_to_tensor(x, name="x")
- x_dtype = x.dtype
- if not x_dtype.is_floating:
- raise ValueError("x has to be a floating point tensor since it's going "
- "to be scaled. Got a %s tensor instead." % x_dtype)
- is_executing_eagerly = context.executing_eagerly()
- if not tensor_util.is_tensor(rate):
- if is_rate_number:
- keep_prob = 1 - rate
- scale = 1 / keep_prob
- scale = ops.convert_to_tensor(scale, dtype=x_dtype)
- ret = gen_math_ops.mul(x, scale)
- else:
- raise ValueError("rate is neither scalar nor scalar tensor %r" % rate)
+ is_rate_number = isinstance(rate, numbers.Real)
+ if is_rate_number and (rate < 0 or rate >= 1):
+ raise ValueError("rate must be a scalar tensor or a float in the "
+ "range [0, 1), got %g" % rate)
+ x = ops.convert_to_tensor(x, name="x")
+ x_dtype = x.dtype
+ if not x_dtype.is_floating:
+ raise ValueError("x has to be a floating point tensor since it's going "
+ "to be scaled. Got a %s tensor instead." % x_dtype)
+ is_executing_eagerly = context.executing_eagerly()
+ if not tensor_util.is_tensor(rate):
+ if is_rate_number:
+ keep_prob = 1 - rate
+ scale = 1 / keep_prob
+ scale = ops.convert_to_tensor(scale, dtype=x_dtype)
+ ret = gen_math_ops.mul(x, scale)
else:
- rate.get_shape().assert_has_rank(0)
- rate_dtype = rate.dtype
- if rate_dtype != x_dtype:
- if not rate_dtype.is_compatible_with(x_dtype):
- raise ValueError(
- "Tensor dtype %s is incomptaible with Tensor dtype %s: %r" %
- (x_dtype.name, rate_dtype.name, rate))
- rate = gen_math_ops.cast(rate, x_dtype, name="rate")
- one_tensor = constant_op.constant(1, dtype=x_dtype)
- ret = gen_math_ops.real_div(x, gen_math_ops.sub(one_tensor, rate))
-
- noise_shape = _get_noise_shape(x, noise_shape)
- # Sample a uniform distribution on [0.0, 1.0) and select values larger
- # than rate.
- #
- # NOTE: Random uniform can only generate 2^23 floats on [1.0, 2.0)
- # and subtract 1.0.
- random_tensor = random_ops.random_uniform(
- noise_shape, seed=seed, dtype=x_dtype)
- # NOTE: if (1.0 + rate) - 1 is equal to rate, then that float is selected,
- # hence a >= comparison is used.
- keep_mask = random_tensor >= rate
- ret = gen_math_ops.mul(ret, gen_math_ops.cast(keep_mask, x_dtype))
- if not is_executing_eagerly:
- ret.set_shape(x.get_shape())
- return ret
+ raise ValueError("rate is neither scalar nor scalar tensor %r" % rate)
else:
- x = ops.convert_to_tensor(x, name="x")
- if not x.dtype.is_floating:
- raise ValueError("x has to be a floating point tensor since it will "
- "be scaled. Got a %s tensor instead." % x.dtype)
- if isinstance(rate, numbers.Real):
- if not (rate >= 0 and rate < 1):
- raise ValueError("rate must be a scalar tensor or a float in the "
- "range [0, 1), got %g" % rate)
- if rate > 0.5:
- logging.log_first_n(
- logging.WARN, "Large dropout rate: %g (>0.5). In TensorFlow "
- "2.x, dropout() uses dropout rate instead of keep_prob. "
- "Please ensure that this is intended.", 5, rate)
+      rate.get_shape().assert_has_rank(0)
+      rate_dtype = rate.dtype
+      if rate_dtype != x_dtype:
+        if not rate_dtype.is_compatible_with(x_dtype):
+          raise ValueError(
+              "Tensor dtype %s is incompatible with Tensor dtype %s: %r" %
+              (x_dtype.name, rate_dtype.name, rate))
+        rate = gen_math_ops.cast(rate, x_dtype, name="rate")
+      one_tensor = constant_op.constant(1, dtype=x_dtype)
+      ret = gen_math_ops.real_div(x, gen_math_ops.sub(one_tensor, rate))
- # Early return if nothing needs to be dropped.
- if isinstance(rate, numbers.Real) and rate == 0:
- return x
- if context.executing_eagerly():
- if isinstance(rate, ops.EagerTensor):
- if rate.numpy() == 0:
- return x
- else:
- rate = ops.convert_to_tensor(rate, dtype=x.dtype, name="rate")
- rate.get_shape().assert_has_rank(0)
-
- # Do nothing if we know rate == 0
- if tensor_util.constant_value(rate) == 0:
- return x
-
- noise_shape = _get_noise_shape(x, noise_shape)
- # Sample a uniform distribution on [0.0, 1.0) and select values larger
- # than rate.
- #
- # NOTE: Random uniform can only generate 2^23 floats on [1.0, 2.0)
- # and subtract 1.0.
- random_tensor = random_ops.random_uniform(
- noise_shape, seed=seed, dtype=x.dtype)
- keep_prob = 1 - rate
- scale = 1 / keep_prob
- # NOTE: if (1.0 + rate) - 1 is equal to rate, then that
- # float is selected, hence we use a >= comparison.
- keep_mask = random_tensor >= rate
- ret = x * scale * math_ops.cast(keep_mask, x.dtype)
- if not context.executing_eagerly():
- ret.set_shape(x.get_shape())
- return ret
+ noise_shape = _get_noise_shape(x, noise_shape)
+ # Sample a uniform distribution on [0.0, 1.0) and select values larger
+ # than rate.
+ #
+ # NOTE: Random uniform can only generate 2^23 floats on [1.0, 2.0)
+ # and subtract 1.0.
+ random_tensor = random_ops.random_uniform(
+ noise_shape, seed=seed, dtype=x_dtype)
+ # NOTE: if (1.0 + rate) - 1 is equal to rate, then that float is selected,
+ # hence a >= comparison is used.
+ keep_mask = random_tensor >= rate
+ ret = gen_math_ops.mul(ret, gen_math_ops.cast(keep_mask, x_dtype))
+ if not is_executing_eagerly:
+ ret.set_shape(x.get_shape())
+ return ret
@tf_export("math.top_k", "nn.top_k")
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index 179e664..d70a899 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -22,7 +22,7 @@
import functools
from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -184,9 +184,21 @@
# Note that we wrap into a tf.function if in eager execution mode or under
# XLA compilation. The latter is so that we don't compile operations like
# tf.placeholder that are created by the loop body.
+  functions_run_eagerly = def_function.functions_run_eagerly()
if context.executing_eagerly() or _is_under_xla_context():
- f = function.defun(f)
- return f()
+    if functions_run_eagerly:
+      logging.warning(
+          "It looks like tf.function behavior was disabled, perhaps using "
+          "tf.config.experimental_run_functions_eagerly. Vectorization "
+          "primitives (e.g. tf.vectorized_map) require tf.function to work. "
+          "These primitives will override the disable.")
+      def_function.run_functions_eagerly(False)
+    f = def_function.function(f)
+  try:
+    return f()
+  finally:
+    # Restore the caller's setting even if the loop body raised.
+    def_function.run_functions_eagerly(functions_run_eagerly)
def _loop_fn_has_config(loop_fn):
@@ -209,6 +221,7 @@
def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
"""Implementation of pfor."""
+ assert not context.executing_eagerly()
loop_fn_has_config = _loop_fn_has_config(loop_fn)
existing_ops = set(ops.get_default_graph().get_operations())
# Run the loop body
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index 7f3930c..cdbe35f 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -150,6 +150,15 @@
(batch_size, num_features, 1))
self.assertAllEqual(per_example_gradients[1].shape, (batch_size, 1))
+  def test_disable_tf_function(self):
+    def_function.run_functions_eagerly(True)
+    self.addCleanup(def_function.run_functions_eagerly, False)
+    # vectorized_map should ignore disabling tf.functions
+    self.assertTrue(def_function.functions_run_eagerly())
+    self.assertAllEqual([0, 1, 4, 9], pfor_control_flow_ops.vectorized_map(
+        lambda x: x * x, math_ops.range(4)))
+    self.assertTrue(def_function.functions_run_eagerly())
+
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(PForTestCase):
@@ -1477,5 +1486,36 @@
self._test_loop_fn(loop_fn, 4)
+class VariableTest(PForTestCase):
+
+ def test_create_variable_once(self):
+ x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
+ y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
+ a_var = []
+
+ def f(z):
+ if not a_var:
+ a_var.append(variables.Variable(lambda: y, name="a"))
+ return math_ops.matmul(z, a_var[0] / 16)
+
+ pfor_control_flow_ops.vectorized_map(f, x)
+
+ @test_util.run_v2_only
+ def test_create_variable_repeated(self):
+ x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
+ y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
+
+ def f(z):
+ a_var = variables.Variable(lambda: y, name="a") / 4
+ return math_ops.matmul(z, a_var / 16)
+
+ # Note that this error is only raised under v2 behavior.
+ with self.assertRaisesRegexp(
+ ValueError,
+ "tf.function-decorated function tried to create variables on non-first"
+ ):
+ pfor_control_flow_ops.vectorized_map(f, x)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/rnn_cell_wrapper_impl.py b/tensorflow/python/ops/rnn_cell_wrapper_impl.py
index 49d61c5..f2f1737 100644
--- a/tensorflow/python/ops/rnn_cell_wrapper_impl.py
+++ b/tensorflow/python/ops/rnn_cell_wrapper_impl.py
@@ -210,7 +210,7 @@
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
binary_tensor = math_ops.floor(random_tensor)
- ret = math_ops.div(value, keep_prob) * binary_tensor
+ ret = math_ops.divide(value, keep_prob) * binary_tensor
ret.set_shape(value.get_shape())
return ret
diff --git a/tensorflow/python/ops/signal/dct_ops.py b/tensorflow/python/ops/signal/dct_ops.py
index 2d87af7..d628e54 100644
--- a/tensorflow/python/ops/signal/dct_ops.py
+++ b/tensorflow/python/ops/signal/dct_ops.py
@@ -34,8 +34,8 @@
raise NotImplementedError("axis must be -1. Got: %s" % axis)
if n is not None and n < 1:
raise ValueError("n should be a positive integer or None")
- if dct_type not in (1, 2, 3):
- raise ValueError("Only Types I, II and III (I)DCT are supported.")
+ if dct_type not in (1, 2, 3, 4):
+ raise ValueError("Types I, II, III and IV (I)DCT are supported.")
if dct_type == 1:
if norm == "ortho":
raise ValueError("Normalization is not supported for the Type-I DCT.")
@@ -53,22 +53,26 @@
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
- Currently only Types I, II and III are supported.
+ Types I, II, III and IV are supported.
Type I is implemented using a length `2N` padded `tf.signal.rfft`.
Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
- described here: [Type 2 DCT using 2N FFT padded (Makhoul)](https://dsp.stackexchange.com/a/10606).
+ described here: [Type 2 DCT using 2N FFT padded (Makhoul)]
+ (https://dsp.stackexchange.com/a/10606).
Type III is a fairly straightforward inverse of Type II
- (i.e. using a length `2N` padded `tf.signal.irfft`).
+ (i.e. using a length `2N` padded `tf.signal.irfft`).
+ Type IV is implemented using a length `2N` zero-padded Type II DCT,
+ keeping only its odd-indexed outputs.
@compatibility(scipy)
- Equivalent to [scipy.fftpack.dct](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html)
- for Type-I, Type-II and Type-III DCT.
+ Equivalent to [scipy.fftpack.dct]
+ (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.dct.html)
+ for Type-I, Type-II, Type-III and Type-IV DCT.
@end_compatibility
Args:
input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
signals to take the DCT of.
- type: The DCT type to perform. Must be 1, 2 or 3.
+ type: The DCT type to perform. Must be 1, 2, 3 or 4.
n: The length of the transform. If length is less than sequence length,
only the first n elements of the sequence are considered for the DCT.
If n is greater than the sequence length, zeros are padded and then
@@ -83,7 +87,7 @@
`input`.
Raises:
- ValueError: If `type` is not `1`, `2` or `3`, `axis` is
+ ValueError: If `type` is not `1`, `2`, `3` or `4`, `axis` is
not `-1`, `n` is not `None` or greater than 0,
or `norm` is not `None` or `'ortho'`.
ValueError: If `type` is `1` and `norm` is `ortho`.
@@ -163,13 +167,24 @@
return dct3
+ elif type == 4:
+ # DCT-2 of 2N length zero-padded signal, unnormalized.
+ dct2 = dct(input, type=2, n=2*axis_dim, axis=axis, norm=None)
+ # Get odd indices of DCT-2 of zero padded 2N signal to obtain
+ # DCT-4 of the original N length signal.
+ dct4 = dct2[..., 1::2]
+ if norm == "ortho":
+ dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)
+
+ return dct4
+
# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
- Currently only Types I, II and III are supported. Type III is the inverse of
+ Types I, II, III and IV are supported. Type III is the inverse of
Type II, and vice versa.
Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
@@ -179,14 +194,15 @@
`signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
@compatibility(scipy)
- Equivalent to [scipy.fftpack.idct](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html)
- for Type-I, Type-II and Type-III DCT.
+ Equivalent to [scipy.fftpack.idct]
+ (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.idct.html)
+ for Type-I, Type-II, Type-III and Type-IV DCT.
@end_compatibility
Args:
input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
signals to take the DCT of.
- type: The IDCT type to perform. Must be 1, 2 or 3.
+ type: The IDCT type to perform. Must be 1, 2, 3 or 4.
n: For future expansion. The length of the transform. Must be `None`.
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
norm: The normalization to apply. `None` for no normalization or `'ortho'`
@@ -205,5 +221,5 @@
https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
"""
_validate_dct_arguments(input, type, n, axis, norm)
- inverse_type = {1: 1, 2: 3, 3: 2}[type]
+ inverse_type = {1: 1, 2: 3, 3: 2, 4: 4}[type]
return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)
diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i
index 65a56f9..92f7d8b 100644
--- a/tensorflow/python/platform/base.i
+++ b/tensorflow/python/platform/base.i
@@ -23,6 +23,7 @@
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
+ using tensorflow::int64;
using tensorflow::uint64;
using tensorflow::string;
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index 4a44dd3..39e6a91 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -497,23 +497,25 @@
_Importing SavedModels from TensorFlow 1.x_
SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
- graph instead of `tf.function` objects. These SavedModels will have functions
- corresponding to their signatures in the `.signatures` attribute, but also
- have a `.prune` method which allows you to extract functions for new
- subgraphs. This is equivalent to importing the SavedModel and naming feeds and
- fetches in a Session from TensorFlow 1.x.
+ graph instead of `tf.function` objects. These SavedModels will be loaded with
+ the following attributes:
- ```python
- imported = tf.saved_model.load(path_to_v1_saved_model)
- pruned = imported.prune("x:0", "out:0")
- pruned(tf.ones([]))
- ```
+ * `.signatures`: A dictionary mapping signature names to functions.
+ * `.prune(feeds, fetches)`: A method which allows you to extract
+ functions for new subgraphs. This is equivalent to importing the SavedModel
+ and naming feeds and fetches in a Session from TensorFlow 1.x.
- See `tf.compat.v1.wrap_function` for details. These SavedModels also have a
- `.variables` attribute containing imported variables, and a `.graph` attribute
- representing the whole imported graph. For SavedModels exported from
- `tf.saved_model.save`, variables are instead assigned to whichever attributes
- they were assigned before export.
+ ```python
+ imported = tf.saved_model.load(path_to_v1_saved_model)
+ pruned = imported.prune("x:0", "out:0")
+ pruned(tf.ones([]))
+ ```
+
+ See `tf.compat.v1.wrap_function` for details.
+ * `.variables`: A list of imported variables.
+ * `.graph`: The whole imported graph.
+ * `.restore(save_path)`: A function that restores variables from a checkpoint
+ saved by `tf.compat.v1.Saver`.
_Consuming SavedModels asynchronously_
diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py
index 8cbabf7..ede91da 100644
--- a/tensorflow/python/saved_model/load_v1_in_v2.py
+++ b/tensorflow/python/saved_model/load_v1_in_v2.py
@@ -91,19 +91,24 @@
# pylint: enable=protected-access
returns[0] = saver
- def restore_variables(self, wrapped, saver):
+ def _extract_saver_restore(self, wrapped, saver):
+ if saver is None:
+ return None
+ saver_def = saver.saver_def
+ filename_tensor = wrapped.graph.as_graph_element(
+ saver_def.filename_tensor_name)
+ # We both feed and fetch filename_tensor so we have an operation to use to
+ # feed into variable initializers (only relevant for v1 graph building).
+ return wrapped.prune(
+ feeds=[filename_tensor],
+ fetches=[filename_tensor,
+ wrapped.graph.as_graph_element(saver_def.restore_op_name)])
+
+ def restore_variables(self, wrapped, restore_from_saver):
"""Restores variables from the checkpoint."""
- if saver is not None:
- saver_def = saver.saver_def
- filename_tensor = wrapped.graph.as_graph_element(
- saver_def.filename_tensor_name)
- # We both feed and fetch filename_tensor so we have an operation to use to
- # feed into variable initializers (only relevant for v1 graph building).
- restore_fn = wrapped.prune(
- feeds=[filename_tensor],
- fetches=[filename_tensor,
- wrapped.graph.as_graph_element(saver_def.restore_op_name)])
- initializer, _ = restore_fn(constant_op.constant(self._variables_path))
+ if restore_from_saver is not None:
+ initializer, _ = restore_from_saver(
+ constant_op.constant(self._variables_path))
if not ops.executing_eagerly_outside_functions():
# Add the initialization operation to the table initializers collection
# in case we don't have any lifted variables to attach it to. There
@@ -203,7 +208,8 @@
functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
signature=[])
saver, = load_graph_returns
- self.restore_variables(wrapped, saver)
+ restore_from_saver = self._extract_saver_restore(wrapped, saver)
+ self.restore_variables(wrapped, restore_from_saver)
with wrapped.graph.as_default():
init_op = loader_impl.get_init_op(
meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
@@ -211,6 +217,9 @@
init_anchor = constant_op.constant(0., name="dummy_fetch")
root = tracking.AutoTrackable()
+ if restore_from_saver is not None:
+ root.restore = (
+ lambda path: restore_from_saver(constant_op.constant(path)))
asset_feed_tensors = []
asset_paths = []
for tensor_name, value in loader_impl.get_asset_tensors(
diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py
index f02ab14..37b439f 100644
--- a/tensorflow/python/saved_model/load_v1_in_v2_test.py
+++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py
@@ -37,9 +37,12 @@
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import load
@@ -48,6 +51,7 @@
from tensorflow.python.saved_model import simple_save
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
+from tensorflow.python.training import saver
class LoadTest(test.TestCase):
@@ -594,6 +598,38 @@
forty_two = constant_op.constant([42], dtype=dtypes.int64)
self.assertEqual([45], imported_fn(forty_two)["output"].numpy())
+ def test_load_and_restore_partitioned_variables(self):
+ export_graph = ops.Graph()
+ with export_graph.as_default():
+ partitioned_var = variable_scope.get_variable(
+ "a", shape=[6], initializer=init_ops.constant_initializer(13),
+ partitioner=partitioned_variables.fixed_size_partitioner(2),
+ use_resource=True)
+ x = array_ops.placeholder(shape=[], dtype=dtypes.float32)
+ y = x * partitioned_var
+ with session_lib.Session() as session:
+ session.run(variables.global_variables_initializer())
+ path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
+ simple_save.simple_save(session, path,
+ inputs={"x": x}, outputs={"y": y})
+
+ # Create a name-based checkpoint with different values.
+ session.run(partitioned_var.assign([[5, 4, 3], [2, 1, 0]]))
+ ckpt_path = os.path.join(self.get_temp_dir(), "restore_ckpt")
+ saver.Saver().save(session, ckpt_path)
+
+ imported = load.load(path)
+ self.assertAllClose(self.evaluate(imported.variables),
+ [[13, 13, 13], [13, 13, 13]])
+
+ self.evaluate(imported.restore(ckpt_path))
+ self.assertAllClose(self.evaluate(imported.variables),
+ [[5, 4, 3], [2, 1, 0]])
+ self.assertAllClose(
+ self.evaluate(
+ imported.signatures["serving_default"](constant_op.constant(2.))),
+ {"y": [10, 8, 6, 4, 2, 0]})
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py
index d176b91..2b9e8fb 100644
--- a/tensorflow/python/saved_model/utils_test.py
+++ b/tensorflow/python/saved_model/utils_test.py
@@ -57,7 +57,7 @@
x = constant_op.constant(1, name="x")
y = constant_op.constant(2, name="y")
init_op_info = utils.build_tensor_info_from_op(my_init_fn(x, y))
- self.assertEqual("PartitionedFunctionCall", init_op_info.name)
+ self.assertEqual("PartitionedCall", init_op_info.name)
self.assertEqual(types_pb2.DT_INVALID, init_op_info.dtype)
self.assertEqual(0, len(init_op_info.tensor_shape.dim))
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 4e65120..2faff27 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -19,8 +19,6 @@
%include "tensorflow/python/client/tf_session.i"
-%include "tensorflow/python/lib/io/file_io.i"
-
%include "tensorflow/python/lib/io/py_record_reader.i"
%include "tensorflow/python/grappler/cluster.i"
@@ -29,3 +27,10 @@
%include "tensorflow/python/grappler/cost_analyzer.i"
%include "tensorflow/compiler/mlir/python/mlir.i"
+
+// TODO(slebedev): This is a temporary workaround for projects implicitly
+// relying on TensorFlow exposing tensorflow::Status.
+%unignoreall
+
+%ignore tensorflow::Status::operator=;
+%include "tensorflow/core/platform/status.h"
diff --git a/tensorflow/python/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD
index 664d368..bdf4871 100644
--- a/tensorflow/python/tools/api/generator/BUILD
+++ b/tensorflow/python/tools/api/generator/BUILD
@@ -82,7 +82,6 @@
srcs_version = "PY2AND3",
tags = [
"no_pip",
- "no_rocm",
],
deps = [
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
index ec5317d..3c7f7ac 100644
--- a/tensorflow/python/tools/api/generator/api_gen.bzl
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -84,10 +84,10 @@
"""
root_init_template_flag = ""
if root_init_template:
- root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
+ root_init_template_flag = "--root_init_template=" + root_init_template
primary_package = packages[0]
- api_gen_binary_target = ("create_" + primary_package + "_api_%d_%s") % (api_version, name)
+ api_gen_binary_target = ("create_" + primary_package + "_api_%s") % name
native.py_binary(
name = api_gen_binary_target,
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD
index 82dbc04..cf32d93 100644
--- a/tensorflow/python/tpu/BUILD
+++ b/tensorflow/python/tpu/BUILD
@@ -437,6 +437,7 @@
tf_proto_library(
name = "tensor_tracer_proto",
srcs = ["tensor_tracer.proto"],
+ cc_api_version = 2,
protodeps = [
"//tensorflow/core:protos_all",
],
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index fb74e7e..1e477e6 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -1011,17 +1011,11 @@
def _generate_enqueue_op(self, enqueue_datas, device_ordinal):
enqueue_data0 = list(enqueue_datas.values())[0]
with ops.colocate_with(enqueue_data0.embedding_indices):
- (sample_indices_list, embedding_indices_list, aggregation_weights_list,
- table_ids, max_sequence_lengths) = (
- self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas))
return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
- sample_indices_list,
- embedding_indices_list,
- aggregation_weights_list,
- table_ids,
device_ordinal=device_ordinal,
combiners=self._combiners,
- max_sequence_lengths=max_sequence_lengths)
+ **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)
+ )
def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas):
"""Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`.
@@ -1031,36 +1025,37 @@
dense.
Returns:
- Arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
+ Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`.
"""
-
- (sample_indices_list, embedding_indices_list, aggregation_weights_list,
- table_ids, max_sequence_lengths) = [], [], [], [], []
+ kwargs = {
+ 'sample_indices': [],
+ 'embedding_indices': [],
+ 'aggregation_weights': [],
+ 'table_ids': [],
+ 'max_sequence_lengths': [],
+ }
for table_id, table in enumerate(self._table_to_features_dict):
features = self._table_to_features_dict[table]
for feature in features:
enqueue_data = enqueue_datas[feature]
- sample_indices = (
+ kwargs['sample_indices'].append(
enqueue_data.sample_indices
if enqueue_data.sample_indices is not None else array_ops.zeros(
(0,), dtype=dtypes.int64))
- sample_indices_list.append(sample_indices)
- aggregation_weights = (
+ kwargs['aggregation_weights'].append(
enqueue_data.aggregation_weights if
enqueue_data.aggregation_weights is not None else array_ops.zeros(
(0,), dtype=dtypes.float32))
- aggregation_weights_list.append(aggregation_weights)
- embedding_indices_list.append(enqueue_data.embedding_indices)
+ kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
- table_ids.append(table_id)
- max_sequence_lengths.append(
+ kwargs['table_ids'].append(table_id)
+ kwargs['max_sequence_lengths'].append(
self._feature_to_config_dict[feature].max_sequence_length)
- return (sample_indices_list, embedding_indices_list,
- aggregation_weights_list, table_ids, max_sequence_lengths)
+ return kwargs
def get_activations(self):
"""Get activations for features.
diff --git a/tensorflow/python/training/checkpoint_ops_test.py b/tensorflow/python/training/checkpoint_ops_test.py
index a0fd2dc..5a6a66f 100644
--- a/tensorflow/python/training/checkpoint_ops_test.py
+++ b/tensorflow/python/training/checkpoint_ops_test.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import os
+
import numpy as np
from tensorflow.python.framework import constant_op
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 64ba099..1c24903 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import time
+
import six
from tensorflow.python.distribute import distribution_strategy_context
diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py
index 59972e6..caba001 100644
--- a/tensorflow/python/training/checkpoint_utils_test.py
+++ b/tensorflow/python/training/checkpoint_utils_test.py
@@ -20,6 +20,7 @@
import os
import time
+
import numpy as np
from tensorflow.core.protobuf import config_pb2
diff --git a/tensorflow/python/training/experimental/mixed_precision_test.py b/tensorflow/python/training/experimental/mixed_precision_test.py
index bbf9bad..7397ae9 100644
--- a/tensorflow/python/training/experimental/mixed_precision_test.py
+++ b/tensorflow/python/training/experimental/mixed_precision_test.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import os
+
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
index 0007c0e..c7b3867 100644
--- a/tensorflow/python/training/ftrl.py
+++ b/tensorflow/python/training/ftrl.py
@@ -132,11 +132,10 @@
def _create_slots(self, var_list):
# Create the "accum" and "linear" slots.
for v in var_list:
- with ops.colocate_with(v):
- val = constant_op.constant(
- self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
- self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
- self._zeros_slot(v, "linear", self._linear_name or self._name)
+ val = constant_op.constant(
+ self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
+ self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
+ self._zeros_slot(v, "linear", self._linear_name or self._name)
def _prepare(self):
self._learning_rate_tensor = ops.convert_to_tensor(
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index cc5bcbb..6b9563f 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -175,7 +175,7 @@
if truediv:
return math_ops.truediv(numerator, denominator, name=scope.name)
else:
- return math_ops.div(numerator, denominator, name=scope.name)
+ return math_ops.divide(numerator, denominator, name=scope.name)
def _zero_debias(strategy, unbiased_var, value, decay):
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index 9c2db27..fe62209 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import time
+
import numpy as np
from tensorflow.python.client import session
diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py
index d40e00f..01f8620 100644
--- a/tensorflow/python/training/tracking/util.py
+++ b/tensorflow/python/training/tracking/util.py
@@ -1446,14 +1446,15 @@
"""
super(CheckpointV1, self).__init__()
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
- if not isinstance(v, (base.Trackable, def_function.Function)):
+ setattr(self, k, v)
+ if not isinstance(
+ getattr(self, k), (base.Trackable, def_function.Function)):
raise ValueError(
("`Checkpoint` was expecting a trackable object (an object "
"derived from `TrackableBase`), got %s. If you believe this "
"object should be trackable (i.e. it is part of the "
"TensorFlow Python API and manages state), please open an issue.")
% (v,))
- setattr(self, k, v)
self._save_counter = None # Created lazily for restore-on-create.
self._save_assign_op = None
self._saver = saver_with_op_caching(self)
@@ -1783,14 +1784,15 @@
"""
super(Checkpoint, self).__init__()
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
- if not isinstance(v, (base.Trackable, def_function.Function)):
+ setattr(self, k, v)
+ if not isinstance(
+ getattr(self, k), (base.Trackable, def_function.Function)):
raise ValueError(
("`Checkpoint` was expecting a trackable object (an object "
"derived from `TrackableBase`), got %s. If you believe this "
"object should be trackable (i.e. it is part of the "
"TensorFlow Python API and manages state), please open an issue.")
% (v,))
- setattr(self, k, v)
self._save_counter = None # Created lazily for restore-on-create.
self._save_assign_op = None
self._saver = saver_with_op_caching(self)
diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py
index af4f504..646ca93 100644
--- a/tensorflow/python/training/tracking/util_test.py
+++ b/tensorflow/python/training/tracking/util_test.py
@@ -1055,9 +1055,9 @@
@test_util.run_in_graph_and_eager_modes
def testEmptyContainersIgnored(self):
checkpoint_directory = self.get_temp_dir()
- save_root = trackable_utils.Checkpoint()
+ save_root = trackable_utils.Checkpoint(a=[])
path = save_root.save(checkpoint_directory)
- load_root = trackable_utils.Checkpoint()
+ load_root = trackable_utils.Checkpoint(b=[])
load_root.dep = []
load_root.dep.append([])
status = load_root.restore(path)
@@ -1396,6 +1396,20 @@
load_checkpoint.restore(checkpoint_prefix).run_restore_ops()
self.assertEqual(3., self.evaluate(load_checkpoint.v))
+ def test_inititialize_with_data_structures(self):
+ checkpoint = trackable_utils.Checkpoint(
+ a=[variables_lib.Variable(0.), variables_lib.Variable(1.)],
+ b={"a": variables_lib.Variable(2.), "b": variables_lib.Variable(3.)})
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = checkpoint.save(checkpoint_prefix)
+ load_checkpoint = trackable_utils.Checkpoint(
+ a=[variables_lib.Variable(4.), variables_lib.Variable(5.)],
+ b={"a": variables_lib.Variable(6.), "b": variables_lib.Variable(7.)})
+ load_checkpoint.restore(save_path)
+ self.assertAllClose(self.evaluate(load_checkpoint.a), [0, 1])
+ self.assertAllClose(self.evaluate(load_checkpoint.b), {"a": 2, "b": 3})
+
class _ManualScope(tracking.AutoTrackable):
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index fa334a4..4d329b4 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import collections
+
import six
from tensorflow.python.framework import errors
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 14fa243..084779b 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import os
+
import numpy as np
import six
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 2b4ecd4..6187e32 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -87,7 +87,7 @@
def _sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
- return sorted(dict_)
+ return sorted(dict_.keys())
except TypeError:
raise TypeError("nest only supports dicts with sortable keys.")
diff --git a/tensorflow/stream_executor/gpu/redzone_allocator.cc b/tensorflow/stream_executor/gpu/redzone_allocator.cc
index 89f514c..7d21062 100644
--- a/tensorflow/stream_executor/gpu/redzone_allocator.cc
+++ b/tensorflow/stream_executor/gpu/redzone_allocator.cc
@@ -311,7 +311,8 @@
std::call_once(ptxas_not_found_logged, [&]() {
LOG(WARNING) << compiled_ptx_or.status().ToString()
<< "\nRelying on driver to perform ptx compilation. "
- << "This message will be only logged once.";
+ << "\nModify $PATH to customize ptxas location."
+ << "\nThis message will be only logged once.";
});
}
diff --git a/tensorflow/stream_executor/platform/BUILD b/tensorflow/stream_executor/platform/BUILD
index e2ada9d..3d20cc2 100644
--- a/tensorflow/stream_executor/platform/BUILD
+++ b/tensorflow/stream_executor/platform/BUILD
@@ -1,3 +1,4 @@
+load("//tensorflow/core/platform:build_config.bzl", "tf_platform_deps")
load("//tensorflow/stream_executor:build_defs.bzl", "stream_executor_friends")
package(
@@ -13,17 +14,16 @@
cc_library(
name = "platform",
textual_hdrs = [
+ "initialize.h",
"logging.h",
"platform.h",
"port.h",
"thread_annotations.h",
- "initialize.h",
],
deps = [
- "//tensorflow/core:lib",
- "//tensorflow/stream_executor/platform/default:platform",
"@com_google_absl//absl/strings",
- ],
+ "//tensorflow/core:lib",
+ ] + tf_platform_deps("platform", "//tensorflow/stream_executor/platform/"),
)
cc_library(
@@ -31,6 +31,5 @@
hdrs = ["dso_loader.h"],
deps = [
":platform",
- "//tensorflow/stream_executor/platform/default:dso_loader",
- ],
+ ] + tf_platform_deps("dso_loader", "//tensorflow/stream_executor/platform/"),
)
diff --git a/tensorflow/stream_executor/platform/default/BUILD b/tensorflow/stream_executor/platform/default/BUILD
index bd6404b..032dc51 100644
--- a/tensorflow/stream_executor/platform/default/BUILD
+++ b/tensorflow/stream_executor/platform/default/BUILD
@@ -6,9 +6,7 @@
cc_library(
name = "platform",
- textual_hdrs = [
- "initialize.h",
- ],
+ textual_hdrs = ["initialize.h"],
deps = ["//tensorflow/core:lib"],
)
@@ -21,6 +19,7 @@
}),
hdrs = ["dso_loader.h"],
copts = tf_copts(),
+ tags = ["nobuilder"],
deps = [
"//tensorflow/stream_executor:platform",
"//tensorflow/stream_executor/lib",
diff --git a/tensorflow/stream_executor/platform/logging.h b/tensorflow/stream_executor/platform/logging.h
index 6bc6ccb..348349b 100644
--- a/tensorflow/stream_executor/platform/logging.h
+++ b/tensorflow/stream_executor/platform/logging.h
@@ -19,7 +19,7 @@
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
-#if !defined(PLATFORM_GOOGLE)
+#if !defined(PLATFORM_GOOGLE) && !defined(PLATFORM_GOOGLE_ANDROID)
#define PCHECK(invocation) CHECK(invocation)
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 3ba4456..265371d 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -267,6 +267,8 @@
# "/EHs-c-",
"/wd4577",
"/DNOGDI",
+ # Also see build:windows lines in tensorflow/opensource_only/.bazelrc
+ # where we set some other options globally.
]
if is_external:
return WINDOWS_COPTS + ["/UTF_COMPILE_LIBRARY"]
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
index eb315e3..ee3d1f3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
@@ -53,6 +53,10 @@
argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "swish"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "tanh"
argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
index aba2d4cd..f269a54 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -139,7 +139,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\', \'True\', \'zeros\', \'zeros\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
index 9d54752..2785166 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
@@ -113,7 +113,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'trainable\', \'name\', \'partitioner\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
index e6a8267..6f0000b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
@@ -49,6 +49,14 @@
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
}
member_method {
+ name: "get_registered_name"
+ argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_registered_object"
+ argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "get_source_inputs"
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 2441232..e4bd8c5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -493,6 +493,10 @@
argspec: "args=[\'node_id_range\', \'stats_summary\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], "
}
member_method {
+ name: "BoostedTreesCalculateBestFeatureSplitV2"
+ argspec: "args=[\'node_id_range\', \'stats_summaries_list\', \'split_types\', \'candidate_feature_ids\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "BoostedTreesCalculateBestGainsPerFeature"
argspec: "args=[\'node_id_range\', \'stats_summary_list\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'max_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
index eb315e3..ee3d1f3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
@@ -53,6 +53,10 @@
argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "swish"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "tanh"
argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
index aba2d4cd..f269a54 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -139,7 +139,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\', \'True\', \'zeros\', \'zeros\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
index e6a8267..6f0000b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
@@ -49,6 +49,14 @@
argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
}
member_method {
+ name: "get_registered_name"
+ argspec: "args=[\'obj\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_registered_object"
+ argspec: "args=[\'name\', \'custom_objects\', \'module_objects\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "get_source_inputs"
argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 2441232..e4bd8c5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -493,6 +493,10 @@
argspec: "args=[\'node_id_range\', \'stats_summary\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], "
}
member_method {
+ name: "BoostedTreesCalculateBestFeatureSplitV2"
+ argspec: "args=[\'node_id_range\', \'stats_summaries_list\', \'split_types\', \'candidate_feature_ids\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "BoostedTreesCalculateBestGainsPerFeature"
argspec: "args=[\'node_id_range\', \'stats_summary_list\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'max_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 383dbb4..3680cad 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -367,7 +367,9 @@
api_version=api_version)
def testAPIBackwardsCompatibility(self):
- api_version = 2 if '_api.v2' in tf.bitwise.__name__ else 1
+ api_version = 1
+ if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
+ api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
@@ -390,7 +392,7 @@
# Also check that V1 API has contrib
self.assertTrue(
api_version == 2 or
- 'tensorflow.python.util.lazy_loader.LazyLoader'
+ 'LazyLoader'
in str(type(tf.contrib)))
# Check that V2 API does not have contrib
self.assertTrue(api_version == 1 or not hasattr(tf, 'contrib'))
diff --git a/tensorflow/tools/api/tests/module_test.py b/tensorflow/tools/api/tests/module_test.py
index 1732ba4..2b3a7db 100644
--- a/tensorflow/tools/api/tests/module_test.py
+++ b/tensorflow/tools/api/tests/module_test.py
@@ -73,7 +73,7 @@
tf.summary.image
# If we use v2 API, check for create_file_writer,
# otherwise check for FileWriter.
- if '._api.v2' in tf.bitwise.__name__:
+ if hasattr(tf, '_major_api_version') and tf._major_api_version == 2:
tf.summary.create_file_writer
else:
tf.summary.FileWriter
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu16.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu16.04
index 5bf7d05..7fb037f 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu16.04
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.rocm-ubuntu16.04
@@ -16,8 +16,7 @@
RUN apt-get update --allow-insecure-repositories && \
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
rocm-dev rocm-libs hipcub rocm-utils rocm-cmake \
- rocfft miopen-hip miopengemm rocblas hipblas rocrand rccl \
- rocm-profiler cxlactivitylogger && \
+ rocfft miopen-hip miopengemm rocblas hipblas rocrand rccl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
diff --git a/tensorflow/tools/ci_build/Dockerfile.rocm b/tensorflow/tools/ci_build/Dockerfile.rocm
index a083bc6..70029d2 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rocm
+++ b/tensorflow/tools/ci_build/Dockerfile.rocm
@@ -58,8 +58,7 @@
RUN apt-get update --allow-insecure-repositories && \
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
rocm-dev rocm-libs hipcub rocm-utils rocm-cmake \
- rocfft miopen-hip miopengemm rocblas hipblas rocrand rccl \
- rocm-profiler cxlactivitylogger && \
+ rocfft miopen-hip miopengemm rocblas hipblas rocrand rccl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
diff --git a/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh b/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
index 94e2aaa..ce7789b 100755
--- a/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_centos_pip_packages.sh
@@ -57,8 +57,8 @@
pip2 install --upgrade numpy==1.14.5
pip3 install --upgrade numpy==1.14.5
-pip2 install scipy==1.1.0
-pip3 install scipy==1.1.0
+pip2 install scipy==1.2.2
+pip3 install scipy==1.4.1
pip2 install scikit-learn==0.18.1
pip3 install scikit-learn==0.18.1
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index d0c922b..7fdb8bb 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -78,8 +78,8 @@
pip3 install --upgrade numpy==1.14.5
fi
-pip2 install scipy==1.1.0
-pip3 install scipy==1.1.0
+pip2 install scipy==1.2.2
+pip3 install scipy==1.4.1
pip2 install scikit-learn==0.18.1
pip3 install scikit-learn==0.18.1
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 1053d99..bb53fc9 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -64,7 +64,7 @@
# This workaround isn't needed for Ubuntu 16.04 or later.
pip3.5 install --no-binary=:all: --upgrade numpy==1.14.5
-pip3.5 install scipy==0.18.1
+pip3.5 install scipy==1.4.1
pip3.5 install scikit-learn==0.19.1
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index 3a28890..bcf0d0b 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -76,7 +76,7 @@
# This workaround isn't needed for Ubuntu 16.04 or later.
pip3 install --no-binary=:all: --upgrade numpy==1.14.5
-pip3 install scipy==0.18.1
+pip3 install scipy==1.4.1
pip3 install scikit-learn==0.19.1
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
index c07e1a0..af3bf0d 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh
@@ -40,4 +40,4 @@
--build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \
--test_size_filters=small,medium \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
- //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
+ //tensorflow/... -//tensorflow/compiler/...
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh
index 50ee07e..aa22f8f 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh
@@ -43,5 +43,5 @@
--test_timeout 300,450,1200,3600 --build_tests_only --test_env=KMP_BLOCKTIME=0\
--config=mkl --config=opt --test_output=errors --local_test_jobs=8 \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
- //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
+ //tensorflow/... -//tensorflow/compiler/...
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
index 7cefca0..b0f47b1 100755
--- a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh
@@ -40,4 +40,4 @@
--build_tests_only --test_output=errors --local_test_jobs=8 --config=opt \
--test_size_filters=small,medium \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
- //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
+ //tensorflow/... -//tensorflow/compiler/...
diff --git a/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh b/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh
index 0286d0a..0eb7fec 100755
--- a/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh
@@ -35,10 +35,33 @@
yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
-bazel test --config=rocm --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-no_rocm,-benchmark-test -k \
- --test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
- --build_tests_only --test_output=errors --local_test_jobs=${TF_GPU_COUNT} --config=opt \
- --test_sharding_strategy=disabled \
- --test_size_filters=small,medium \
- --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
- //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
+bazel test \
+ --config=rocm \
+ -k \
+ --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-no_rocm,-benchmark-test,-rocm_multi_gpu,-v1only \
+ --test_lang_filters=cc \
+ --jobs=${N_JOBS} \
+ --local_test_jobs=${TF_GPU_COUNT}\
+ --test_timeout 300,450,1200,3600 \
+ --build_tests_only \
+ --test_output=errors \
+ --test_sharding_strategy=disabled \
+ --test_size_filters=small,medium \
+ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
+ -- \
+ //tensorflow/... \
+ -//tensorflow/compiler/... \
+ -//tensorflow/lite/delegates/gpu/gl/... \
+ -//tensorflow/lite/delegates/gpu/cl/... \
+&& bazel test \
+ --config=rocm \
+ -k \
+ --test_tag_filters=-no_gpu,-no_rocm,-v1only \
+ --jobs=${N_JOBS} \
+ --local_test_jobs=1 \
+ --test_timeout 600,900,2400,7200 \
+ --build_tests_only \
+ --test_output=errors \
+ --test_sharding_strategy=disabled \
+ -- \
+ //tensorflow/core/nccl:nccl_manager_test
diff --git a/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh b/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh
index 424b3e6..64bfffa 100755
--- a/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh
@@ -35,9 +35,18 @@
yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
-bazel test --config=rocm --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-no_rocm,-benchmark-test -k \
- --test_lang_filters=py --jobs=${N_JOBS} --test_timeout 600,900,2400,7200 \
- --build_tests_only --test_output=errors --local_test_jobs=${TF_GPU_COUNT} --config=opt \
- --test_sharding_strategy=disabled \
- --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
- //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
+bazel test \
+ --config=rocm \
+ -k \
+ --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-no_rocm,-benchmark-test,-rocm_multi_gpu,-v1only \
+ --test_lang_filters=py \
+ --jobs=${N_JOBS} \
+ --local_test_jobs=${TF_GPU_COUNT} \
+ --test_timeout 600,900,2400,7200 \
+ --build_tests_only \
+ --test_output=errors \
+ --test_sharding_strategy=disabled \
+ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
+ -- \
+ //tensorflow/... \
+ -//tensorflow/compiler/...
diff --git a/tensorflow/tools/ci_build/presubmit/macos/py2_cc/build.sh b/tensorflow/tools/ci_build/presubmit/macos/py2_cc/build.sh
new file mode 100644
index 0000000..92acb7a
--- /dev/null
+++ b/tensorflow/tools/ci_build/presubmit/macos/py2_cc/build.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# TODO(mihaimaruseac,hyey,ggadde): Convert to py3
+
+set -e
+
+# Error if we somehow forget to set the path to bazel_wrapper.py
+set -u
+BAZEL_WRAPPER_PATH=$1
+set +u
+
+# From this point on, logs can be publicly available
+set -x
+
+function setup_pip () {
+ install_pip2
+ python -m virtualenv tf_build_env --system-site-packages
+ source tf_build_env/bin/activate
+ install_macos_pip_deps
+}
+
+function run_build () {
+ # Run configure.
+ export TF_NEED_CUDA=0
+ export PYTHON_BIN_PATH=$(which python2)
+ yes "" | $PYTHON_BIN_PATH configure.py
+ tag_filters="-no_oss,-no_oss_py2,-gpu,-tpu,-benchmark-test,-nomac,-no_mac,-v1only"
+
+ # Get the default test targets for bazel.
+ source tensorflow/tools/ci_build/build_scripts/PRESUBMIT_BUILD_TARGETS.sh
+
+ "${BAZEL_WRAPPER_PATH}" \
+ test \
+ --build_tag_filters="${tag_filters}" \
+ --test_tag_filters="${tag_filters}" \
+ --action_env=PATH \
+ --remote_accept_cached=true \
+ --spawn_strategy=standalone \
+ --remote_local_fallback=false \
+ --remote_timeout=600 \
+ --strategy=Javac=standalone \
+ --strategy=Closure=standalone \
+ --genrule_strategy=standalone \
+ -- ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
+
+ # Copy log to output to be available to GitHub
+ ls -la "$(bazel info output_base)/java.log"
+ cp "$(bazel info output_base)/java.log" "${KOKORO_ARTIFACTS_DIR}/"
+}
+
+source tensorflow/tools/ci_build/release/common.sh
+update_bazel_macos
+which bazel
+set_bazel_outdir
+
+setup_pip
+run_build
diff --git a/tensorflow/tools/ci_build/presubmit/ubuntu_16/android/build.sh b/tensorflow/tools/ci_build/presubmit/ubuntu_16/android/build.sh
new file mode 100644
index 0000000..5fe3c41
--- /dev/null
+++ b/tensorflow/tools/ci_build/presubmit/ubuntu_16/android/build.sh
@@ -0,0 +1,81 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+# Fail fast (set -u) if the bazel_wrapper.py path was not passed as $1.
+set -u
+BAZEL_WRAPPER_PATH=$1
+set +u
+
+# Trace every command from here on; these logs may be made public.
+set -x
+
+function run_build () {  # Run android.sh, then re-emit its log and exit code through a generated sh_test.
+ export ANDROID_NDK_HOME="/opt/android-ndk-r17c"
+ export NDK_HOME=$ANDROID_NDK_HOME
+ export ANDROID_SDK_HOME="/opt/android-sdk/current"
+ export ANDROID_API_LEVEL="23"
+ export ANDROID_BUILD_TOOLS_VERSION="28.0.0"
+
+ ANDROID_OUT=android.out
+ ANDROID_OUT_TARGET=gen_android_out
+
+ # Run the presubmit Android build, saving its combined output to builds/${ANDROID_OUT}.
+ tensorflow/tools/ci_build/builds/android.sh 2>&1 | tee tensorflow/tools/ci_build/builds/${ANDROID_OUT}
+ RC=${PIPESTATUS[0]}  # exit status of android.sh, not of tee
+
+ # Since we are running the build remotely (rbe), we need to build a bazel
+ # target that would output the log generated above and return the expected
+ # error code.
+ cat << EOF > tensorflow/tools/ci_build/builds/BUILD
+package(default_visibility = ["//tensorflow:internal"])
+
+sh_test(
+ name = "${ANDROID_OUT_TARGET}",
+ srcs = ["${ANDROID_OUT_TARGET}.sh"],
+ data = ["${ANDROID_OUT}"],
+ tags = ["local"],
+)
+EOF
+
+ cat << EOF > tensorflow/tools/ci_build/builds/${ANDROID_OUT_TARGET}.sh
+#!/bin/bash
+cat tensorflow/tools/ci_build/builds/${ANDROID_OUT}
+exit ${RC}
+EOF
+
+ # Make the generated script executable so the sh_test can run it.
+ chmod +x tensorflow/tools/ci_build/builds/${ANDROID_OUT_TARGET}.sh
+
+ # Run the generated test: it prints the saved log and exits with android.sh's
+ # status, so this presubmit passes or fails exactly as the Android build did.
+ "${BAZEL_WRAPPER_PATH}" \
+ --host_jvm_args=-Dbazel.DigestFunction=SHA256 \
+ test \
+ --test_output=all \
+ tensorflow/tools/ci_build/builds:${ANDROID_OUT_TARGET}
+
+ # Publish Bazel's java.log as a Kokoro artifact so it is visible from GitHub.
+ ls -la "$(bazel info output_base)/java.log"
+ cp "$(bazel info output_base)/java.log" "${KOKORO_ARTIFACTS_DIR}/"
+}
+
+source tensorflow/tools/ci_build/release/common.sh  # presumably defines update_bazel_linux
+update_bazel_linux
+which bazel
+
+run_build
diff --git a/tensorflow/tools/ci_build/presubmit/ubuntu_16/cpu_py36_full/build.sh b/tensorflow/tools/ci_build/presubmit/ubuntu_16/cpu_py36_full/build.sh
new file mode 100644
index 0000000..d852ba3
--- /dev/null
+++ b/tensorflow/tools/ci_build/presubmit/ubuntu_16/cpu_py36_full/build.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+# Fail fast (set -u) if the bazel_wrapper.py path was not passed as $1.
+set -u
+BAZEL_WRAPPER_PATH=$1
+set +u
+
+# Trace every command from here on; these logs may be made public.
+set -x
+
+function run_build () {  # Run the CPU py3 presubmit tests on RBE, excluding TF Lite.
+ # Remote-cache silo key derived from the Ubuntu release and image version.
+ UBUNTU_VERSION=$(lsb_release -a | grep Release | awk '{print $2}')
+ IMAGE_VERSION=$(cat /VERSION)
+ CACHE_SILO_VAL="cpu-py3-ubuntu-16-${UBUNTU_VERSION}-${IMAGE_VERSION}"
+
+ # Configure via environment variables instead of running configure.py:
+ # Do not run configure.py when doing remote build & test:
+ # Most things we set with configure.py are not used in a remote build setting,
+ # as the build will be defined by pre-configured build files that are checked
+ # in.
+ # TODO(klimek): Allow using the right set of bazel flags without the need to
+ # run configure.py; currently we need to carefully copy them, which is brittle.
+ export TF_NEED_GCP=0
+ export TF_NEED_HDFS=0
+ export TF_NEED_CUDA=0
+ export ACTION_PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
+ export PYTHON_BIN_PATH="/usr/bin/python3"
+ export TF2_BEHAVIOR=1
+ tag_filters="-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test""$(maybe_skip_v1)"
+
+ # Presumably defines DEFAULT_BAZEL_TARGETS, which the test command below uses.
+ source tensorflow/tools/ci_build/build_scripts/PRESUBMIT_BUILD_TARGETS.sh
+
+ # Run the presubmit tests remotely (--config=rbe); TF Lite targets are excluded.
+ # (platform_setround_test is unsupported, b/64264700; presumably excluded via DEFAULT_BAZEL_TARGETS.)
+ "${BAZEL_WRAPPER_PATH}" \
+ test \
+ --config=rbe \
+ --python_path="${PYTHON_BIN_PATH}" \
+ --action_env=PATH="${ACTION_PATH}" \
+ --action_env=PYTHON_BIN_PATH="${PYTHON_BIN_PATH}" \
+ --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
+ --action_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3 \
+ --action_env=TF_ENABLE_XLA=1 \
+ --test_tag_filters="${tag_filters}" \
+ --build_tag_filters="${tag_filters}" \
+ --test_lang_filters=cc,py \
+ --define=with_default_optimizations=true \
+ --define=framework_shared_object=true \
+ --define=with_xla_support=true \
+ -c opt \
+ --copt="-w" \
+ --copt=-mavx \
+ --linkopt=-lrt \
+ --distinct_host_configuration=false \
+ --remote_default_platform_properties="properties:{name:\"build\" value:\"${CACHE_SILO_VAL}\"}" \
+ --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain \
+ --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 \
+ --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 \
+ --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 \
+ --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 \
+ --extra_toolchains=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8 \
+ --extra_execution_platforms=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010 \
+ --host_platform=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010 \
+ --remote_timeout=3600 \
+ --platforms=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010 \
+ -- \
+ ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
+
+ # Publish Bazel's java.log as a Kokoro artifact so it is visible from GitHub.
+ ls -la "$(bazel info output_base)/java.log"
+ cp "$(bazel info output_base)/java.log" "${KOKORO_ARTIFACTS_DIR}/"
+}
+
+source tensorflow/tools/ci_build/release/common.sh  # presumably defines update_bazel_linux and maybe_skip_v1
+update_bazel_linux
+which bazel
+
+run_build
diff --git a/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh b/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh
new file mode 100644
index 0000000..3fa4d4f
--- /dev/null
+++ b/tensorflow/tools/ci_build/presubmit/ubuntu_16/gpu_py36_full/build.sh
@@ -0,0 +1,114 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+# Fail fast (set -u) if the bazel_wrapper.py path was not passed as $1.
+set -u
+BAZEL_WRAPPER_PATH=$1
+set +u
+
+# Trace every command from here on; these logs may be made public.
+set -x
+
+function run_build () {  # Run the GPU (CUDA 10.0 / cuDNN 7) py3 presubmit tests on RBE, excluding TF Lite.
+ # Remote-cache silo key derived from the Ubuntu release and image version.
+ UBUNTU_VERSION=$(lsb_release -a | grep Release | awk '{print $2}')
+ IMAGE_VERSION=$(cat /VERSION)
+ CACHE_SILO_VAL="gpu-py3-ubuntu-16-${UBUNTU_VERSION}-${IMAGE_VERSION}"
+
+ # Configure via environment variables instead of running configure.py:
+ # Do not run configure.py when doing remote build & test:
+ # Most things we set with configure.py are not used in a remote build setting,
+ # as the build will be defined by pre-configured build files that are checked
+ # in.
+ # TODO(klimek): Allow using the right set of bazel flags without the need to
+ # run configure.py; currently we need to carefully copy them, which is brittle.
+ export LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+ # TODO(klimek): Remove once we don't try to read it while setting up the remote
+ # config for cuda (we currently don't use it, as it's only used when compiling
+ # with clang, but we still require it to be set anyway).
+ export TF_CUDA_COMPUTE_CAPABILITIES=6.0
+ export ACTION_PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
+ export PYTHON_BIN_PATH="/usr/bin/python3"
+ export TF2_BEHAVIOR=1
+ tag_filters="gpu,-no_gpu,-nogpu,-benchmark-test,-no_oss,-oss_serial""$(maybe_skip_v1)"
+
+ # Presumably defines DEFAULT_BAZEL_TARGETS, which the test command below uses.
+ source tensorflow/tools/ci_build/build_scripts/PRESUBMIT_BUILD_TARGETS.sh
+
+ # Run the presubmit GPU tests remotely (--config=rbe); TF Lite targets are excluded.
+ # (platform_setround_test is unsupported, b/64264700; presumably excluded via DEFAULT_BAZEL_TARGETS.)
+ # TODO(klimek): Re-enable tensorrt tests (with different runtime image) once
+ # we can build them.
+ # TODO(klimek): Stop using action_env for things that are only needed during
+ # setup - we're artificially poisoning the cache.
+ "${BAZEL_WRAPPER_PATH}" \
+ test \
+ --config=rbe \
+ --python_path="${PYTHON_BIN_PATH}" \
+ --action_env=PATH="${ACTION_PATH}" \
+ --action_env=PYTHON_BIN_PATH="${PYTHON_BIN_PATH}" \
+ --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \
+ --action_env=REMOTE_GPU_TESTING=1 \
+ --action_env=TF_CUDA_COMPUTE_CAPABILITIES="${TF_CUDA_COMPUTE_CAPABILITIES}" \
+ --action_env=TF_CUDA_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7 \
+ --action_env=TF_CUDA_VERSION=10 \
+ --action_env=TF_CUDNN_VERSION=7 \
+ --action_env=TF_NEED_TENSORRT=0 \
+ --action_env=TF_NEED_CUDA=1 \
+ --action_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3 \
+ --test_env=LD_LIBRARY_PATH \
+ --test_tag_filters="${tag_filters}" \
+ --build_tag_filters="${tag_filters}" \
+ --test_lang_filters=cc,py \
+ --define=with_default_optimizations=true \
+ --define=framework_shared_object=true \
+ --define=with_xla_support=true \
+ --define=using_cuda_nvcc=true \
+ --define=use_fast_cpp_protos=true \
+ --define=allow_oversize_protos=true \
+ --define=grpc_no_ares=true \
+ -c opt \
+ --copt="-w" \
+ --copt=-mavx \
+ --linkopt=-lrt \
+ --distinct_host_configuration=false \
+ --remote_default_platform_properties="properties:{name:\"build\" value:\"${CACHE_SILO_VAL}\"}" \
+ --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0:toolchain \
+ --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 \
+ --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 \
+ --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 \
+ --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 \
+ --extra_toolchains=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0:toolchain-linux-x86_64 \
+ --extra_execution_platforms=@org_tensorflow//third_party/toolchains:rbe_cuda10.0-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.0-cudnn7-ubuntu16.04-manylinux2010-gpu \
+ --host_platform=@org_tensorflow//third_party/toolchains:rbe_cuda10.0-cudnn7-ubuntu16.04-manylinux2010 \
+ --local_test_jobs=4 \
+ --remote_timeout=3600 \
+ --platforms=@org_tensorflow//third_party/toolchains:rbe_cuda10.0-cudnn7-ubuntu16.04-manylinux2010 \
+ -- \
+ ${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/...
+
+ # Publish Bazel's java.log as a Kokoro artifact so it is visible from GitHub.
+ ls -la "$(bazel info output_base)/java.log"
+ cp "$(bazel info output_base)/java.log" "${KOKORO_ARTIFACTS_DIR}/"
+}
+
+source tensorflow/tools/ci_build/release/common.sh  # presumably defines update_bazel_linux and maybe_skip_v1
+update_bazel_linux
+which bazel
+
+run_build
diff --git a/tensorflow/tools/ci_build/presubmit/ubuntu_16/sanity/build.sh b/tensorflow/tools/ci_build/presubmit/ubuntu_16/sanity/build.sh
new file mode 100644
index 0000000..250b0c1
--- /dev/null
+++ b/tensorflow/tools/ci_build/presubmit/ubuntu_16/sanity/build.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+# Fail fast (set -u) if the bazel_wrapper.py path was not passed as $1.
+set -u
+BAZEL_WRAPPER_PATH=$1
+set +u
+
+# Trace every command from here on; these logs may be made public.
+set -x
+
+function install_pylint () {  # Install pinned pylint 1.6.4 system-wide (sudo) for both python2 and python3.
+ # TODO(gunan): figure out why we get stuck with later versions of pylint.
+ # TODO(mihaimaruseac): this is used in the release build in the same way,
+ # maybe extract out to a common?
+ sudo python2 -m pip install pylint==1.6.4
+ sudo python3 -m pip install pylint==1.6.4
+}
+
+function run_sanity_checks () {  # Run ci_sanity.sh, then re-emit its log and exit code through a generated sh_test.
+ SANITY_OUT=ci_sanity.out
+ SANITY_OUT_TARGET=gen_ci_sanity_out
+
+ # Run TensorFlow sanity checks, saving the combined output to ${SANITY_OUT}.
+ tensorflow/tools/ci_build/ci_sanity.sh 2>&1 | tee tensorflow/tools/ci_build/${SANITY_OUT}
+ RC=${PIPESTATUS[0]}  # exit status of ci_sanity.sh, not of tee
+
+ # Since we are running the sanity build remotely (rbe), we need to build a bazel
+ # target that would output the log generated above and return the expected
+ # error code.
+ cat << EOF > tensorflow/tools/ci_build/BUILD
+package(default_visibility = ["//tensorflow:internal"])
+
+sh_test(
+ name = "${SANITY_OUT_TARGET}",
+ srcs = ["${SANITY_OUT_TARGET}.sh"],
+ data = ["${SANITY_OUT}"],
+ tags = ["local"],
+)
+EOF
+
+ cat << EOF > tensorflow/tools/ci_build/${SANITY_OUT_TARGET}.sh
+#!/bin/bash
+cat tensorflow/tools/ci_build/${SANITY_OUT}
+exit ${RC}
+EOF
+
+ # Make the generated script executable so the sh_test can run it.
+ chmod +x tensorflow/tools/ci_build/${SANITY_OUT_TARGET}.sh
+
+ # Run the generated test: it prints the saved log and exits with ci_sanity.sh's
+ # status, so this presubmit passes or fails exactly as the sanity checks did.
+ "${BAZEL_WRAPPER_PATH}" \
+ --host_jvm_args=-Dbazel.DigestFunction=SHA256 \
+ test \
+ --test_output=all \
+ tensorflow/tools/ci_build:${SANITY_OUT_TARGET}
+
+ # Publish Bazel's java.log as a Kokoro artifact so it is visible from GitHub.
+ ls -la "$(bazel info output_base)/java.log"
+ cp "$(bazel info output_base)/java.log" "${KOKORO_ARTIFACTS_DIR}/"
+}
+
+
+source tensorflow/tools/ci_build/release/common.sh  # presumably defines update_bazel_linux
+update_bazel_linux
+which bazel
+
+install_pylint
+
+run_sanity_checks
diff --git a/tensorflow/tools/ci_build/presubmit/windows/cpu_py36_full/build.bat b/tensorflow/tools/ci_build/presubmit/windows/cpu_py36_full/build.bat
new file mode 100644
index 0000000..fcc079f
--- /dev/null
+++ b/tensorflow/tools/ci_build/presubmit/windows/cpu_py36_full/build.bat
@@ -0,0 +1,44 @@
+echo on
+setlocal enableextensions enabledelayedexpansion
+
+@REM Arg 1 is the path to bazel_wrapper.py; expand it to a full path.
+set BAZEL_WRAPPER_PATH=%~f1
+
+@REM Load common definitions, install bazel
+CALL tensorflow\tools\ci_build\release\common_win.bat
+
+@REM Set up common variables used through the script
+set WIN_OUT=win.out
+set WIN_OUT_TARGET=gen_win_out
+set BUILD_PATH=tensorflow/tools/ci_build/builds
+set GEN_SCRIPT=%BUILD_PATH%/%WIN_OUT_TARGET%.sh
+set GEN_BUILD=%BUILD_PATH%/BUILD
+
+@REM Run the presubmit Windows CPU build into the log. NOTE(review): all args, including arg 1, go to run.bat.
+CALL tensorflow\tools\ci_build\windows\cpu\pip\run.bat --enable_remote_cache %* > %BUILD_PATH%/%WIN_OUT% 2>&1
+set RC=%errorlevel%
+
+@REM Since we are running this presubmit build remotely (rbe), we need to build a bazel
+@REM target that would output the log generated above and return the expected
+@REM error code.
+echo package(default_visibility = ["//visibility:public"]) > %GEN_BUILD%
+echo. >> %GEN_BUILD%
+echo sh_test( >> %GEN_BUILD%
+echo name = "%WIN_OUT_TARGET%", >> %GEN_BUILD%
+echo srcs = ["%WIN_OUT_TARGET%.sh"], >> %GEN_BUILD%
+echo data = ["%WIN_OUT%"], >> %GEN_BUILD%
+echo tags = ["local"], >> %GEN_BUILD%
+echo ) >> %GEN_BUILD%
+
+echo #!/bin/bash > %GEN_SCRIPT%
+echo function rlocation() { >> %GEN_SCRIPT%
+echo fgrep -m1 "$1 " "$RUNFILES_MANIFEST_FILE" ^| cut -d' ' -f2- >> %GEN_SCRIPT%
+echo } >> %GEN_SCRIPT%
+echo cat $(rlocation %BUILD_PATH%/%WIN_OUT%) >> %GEN_SCRIPT%
+echo exit %RC% >> %GEN_SCRIPT%
+
+@REM Make the generated script executable (chmod presumably comes from MSYS).
+chmod +x %GEN_SCRIPT%
+
+@REM Run the generated test, which prints the log and exits with the build's RC. Quote the wrapper path: it may contain spaces.
+%PY_EXE% "%BAZEL_WRAPPER_PATH%" --output_user_root=%TMPDIR% --host_jvm_args=-Dbazel.DigestFunction=SHA256 test %BUILD_PATH%:%WIN_OUT_TARGET% --test_output=all
diff --git a/tensorflow/tools/ci_build/presubmit/windows/gpu_py36_full/build.bat b/tensorflow/tools/ci_build/presubmit/windows/gpu_py36_full/build.bat
new file mode 100644
index 0000000..80edefc
--- /dev/null
+++ b/tensorflow/tools/ci_build/presubmit/windows/gpu_py36_full/build.bat
@@ -0,0 +1,45 @@
+echo on
+setlocal enableextensions enabledelayedexpansion
+
+@REM Arg 1 is the path to bazel_wrapper.py; expand it to a full path.
+set BAZEL_WRAPPER_PATH=%~f1
+
+@REM Load common definitions, install bazel
+CALL tensorflow\tools\ci_build\release\common_win.bat
+
+@REM Set up common variables used through the script
+set WIN_OUT=win.out
+set WIN_OUT_TARGET=gen_win_out
+set BUILD_PATH=tensorflow/tools/ci_build/builds
+set GEN_SCRIPT=%BUILD_PATH%/%WIN_OUT_TARGET%.sh
+set GEN_BUILD=%BUILD_PATH%/BUILD
+
+@REM Run the presubmit Windows GPU build into the log. NOTE(review): all args, including arg 1, go to run.bat.
+CALL tensorflow\tools\ci_build\windows\gpu\pip\run.bat --enable_remote_cache %* > %BUILD_PATH%/%WIN_OUT% 2>&1
+set RC=%errorlevel%
+
+@REM Since we are running this presubmit build remotely (rbe), we need to build a bazel
+@REM target that would output the log generated above and return the expected
+@REM error code.
+echo package(default_visibility = ["//visibility:public"]) > %GEN_BUILD%
+echo. >> %GEN_BUILD%
+echo sh_test( >> %GEN_BUILD%
+echo name = "%WIN_OUT_TARGET%", >> %GEN_BUILD%
+echo srcs = ["%WIN_OUT_TARGET%.sh"], >> %GEN_BUILD%
+echo data = ["%WIN_OUT%"], >> %GEN_BUILD%
+echo tags = ["local"], >> %GEN_BUILD%
+echo ) >> %GEN_BUILD%
+
+echo #!/bin/bash > %GEN_SCRIPT%
+echo function rlocation() { >> %GEN_SCRIPT%
+echo fgrep -m1 "$1 " "$RUNFILES_MANIFEST_FILE" ^| cut -d' ' -f2- >> %GEN_SCRIPT%
+echo } >> %GEN_SCRIPT%
+echo cat $(rlocation %BUILD_PATH%/%WIN_OUT%) >> %GEN_SCRIPT%
+echo exit %RC% >> %GEN_SCRIPT%
+
+@REM Make the generated script executable (chmod presumably comes from MSYS).
+chmod +x %GEN_SCRIPT%
+
+@REM Run the generated test, which prints the log and exits with the build's RC. Quote the wrapper path: it may contain spaces.
+%PY_EXE% "%BAZEL_WRAPPER_PATH%" --output_user_root=%TMPDIR% --host_jvm_args=-Dbazel.DigestFunction=SHA256 test %BUILD_PATH%:%WIN_OUT_TARGET% --test_output=all
+
diff --git a/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh
index 72924fb..9288b7b35 100755
--- a/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh
+++ b/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh
@@ -27,6 +27,7 @@
# Run configure.
export PYTHON_BIN_PATH=`which python3`
+export CC_OPT_FLAGS='-mavx'
export TF_NEED_ROCM=1
export TF_GPU_COUNT=${N_GPUS}
@@ -34,12 +35,50 @@
yes "" | $PYTHON_BIN_PATH configure.py
echo "build --distinct_host_configuration=false" >> .tf_configure.bazelrc
-bazel clean
# Run bazel test command. Double test timeouts to avoid flakes.
-bazel test --config=rocm --test_tag_filters=-no_gpu,-benchmark-test,-no_oss,-no_rocm -k \
- --jobs=${N_JOBS} --test_timeout 600,900,2400,7200 \
- --build_tests_only --test_output=errors --local_test_jobs=${TF_GPU_COUNT} \
- --test_sharding_strategy=disabled \
- --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
- --config=xla -- \
- //tensorflow/compiler/...
+bazel test \
+ --config=rocm \
+ --config=xla \
+ -k \
+ --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-no_rocm,-benchmark-test,-rocm_multi_gpu,-v1only \
+ --jobs=${N_JOBS} \
+ --local_test_jobs=${TF_GPU_COUNT} \
+ --test_timeout 600,900,2400,7200 \
+ --build_tests_only \
+ --test_output=errors \
+ --test_sharding_strategy=disabled \
+ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
+ -- \
+ //tensorflow/compiler/... \
+ -//tensorflow/compiler/tests:dense_layer_test \
+ -//tensorflow/compiler/tests:dense_layer_test_gpu \
+ -//tensorflow/compiler/tests:jit_test \
+ -//tensorflow/compiler/tests:jit_test_gpu \
+ -//tensorflow/compiler/tests:matrix_triangular_solve_op_test \
+ -//tensorflow/compiler/tests:tensor_array_ops_test \
+ -//tensorflow/compiler/tests:xla_ops_test \
+ -//tensorflow/compiler/xla/client/lib:svd_test \
+ -//tensorflow/compiler/tests:lstm_test \
+&& bazel test \
+ --config=rocm \
+ --config=xla \
+ -k \
+ --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-no_rocm,-benchmark-test,-rocm_multi_gpu,-v1only \
+ --jobs=${N_JOBS} \
+ --local_test_jobs=${TF_GPU_COUNT} \
+ --test_timeout 600,900,2400,7200 \
+ --build_tests_only \
+ --test_output=errors \
+ --test_sharding_strategy=disabled \
+ --test_env=TF2_BEHAVIOR=0 \
+ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
+ -- \
+ //tensorflow/compiler/tests:dense_layer_test \
+ //tensorflow/compiler/tests:dense_layer_test_gpu \
+ //tensorflow/compiler/tests:jit_test \
+ //tensorflow/compiler/tests:jit_test_gpu \
+ //tensorflow/compiler/tests:matrix_triangular_solve_op_test \
+ //tensorflow/compiler/tests:tensor_array_ops_test \
+ //tensorflow/compiler/tests:xla_ops_test \
+ //tensorflow/compiler/xla/client/lib:svd_test \
+ //tensorflow/compiler/tests:lstm_test
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
index 82640ea..f894c00 100644
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
@@ -175,7 +175,7 @@
else:
# If not a section header and not an empty line, then it's a symbol
# line. e.g. `tensorflow::swig::IsSequence`
- symbols[curr_lib].append(re.escape(line))
+ symbols[curr_lib].append(line)
lib_paths = []
with open(lib_paths_file, "r") as f:
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 23a5b20..e657edc 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -43,6 +43,10 @@
[graph_analyzer_tool] # graph_analyzer
tensorflow::grappler::graph_analyzer::GraphAnalyzerTool
+[bfloat16_lib] # bfloat16
+tensorflow::RegisterNumpyBfloat16
+tensorflow::Bfloat16PyType
+
[events_writer] # events_writer
tensorflow::EventsWriter::Init
tensorflow::EventsWriter::InitWithSuffix
@@ -185,13 +189,3 @@
[context] # tfe
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
-
-[ndarray_tensor_types] # _dtypes
-tensorflow::MaybeRegisterCustomNumPyTypes
-tensorflow::BFLOAT16_DESCR
-tensorflow::QINT8_DESCR
-tensorflow::QINT16_DESCR
-tensorflow::QINT32_DESCR
-tensorflow::QUINT8_DESCR
-tensorflow::QUINT16_DESCR
-tensorflow::RESOURCE_DESCR
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
index 46443bb..8e83923 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
# Options:
# tensorflow
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
index bf1d518..6e7e29f 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
# Options:
# tensorflow
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
index 19732f3..fe0b901 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
pkg-config \
rsync \
software-properties-common \
- sudo \
+ sudo \
unzip \
zip \
zlib1g-dev \
@@ -71,7 +71,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
RUN apt-get update && apt-get install -y \
build-essential \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
index 05528a1..293934d 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
pkg-config \
rsync \
software-properties-common \
- sudo \
+ sudo \
unzip \
zip \
zlib1g-dev \
@@ -71,7 +71,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
RUN apt-get update && apt-get install -y \
build-essential \
@@ -108,3 +108,4 @@
rm -f /bazel/installer.sh
COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
index 4df7f84..ba4f620 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,22 +22,30 @@
ARG UBUNTU_VERSION=18.04
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
ARG CUDNN_MAJOR_VERSION=7
ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-dev-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ libcublas-dev=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
+ cuda-nvrtc-dev-${CUDA/./-} \
cuda-cudart-dev-${CUDA/./-} \
cuda-cufft-dev-${CUDA/./-} \
cuda-curand-dev-${CUDA/./-} \
@@ -61,18 +69,19 @@
find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+# Install TensorRT if not building for PowerPC
RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
- libnvinfer-dev=5.1.5-1+cuda${CUDA} \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; }
# Configure the build for our CUDA configuration.
-ENV CI_BUILD_PYTHON python
-ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/usr/include/${LIB_DIR_PREFIX}-linux-gnu:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
ENV TF_NEED_TENSORRT 1
-ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
ENV TF_CUDA_VERSION=${CUDA}
ENV TF_CUDNN_VERSION=${CUDNN_MAJOR_VERSION}
# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
@@ -104,7 +113,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
RUN apt-get update && apt-get install -y \
build-essential \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
index 79787ad..ae6ad2a 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,22 +22,30 @@
ARG UBUNTU_VERSION=18.04
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
ARG CUDNN_MAJOR_VERSION=7
ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-dev-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ libcublas-dev=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
+ cuda-nvrtc-dev-${CUDA/./-} \
cuda-cudart-dev-${CUDA/./-} \
cuda-cufft-dev-${CUDA/./-} \
cuda-curand-dev-${CUDA/./-} \
@@ -61,18 +69,19 @@
find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+# Install TensorRT if not building for PowerPC
RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
- libnvinfer-dev=5.1.5-1+cuda${CUDA} \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; }
# Configure the build for our CUDA configuration.
-ENV CI_BUILD_PYTHON python
-ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/usr/include/x64_64-linux-gnu:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
ENV TF_NEED_TENSORRT 1
-ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
ENV TF_CUDA_VERSION=${CUDA}
ENV TF_CUDNN_VERSION=${CUDNN_MAJOR_VERSION}
# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
@@ -104,7 +113,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
RUN apt-get update && apt-get install -y \
build-essential \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
index fe2045b..30d9183 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,13 +22,17 @@
ARG UBUNTU_VERSION=18.04
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
+ARG CUDNN_MAJOR_VERSION=7
+ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
@@ -36,7 +40,11 @@
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
cuda-cufft-${CUDA/./-} \
cuda-curand-${CUDA/./-} \
cuda-cusolver-${CUDA/./-} \
@@ -50,10 +58,12 @@
software-properties-common \
unzip
-RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
+# Install TensorRT if not building for PowerPC
+RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*)
+ && rm -rf /var/lib/apt/lists/*; }
# For CUDA profiling, TensorFlow requires CUPTI.
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
@@ -81,7 +91,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
# Options:
# tensorflow
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
index bfeaebe..d6ea415 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,13 +22,17 @@
ARG UBUNTU_VERSION=18.04
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
+ARG CUDNN_MAJOR_VERSION=7
+ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
@@ -36,7 +40,11 @@
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
cuda-cufft-${CUDA/./-} \
cuda-curand-${CUDA/./-} \
cuda-cusolver-${CUDA/./-} \
@@ -50,10 +58,12 @@
software-properties-common \
unzip
-RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
+# Install TensorRT if not building for PowerPC
+RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*)
+ && rm -rf /var/lib/apt/lists/*; }
# For CUDA profiling, TensorFlow requires CUPTI.
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
@@ -81,7 +91,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
# Options:
# tensorflow
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile
index b5bb5d6..6ac98b9 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod-jupyter.Dockerfile
@@ -34,7 +34,7 @@
pkg-config \
rsync \
software-properties-common \
- sudo \
+ sudo \
unzip \
zip \
zlib1g-dev \
@@ -50,6 +50,8 @@
ARG CACHE_STOP=1
# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1
ARG CHECKOUT_TF_SRC=0
+# In case of Python 2.7+ we need to add passwd entries for user and group id
+RUN chmod a+w /etc/passwd /etc/group
RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
ARG USE_PYTHON_3_NOT_2
@@ -97,7 +99,7 @@
enum34
# Install bazel
-ARG BAZEL_VERSION=0.24.1
+ARG BAZEL_VERSION=1.1.0
RUN mkdir /bazel && \
wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
@@ -168,8 +170,12 @@
RUN mkdir /.local && chmod a+rwx /.local
RUN apt-get install -y --no-install-recommends wget
WORKDIR /tf/tensorflow-tutorials
-RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb
-RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/regression.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/save_and_load.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/text_classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/text_classification_with_hub.ipynb
COPY readme-for-jupyter.md README.md
RUN apt-get autoremove -y && apt-get remove -y wget
WORKDIR /tf
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile
index f4162a2..e35e877 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/devel-horovod.Dockerfile
@@ -34,7 +34,7 @@
pkg-config \
rsync \
software-properties-common \
- sudo \
+ sudo \
unzip \
zip \
zlib1g-dev \
@@ -50,6 +50,8 @@
ARG CACHE_STOP=1
# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1
ARG CHECKOUT_TF_SRC=0
+# In case of Python 2.7+ we need to add passwd entries for user and group id
+RUN chmod a+w /etc/passwd /etc/group
RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
ARG USE_PYTHON_3_NOT_2
@@ -97,7 +99,7 @@
enum34
# Install bazel
-ARG BAZEL_VERSION=0.24.1
+ARG BAZEL_VERSION=1.1.0
RUN mkdir /bazel && \
wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile
index 5ba0fe6..cb1155a 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod-jupyter.Dockerfile
@@ -23,6 +23,8 @@
FROM ubuntu:${UBUNTU_VERSION} as base
+RUN apt-get update && apt-get install -y curl
+
ARG USE_PYTHON_3_NOT_2
ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
ARG PYTHON=python${_PY_SUFFIX}
@@ -116,8 +118,12 @@
RUN mkdir /.local && chmod a+rwx /.local
RUN apt-get install -y --no-install-recommends wget
WORKDIR /tf/tensorflow-tutorials
-RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb
-RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/overfit_and_underfit.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/regression.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/save_and_load.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/text_classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/text_classification_with_hub.ipynb
COPY readme-for-jupyter.md README.md
RUN apt-get autoremove -y && apt-get remove -y wget
WORKDIR /tf
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile
index e08b910..9102967 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/mkl_horovod/horovod.Dockerfile
@@ -23,6 +23,8 @@
FROM ubuntu:${UBUNTU_VERSION} as base
+RUN apt-get update && apt-get install -y curl
+
ARG USE_PYTHON_3_NOT_2
ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
ARG PYTHON=python${_PY_SUFFIX}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile
index 907d6af..72a33cd 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
# Options:
# tensorflow
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile
index 3ec3f3a..1abf31b8 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
# Options:
# tensorflow
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile
index 8006db4..d4fb001 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
pkg-config \
rsync \
software-properties-common \
- sudo \
+ sudo \
unzip \
zip \
zlib1g-dev \
@@ -71,7 +71,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
RUN apt-get update && apt-get install -y \
build-essential \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile
index 06f2b77..15ca286 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
pkg-config \
rsync \
software-properties-common \
- sudo \
+ sudo \
unzip \
zip \
zlib1g-dev \
@@ -71,7 +71,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
RUN apt-get update && apt-get install -y \
build-essential \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
index 5fc850a..be13cff 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,22 +22,30 @@
ARG UBUNTU_VERSION=18.04
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
ARG CUDNN_MAJOR_VERSION=7
ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-dev-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ libcublas-dev=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
+ cuda-nvrtc-dev-${CUDA/./-} \
cuda-cudart-dev-${CUDA/./-} \
cuda-cufft-dev-${CUDA/./-} \
cuda-curand-dev-${CUDA/./-} \
@@ -61,18 +69,19 @@
find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+# Install TensorRT if not building for PowerPC
RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
- libnvinfer-dev=5.1.5-1+cuda${CUDA} \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; }
# Configure the build for our CUDA configuration.
-ENV CI_BUILD_PYTHON python
-ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/usr/include/x64_64-linux-gnu:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
ENV TF_NEED_TENSORRT 1
-ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
ENV TF_CUDA_VERSION=${CUDA}
ENV TF_CUDNN_VERSION=${CUDNN_MAJOR_VERSION}
# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
@@ -104,7 +113,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
RUN apt-get update && apt-get install -y \
build-essential \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
index 21cd014..015fc39 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,22 +22,30 @@
ARG UBUNTU_VERSION=18.04
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
ARG CUDNN_MAJOR_VERSION=7
ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-dev-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ libcublas-dev=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
+ cuda-nvrtc-dev-${CUDA/./-} \
cuda-cudart-dev-${CUDA/./-} \
cuda-cufft-dev-${CUDA/./-} \
cuda-curand-dev-${CUDA/./-} \
@@ -61,18 +69,19 @@
find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+# Install TensorRT if not building for PowerPC
RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
- libnvinfer-dev=5.1.5-1+cuda${CUDA} \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; }
# Configure the build for our CUDA configuration.
-ENV CI_BUILD_PYTHON python
-ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/usr/include/x64_64-linux-gnu:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
ENV TF_NEED_TENSORRT 1
-ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
ENV TF_CUDA_VERSION=${CUDA}
ENV TF_CUDNN_VERSION=${CUDNN_MAJOR_VERSION}
# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
@@ -104,7 +113,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
RUN apt-get update && apt-get install -y \
build-essential \
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
index 71a1b79..b2ebddb 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,13 +22,17 @@
ARG UBUNTU_VERSION=18.04
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
+ARG CUDNN_MAJOR_VERSION=7
+ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
@@ -36,7 +40,11 @@
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
cuda-cufft-${CUDA/./-} \
cuda-curand-${CUDA/./-} \
cuda-cusolver-${CUDA/./-} \
@@ -50,10 +58,12 @@
software-properties-common \
unzip
-RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
+# Install TensorRT if not building for PowerPC
+RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*)
+ && rm -rf /var/lib/apt/lists/*; }
# For CUDA profiling, TensorFlow requires CUPTI.
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
@@ -81,7 +91,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
# Options:
# tensorflow
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
index 4655b1d..cef34a5 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,13 +22,17 @@
ARG UBUNTU_VERSION=18.04
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
+ARG CUDNN_MAJOR_VERSION=7
+ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
@@ -36,7 +40,11 @@
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
cuda-cufft-${CUDA/./-} \
cuda-curand-${CUDA/./-} \
cuda-cusolver-${CUDA/./-} \
@@ -50,10 +58,12 @@
software-properties-common \
unzip
-RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
+# Install TensorRT if not building for PowerPC
+RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*)
+ && rm -rf /var/lib/apt/lists/*; }
# For CUDA profiling, TensorFlow requires CUPTI.
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
@@ -81,7 +91,7 @@
setuptools
# Some TF tools expect a "python" binary
-RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
# Options:
# tensorflow
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-cpu.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-cpu.partial.Dockerfile
index 8b4f1a7..5f6caa6 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-cpu.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-cpu.partial.Dockerfile
@@ -11,7 +11,7 @@
pkg-config \
rsync \
software-properties-common \
- sudo \
+ sudo \
unzip \
zip \
zlib1g-dev \
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
index 496b3ac..d7e0107 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
@@ -1,20 +1,28 @@
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
ARG CUDNN_MAJOR_VERSION=7
ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-dev-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ libcublas-dev=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
+ cuda-nvrtc-dev-${CUDA/./-} \
cuda-cudart-dev-${CUDA/./-} \
cuda-cufft-dev-${CUDA/./-} \
cuda-curand-dev-${CUDA/./-} \
@@ -38,18 +46,19 @@
find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+# Install TensorRT if not building for PowerPC
RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
- libnvinfer-dev=5.1.5-1+cuda${CUDA} \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin-dev=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; }
# Configure the build for our CUDA configuration.
-ENV CI_BUILD_PYTHON python
-ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/usr/include/x64_64-linux-gnu:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
ENV TF_NEED_TENSORRT 1
-ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
ENV TF_CUDA_VERSION=${CUDA}
ENV TF_CUDNN_VERSION=${CUDNN_MAJOR_VERSION}
# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
index 8593d1f..555caf0 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
@@ -1,11 +1,15 @@
ARG ARCH=
-ARG CUDA=10.0
+ARG CUDA=10.1
FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
# ARCH and CUDA are specified again because the FROM directive resets ARGs
# (but their default value is retained if set previously)
ARG ARCH
ARG CUDA
-ARG CUDNN=7.6.2.24-1
+ARG CUDNN=7.6.4.38-1
+ARG CUDNN_MAJOR_VERSION=7
+ARG LIB_DIR_PREFIX=x86_64
+ARG LIBNVINFER=6.0.1-1
+ARG LIBNVINFER_MAJOR_VERSION=6
# Needed for string substitution
SHELL ["/bin/bash", "-c"]
@@ -13,7 +17,11 @@
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cuda-command-line-tools-${CUDA/./-} \
- cuda-cublas-${CUDA/./-} \
+ # There appears to be a regression in libcublas10=10.2.2.89-1 which
+ # prevents cublas from initializing in TF. See
+ # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257
+ libcublas10=10.2.1.243-1 \
+ cuda-nvrtc-${CUDA/./-} \
cuda-cufft-${CUDA/./-} \
cuda-curand-${CUDA/./-} \
cuda-cusolver-${CUDA/./-} \
@@ -27,10 +35,12 @@
software-properties-common \
unzip
-RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
- apt-get install -y --no-install-recommends libnvinfer5=5.1.5-1+cuda${CUDA} \
+# Install TensorRT if not building for PowerPC
+RUN [[ "${ARCH}" = "ppc64le" ]] || { apt-get update && \
+ apt-get install -y --no-install-recommends libnvinfer${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
+ libnvinfer-plugin${LIBNVINFER_MAJOR_VERSION}=${LIBNVINFER}+cuda${CUDA} \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*)
+ && rm -rf /var/lib/apt/lists/*; }
# For CUDA profiling, TensorFlow requires CUPTI.
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml
index 5a64b70..29a4f74 100644
--- a/tensorflow/tools/dockerfiles/spec.yml
+++ b/tensorflow/tools/dockerfiles/spec.yml
@@ -57,6 +57,8 @@
- "{ubuntu-devel}{jupyter}"
- "{ubuntu-ppc64le}{jupyter}"
- "{ubuntu-devel-ppc64le}{jupyter}"
+ - "{ubuntu-horovod}{jupyter}"
+ - "{ubuntu-devel-horovod}{jupyter}"
slice_sets:
@@ -83,21 +85,6 @@
- ubuntu/python
- tensorflow
- shell
- - add_to_name: "-horovod"
- dockerfile_exclusive_name: "horovod"
- dockerfile_subdirectory: "mkl_horovod"
- partials:
- - ubuntu/version
- - ubuntu/cpu
- - ubuntu/python
- - tensorflow
- - mkl_horovod/mpi
- - mkl_horovod/horovod
- - shell
- tests:
- - import-mkl-horovod.sh
- args:
- - TF_PACKAGE=intel-tensorflow
- add_to_name: "-gpu"
dockerfile_exclusive_name: "gpu"
args:
@@ -125,6 +112,38 @@
- build-cpu.sh
args:
- CHECKOUT_TF_SRC=1
+ - add_to_name: "devel-gpu"
+ dockerfile_exclusive_name: "devel-gpu"
+ partials:
+ - ubuntu/version
+ - ubuntu/devel-nvidia
+ - ubuntu/python
+ - ubuntu/bazel
+ - shell
+ tests:
+ - build-gpu.sh
+ test_runtime: nvidia
+ args:
+ - CHECKOUT_TF_SRC=1
+
+ ubuntu-horovod:
+ - add_to_name: "-horovod"
+ dockerfile_exclusive_name: "horovod"
+ dockerfile_subdirectory: "mkl_horovod"
+ partials:
+ - ubuntu/version
+ - ubuntu/cpu
+ - ubuntu/python
+ - tensorflow
+ - mkl_horovod/mpi
+ - mkl_horovod/horovod
+ - shell
+ tests:
+ - import-mkl-horovod.sh
+ args:
+ - TF_PACKAGE=intel-tensorflow
+
+ ubuntu-devel-horovod:
- add_to_name: "devel-horovod"
dockerfile_exclusive_name: "devel-horovod"
dockerfile_subdirectory: "mkl_horovod"
@@ -141,19 +160,6 @@
args:
- CHECKOUT_TF_SRC=1
- CHECKOUT_HOROVOD_SRC=1
- - add_to_name: "devel-gpu"
- dockerfile_exclusive_name: "devel-gpu"
- partials:
- - ubuntu/version
- - ubuntu/devel-nvidia
- - ubuntu/python
- - ubuntu/bazel
- - shell
- tests:
- - build-gpu.sh
- test_runtime: nvidia
- args:
- - CHECKOUT_TF_SRC=1
ubuntu-ppc64le:
- add_to_name: "-ppc64le"
diff --git a/tensorflow/tools/dockerfiles/tests/build-cpu.sh b/tensorflow/tools/dockerfiles/tests/build-cpu.sh
index bcdc4c2..6b3c23f 100755
--- a/tensorflow/tools/dockerfiles/tests/build-cpu.sh
+++ b/tensorflow/tools/dockerfiles/tests/build-cpu.sh
@@ -15,23 +15,24 @@
# limitations under the License.
# ============================================================================
-# Download and build TensorFlow.
-set -euxo pipefail
-git clone --branch=master --depth=1 https://github.com/tensorflow/tensorflow.git /tensorflow
+set -ex
+git clone --branch=master --depth=1 https://github.com/tensorflow/tensorflow.git /tensorflow || true
cd /tensorflow
+ln -snf $(which ${PYTHON}) /usr/local/bin/python
+# Run configure.
+export TF_NEED_GCP=1
+export TF_NEED_HDFS=1
+export TF_NEED_S3=1
+export TF_NEED_CUDA=0
+# TensorRT build failing as of 2019-12-18, see
+# https://github.com/tensorflow/tensorflow/issues/35115
+export CC_OPT_FLAGS='-mavx'
+export PYTHON_BIN_PATH=$(which python3.7)
+export TMP=/tmp
+yes "" | /usr/local/bin/python configure.py
-ln -s $(which ${PYTHON}) /usr/local/bin/python
-
-# For optimized builds appropriate for the hardware platform of your choosing, uncomment below...
-# For ivy-bridge or sandy-bridge
-# --copt=-march="ivybridge" \
-# for haswell, broadwell, or skylake
-# --copt=-march="haswell" \
-tensorflow/tools/ci_build/builds/configured CPU \
- bazel build -c opt --copt=-mavx --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
- tensorflow/tools/pip_package:build_pip_package && \
- bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \
- pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \
- rm -rf /tmp/pip && \
- rm -rf /root/.cache
+# Build the pip package and import
+bazel build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --config=opt --config=v2 tensorflow/tools/pip_package:build_pip_package
+./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag
+pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl
diff --git a/tensorflow/tools/dockerfiles/tests/build-gpu.sh b/tensorflow/tools/dockerfiles/tests/build-gpu.sh
index 0e107e3..2edef56 100755
--- a/tensorflow/tools/dockerfiles/tests/build-gpu.sh
+++ b/tensorflow/tools/dockerfiles/tests/build-gpu.sh
@@ -16,19 +16,25 @@
# ============================================================================
# Download and build TensorFlow.
-set -euxo pipefail
-git clone --branch=master --depth=1 https://github.com/tensorflow/tensorflow.git /tensorflow
+
+set -ex
+git clone --branch=master --depth=1 https://github.com/tensorflow/tensorflow.git /tensorflow || true
cd /tensorflow
+ln -snf $(which ${PYTHON}) /usr/local/bin/python
+# Run configure.
+export TF_NEED_GCP=1
+export TF_NEED_HDFS=1
+export TF_NEED_S3=1
+export TF_NEED_CUDA=1
+# TensorRT build failing as of 2019-12-18, see
+# https://github.com/tensorflow/tensorflow/issues/35115
+export TF_NEED_TENSORRT=0
+export CC_OPT_FLAGS='-mavx'
+export PYTHON_BIN_PATH=$(which python3.7)
+export TMP=/tmp
+yes "" | /usr/local/bin/python configure.py
-ln -s $(which ${PYTHON}) /usr/local/bin/python
-
-LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs:${LD_LIBRARY_PATH} \
-tensorflow/tools/ci_build/builds/configured GPU \
-bazel build -c opt --copt=-mavx --config=cuda \
- --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
- tensorflow/tools/pip_package:build_pip_package && \
-rm /usr/local/cuda/lib64/stubs/libcuda.so.1 && \
-bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \
-pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \
-rm -rf /tmp/pip && \
-rm -rf /root/.cache
+# Build the pip package and import
+bazel build --config=cuda --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --config=opt --config=v2 tensorflow/tools/pip_package:build_pip_package
+./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag
+pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl
diff --git a/tensorflow/tools/docs/generate2.py b/tensorflow/tools/docs/generate2.py
index 4c41e5c..6df4fc3 100644
--- a/tensorflow/tools/docs/generate2.py
+++ b/tensorflow/tools/docs/generate2.py
@@ -20,9 +20,9 @@
Requires a local installation of `tensorflow_docs`:
- ```
- pip install git+https://github.com/tensorflow/docs
- ```
+```
+pip install git+https://github.com/tensorflow/docs
+```
"""
from __future__ import absolute_import
@@ -34,7 +34,6 @@
from absl import app
from absl import flags
-from distutils.version import LooseVersion
import tensorflow as tf
@@ -56,7 +55,6 @@
# So patch `tf.__all__` to list everything.
tf.__all__ = [item_name for item_name, value in tf_inspect.getmembers(tf)]
-
FLAGS = flags.FLAGS
flags.DEFINE_string(
@@ -64,47 +62,31 @@
"/code/stable/tensorflow",
"A url to prepend to code paths when creating links to defining code")
-flags.DEFINE_string(
- "output_dir", "/tmp/out",
- "A directory, where the docs will be output to.")
+flags.DEFINE_string("output_dir", "/tmp/out",
+ "A directory, where the docs will be output to.")
flags.DEFINE_bool("search_hints", True,
"Include meta-data search hints at the top of each file.")
-flags.DEFINE_string("site_path", "",
- "The prefix ({site-path}/api_docs/python/...) used in the "
- "`_toc.yaml` and `_redirects.yaml` files")
+flags.DEFINE_string(
+ "site_path", "", "The prefix ({site-path}/api_docs/python/...) used in the "
+ "`_toc.yaml` and `_redirects.yaml` files")
+_PRIVATE_MAP = {
+ "tf": ["python", "core", "compiler", "examples", "tools"],
+ # There's some aliasing between the compats and v1/2s, so it's easier to
+ # block by name and location than by deleting, or hiding objects.
+ "tf.compat.v1.compat": ["v1", "v2"],
+ "tf.compat.v2.compat": ["v1", "v2"]
+}
-if tf.__version__.startswith('1'):
- PRIVATE_MAP = {
- 'tf.test': ['mock'],
- 'tf': ['python', 'core', 'compiler', 'examples', 'tools', 'contrib'],
- # There's some aliasing between the compats and v1/2s, so it's easier to
- # block by name and location than by deleting, or hiding objects.
- 'tf.compat.v1.compat': ['v1', 'v2'],
- 'tf.compat.v2.compat': ['v1', 'v2']
- }
+tf.__doc__ = """
+ ## TensorFlow
- DO_NOT_DESCEND_MAP = {
- 'tf': ['cli', 'lib', 'wrappers', 'contrib'],
- }
-else:
- PRIVATE_MAP = {
- 'tf': ['python', 'core', 'compiler', 'examples', 'tools'],
- # There's some aliasing between the compats and v1/2s, so it's easier to
- # block by name and location than by deleting, or hiding objects.
- 'tf.compat.v1.compat': ['v1', 'v2'],
- 'tf.compat.v2.compat': ['v1', 'v2']
- }
- DO_NOT_DESCEND_MAP = {}
- tf.__doc__ = """
- ## TensorFlow
-
- ```
- pip install tensorflow
- ```
- """
+ ```
+ pip install tensorflow
+ ```
+ """
_raw_ops_doc = textwrap.dedent("""\n
Note: `tf.raw_ops` provides direct/low level access to all TensorFlow ops. See \
@@ -112,27 +94,14 @@
for details. Unless you are library writer, you likely do not need to use these
ops directly.""")
-if LooseVersion(tf.__version__) < LooseVersion('2'):
- tf.raw_ops.__doc__ = _raw_ops_doc
- tf.contrib.__doc__ = """
- Contrib module containing volatile or experimental code.
-
- Warning: The `tf.contrib` module will not be included in TensorFlow 2.0. Many
- of its submodules have been integrated into TensorFlow core, or spun-off into
- other projects like [`tensorflow_io`](https://github.com/tensorflow/io), or
- [`tensorflow_addons`](https://github.com/tensorflow/addons). For instructions
- on how to upgrade see the
- [Migration guide](https://www.tensorflow.org/guide/migrate).
- """
-else:
- tf.raw_ops.__doc__ += _raw_ops_doc
+tf.raw_ops.__doc__ += _raw_ops_doc
# The doc generator isn't aware of tf_export.
# So prefix the score tuples with -1 when this is the canonical name, +1
# otherwise. The generator chooses the name with the lowest score.
-class TfExportAwareDocGeneratorVisitor(
- doc_generator_visitor.DocGeneratorVisitor):
+class TfExportAwareDocGeneratorVisitor(doc_generator_visitor.DocGeneratorVisitor
+ ):
"""A `tf_export` aware doc_visitor."""
def _score_name(self, name):
@@ -214,30 +183,25 @@
"https://github.com/tensorflow/estimator/tree/master/tensorflow_estimator",
)
- if LooseVersion(tf.__version__) < LooseVersion('2'):
- root_title = 'TensorFlow'
- elif LooseVersion(tf.__version__) >= LooseVersion('2'):
- root_title = 'TensorFlow 2.0'
-
doc_generator = generate_lib.DocGenerator(
- root_title=root_title,
+ root_title="TensorFlow 2.0",
py_modules=[("tf", tf)],
base_dir=base_dirs,
search_hints=search_hints,
code_url_prefix=code_url_prefixes,
site_path=FLAGS.site_path,
visitor_cls=TfExportAwareDocGeneratorVisitor,
- private_map=PRIVATE_MAP,
- do_not_descend_map=DO_NOT_DESCEND_MAP)
+ private_map=_PRIVATE_MAP)
doc_generator.build(output_dir)
def main(argv):
del argv
- build_docs(output_dir=FLAGS.output_dir,
- code_url_prefix=FLAGS.code_url_prefix,
- search_hints=FLAGS.search_hints)
+ build_docs(
+ output_dir=FLAGS.output_dir,
+ code_url_prefix=FLAGS.code_url_prefix,
+ search_hints=FLAGS.search_hints)
if __name__ == "__main__":
diff --git a/tensorflow/tools/docs/generate2_test.py b/tensorflow/tools/docs/generate2_test.py
index e4cd344..2775667 100644
--- a/tensorflow/tools/docs/generate2_test.py
+++ b/tensorflow/tools/docs/generate2_test.py
@@ -32,6 +32,7 @@
del tf.compat.v2
del tf.compat.v1
+
class Generate2Test(googletest.TestCase):
def test_end_to_end(self):
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index 7e3643f..73c8cfb 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -84,6 +84,7 @@
"//tensorflow/core/kernels/cloud:bigquery_reader_ops",
"//tensorflow/python/debug:grpc_tensorflow_server.par",
"//tensorflow/python/feature_column:vocabulary_testdata",
+ "//tensorflow/python/keras:vocabulary_testdata",
"//tensorflow/python:framework/test_file_system.so",
"//tensorflow/python:util_nest_test_main_lib",
# lite
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 90205a5..ea08517 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -73,6 +73,10 @@
# functools comes with python3, need to install the backport for python2
'functools32 >= 3.2.3;python_version<"3"',
'six >= 1.12.0',
+ # scipy < 1.4.1 causes segfaults due to pybind11
+ # Latest scipy pip for py2 is scipy==1.2.2
+ 'scipy == 1.4.1;python_version>="3"',
+ 'scipy == 1.2.2;python_version<"3"',
]
if sys.byteorder == 'little':
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index 85fba7e..9e0f976 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -41,7 +41,7 @@
"@com_google_protobuf//:protobuf",
"//tensorflow/core/platform:protobuf_compiler",
"//tensorflow/core:lib_proto_parsing",
- ] + if_ios(["//tensorflow/core/platform/default/build_config:logging"]),
+ ] + if_ios(["//tensorflow/core/platform:logging"]),
)
cc_library(
@@ -67,7 +67,7 @@
}),
deps = [
"//tensorflow/core:lib_proto_parsing",
- ] + if_ios(["//tensorflow/core/platform/default/build_config:logging"]),
+ ] + if_ios(["//tensorflow/core/platform:logging"]),
)
tf_proto_library_cc(
diff --git a/tensorflow/tools/test/run_and_gather_logs_lib.py b/tensorflow/tools/test/run_and_gather_logs_lib.py
index f629e3a..f92fb7b 100644
--- a/tensorflow/tools/test/run_and_gather_logs_lib.py
+++ b/tensorflow/tools/test/run_and_gather_logs_lib.py
@@ -158,7 +158,10 @@
try:
if not gfile.Exists(test_executable):
- raise ValueError("Executable does not exist: %s" % test_executable)
+ test_executable_py3 = test_executable + ".python3"
+ if not gfile.Exists(test_executable_py3):
+ raise ValueError("Executable does not exist: %s" % test_executable)
+ test_executable = test_executable_py3
test_args = shlex.split(test_args)
# This key is defined in tf/core/util/reporter.h as
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 638346e..7f5b4f1 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -167,6 +167,7 @@
"include/mlir/Pass/Pass.h",
"include/mlir/Pass/PassInstrumentation.h",
"include/mlir/Pass/PassManager.h",
+ "include/mlir/Pass/PassOptions.h",
"include/mlir/Pass/PassRegistry.h",
],
includes = ["include"],
@@ -2153,6 +2154,7 @@
srcs = [
"include/mlir/Dialect/Linalg/IR/LinalgBase.td",
"include/mlir/Dialect/Linalg/IR/LinalgOps.td",
+ "include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td",
":AffineOpsTdFiles",
":OpBaseTdFiles",
],
@@ -2179,40 +2181,64 @@
)
filegroup(
- name = "LinalgLibraryOpsTdFiles",
+ name = "LinalgStructuredOpsTdFiles",
srcs = [
"include/mlir/Dialect/Linalg/IR/LinalgBase.td",
- "include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td",
+ "include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td",
":AffineOpsTdFiles",
":OpBaseTdFiles",
],
)
gentbl(
- name = "LinalgLibraryOpsIncGen",
+ name = "LinalgStructuredOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
"-gen-op-decls",
- "include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.h.inc",
+ "include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc",
),
(
"-gen-op-defs",
- "include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc",
+ "include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc",
),
(
"-gen-op-interface-decls",
- "include/mlir/Dialect/Linalg/IR/LinalgLibraryOpInterfaces.h.inc",
+ "include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc",
),
(
"-gen-op-interface-defs",
- "include/mlir/Dialect/Linalg/IR/LinalgLibraryOpInterfaces.cpp.inc",
+ "include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td",
+ td_file = "include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td",
td_srcs = [
- ":LinalgLibraryOpsTdFiles",
+ ":LinalgStructuredOpsTdFiles",
+ ],
+)
+
+filegroup(
+ name = "LinalgDocTdFiles",
+ srcs = [
+ "include/mlir/Dialect/Linalg/IR/LinalgDoc.td",
+ ":LinalgOpsTdFiles",
+ ],
+)
+
+gentbl(
+ name = "LinalgDocIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ "-gen-op-doc",
+ "g3doc/Dialects/Linalg/LinalgOps.md",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Linalg/IR/LinalgDoc.td",
+ td_srcs = [
+ ":LinalgDocTdFiles",
],
)
@@ -2221,8 +2247,8 @@
srcs = [
"include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td",
":AffineOpsTdFiles",
- ":LinalgLibraryOpsTdFiles",
":LinalgOpsTdFiles",
+ ":LinalgStructuredOpsTdFiles",
":OpBaseTdFiles",
],
)
@@ -2308,8 +2334,8 @@
":IR",
":LLVMDialect",
":LLVMTransforms",
- ":LinalgLibraryOpsIncGen",
":LinalgOpsIncGen",
+ ":LinalgStructuredOpsIncGen",
":LinalgTransformPatternsIncGen",
":LoopOps",
":Parser",
@@ -2427,8 +2453,8 @@
srcs = [
"include/mlir/Dialect/VectorOps/VectorTransformPatterns.td",
":AffineOpsTdFiles",
- ":LinalgLibraryOpsTdFiles",
":LinalgOpsTdFiles",
+ ":LinalgStructuredOpsTdFiles",
":OpBaseTdFiles",
":StdOpsTdFiles",
":VectorOpsTdFiles",
diff --git a/third_party/mlir/CMakeLists.txt b/third_party/mlir/CMakeLists.txt
index d6767fa..67d1f00 100644
--- a/third_party/mlir/CMakeLists.txt
+++ b/third_party/mlir/CMakeLists.txt
@@ -12,23 +12,24 @@
PARENT_SCOPE)
endfunction()
-function(add_mlir_dialect dialect)
+function(add_mlir_dialect dialect dialect_doc_filename)
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIR${dialect}IncGen)
# Generate Dialect Documentation
- tablegen(MLIR ${dialect}.md -gen-op-doc "-I${MLIR_MAIN_SRC_DIR}" "-I${MLIR_INCLUDE_DIR}")
- set(GEN_DOC_FILE ${MLIR_BINARY_DIR}/docs/Dialects/${dialect}.md)
+ set(LLVM_TARGET_DEFINITIONS ${dialect_doc_filename}.td)
+ tablegen(MLIR ${dialect_doc_filename}.md -gen-op-doc "-I${MLIR_MAIN_SRC_DIR}" "-I${MLIR_INCLUDE_DIR}")
+ set(GEN_DOC_FILE ${MLIR_BINARY_DIR}/docs/Dialects/${dialect_doc_filename}.md)
add_custom_command(
OUTPUT ${GEN_DOC_FILE}
COMMAND ${CMAKE_COMMAND} -E copy
- ${CMAKE_CURRENT_BINARY_DIR}/${dialect}.md
+ ${CMAKE_CURRENT_BINARY_DIR}/${dialect_doc_filename}.md
${GEN_DOC_FILE}
- DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${dialect}.md)
- add_custom_target(${dialect}DocGen DEPENDS ${GEN_DOC_FILE})
- add_dependencies(mlir-doc ${dialect}DocGen)
+ DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${dialect_doc_filename}.md)
+ add_custom_target(${dialect_doc_filename}DocGen DEPENDS ${GEN_DOC_FILE})
+ add_dependencies(mlir-doc ${dialect_doc_filename}DocGen)
endfunction()
add_custom_target(mlir-doc)
diff --git a/third_party/mlir/CONTRIBUTING.md b/third_party/mlir/CONTRIBUTING.md
index e21e4b8..ffb19fe 100644
--- a/third_party/mlir/CONTRIBUTING.md
+++ b/third_party/mlir/CONTRIBUTING.md
@@ -46,4 +46,3 @@
Include a license at the top of new files.
* [C/C++ license example](https://github.com/tensorflow/mlir/blob/master/examples/toy/Ch1/toyc.cpp)
-* [Python license example](https://github.com/tensorflow/mlir/blob/master/bindings/python/test/test_py2and3.py)
diff --git a/third_party/mlir/LICENSE.TXT b/third_party/mlir/LICENSE.TXT
index a4b160b..fa6ac54 100644
--- a/third_party/mlir/LICENSE.TXT
+++ b/third_party/mlir/LICENSE.TXT
@@ -1,12 +1,14 @@
-Copyright 2019 The MLIR Authors.
+==============================================================================
+The LLVM Project is under the Apache License v2.0 with LLVM Exceptions:
+==============================================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
- 1. Definitions.
+ 1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
@@ -65,14 +67,14 @@
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
- 2. Grant of Copyright License. Subject to the terms and conditions of
+ 2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
- 3. Grant of Patent License. Subject to the terms and conditions of
+ 3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
@@ -88,7 +90,7 @@
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
- 4. Redistribution. You may reproduce and distribute copies of the
+ 4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
@@ -129,7 +131,7 @@
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
- 5. Submission of Contributions. Unless You explicitly state otherwise,
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
@@ -137,12 +139,12 @@
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
- 6. Trademarks. This License does not grant permission to use the trade
+ 6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
- 7. Disclaimer of Warranty. Unless required by applicable law or
+ 7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
@@ -152,7 +154,7 @@
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
- 8. Limitation of Liability. In no event and under no legal theory,
+ 8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
@@ -164,7 +166,7 @@
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
- 9. Accepting Warranty or Additional Liability. While redistributing
+ 9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
@@ -175,9 +177,9 @@
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
- END OF TERMS AND CONDITIONS
+ END OF TERMS AND CONDITIONS
- APPENDIX: How to apply the Apache License to your work.
+ APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
@@ -188,18 +190,90 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright [yyyy] [name of copyright owner]
+ Copyright [yyyy] [name of copyright owner]
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---- LLVM Exceptions to the Apache 2.0 License ----
+
+As an exception, if, as a result of your compiling your source code, portions
+of this Software are embedded into an Object form of such source code, you
+may redistribute such embedded portions in such Object form without complying
+with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
+
+In addition, if you combine or link compiled forms of this Software with
+software that is licensed under the GPLv2 ("Combined Software") and if a
+court of competent jurisdiction determines that the patent provision (Section
+3), the indemnity provision (Section 9) or other Section of the License
+conflicts with the conditions of the GPLv2, you may retroactively and
+prospectively choose to deem waived or otherwise exclude such Section(s) of
+the License, but only in their entirety and only with respect to the Combined
+Software.
+
+==============================================================================
+Software from third parties included in the LLVM Project:
+==============================================================================
+The LLVM Project contains third party software which is under different license
+terms. All such code will be identified clearly using at least one of two
+mechanisms:
+1) It will be in a separate directory tree with its own `LICENSE.txt` or
+ `LICENSE` file at the top containing the specific license and restrictions
+ which apply to that software, or
+2) It will contain specific license and restriction terms at the top of every
+ file.
+
+==============================================================================
+Legacy LLVM License (https://llvm.org/docs/DeveloperPolicy.html#legacy):
+==============================================================================
+University of Illinois/NCSA
+Open Source License
+
+Copyright (c) 2003-2019 University of Illinois at Urbana-Champaign.
+All rights reserved.
+
+Developed by:
+
+ LLVM Team
+
+ University of Illinois at Urbana-Champaign
+
+ http://llvm.org
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal with
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+ * Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimers.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimers in the
+ documentation and/or other materials provided with the distribution.
+
+ * Neither the names of the LLVM Team, University of Illinois at
+ Urbana-Champaign, nor the names of its contributors may be used to
+ endorse or promote products derived from this Software without specific
+ prior written permission.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE
+SOFTWARE.
+
diff --git a/third_party/mlir/bindings/python/BUILD b/third_party/mlir/bindings/python/BUILD
deleted file mode 100644
index 64ade7f..0000000
--- a/third_party/mlir/bindings/python/BUILD
+++ /dev/null
@@ -1,38 +0,0 @@
-# Description:
-# BUILD file for the Python bindings.
-
-licenses(["notice"]) # Apache 2.0
-
-# Export the BUILD file so automated tooling can check licenses
-exports_files(["BUILD"])
-
-package(
- default_visibility = ["@local_config_mlir//:friends"],
-)
-
-#
-# Pybind route uses exceptions and py_extension.
-#
-py_extension(
- name = "_pybind",
- srcs = ["pybind.cpp"],
- copts = ["-fexceptions"],
- features = ["-use_header_modules"],
- module_name = "pybind",
- deps = [
- "//third_party/llvm/llvm:ir",
- "//third_party/llvm/llvm:support",
- "//third_party/pybind11",
- "@local_config_mlir//:AffineToStandardTransforms",
- "@local_config_mlir//:EDSC",
- "@local_config_mlir//:EDSCInterface",
- "@local_config_mlir//:ExecutionEngine",
- "@local_config_mlir//:ExecutionEngineUtils",
- "@local_config_mlir//:IR",
- "@local_config_mlir//:LLVMTransforms",
- "@local_config_mlir//:Pass",
- "@local_config_mlir//:StandardDialectRegistration",
- "@local_config_mlir//:TargetLLVMIR",
- "@local_config_mlir//:Transforms",
- ],
-)
diff --git a/third_party/mlir/bindings/python/pybind.cpp b/third_party/mlir/bindings/python/pybind.cpp
deleted file mode 100644
index 825f800..0000000
--- a/third_party/mlir/bindings/python/pybind.cpp
+++ /dev/null
@@ -1,1167 +0,0 @@
-//===- pybind.cpp - MLIR Python bindings ----------------------------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/Module.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Support/raw_ostream.h"
-#include <cstddef>
-#include <unordered_map>
-
-#include "mlir-c/Core.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/EDSC/Builders.h"
-#include "mlir/EDSC/Helpers.h"
-#include "mlir/EDSC/Intrinsics.h"
-#include "mlir/ExecutionEngine/ExecutionEngine.h"
-#include "mlir/ExecutionEngine/OptUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Target/LLVMIR.h"
-#include "mlir/Transforms/Passes.h"
-#include "pybind11/pybind11.h"
-#include "pybind11/pytypes.h"
-#include "pybind11/stl.h"
-
-static bool inited = [] {
- llvm::InitializeNativeTarget();
- llvm::InitializeNativeTargetAsmPrinter();
- return true;
-}();
-
-namespace mlir {
-namespace edsc {
-namespace python {
-
-namespace py = pybind11;
-
-struct PythonAttribute;
-struct PythonAttributedType;
-struct PythonBindable;
-struct PythonExpr;
-struct PythonFunctionContext;
-struct PythonStmt;
-struct PythonBlock;
-struct PythonAffineExpr;
-struct PythonAffineMap;
-
-struct PythonType {
- PythonType() : type{nullptr} {}
- PythonType(mlir_type_t t) : type{t} {}
-
- operator mlir_type_t() const { return type; }
-
- PythonAttributedType attachAttributeDict(
- const std::unordered_map<std::string, PythonAttribute> &attrs) const;
-
- std::string str() {
- mlir::Type f = mlir::Type::getFromOpaquePointer(type);
- std::string res;
- llvm::raw_string_ostream os(res);
- f.print(os);
- return res;
- }
-
- mlir_type_t type;
-};
-
-struct PythonValueHandle {
- PythonValueHandle(PythonType type)
- : value(mlir::Type::getFromOpaquePointer(type.type)) {}
- PythonValueHandle(const PythonValueHandle &other) = default;
- PythonValueHandle(const mlir::edsc::ValueHandle &other) : value(other) {}
- operator ValueHandle() const { return value; }
- operator ValueHandle &() { return value; }
-
- std::string str() const {
- return std::to_string(reinterpret_cast<intptr_t>(value.getValue()));
- }
-
- PythonValueHandle call(const std::vector<PythonValueHandle> &args) {
- assert(value.hasType() && value.getType().isa<FunctionType>() &&
- "can only call function-typed values");
-
- std::vector<Value *> argValues;
- argValues.reserve(args.size());
- for (auto arg : args)
- argValues.push_back(arg.value.getValue());
- return ValueHandle::create<CallIndirectOp>(value, argValues);
- }
-
- PythonType type() const {
- return PythonType(value.getType().getAsOpaquePointer());
- }
-
- mlir::edsc::ValueHandle value;
-};
-
-struct PythonFunction {
- PythonFunction() : function{nullptr} {}
- PythonFunction(mlir_func_t f) : function{f} {}
- PythonFunction(mlir::FuncOp f)
- : function(const_cast<void *>(f.getAsOpaquePointer())) {}
- operator mlir_func_t() { return function; }
- std::string str() {
- mlir::FuncOp f = mlir::FuncOp::getFromOpaquePointer(function);
- std::string res;
- llvm::raw_string_ostream os(res);
- f.print(os);
- return res;
- }
-
- // If the function does not yet have an entry block, i.e. if it is a function
- // declaration, add the entry block, transforming the declaration into a
- // definition. Return true if the block was added, false otherwise.
- bool define() {
- auto f = mlir::FuncOp::getFromOpaquePointer(function);
- if (!f.getBlocks().empty())
- return false;
-
- f.addEntryBlock();
- return true;
- }
-
- PythonValueHandle arg(unsigned index) {
- auto f = mlir::FuncOp::getFromOpaquePointer(function);
- assert(index < f.getNumArguments() && "argument index out of bounds");
- return PythonValueHandle(ValueHandle(f.getArgument(index)));
- }
-
- mlir_func_t function;
-};
-
-/// Trivial C++ wrappers make use of the EDSC C API.
-struct PythonMLIRModule {
- PythonMLIRModule()
- : mlirContext(),
- module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&mlirContext))),
- symbolTable(*module) {}
-
- PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) {
- return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
- int64_list_t{sizes.data(), sizes.size()});
- }
- PythonType makeIndexType() {
- return ::makeIndexType(mlir_context_t{&mlirContext});
- }
- PythonType makeType(const std::string &type) {
- return ::mlirParseType(type.c_str(), mlir_context_t{&mlirContext}, nullptr);
- }
-
- // Declare a function with the given name, input types and their attributes,
- // output types, and function attributes, but do not define it.
- PythonFunction declareFunction(const std::string &name,
- const py::list &inputs,
- const std::vector<PythonType> &outputTypes,
- const py::kwargs &funcAttributes);
-
- // Declare a function with the given name, input types and their attributes,
- // output types, and function attributes.
- PythonFunction makeFunction(const std::string &name, const py::list &inputs,
- const std::vector<PythonType> &outputTypes,
- const py::kwargs &funcAttributes) {
- auto declaration =
- declareFunction(name, inputs, outputTypes, funcAttributes);
- declaration.define();
- return declaration;
- }
-
- // Create a custom op given its name and arguments.
- PythonExpr op(const std::string &name, PythonType type,
- const py::list &arguments, const py::list &successors,
- py::kwargs attributes);
-
- // Creates an integer attribute.
- PythonAttribute integerAttr(PythonType type, int64_t value);
-
- // Creates a boolean attribute.
- PythonAttribute boolAttr(bool value);
-
- // Creates a float attribute.
- PythonAttribute floatAttr(float value);
-
- // Creates a string atrribute.
- PythonAttribute stringAttr(const std::string &value);
-
- // Creates an Array attribute.
- PythonAttribute arrayAttr(const std::vector<PythonAttribute> &values);
-
- // Creates an AffineMap attribute.
- PythonAttribute affineMapAttr(PythonAffineMap value);
-
- // Creates an affine constant expression.
- PythonAffineExpr affineConstantExpr(int64_t value);
-
- // Creates an affine symbol expression.
- PythonAffineExpr affineSymbolExpr(unsigned position);
-
- // Creates an affine dimension expression.
- PythonAffineExpr affineDimExpr(unsigned position);
-
- // Creates a single constant result affine map.
- PythonAffineMap affineConstantMap(int64_t value);
-
- // Creates an affine map.
- PythonAffineMap affineMap(unsigned dimCount, unsigned symbolCount,
- const std::vector<PythonAffineExpr> &results);
-
- // Compile the module save the execution engine. "optLevel" and
- // "codegenOptLevel" contain the levels of optimization to run (0 to 3) for
- // transformations and codegen. -1 means ExecutionEngine default.
- void compile(int optLevel, int codegenOptLevel) {
- PassManager manager(module->getContext());
- manager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
- manager.addNestedPass<FuncOp>(mlir::createCSEPass());
- manager.addPass(mlir::createLowerAffinePass());
- manager.addPass(mlir::createLowerToLLVMPass());
- if (failed(manager.run(*module))) {
- llvm::errs() << "conversion to the LLVM IR dialect failed\n";
- return;
- }
-
- // Make sure the executione engine runs LLVM passes for the specified
- // optimization level.
- auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
- assert(tmBuilderOrError);
- auto tmOrError = tmBuilderOrError->createTargetMachine();
- assert(tmOrError);
- targetMachine = std::move(tmOrError.get());
- auto transformer = mlir::makeLLVMPassesTransformer(
- /*llvmPasses=*/{},
- optLevel == -1 ? llvm::Optional<unsigned>() : optLevel,
- targetMachine.get(),
- /*optPassesInsertPos=*/0);
-
- auto created = mlir::ExecutionEngine::create(
- *module, transformer,
- codegenOptLevel == -1
- ? llvm::Optional<llvm::CodeGenOpt::Level>()
- : static_cast<llvm::CodeGenOpt::Level>(codegenOptLevel));
- llvm::handleAllErrors(created.takeError(),
- [](const llvm::ErrorInfoBase &b) {
- b.log(llvm::errs());
- assert(false);
- });
- engine = std::move(*created);
- }
-
- std::string getIR() {
- std::string res;
- llvm::raw_string_ostream os(res);
- module->print(os);
- return res;
- }
-
- uint64_t getEngineAddress() {
- assert(engine && "module must be compiled into engine first");
- return reinterpret_cast<uint64_t>(reinterpret_cast<void *>(engine.get()));
- }
-
- PythonFunction getNamedFunction(const std::string &name) {
- return symbolTable.lookup<FuncOp>(name);
- }
-
- PythonFunctionContext
- makeFunctionContext(const std::string &name, const py::list &inputs,
- const std::vector<PythonType> &outputs,
- const py::kwargs &attributes);
-
-private:
- mlir::MLIRContext mlirContext;
- // One single module in a python-exposed MLIRContext for now.
- mlir::OwningModuleRef module;
- mlir::SymbolTable symbolTable;
-
- // An execution engine and an associated target machine. The latter must
- // outlive the former since it may be used by the transformation layers.
- std::unique_ptr<mlir::ExecutionEngine> engine;
- std::unique_ptr<llvm::TargetMachine> targetMachine;
-};
-
-struct PythonFunctionContext {
- PythonFunctionContext(PythonFunction f) : function(f) {}
- PythonFunctionContext(PythonMLIRModule &module, const std::string &name,
- const py::list &inputs,
- const std::vector<PythonType> &outputs,
- const py::kwargs &attributes) {
- auto function = module.declareFunction(name, inputs, outputs, attributes);
- function.define();
- }
-
- PythonFunction enter() {
- assert(function.function && "function is not set up");
- auto mlirFunc = mlir::FuncOp::getFromOpaquePointer(function.function);
- contextBuilder.emplace(mlirFunc.getBody());
- context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc());
- return function;
- }
-
- void exit(py::object, py::object, py::object) {
- delete context;
- context = nullptr;
- contextBuilder.reset();
- }
-
- PythonFunction function;
- mlir::edsc::ScopedContext *context;
- llvm::Optional<OpBuilder> contextBuilder;
-};
-
-PythonFunctionContext PythonMLIRModule::makeFunctionContext(
- const std::string &name, const py::list &inputs,
- const std::vector<PythonType> &outputs, const py::kwargs &attributes) {
- auto func = declareFunction(name, inputs, outputs, attributes);
- func.define();
- return PythonFunctionContext(func);
-}
-
-struct PythonBlockHandle {
- PythonBlockHandle() : value(nullptr) {}
- PythonBlockHandle(const PythonBlockHandle &other) = default;
- PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
- operator mlir::edsc::BlockHandle() const { return value; }
-
- PythonValueHandle arg(int index) { return arguments[index]; }
-
- std::string str() {
- std::string s;
- llvm::raw_string_ostream os(s);
- value.getBlock()->print(os);
- return os.str();
- }
-
- mlir::edsc::BlockHandle value;
- std::vector<mlir::edsc::ValueHandle> arguments;
-};
-
-struct PythonLoopContext {
- PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step)
- : lb(lb), ub(ub), step(step) {}
- PythonLoopContext(const PythonLoopContext &) = delete;
- PythonLoopContext(PythonLoopContext &&) = default;
- PythonLoopContext &operator=(const PythonLoopContext &) = delete;
- PythonLoopContext &operator=(PythonLoopContext &&) = default;
- ~PythonLoopContext() { assert(!builder && "did not exit from the context"); }
-
- PythonValueHandle enter() {
- ValueHandle iv(lb.value.getType());
- builder = new AffineLoopNestBuilder(&iv, lb.value, ub.value, step);
- return iv;
- }
-
- void exit(py::object, py::object, py::object) {
- (*builder)({}); // exit from the builder's scope.
- delete builder;
- builder = nullptr;
- }
-
- PythonValueHandle lb, ub;
- int64_t step;
- AffineLoopNestBuilder *builder = nullptr;
-};
-
-struct PythonLoopNestContext {
- PythonLoopNestContext(const std::vector<PythonValueHandle> &lbs,
- const std::vector<PythonValueHandle> &ubs,
- const std::vector<int64_t> steps)
- : lbs(lbs), ubs(ubs), steps(steps) {
- assert(lbs.size() == ubs.size() && lbs.size() == steps.size() &&
- "expected the same number of lower, upper bounds, and steps");
- }
- PythonLoopNestContext(const PythonLoopNestContext &) = delete;
- PythonLoopNestContext(PythonLoopNestContext &&) = default;
- PythonLoopNestContext &operator=(const PythonLoopNestContext &) = delete;
- PythonLoopNestContext &operator=(PythonLoopNestContext &&) = default;
- ~PythonLoopNestContext() {
- assert(!builder && "did not exit from the context");
- }
-
- std::vector<PythonValueHandle> enter() {
- if (steps.empty())
- return {};
-
- auto type = mlir_type_t(lbs.front().value.getType().getAsOpaquePointer());
- std::vector<PythonValueHandle> handles(steps.size(),
- PythonValueHandle(type));
- std::vector<ValueHandle *> handlePtrs;
- handlePtrs.reserve(steps.size());
- for (auto &h : handles)
- handlePtrs.push_back(&h.value);
- builder = new AffineLoopNestBuilder(
- handlePtrs, std::vector<ValueHandle>(lbs.begin(), lbs.end()),
- std::vector<ValueHandle>(ubs.begin(), ubs.end()), steps);
- return handles;
- }
-
- void exit(py::object, py::object, py::object) {
- (*builder)({}); // exit from the builder's scope.
- delete builder;
- builder = nullptr;
- }
-
- std::vector<PythonValueHandle> lbs;
- std::vector<PythonValueHandle> ubs;
- std::vector<int64_t> steps;
- AffineLoopNestBuilder *builder = nullptr;
-};
-
-struct PythonBlockAppender {
- PythonBlockAppender(const PythonBlockHandle &handle) : handle(handle) {}
- PythonBlockHandle handle;
-};
-
-struct PythonBlockContext {
-public:
- PythonBlockContext() {
- createBlockBuilder();
- clearBuilder();
- }
- PythonBlockContext(const std::vector<PythonType> &argTypes) {
- handle.arguments.reserve(argTypes.size());
- for (const auto &t : argTypes) {
- auto type =
- Type::getFromOpaquePointer(reinterpret_cast<const void *>(t.type));
- handle.arguments.emplace_back(type);
- }
- createBlockBuilder();
- clearBuilder();
- }
- PythonBlockContext(const PythonBlockAppender &a) : handle(a.handle) {}
- PythonBlockContext(const PythonBlockContext &) = delete;
- PythonBlockContext(PythonBlockContext &&) = default;
- PythonBlockContext &operator=(const PythonBlockContext &) = delete;
- PythonBlockContext &operator=(PythonBlockContext &&) = default;
- ~PythonBlockContext() {
- assert(!builder && "did not exit from the block context");
- }
-
- // EDSC maintain an implicit stack of builders (mostly for keeping track of
- // insertion points); every operation gets inserted using the top-of-the-stack
- // builder. Creating a new EDSC Builder automatically puts it on the stack,
- // effectively entering the block for it.
- void createBlockBuilder() {
- if (handle.value.getBlock()) {
- builder = new BlockBuilder(handle.value, mlir::edsc::Append());
- } else {
- std::vector<ValueHandle *> args;
- args.reserve(handle.arguments.size());
- for (auto &a : handle.arguments)
- args.push_back(&a);
- builder = new BlockBuilder(&handle.value, args);
- }
- }
-
- PythonBlockHandle enter() {
- createBlockBuilder();
- return handle;
- }
-
- void exit(py::object, py::object, py::object) { clearBuilder(); }
-
- PythonBlockHandle getHandle() { return handle; }
-
- // EDSC maintain an implicit stack of builders (mostly for keeping track of
- // insertion points); every operation gets inserted using the top-of-the-stack
- // builder. Calling operator() on a builder pops the builder from the stack,
- // effectively resetting the insertion point to its position before we entered
- // the block.
- void clearBuilder() {
- (*builder)({}); // exit from the builder's scope.
- delete builder;
- builder = nullptr;
- }
-
- PythonBlockHandle handle;
- BlockBuilder *builder = nullptr;
-};
-
-struct PythonAttribute {
- PythonAttribute() : attr(nullptr) {}
- PythonAttribute(const mlir_attr_t &a) : attr(a) {}
- PythonAttribute(const PythonAttribute &other) = default;
- operator mlir_attr_t() { return attr; }
-
- operator Attribute() const { return Attribute::getFromOpaquePointer(attr); }
-
- std::string str() const {
- if (!attr)
- return "##null attr##";
-
- std::string res;
- llvm::raw_string_ostream os(res);
- Attribute().print(os);
- return res;
- }
-
- mlir_attr_t attr;
-};
-
-struct PythonAttributedType {
- PythonAttributedType() : type(nullptr) {}
- PythonAttributedType(mlir_type_t t) : type(t) {}
- PythonAttributedType(
- PythonType t,
- const std::unordered_map<std::string, PythonAttribute> &attributes =
- std::unordered_map<std::string, PythonAttribute>())
- : type(t), attrs(attributes) {}
-
- operator mlir_type_t() const { return type.type; }
- operator PythonType() const { return type; }
-
- // Return a vector of named attribute descriptors. The vector owns the
- // mlir_named_attr_t objects it contains, but not the names and attributes
- // those objects point to (names and opaque pointers to attributes are owned
- // by `this`).
- std::vector<mlir_named_attr_t> getNamedAttrs() const {
- std::vector<mlir_named_attr_t> result;
- result.reserve(attrs.size());
- for (const auto &namedAttr : attrs)
- result.push_back({namedAttr.first.c_str(), namedAttr.second.attr});
- return result;
- }
-
- std::string str() {
- mlir::Type t = mlir::Type::getFromOpaquePointer(type);
- std::string res;
- llvm::raw_string_ostream os(res);
- t.print(os);
- if (attrs.empty())
- return os.str();
-
- os << '{';
- bool first = true;
- for (const auto &namedAttr : attrs) {
- if (first)
- first = false;
- else
- os << ", ";
- os << namedAttr.first << ": " << namedAttr.second.str();
- }
- os << '}';
-
- return os.str();
- }
-
-private:
- PythonType type;
- std::unordered_map<std::string, PythonAttribute> attrs;
-};
-
-// Wraps mlir::AffineExpr.
-struct PythonAffineExpr {
- PythonAffineExpr() : affine_expr() {}
- PythonAffineExpr(const AffineExpr &a) : affine_expr(a) {}
- PythonAffineExpr(const PythonAffineExpr &other) = default;
-
- operator AffineExpr() const { return affine_expr; }
- operator AffineExpr &() { return affine_expr; }
-
- AffineExpr get() const { return affine_expr; }
-
- std::string str() const {
- std::string res;
- llvm::raw_string_ostream os(res);
- affine_expr.print(os);
- return res;
- }
-
-private:
- AffineExpr affine_expr;
-};
-
-// Wraps mlir::AffineMap.
-struct PythonAffineMap {
- PythonAffineMap() : affine_map() {}
- PythonAffineMap(const AffineMap &a) : affine_map(a) {}
- PythonAffineMap(const PythonAffineMap &other) = default;
-
- operator AffineMap() const { return affine_map; }
- operator AffineMap &() { return affine_map; }
-
- std::string str() const {
- std::string res;
- llvm::raw_string_ostream os(res);
- affine_map.print(os);
- return res;
- }
-
-private:
- AffineMap affine_map;
-};
-
-struct PythonIndexedValue {
- explicit PythonIndexedValue(PythonType type)
- : indexed(Type::getFromOpaquePointer(type.type)) {}
- explicit PythonIndexedValue(const IndexedValue &other) : indexed(other) {}
- PythonIndexedValue(PythonValueHandle handle) : indexed(handle.value) {}
- PythonIndexedValue(const PythonIndexedValue &other) = default;
-
- // Create a new indexed value with the same base as this one but with indices
- // provided as arguments.
- PythonIndexedValue index(const std::vector<PythonValueHandle> &indices) {
- std::vector<ValueHandle> handles(indices.begin(), indices.end());
- return PythonIndexedValue(IndexedValue(indexed(handles)));
- }
-
- void store(const std::vector<PythonValueHandle> &indices,
- PythonValueHandle value) {
- // Uses the overloaded `operator=` to emit a store.
- index(indices).indexed = value.value;
- }
-
- PythonValueHandle load(const std::vector<PythonValueHandle> &indices) {
- // Uses the overloaded cast to `ValueHandle` to emit a load.
- return static_cast<ValueHandle>(index(indices).indexed);
- }
-
- IndexedValue indexed;
-};
-
-template <typename ListTy, typename PythonTy, typename Ty>
-ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
- for (auto &inp : list) {
- owning.push_back(Ty{inp.cast<PythonTy>()});
- }
- return ListTy{owning.data(), owning.size()};
-}
-
-static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning,
- const py::list &types) {
- return makeCList<mlir_type_list_t, PythonType>(owning, types);
-}
-
-PythonFunction
-PythonMLIRModule::declareFunction(const std::string &name,
- const py::list &inputs,
- const std::vector<PythonType> &outputTypes,
- const py::kwargs &funcAttributes) {
-
- std::vector<PythonAttributedType> attributedInputs;
- attributedInputs.reserve(inputs.size());
- for (const auto &in : inputs) {
- std::string className = in.get_type().str();
- if (className.find(".Type'") != std::string::npos)
- attributedInputs.emplace_back(in.cast<PythonType>());
- else
- attributedInputs.push_back(in.cast<PythonAttributedType>());
- }
-
- // Create the function type.
- std::vector<mlir_type_t> ins(attributedInputs.begin(),
- attributedInputs.end());
- std::vector<mlir_type_t> outs(outputTypes.begin(), outputTypes.end());
- auto funcType = ::makeFunctionType(
- mlir_context_t{&mlirContext}, mlir_type_list_t{ins.data(), ins.size()},
- mlir_type_list_t{outs.data(), outs.size()});
-
- // Build the list of function attributes.
- std::vector<mlir::NamedAttribute> attrs;
- attrs.reserve(funcAttributes.size());
- for (const auto &named : funcAttributes)
- attrs.emplace_back(
- Identifier::get(std::string(named.first.str()), &mlirContext),
- mlir::Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(
- named.second.cast<PythonAttribute>().attr)));
-
- // Build the list of lists of function argument attributes.
- std::vector<mlir::NamedAttributeList> inputAttrs;
- inputAttrs.reserve(attributedInputs.size());
- for (const auto &in : attributedInputs) {
- std::vector<mlir::NamedAttribute> inAttrs;
- for (const auto &named : in.getNamedAttrs())
- inAttrs.emplace_back(Identifier::get(named.name, &mlirContext),
- mlir::Attribute::getFromOpaquePointer(
- reinterpret_cast<const void *>(named.value)));
- inputAttrs.emplace_back(inAttrs);
- }
-
- // Create the function itself.
- auto func = mlir::FuncOp::create(
- UnknownLoc::get(&mlirContext), name,
- mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
- inputAttrs);
- symbolTable.insert(func);
- return func;
-}
-
-PythonAttributedType PythonType::attachAttributeDict(
- const std::unordered_map<std::string, PythonAttribute> &attrs) const {
- return PythonAttributedType(*this, attrs);
-}
-
-PythonAttribute PythonMLIRModule::integerAttr(PythonType type, int64_t value) {
- return PythonAttribute(::makeIntegerAttr(type, value));
-}
-
-PythonAttribute PythonMLIRModule::boolAttr(bool value) {
- return PythonAttribute(::makeBoolAttr(&mlirContext, value));
-}
-
-PythonAttribute PythonMLIRModule::floatAttr(float value) {
- return PythonAttribute(::makeFloatAttr(&mlirContext, value));
-}
-
-PythonAttribute PythonMLIRModule::stringAttr(const std::string &value) {
- return PythonAttribute(::makeStringAttr(&mlirContext, value.c_str()));
-}
-
-PythonAttribute
-PythonMLIRModule::arrayAttr(const std::vector<PythonAttribute> &values) {
- std::vector<mlir::Attribute> mlir_attributes(values.begin(), values.end());
- auto array_attr = ArrayAttr::get(
- llvm::ArrayRef<mlir::Attribute>(mlir_attributes), &mlirContext);
- return PythonAttribute(array_attr.getAsOpaquePointer());
-}
-
-PythonAttribute PythonMLIRModule::affineMapAttr(PythonAffineMap value) {
- return PythonAttribute(AffineMapAttr::get(value).getAsOpaquePointer());
-}
-
-PythonAffineExpr PythonMLIRModule::affineConstantExpr(int64_t value) {
- return PythonAffineExpr(getAffineConstantExpr(value, &mlirContext));
-}
-
-PythonAffineExpr PythonMLIRModule::affineSymbolExpr(unsigned position) {
- return PythonAffineExpr(getAffineSymbolExpr(position, &mlirContext));
-}
-
-PythonAffineExpr PythonMLIRModule::affineDimExpr(unsigned position) {
- return PythonAffineExpr(getAffineDimExpr(position, &mlirContext));
-}
-
-PythonAffineMap PythonMLIRModule::affineConstantMap(int64_t value) {
- return PythonAffineMap(AffineMap::getConstantMap(value, &mlirContext));
-}
-
-PythonAffineMap
-PythonMLIRModule::affineMap(unsigned dimCount, unsigned SymbolCount,
- const std::vector<PythonAffineExpr> &results) {
- std::vector<AffineExpr> mlir_results(results.begin(), results.end());
- return PythonAffineMap(AffineMap::get(
- dimCount, SymbolCount, llvm::ArrayRef<AffineExpr>(mlir_results)));
-}
-
-PYBIND11_MODULE(pybind, m) {
- m.doc() =
- "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
- m.def("version", []() { return "EDSC Python extensions v1.0"; });
-
- py::class_<PythonLoopContext>(
- m, "LoopContext", "A context for building the body of a 'for' loop")
- .def(py::init<PythonValueHandle, PythonValueHandle, int64_t>())
- .def("__enter__", &PythonLoopContext::enter)
- .def("__exit__", &PythonLoopContext::exit);
-
- py::class_<PythonLoopNestContext>(m, "LoopNestContext",
- "A context for building the body of a the "
- "innermost loop in a nest of 'for' loops")
- .def(py::init<const std::vector<PythonValueHandle> &,
- const std::vector<PythonValueHandle> &,
- const std::vector<int64_t> &>())
- .def("__enter__", &PythonLoopNestContext::enter)
- .def("__exit__", &PythonLoopNestContext::exit);
-
- m.def("constant_index", [](int64_t val) -> PythonValueHandle {
- return ValueHandle(index_t(val));
- });
- m.def("constant_int", [](int64_t val, int width) -> PythonValueHandle {
- return ValueHandle::create<ConstantIntOp>(val, width);
- });
- m.def("constant_float", [](double val, PythonType type) -> PythonValueHandle {
- FloatType floatType =
- Type::getFromOpaquePointer(type.type).cast<FloatType>();
- assert(floatType);
- auto value = APFloat(val);
- bool lostPrecision;
- value.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
- &lostPrecision);
- return ValueHandle::create<ConstantFloatOp>(value, floatType);
- });
- m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
- auto function = FuncOp::getFromOpaquePointer(func.function);
- auto attr = SymbolRefAttr::get(function.getName(), function.getContext());
- return ValueHandle::create<ConstantOp>(function.getType(), attr);
- });
- m.def("appendTo", [](const PythonBlockHandle &handle) {
- return PythonBlockAppender(handle);
- });
- m.def(
- "ret",
- [](const std::vector<PythonValueHandle> &args) {
- std::vector<ValueHandle> values(args.begin(), args.end());
- (intrinsics::ret(ArrayRef<ValueHandle>{values})); // vexing parse
- return PythonValueHandle(nullptr);
- },
- py::arg("args") = std::vector<PythonValueHandle>());
- m.def(
- "br",
- [](const PythonBlockHandle &dest,
- const std::vector<PythonValueHandle> &args) {
- std::vector<ValueHandle> values(args.begin(), args.end());
- intrinsics::br(dest, values);
- return PythonValueHandle(nullptr);
- },
- py::arg("dest"), py::arg("args") = std::vector<PythonValueHandle>());
- m.def(
- "cond_br",
- [](PythonValueHandle condition, const PythonBlockHandle &trueDest,
- const std::vector<PythonValueHandle> &trueArgs,
- const PythonBlockHandle &falseDest,
- const std::vector<PythonValueHandle> &falseArgs) -> PythonValueHandle {
- std::vector<ValueHandle> trueArguments(trueArgs.begin(),
- trueArgs.end());
- std::vector<ValueHandle> falseArguments(falseArgs.begin(),
- falseArgs.end());
- intrinsics::cond_br(condition, trueDest, trueArguments, falseDest,
- falseArguments);
- return PythonValueHandle(nullptr);
- });
- m.def("index_cast",
- [](PythonValueHandle element, PythonType type) -> PythonValueHandle {
- return ValueHandle::create<IndexCastOp>(
- element.value, Type::getFromOpaquePointer(type.type));
- });
- m.def("select",
- [](PythonValueHandle condition, PythonValueHandle trueValue,
- PythonValueHandle falseValue) -> PythonValueHandle {
- return ValueHandle::create<SelectOp>(condition.value, trueValue.value,
- falseValue.value);
- });
- m.def("op",
- [](const std::string &name,
- const std::vector<PythonValueHandle> &operands,
- const std::vector<PythonType> &resultTypes,
- const py::kwargs &attributes) -> PythonValueHandle {
- std::vector<ValueHandle> operandHandles(operands.begin(),
- operands.end());
- std::vector<Type> types;
- types.reserve(resultTypes.size());
- for (auto t : resultTypes)
- types.push_back(Type::getFromOpaquePointer(t.type));
-
- std::vector<NamedAttribute> attrs;
- attrs.reserve(attributes.size());
- for (const auto &a : attributes) {
- std::string name = a.first.str();
- auto pyAttr = a.second.cast<PythonAttribute>();
- auto cppAttr = Attribute::getFromOpaquePointer(pyAttr.attr);
- auto identifier =
- Identifier::get(name, ScopedContext::getContext());
- attrs.emplace_back(identifier, cppAttr);
- }
-
- return ValueHandle::create(name, operandHandles, types, attrs);
- });
-
- py::class_<PythonFunction>(m, "Function", "Wrapping class for mlir::FuncOp.")
- .def(py::init<PythonFunction>())
- .def("__str__", &PythonFunction::str)
- .def("define", &PythonFunction::define,
- "Adds a body to the function if it does not already have one. "
- "Returns true if the body was added")
- .def("arg", &PythonFunction::arg,
- "Get the ValueHandle to the indexed argument of the function");
-
- py::class_<PythonAttribute>(m, "Attribute",
- "Wrapping class for mlir::Attribute")
- .def(py::init<PythonAttribute>())
- .def("__str__", &PythonAttribute::str);
-
- py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.")
- .def(py::init<PythonType>())
- .def("__call__", &PythonType::attachAttributeDict,
- "Attach the attributes to these type, making it suitable for "
- "constructing functions with argument attributes")
- .def("__str__", &PythonType::str);
-
- py::class_<PythonAttributedType>(
- m, "AttributedType",
- "A class containing a wrapped mlir::Type and a wrapped "
- "mlir::NamedAttributeList that are used together, e.g. in function "
- "argument declaration")
- .def(py::init<PythonAttributedType>())
- .def("__str__", &PythonAttributedType::str);
-
- py::class_<PythonMLIRModule>(
- m, "MLIRModule",
- "An MLIRModule is the abstraction that owns the allocations to support "
- "compilation of a single mlir::ModuleOp into an ExecutionEngine backed "
- "by "
- "the LLVM ORC JIT. A typical flow consists in creating an MLIRModule, "
- "adding functions, compiling the module to obtain an ExecutionEngine on "
- "which named functions may be called. For now the only means to retrieve "
- "the ExecutionEngine is by calling `get_engine_address`. This mode of "
- "execution is limited to passing the pointer to C++ where the function "
- "is called. Extending the API to allow calling JIT compiled functions "
- "directly require integration with a tensor library (e.g. numpy). This "
- "is left as the prerogative of libraries and frameworks for now.")
- .def(py::init<>())
- .def("boolAttr", &PythonMLIRModule::boolAttr,
- "Creates an mlir::BoolAttr with the given value")
- .def(
- "integerAttr", &PythonMLIRModule::integerAttr,
- "Creates an mlir::IntegerAttr of the given type with the given value "
- "in the context associated with this MLIR module.")
- .def("floatAttr", &PythonMLIRModule::floatAttr,
- "Creates an mlir::FloatAttr with the given value")
- .def("stringAttr", &PythonMLIRModule::stringAttr,
- "Creates an mlir::StringAttr with the given value")
- .def("arrayAttr", &PythonMLIRModule::arrayAttr,
- "Creates an mlir::ArrayAttr of the given type with the given values "
- "in the context associated with this MLIR module.")
- .def("affineMapAttr", &PythonMLIRModule::affineMapAttr,
- "Creates an mlir::AffineMapAttr of the given type with the given "
- "value in the context associated with this MLIR module.")
- .def("declare_function", &PythonMLIRModule::declareFunction,
- "Declares a new mlir::FuncOp in the current mlir::ModuleOp. The "
- "function arguments can have attributes. The function has no "
- "definition and can be linked to an external library.")
- .def("make_function", &PythonMLIRModule::makeFunction,
- "Defines a new mlir::FuncOp in the current mlir::ModuleOp.")
- .def("function_context", &PythonMLIRModule::makeFunctionContext,
- "Defines a new mlir::FuncOp in the mlir::ModuleOp and creates the "
- "function context for building the body of the function.")
- .def("get_function", &PythonMLIRModule::getNamedFunction,
- "Looks up the function with the given name in the module.")
- .def("make_memref_type", &PythonMLIRModule::makeMemRefType,
- "Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
- "denote symbolic dimensions in the resulting memref shape.")
- .def("make_index_type", &PythonMLIRModule::makeIndexType,
- "Returns an mlir::IndexType")
- .def("make_type", &PythonMLIRModule::makeType,
- "Returns an mlir::Type defined by the IR passed in as the argument.")
- .def("compile", &PythonMLIRModule::compile,
- "Compiles the mlir::ModuleOp to LLVMIR a creates new opaque "
- "ExecutionEngine backed by the ORC JIT. The arguments, if present, "
- "indicates the level of LLVM optimizations to run (similar to -O?).",
- py::arg("optLevel") = -1, py::arg("codegenOptLevel") = -1)
- .def("get_ir", &PythonMLIRModule::getIR,
- "Returns a dump of the MLIR representation of the module. This is "
- "used for serde to support out-of-process execution as well as "
- "debugging purposes.")
- .def("get_engine_address", &PythonMLIRModule::getEngineAddress,
- "Returns the address of the compiled ExecutionEngine. This is used "
- "for in-process execution.")
- .def("affine_constant_expr", &PythonMLIRModule::affineConstantExpr,
- "Returns an affine constant expression.")
- .def("affine_symbol_expr", &PythonMLIRModule::affineSymbolExpr,
- "Returns an affine symbol expression.")
- .def("affine_dim_expr", &PythonMLIRModule::affineDimExpr,
- "Returns an affine dim expression.")
- .def("affine_constant_map", &PythonMLIRModule::affineConstantMap,
- "Returns an affine map with single constant result.")
- .def("affine_map", &PythonMLIRModule::affineMap, "Returns an affine map.",
- py::arg("dimCount"), py::arg("symbolCount"), py::arg("results"))
- .def("__str__", &PythonMLIRModule::getIR,
- "Get the string representation of the module");
-
- py::class_<PythonFunctionContext>(
- m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext")
- .def(py::init<PythonFunction>())
- .def("__enter__", &PythonFunctionContext::enter)
- .def("__exit__", &PythonFunctionContext::exit);
-
- {
- using namespace mlir::edsc::op;
- py::class_<PythonValueHandle>(m, "ValueHandle",
- "A wrapper around mlir::edsc::ValueHandle")
- .def(py::init<PythonType>())
- .def(py::init<PythonValueHandle>())
- .def("__add__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return lhs.value + rhs.value; })
- .def("__sub__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return lhs.value - rhs.value; })
- .def("__mul__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return lhs.value * rhs.value; })
- .def("__div__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return lhs.value / rhs.value; })
- .def("__truediv__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return lhs.value / rhs.value; })
- .def("__floordiv__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return floorDiv(lhs, rhs); })
- .def("__mod__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return lhs.value % rhs.value; })
- .def("__lt__",
- [](PythonValueHandle lhs,
- PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::slt, lhs.value,
- rhs.value);
- })
- .def("__le__",
- [](PythonValueHandle lhs,
- PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::sle, lhs.value,
- rhs.value);
- })
- .def("__gt__",
- [](PythonValueHandle lhs,
- PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::sgt, lhs.value,
- rhs.value);
- })
- .def("__ge__",
- [](PythonValueHandle lhs,
- PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::sge, lhs.value,
- rhs.value);
- })
- .def("__eq__",
- [](PythonValueHandle lhs,
- PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::eq, lhs.value,
- rhs.value);
- })
- .def("__ne__",
- [](PythonValueHandle lhs,
- PythonValueHandle rhs) -> PythonValueHandle {
- return ValueHandle::create<CmpIOp>(CmpIPredicate::ne, lhs.value,
- rhs.value);
- })
- .def("__invert__",
- [](PythonValueHandle handle) -> PythonValueHandle {
- return !handle.value;
- })
- .def("__and__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return lhs.value && rhs.value; })
- .def("__or__",
- [](PythonValueHandle lhs, PythonValueHandle rhs)
- -> PythonValueHandle { return lhs.value || rhs.value; })
- .def("__call__", &PythonValueHandle::call)
- .def("type", &PythonValueHandle::type);
- }
-
- py::class_<PythonBlockAppender>(
- m, "BlockAppender",
- "A dummy class signaling BlockContext to append IR to the given block "
- "instead of creating a new block")
- .def(py::init<const PythonBlockHandle &>());
- py::class_<PythonBlockHandle>(m, "BlockHandle",
- "A wrapper around mlir::edsc::BlockHandle")
- .def(py::init<PythonBlockHandle>())
- .def("arg", &PythonBlockHandle::arg);
-
- py::class_<PythonBlockContext>(m, "BlockContext",
- "A wrapper around mlir::edsc::BlockBuilder")
- .def(py::init<>())
- .def(py::init<const std::vector<PythonType> &>())
- .def(py::init<const PythonBlockAppender &>())
- .def("__enter__", &PythonBlockContext::enter)
- .def("__exit__", &PythonBlockContext::exit)
- .def("handle", &PythonBlockContext::getHandle);
-
- py::class_<PythonIndexedValue>(m, "IndexedValue",
- "A wrapper around mlir::edsc::IndexedValue")
- .def(py::init<PythonValueHandle>())
- .def("load", &PythonIndexedValue::load)
- .def("store", &PythonIndexedValue::store);
-
- py::class_<PythonAffineExpr>(m, "AffineExpr",
- "A wrapper around mlir::AffineExpr")
- .def(py::init<PythonAffineExpr>())
- .def("__add__",
- [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get() + rhs);
- })
- .def("__add__",
- [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get() + rhs.get());
- })
- .def("__neg__",
- [](PythonAffineExpr lhs) -> PythonAffineExpr {
- return PythonAffineExpr(-lhs.get());
- })
- .def("__sub__",
- [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get() - rhs);
- })
- .def("__sub__",
- [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get() - rhs.get());
- })
- .def("__mul__",
- [](PythonAffineExpr lhs, int64_t rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get() * rhs);
- })
- .def("__mul__",
- [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get() * rhs.get());
- })
- .def("__floordiv__",
- [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get().floorDiv(rhs));
- })
- .def("__floordiv__",
- [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get().floorDiv(rhs.get()));
- })
- .def("ceildiv",
- [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get().ceilDiv(rhs));
- })
- .def("ceildiv",
- [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get().ceilDiv(rhs.get()));
- })
- .def("__mod__",
- [](PythonAffineExpr lhs, uint64_t rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get() % rhs);
- })
- .def("__mod__",
- [](PythonAffineExpr lhs, PythonAffineExpr rhs) -> PythonAffineExpr {
- return PythonAffineExpr(lhs.get() % rhs.get());
- })
- .def("compose",
- [](PythonAffineExpr self, PythonAffineMap map) -> PythonAffineExpr {
- return PythonAffineExpr(self.get().compose(map));
- })
- .def(
- "get_constant_value",
- [](PythonAffineExpr self) -> py::object {
- auto const_expr = self.get().dyn_cast<AffineConstantExpr>();
- if (const_expr)
- return py::cast(const_expr.getValue());
- return py::none();
- },
- "Returns the constant value for the affine expression if any, or "
- "returns None.")
- .def("__str__", &PythonAffineExpr::str);
-
- py::class_<PythonAffineMap>(m, "AffineMap",
- "A wrapper around mlir::AffineMap")
- .def(py::init<PythonAffineMap>())
- .def("__str__", &PythonAffineMap::str);
-}
-
-} // namespace python
-} // namespace edsc
-} // namespace mlir
diff --git a/third_party/mlir/bindings/python/test/BUILD b/third_party/mlir/bindings/python/test/BUILD
deleted file mode 100644
index 36fe5cb..0000000
--- a/third_party/mlir/bindings/python/test/BUILD
+++ /dev/null
@@ -1,36 +0,0 @@
-# Description:
-# BUILD file for the Python wrappers for EDSCs
-
-licenses(["notice"]) # Apache 2.0
-
-# Export the BUILD file so automated tooling can check licenses
-exports_files(["BUILD"])
-
-load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests")
-
-glob_lit_tests(
- data = [":test_utilities"],
- driver = "@local_config_mlir//:run_lit.sh",
- test_file_exts = ["py"],
-)
-
-# Bundle together all of the test utilities that are used by tests.
-filegroup(
- name = "test_utilities",
- testonly = True,
- data = [
- ":test_edsc",
- "//third_party/llvm/llvm:FileCheck",
- ],
-)
-
-py_binary(
- name = "test_edsc",
- srcs = ["test_py2and3.py"],
- main = "test_py2and3.py",
- python_version = "PY2",
- deps = [
- "//testing/pybase",
- "@local_config_mlir//bindings/python:_pybind",
- ],
-)
diff --git a/third_party/mlir/bindings/python/test/test_py2and3.py b/third_party/mlir/bindings/python/test/test_py2and3.py
deleted file mode 100644
index 02f8f62..0000000
--- a/third_party/mlir/bindings/python/test/test_py2and3.py
+++ /dev/null
@@ -1,583 +0,0 @@
-# Copyright 2019 The MLIR Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# RUN: %p/test_edsc %s | FileCheck %s
-"""Python2 and 3 test for the MLIR EDSC Python bindings"""
-
-import google_mlir.bindings.python.pybind as E
-import inspect
-
-
-# Prints `str` prefixed by the current test function name so we can use it in
-# Filecheck label directives.
-# This is achieved by inspecting the stack and getting the parent name.
-def printWithCurrentFunctionName(str):
- print(inspect.stack()[1][3])
- print(str)
-
-
-class EdscTest:
-
- def setUp(self):
- self.module = E.MLIRModule()
- self.boolType = self.module.make_type("i1")
- self.i32Type = self.module.make_type("i32")
- self.f32Type = self.module.make_type("f32")
- self.indexType = self.module.make_index_type()
-
- def testBlockArguments(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- E.constant_index(42)
- with E.BlockContext([self.f32Type, self.f32Type]) as b:
- b.arg(0) + b.arg(1)
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testBlockArguments
- # CHECK: %{{.*}} = constant 42 : index
- # CHECK: ^bb{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
- # CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
-
- def testBlockContext(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- cst = E.constant_index(42)
- with E.BlockContext():
- cst + cst
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testBlockContext
- # CHECK: %{{.*}} = constant 42 : index
- # CHECK: ^bb
- # CHECK: %{{.*}} = "affine.apply"() {map = () -> (84)} : () -> index
-
- def testBlockContextAppend(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- E.constant_index(41)
- with E.BlockContext() as b:
- blk = b # save block handle for later
- E.constant_index(0)
- E.constant_index(42)
- with E.BlockContext(E.appendTo(blk)):
- E.constant_index(1)
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testBlockContextAppend
- # CHECK: %{{.*}} = constant 41 : index
- # CHECK: %{{.*}} = constant 42 : index
- # CHECK: ^bb
- # CHECK: %{{.*}} = constant 0 : index
- # CHECK: %{{.*}} = constant 1 : index
-
- def testBlockContextStandalone(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- blk1 = E.BlockContext()
- blk2 = E.BlockContext()
- with blk1:
- E.constant_index(0)
- with blk2:
- E.constant_index(56)
- E.constant_index(57)
- E.constant_index(41)
- with blk1:
- E.constant_index(1)
- E.constant_index(42)
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testBlockContextStandalone
- # CHECK: %{{.*}} = constant 41 : index
- # CHECK: %{{.*}} = constant 42 : index
- # CHECK: ^bb
- # CHECK: %{{.*}} = constant 0 : index
- # CHECK: %{{.*}} = constant 1 : index
- # CHECK: ^bb
- # CHECK: %{{.*}} = constant 56 : index
- # CHECK: %{{.*}} = constant 57 : index
-
- def testBooleanOps(self):
- self.setUp()
- with self.module.function_context("booleans",
- [self.boolType for _ in range(4)],
- []) as fun:
- i, j, k, l = (fun.arg(x) for x in range(4))
- stmt1 = (i < j) & (j >= k)
- stmt2 = ~(stmt1 | (k == l))
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testBooleanOps
- # CHECK: %{{.*}} = cmpi "slt", %{{.*}}, %{{.*}} : i1
- # CHECK: %{{.*}} = cmpi "sge", %{{.*}}, %{{.*}} : i1
- # CHECK: %{{.*}} = muli %{{.*}}, %{{.*}} : i1
- # CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : i1
- # CHECK: %{{.*}} = constant 1 : i1
- # CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
- # CHECK: %{{.*}} = constant 1 : i1
- # CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
- # CHECK: %{{.*}} = muli %{{.*}}, %{{.*}} : i1
- # CHECK: %{{.*}} = constant 1 : i1
- # CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
- # CHECK: %{{.*}} = constant 1 : i1
- # CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
-
- def testBr(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- with E.BlockContext() as b:
- blk = b
- E.ret()
- E.br(blk)
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testBr
- # CHECK: br ^bb
- # CHECK: ^bb
- # CHECK: return
-
- def testBrArgs(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- # Create an infinite loop.
- with E.BlockContext([self.indexType, self.indexType]) as b:
- E.br(b, [b.arg(1), b.arg(0)])
- E.br(b, [E.constant_index(0), E.constant_index(1)])
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testBrArgs
- # CHECK: %{{.*}} = constant 0 : index
- # CHECK: %{{.*}} = constant 1 : index
- # CHECK: br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
- # CHECK: ^bb{{.*}}(%{{.*}}: index, %{{.*}}: index):
- # CHECK: br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
-
- def testBrDeclaration(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- blk = E.BlockContext()
- E.br(blk.handle())
- with blk:
- E.ret()
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testBrDeclaration
- # CHECK: br ^bb
- # CHECK: ^bb
- # CHECK: return
-
- def testCallOp(self):
- self.setUp()
- callee = self.module.declare_function("sqrtf", [self.f32Type],
- [self.f32Type])
- with self.module.function_context("call", [self.f32Type], []) as fun:
- funCst = E.constant_function(callee)
- funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type)
- printWithCurrentFunctionName(str(self.module))
- # CHECK-LABEL: testCallOp
- # CHECK: func @sqrtf(f32) -> f32
- # CHECK: %{{.*}} = constant @sqrtf : (f32) -> f32
- # CHECK: %{{.*}} = call_indirect %{{.*}}(%{{.*}}) : (f32) -> f32
-
- def testCondBr(self):
- self.setUp()
- with self.module.function_context("foo", [self.boolType], []) as fun:
- with E.BlockContext() as blk1:
- E.ret([])
- with E.BlockContext([self.indexType]) as blk2:
- E.ret([])
- cst = E.constant_index(0)
- E.cond_br(fun.arg(0), blk1, [], blk2, [cst])
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testCondBr
- # CHECK: cond_br %{{.*}}, ^bb{{.*}}, ^bb{{.*}}(%{{.*}} : index)
-
- def testConstantAffineExpr(self):
- self.setUp()
- with self.module.function_context("constant_affine", [], []) as fun:
- a1 = self.module.affine_dim_expr(0)
- a2 = self.module.affine_dim_expr(1)
- a3 = a1 + a2 + 3
- composedExpr = a3.compose(
- self.module.affine_map(2, 0, [
- self.module.affine_constant_expr(4),
- self.module.affine_constant_expr(7)
- ]))
- printWithCurrentFunctionName(str(fun))
- print("constant value : %d" % composedExpr.get_constant_value())
- # CHECK-LABEL: testConstantAffineExpr
- # CHECK: constant value : 14
-
- def testConstants(self):
- self.setUp()
- with self.module.function_context("constants", [], []) as fun:
- E.constant_float(1.23, self.module.make_type("bf16"))
- E.constant_float(1.23, self.module.make_type("f16"))
- E.constant_float(1.23, self.module.make_type("f32"))
- E.constant_float(1.23, self.module.make_type("f64"))
- E.constant_int(1, 1)
- E.constant_int(123, 8)
- E.constant_int(123, 16)
- E.constant_int(123, 32)
- E.constant_int(123, 64)
- E.constant_index(123)
- E.constant_function(fun)
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testConstants
- # CHECK: constant 1.230000e+00 : bf16
- # CHECK: constant 1.230470e+00 : f16
- # CHECK: constant 1.230000e+00 : f32
- # CHECK: constant 1.230000e+00 : f64
- # CHECK: constant 1 : i1
- # CHECK: constant 123 : i8
- # CHECK: constant 123 : i16
- # CHECK: constant 123 : i32
- # CHECK: constant 123 : index
- # CHECK: constant @constants : () -> ()
-
- def testCustom(self):
- self.setUp()
- with self.module.function_context("custom", [self.indexType, self.f32Type],
- []) as fun:
- E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1)
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testCustom
- # CHECK: %{{.*}} = "foo"(%{{.*}}) : (index) -> f32
- # CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
-
- # Create 'addi' using the generic Op interface. We need an operation known
- # to the execution engine so that the engine can compile it.
- def testCustomOpCompilation(self):
- self.setUp()
- with self.module.function_context("adder", [self.i32Type], []) as f:
- c1 = E.op(
- "std.constant", [], [self.i32Type],
- value=self.module.integerAttr(self.i32Type, 42))
- E.op("std.addi", [c1, f.arg(0)], [self.i32Type])
- E.ret([])
- self.module.compile()
- printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
- # CHECK-LABEL: testCustomOpCompilation
- # CHECK: False
-
- def testDivisions(self):
- self.setUp()
- with self.module.function_context(
- "division", [self.indexType, self.i32Type, self.i32Type], []) as fun:
- # indices only support floor division
- fun.arg(0) // E.constant_index(42)
- # regular values only support regular division
- fun.arg(1) / fun.arg(2)
- printWithCurrentFunctionName(str(self.module))
- # CHECK-LABEL: testDivisions
- # CHECK: floordiv 42
- # CHECK: divis %{{.*}}, %{{.*}} : i32
-
- def testFunctionArgs(self):
- self.setUp()
- with self.module.function_context("foo", [self.f32Type, self.f32Type],
- [self.indexType]) as fun:
- pass
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testFunctionArgs
- # CHECK: func @foo(%{{.*}}: f32, %{{.*}}: f32) -> index
-
- def testFunctionContext(self):
- self.setUp()
- with self.module.function_context("foo", [], []):
- pass
- printWithCurrentFunctionName(self.module.get_function("foo"))
- # CHECK-LABEL: testFunctionContext
- # CHECK: func @foo() {
-
- def testFunctionDeclaration(self):
- self.setUp()
- boolAttr = self.module.boolAttr(True)
- t = self.module.make_memref_type(self.f32Type, [10])
- t_llvm_noalias = t({"llvm.noalias": boolAttr})
- t_readonly = t({"readonly": boolAttr})
- f = self.module.declare_function("foo", [t, t_llvm_noalias, t_readonly], [])
- printWithCurrentFunctionName(str(self.module))
- # CHECK-LABEL: testFunctionDeclaration
- # CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias = true}, memref<10xf32> {readonly = true})
-
- def testFunctionDeclarationWithAffineAttr(self):
- self.setUp()
- a1 = self.module.affine_constant_expr(23)
- a2 = self.module.affine_constant_expr(44)
- a3 = self.module.affine_dim_expr(1)
- s0 = self.module.affine_symbol_expr(0)
- aMap1 = self.module.affine_map(2, 0, [a1, a2, s0])
- aMap2 = self.module.affine_constant_map(42)
- aMap3 = self.module.affine_map(
- 2, 0,
- [a1 + a2 * a3, a1 // a3 % a2,
- a1.ceildiv(a2), a1 - 2, a2 * 2, -a3])
-
- affineAttr1 = self.module.affineMapAttr(aMap1)
- affineAttr2 = self.module.affineMapAttr(aMap2)
- affineAttr3 = self.module.affineMapAttr(aMap3)
-
- t = self.module.make_memref_type(self.f32Type, [10])
- t_with_attr = t({
- "affine_attr_1": affineAttr1,
- "affine_attr_2": affineAttr2,
- "affine_attr_3": affineAttr3,
- })
-
- f = self.module.declare_function("foo", [t, t_with_attr], [])
- printWithCurrentFunctionName(str(self.module))
- # CHECK-LABEL: testFunctionDeclarationWithAffineAttr
- # CHECK: func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = (d0, d1) -> (23, 44, s0), affine_attr_2 = () -> (42), affine_attr_3 = (d0, d1) -> (d1 * 44 + 23, (23 floordiv d1) mod 44, 1, 21, 88, -d1)})
-
- def testFunctionDeclarationWithArrayAttr(self):
- self.setUp()
- arrayAttr = self.module.arrayAttr([
- self.module.integerAttr(self.i32Type, 43),
- self.module.integerAttr(self.i32Type, 33),
- ])
- t = self.module.make_memref_type(self.f32Type, [10])
- t_with_attr = t({"array_attr": arrayAttr})
-
- f = self.module.declare_function("foo", [t, t_with_attr], [])
- printWithCurrentFunctionName(str(self.module))
- # CHECK-LABEL: testFunctionDeclarationWithArrayAttr
- # CHECK: func @foo(memref<10xf32>, memref<10xf32> {array_attr = [43 : i32, 33 : i32]})
-
- def testFunctionDeclarationWithFloatAndStringAttr(self):
- self.setUp()
- float_attr = self.module.floatAttr(23.3)
- string_attr = self.module.stringAttr("TEST_STRING")
-
- f = self.module.declare_function(
- "foo", [], [], float_attr=float_attr, string_attr=string_attr)
- printWithCurrentFunctionName(str(self.module))
- # CHECK-LABEL: testFunctionDeclarationWithFloatAndStringAttr
- # CHECK: func @foo() attributes {float_attr = 2.330000e+01 : f32, string_attr = "TEST_STRING"}
-
- def testFunctionMultiple(self):
- self.setUp()
- with self.module.function_context("foo", [], []):
- pass
- with self.module.function_context("foo", [], []):
- E.constant_index(0)
- printWithCurrentFunctionName(str(self.module))
- # CHECK-LABEL: testFunctionMultiple
- # CHECK: func @foo()
- # CHECK: func @foo_0()
- # CHECK: %{{.*}} = constant 0 : index
-
- def testIndexCast(self):
- self.setUp()
- with self.module.function_context("testIndexCast", [], []):
- index = E.constant_index(0)
- E.index_cast(index, self.i32Type)
- printWithCurrentFunctionName(str(self.module))
- # CHECK-LABEL: testIndexCast
- # CHECK: index_cast %{{.*}} : index to i32
-
- def testIndexedValue(self):
- self.setUp()
- memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
- with self.module.function_context("indexed", [memrefType],
- [memrefType]) as fun:
- A = E.IndexedValue(fun.arg(0))
- cst = E.constant_float(1., self.f32Type)
- with E.LoopNestContext(
- [E.constant_index(0), E.constant_index(0)],
- [E.constant_index(10), E.constant_index(42)], [1, 1]) as (i, j):
- A.store([i, j], A.load([i, j]) + cst)
- E.ret([fun.arg(0)])
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testIndexedValue
- # CHECK: "affine.for"()
- # CHECK: "affine.for"()
- # CHECK: "affine.load"
- # CHECK-SAME: memref<10x42xf32>
- # CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
- # CHECK: "affine.store"
- # CHECK-SAME: memref<10x42xf32>
- # CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (42)}
- # CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (10)}
-
- def testLoopContext(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- lhs = E.constant_index(0)
- rhs = E.constant_index(42)
- with E.LoopContext(lhs, rhs, 1) as i:
- lhs + rhs + i
- with E.LoopContext(rhs, rhs + rhs, 2) as j:
- x = i + j
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testLoopContext
- # CHECK: "affine.for"() (
- # CHECK: ^bb{{.*}}(%{{.*}}: index):
- # CHECK: "affine.for"(%{{.*}}, %{{.*}}) (
- # CHECK: ^bb{{.*}}(%{{.*}}: index):
- # CHECK: "affine.apply"(%{{.*}}, %{{.*}}) {map = (d0, d1) -> (d0 + d1)} : (index, index) -> index
- # CHECK: {lower_bound = (d0) -> (d0), step = 2 : index, upper_bound = (d0) -> (d0)} : (index, index) -> ()
- # CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (42)}
-
- def testLoopNestContext(self):
- self.setUp()
- with self.module.function_context("foo", [], []) as fun:
- lbs = [E.constant_index(i) for i in range(4)]
- ubs = [E.constant_index(10 * i + 5) for i in range(4)]
- with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l):
- i + j + k + l
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testLoopNestContext
- # CHECK: "affine.for"() (
- # CHECK: ^bb{{.*}}(%{{.*}}: index):
- # CHECK: "affine.for"() (
- # CHECK: ^bb{{.*}}(%{{.*}}: index):
- # CHECK: "affine.for"() (
- # CHECK: ^bb{{.*}}(%{{.*}}: index):
- # CHECK: "affine.for"() (
- # CHECK: ^bb{{.*}}(%{{.*}}: index):
- # CHECK: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {map = (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index
-
- def testMLIRBooleanCompilation(self):
- self.setUp()
- m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor
- with self.module.function_context("mkbooltensor", [m, m], []) as f:
- input = E.IndexedValue(f.arg(0))
- output = E.IndexedValue(f.arg(1))
- zero = E.constant_index(0)
- ten = E.constant_index(10)
- with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
- b1 = (i < j) & (j < k)
- b2 = ~b1
- b3 = b2 | (k < j)
- output.store([i], input.load([i]) & b3)
- E.ret([])
- self.module.compile()
- printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
- # CHECK-LABEL: testMLIRBooleanCompilation
- # CHECK: False
-
- def testMLIRFunctionCreation(self):
- self.setUp()
- module = E.MLIRModule()
- t = module.make_type("f32")
- m = module.make_memref_type(t, [3, 4, -1, 5])
- printWithCurrentFunctionName(str(t))
- print(str(m))
- print(str(module.make_function("copy", [m, m], [])))
- print(str(module.make_function("sqrtf", [t], [t])))
- # CHECK-LABEL: testMLIRFunctionCreation
- # CHECK: f32
- # CHECK: memref<3x4x?x5xf32>
- # CHECK: func @copy(%{{.*}}: memref<3x4x?x5xf32>, %{{.*}}: memref<3x4x?x5xf32>) {
- # CHECK: func @sqrtf(%{{.*}}: f32) -> f32
-
- def testMLIRScalarTypes(self):
- self.setUp()
- module = E.MLIRModule()
- printWithCurrentFunctionName(str(module.make_type("bf16")))
- print(str(module.make_type("f16")))
- print(str(module.make_type("f32")))
- print(str(module.make_type("f64")))
- print(str(module.make_type("i1")))
- print(str(module.make_type("i8")))
- print(str(module.make_type("i32")))
- print(str(module.make_type("i123")))
- print(str(module.make_type("index")))
- # CHECK-LABEL: testMLIRScalarTypes
- # CHECK: bf16
- # CHECK: f16
- # CHECK: f32
- # CHECK: f64
- # CHECK: i1
- # CHECK: i8
- # CHECK: i32
- # CHECK: i123
- # CHECK: index
-
- def testMatrixMultiply(self):
- self.setUp()
- memrefType = self.module.make_memref_type(self.f32Type, [32, 32])
- with self.module.function_context("matmul",
- [memrefType, memrefType, memrefType],
- []) as fun:
- A = E.IndexedValue(fun.arg(0))
- B = E.IndexedValue(fun.arg(1))
- C = E.IndexedValue(fun.arg(2))
- c0 = E.constant_index(0)
- c32 = E.constant_index(32)
- with E.LoopNestContext([c0, c0, c0], [c32, c32, c32],
- [1, 1, 1]) as (i, j, k):
- C.store([i, j], A.load([i, k]) * B.load([k, j]))
- E.ret([])
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testMatrixMultiply
- # CHECK: "affine.for"()
- # CHECK: "affine.for"()
- # CHECK: "affine.for"()
- # CHECK-DAG: %{{.*}} = "affine.load"
- # CHECK-DAG: %{{.*}} = "affine.load"
- # CHECK: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
- # CHECK: "affine.store"
- # CHECK-SAME: memref<32x32xf32>
- # CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
- # CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
- # CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
-
- def testRet(self):
- self.setUp()
- with self.module.function_context("foo", [],
- [self.indexType, self.indexType]) as fun:
- c42 = E.constant_index(42)
- c0 = E.constant_index(0)
- E.ret([c42, c0])
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testRet
- # CHECK: %{{.*}} = constant 42 : index
- # CHECK: %{{.*}} = constant 0 : index
- # CHECK: return %{{.*}}, %{{.*}} : index, index
-
- def testSelectOp(self):
- self.setUp()
- with self.module.function_context("foo", [self.boolType],
- [self.i32Type]) as fun:
- a = E.constant_int(42, 32)
- b = E.constant_int(0, 32)
- E.ret([E.select(fun.arg(0), a, b)])
- printWithCurrentFunctionName(str(fun))
- # CHECK-LABEL: testSelectOp
- # CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : i32
-
- def testType(self):
- self.setUp()
- printWithCurrentFunctionName("")
- with self.module.function_context(
- "foo", [self.module.make_memref_type(self.f32Type, [10])], []) as fun:
- c42 = E.constant_int(42, 32)
- print(str(c42.type()))
- print(str(fun.arg(0).type()))
- # CHECK-LABEL: testType
- # CHECK: i32
- # CHECK: memref<10xf32>
-
-
-# Until python 3.6 this cannot be used because the order in the dict is not the
-# order of method declaration.
-def runTests():
-
- def isTest(attr):
- return inspect.ismethod(attr) and "EdscTest.setUp " not in str(attr)
-
- edscTest = EdscTest()
- tests = sorted(
- filter(isTest, (getattr(edscTest, attr) for attr in dir(edscTest))),
- key=lambda x: str(x))
- for test in tests:
- test()
-
-
-if __name__ == "__main__":
- runTests()
diff --git a/third_party/mlir/g3doc/DeclarativeRewrites.md b/third_party/mlir/g3doc/DeclarativeRewrites.md
index 5adcb32..67ff102 100644
--- a/third_party/mlir/g3doc/DeclarativeRewrites.md
+++ b/third_party/mlir/g3doc/DeclarativeRewrites.md
@@ -233,7 +233,7 @@
Given that `COp` was specified with table-driven op definition, there will be
several `build()` methods generated for it. One of them has aggregated
parameters for result types, operands, and attributes in the signature: `void
-COp::build(..., ArrayRef<Type> resultTypes, Array<Value *> operands,
+COp::build(..., ArrayRef<Type> resultTypes, Array<Value> operands,
ArrayRef<NamedAttribute> attr)`. The pattern in the above calls this `build()`
method for constructing the `COp`.
@@ -266,7 +266,7 @@
```c++
void AOp::build(Builder *builder, OperationState &state,
- Value *input, Attribute attr) {
+ Value input, Attribute attr) {
state.addOperands({input});
state.addAttribute("a_attr", attr);
Type type = ...; // Deduce result type here
@@ -422,7 +422,7 @@
If we have a C++ function for building an op:
```c++
-Operation *createMyOp(OpBuilder builder, Value *input, Attribute attr);
+Operation *createMyOp(OpBuilder builder, Value input, Attribute attr);
```
We can wrap it up and invoke it like:
diff --git a/third_party/mlir/g3doc/DialectConversion.md b/third_party/mlir/g3doc/DialectConversion.md
index b4e309d..e6b652f2 100644
--- a/third_party/mlir/g3doc/DialectConversion.md
+++ b/third_party/mlir/g3doc/DialectConversion.md
@@ -209,7 +209,7 @@
/// the conversion has finished.
virtual Operation *materializeConversion(PatternRewriter &rewriter,
Type resultType,
- ArrayRef<Value *> inputs,
+ ArrayRef<Value> inputs,
Location loc);
};
```
@@ -232,7 +232,7 @@
/// `operands` parameter, containing the remapped operands of the original
/// operation.
virtual PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const;
};
```
@@ -269,7 +269,7 @@
/// Remap an input of the original signature to another `replacement`
/// value. This drops the original argument.
- void remapInput(unsigned origInputNo, Value *replacement);
+ void remapInput(unsigned origInputNo, Value replacement);
};
```
diff --git a/third_party/mlir/g3doc/Dialects/Linalg.md b/third_party/mlir/g3doc/Dialects/Linalg.md
new file mode 100644
index 0000000..1ed5a2c
--- /dev/null
+++ b/third_party/mlir/g3doc/Dialects/Linalg.md
@@ -0,0 +1,8 @@
+# Linalg Dialect
+
+To generate the documentation:
+
+```sh
+mlir-tblgen --gen-op-doc -I /path/to/mlir/include \
+/path/to/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td
+```
diff --git a/third_party/mlir/g3doc/Dialects/SPIR-V.md b/third_party/mlir/g3doc/Dialects/SPIR-V.md
index b753435..1d72e54 100644
--- a/third_party/mlir/g3doc/Dialects/SPIR-V.md
+++ b/third_party/mlir/g3doc/Dialects/SPIR-V.md
@@ -1,47 +1,101 @@
# SPIR-V Dialect
-This document defines the SPIR-V dialect in MLIR.
+This document describes the design of the SPIR-V dialect in MLIR. It lists
+various design choices we made for modeling different SPIR-V mechanisms, and
+their rationale.
-[SPIR-V][SPIR-V] is the Khronos Group’s binary intermediate language for
-representing graphics shaders and compute kernels. It is adopted by multiple
-Khronos Group’s APIs, including Vulkan and OpenCL.
+This document also explains in a high-level manner how different components are
+organized and implemented in the code and gives steps to follow for extending
+them.
-## Design Principles
+This document assumes familiarity with SPIR-V. [SPIR-V][Spirv] is the Khronos
+Group’s binary intermediate language for representing graphics shaders and
+compute kernels. It is adopted by multiple Khronos Group APIs, including
+Vulkan and OpenCL. It is fully defined in a
+[human-readable specification][SpirvSpec]; the syntax of various SPIR-V
+instructions is encoded in a [machine-readable grammar][SpirvGrammar].
-SPIR-V defines a stable binary format for hardware driver consumption.
-Regularity is one of the design goals of SPIR-V. All concepts are represented
-as SPIR-V instructions, including declaring extensions and capabilities,
-defining types and constants, defining functions, attaching additional
-properties to computation results, etc. This way favors driver consumption
-but not necessarily compiler transformations.
+## Design Guidelines
-The purpose of the SPIR-V dialect is to serve as the "proxy" of the binary
-format and to facilitate transformations. Therefore, it should
+SPIR-V is a binary intermediate language that serves a dual purpose: on one side,
+it is an intermediate language to represent graphics shaders and compute kernels
+for high-level languages to target; on the other side, it defines a stable
+binary format for hardware driver consumption. As a result, SPIR-V has design
+principles that pertain not only to intermediate languages but also to binary formats.
+For example, regularity is one of the design goals of SPIR-V. All concepts are
+represented as SPIR-V instructions, including declaring extensions and
+capabilities, defining types and constants, defining functions, attaching
+additional properties to computation results, etc. This way favors binary
+encoding and decoding for driver consumption but not necessarily compiler
+transformations.
-* Stay as the same semantic level and try to be a mechanical 1:1 mapping;
-* But deviate representationally if possible with MLIR mechanisms.
+### Dialect design principles
+
+The main objective of the SPIR-V dialect is to be a proper intermediate
+representation (IR) to facilitate compiler transformations. While we still aim
+to support serializing to and deserializing from the binary format for various
+good reasons, the binary format and its concerns play a lesser role in the design
+of the SPIR-V dialect: when there is a trade-off to be made between favoring IR
+and supporting binary format, we lean towards the former.
+
+On the IR aspect, the SPIR-V dialect aims to model SPIR-V at the same semantic
+level. It is not intended to be a higher level or lower level abstraction than
+the SPIR-V specification. Those abstractions are easily outside the domain of
+SPIR-V and should be modeled with other proper dialects so they can be shared
+among various compilation paths. Because of the dual purpose of SPIR-V, SPIR-V
+dialect staying at the same semantic level as the SPIR-V specification also
+means we can still have straightforward serialization and deserialization for
+the majority of functionalities.
+
+To summarize, the SPIR-V dialect follows the following design principles:
+
+* Stay at the same semantic level as the SPIR-V specification by having
+ one-to-one mapping for most concepts and entities.
+* Adopt SPIR-V specification's syntax if possible, but deviate intentionally
+ to utilize MLIR mechanisms if it results in better representation and
+ benefits transformation.
* Be straightforward to serialize into and deserialize from the SPIR-V binary
format.
+SPIR-V is designed to be consumed by hardware drivers, so its representation is
+quite clear, yet verbose for some cases. Allowing representational deviation
+gives us the flexibility to reduce the verbosity by using MLIR mechanisms.
+
+### Dialect scopes
+
+SPIR-V supports multiple execution environments, specified by client APIs.
+Notable adopters include Vulkan and OpenCL. It follows that the SPIR-V dialect
+should support multiple execution environments if it is to be a proper proxy of SPIR-V
+in MLIR systems. The SPIR-V dialect is designed with these considerations: it
+has proper support for versions, extensions, and capabilities and is as
+extensible as SPIR-V specification.
+
## Conventions
-The SPIR-V dialect has the following conventions:
+The SPIR-V dialect adopts the following conventions for IR:
* The prefix for all SPIR-V types and operations are `spv.`.
-* Ops that directly mirror instructions in the binary format have `CamelCase`
+* All instructions in an extended instruction set are further qualified with
+ the extended instruction set's prefix. For example, all operations in the
+ GLSL extended instruction set have the prefix `spv.GLSL.`.
+* Ops that directly mirror instructions in the specification have `CamelCase`
names that are the same as the instruction opnames (without the `Op`
- prefix). For example, `spv.FMul` is a direct mirror of `OpFMul`. They will
- be serialized into and deserialized from one instruction.
+ prefix). For example, `spv.FMul` is a direct mirror of `OpFMul` in the
+ specification. Such an op will be serialized into and deserialized from one
+ SPIR-V instruction.
* Ops with `snake_case` names are those that have different representation
- from corresponding instructions (or concepts) in the binary format. These
+ from corresponding instructions (or concepts) in the specification. These
ops are mostly for defining the SPIR-V structure. For example, `spv.module`
- and `spv.constant`. They may correspond to zero or more instructions during
+ and `spv.constant`. They may correspond to one or more instructions during
(de)serialization.
* Ops with `_snake_case` names are those that have no corresponding
instructions (or concepts) in the binary format. They are introduced to
satisfy MLIR structural requirements. For example, `spv._module_end` and
`spv._merge`. They maps to no instructions during (de)serialization.
+(TODO: consider merging the last two cases and adopting `spv.mlir.` prefix for
+them.)
+
## Module
A SPIR-V module is defined via the `spv.module` op, which has one region that
@@ -49,27 +103,77 @@
are all placed inside the block. Functions are defined using the builtin `func`
op.
-Compared to the binary format, we adjust how certain module-level SPIR-V
-instructions are represented in the SPIR-V dialect. Notably,
+We choose to model a SPIR-V module with a dedicated `spv.module` op based on the
+following considerations:
+
+* It maps cleanly to a SPIR-V module in the specification.
+* We can enforce SPIR-V specific verification that is suitable to be performed
+ at the module-level.
+* We can attach additional module-level attributes.
+* We can control custom assembly form.
+
+The `spv.module` op's region cannot capture SSA values from outside, neither
+implicitly nor explicitly. The `spv.module` op's region is closed as to what ops
+can appear inside: apart from the builtin `func` op, it can only contain ops
+from the SPIR-V dialect. The `spv.module` op's verifier enforces this rule. This
+meaningfully guarantees that a `spv.module` can be the entry point and boundary
+for serialization.
+
+### Module-level operations
+
+SPIR-V binary format defines the following [sections][SpirvLogicalLayout]:
+
+1. Capabilities required by the module.
+1. Extensions required by the module.
+1. Extended instruction sets required by the module.
+1. Addressing and memory model specification.
+1. Entry point specifications.
+1. Execution mode declarations.
+1. Debug instructions.
+1. Annotation/decoration instructions.
+1. Type, constant, global variables.
+1. Function declarations.
+1. Function definitions.
+
+Basically, a SPIR-V binary module contains multiple module-level instructions
+followed by a list of functions. Those module-level instructions are essential
+and they can generate result ids referenced by functions, notably, declaring
+resource variables to interact with the execution environment.
+
+Compared to the binary format, we adjust how these module-level SPIR-V
+instructions are represented in the SPIR-V dialect:
+
+#### Use MLIR attributes for metadata
* Requirements for capabilities, extensions, extended instruction sets,
addressing model, and memory model is conveyed using `spv.module`
attributes. This is considered better because these information are for the
- execution environment. It's easier to probe them if on the module op
- itself.
+ execution environment. It's easier to probe them if on the module op itself.
* Annotations/decoration instructions are "folded" into the instructions they
decorate and represented as attributes on those ops. This eliminates
potential forward references of SSA values, improves IR readability, and
- makes querying the annotations more direct.
+ makes querying the annotations more direct. More discussions can be found in
+ the [`Decorations`](#decorations) section.
+
+#### Model types with MLIR custom types
+
* Types are represented using MLIR standard types and SPIR-V dialect specific
- types. There are no type declaration ops in the SPIR-V dialect.
+ types. There are no type declaration ops in the SPIR-V dialect. More
+ discussions can be found in the [Types](#types) section later.
+
+#### Unify and localize constants
+
* Various normal constant instructions are represented by the same
`spv.constant` op. Those instructions are just for constants of different
types; using one op to represent them reduces IR verbosity and makes
transformations less tedious.
* Normal constants are not placed in `spv.module`'s region; they are localized
into functions. This is to make functions in the SPIR-V dialect to be
- isolated and explicit capturing.
+ isolated and explicit capturing. Constants are cheap to duplicate given
+ attributes are uniqued in `MLIRContext`.
+
+#### Adopt symbol-based global variables and specialization constants
+
* Global variables are defined with the `spv.globalVariable` op. They do not
generate SSA values. Instead they have symbols and should be referenced via
symbols. To use a global variables in a function block, `spv._address_of` is
@@ -79,15 +183,90 @@
reference, too. `spv._reference_of` is needed to turn the symbol into a SSA
value for use in a function block.
+The above choices enable functions in the SPIR-V dialect to be isolated and
+explicit capturing.
+
+#### Disallow implicit capturing in functions
+
+* In SPIR-V specification, functions support implicit capturing: they can
+ reference SSA values defined in modules. In the SPIR-V dialect functions are
+ defined with `func` op, which disallows implicit capturing. This is more
+ friendly to compiler analyses and transformations. More discussions can be
+ found in the [Function](#function) section later.
+
+#### Model entry points and execution modes as normal ops
+
+* A SPIR-V module can have multiple entry points. And these entry points refer
+ to the function and interface variables. It’s not suitable to model them as
+ `spv.module` op attributes. We can model them as normal ops using symbol
+ references.
+* Similarly for execution modes, which are coupled with entry points, we can
+ model them as normal ops in `spv.module`'s region.
+
+## Decorations
+
+Annotations/decorations provide additional information on result ids. In SPIR-V,
+all instructions can generate result ids, including value-computing and
+type-defining ones.
+
+For decorations on value result ids, we can just have a corresponding attribute
+attached to the operation generating the SSA value. For example, for the
+following SPIR-V:
+
+```spirv
+OpDecorate %v1 RelaxedPrecision
+OpDecorate %v2 NoContraction
+...
+%v1 = OpFMul %float %0 %0
+%v2 = OpFMul %float %1 %1
+```
+
+We can represent them in the SPIR-V dialect as:
+
+```mlir
+%v1 = "spv.FMul"(%0, %0) {RelaxedPrecision = unit} : (f32, f32) -> (f32)
+%v2 = "spv.FMul"(%1, %1) {NoContraction = unit} : (f32, f32) -> (f32)
+```
+
+This approach benefits transformations. Essentially those decorations are just
+additional properties of the result ids (and thus their defining instructions).
+In SPIR-V binary format, they are just represented as instructions. Literally
+following SPIR-V binary format means we need to go through def-use chains to find
+the decoration instructions and query information from them.
+
+For decorations on type result ids, notice that practically, only result ids
+generated from composite types (e.g., `OpTypeArray`, `OpTypeStruct`) need to be
+decorated for memory layout purposes (e.g., `ArrayStride`, `Offset`, etc.);
+scalar/vector types are required to be uniqued in SPIR-V. Therefore, we can just
+encode them directly in the dialect-specific type.
+
## Types
-The SPIR-V dialect reuses standard integer, float, and vector types and defines
-the following dialect-specific types:
+Theoretically we can define all SPIR-V types using the MLIR extensible type system,
+but other than representational purity, it does not buy us more. Instead, we
+need to maintain the code and invest in pretty printing them. So we prefer to
+use builtin/standard types if possible.
+
+The SPIR-V dialect reuses standard integer, float, and vector types:
+
+Specification | Dialect
+:----------------------------------: | :-------------------------------:
+`OpTypeBool` | `i1`
+`OpTypeInt <bitwidth>` | `i<bitwidth>`
+`OpTypeFloat <bitwidth>` | `f<bitwidth>`
+`OpTypeVector <scalar-type> <count>` | `vector<<count> x <scalar-type>>`
+
+Similarly, `mlir::NoneType` can be used for SPIR-V `OpTypeVoid`; builtin
+function types can be used for SPIR-V `OpTypeFunction` types.
+
+The SPIR-V dialect also defines the following dialect-specific types:
```
spirv-type ::= array-type
+ | image-type
| pointer-type
| runtime-array-type
+ | struct-type
```
### Array type
@@ -134,7 +313,7 @@
For example,
-```
+```mlir
!spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>
!spv.image<f32, Cube, IsDepth, Arrayed, MultiSampled, NeedSampler, Rgba32f>
```
@@ -186,7 +365,7 @@
For Example,
-```
+```mlir
!spv.struct<f32>
!spv.struct<f32 [0]>
!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>
@@ -195,16 +374,115 @@
## Function
-A SPIR-V function is defined using the builtin `func` op. `spv.module` verifies
-that the functions inside it comply with SPIR-V requirements: at most one
-result, no nested functions, and so on.
+In SPIR-V, a function construct consists of multiple instructions involving
+`OpFunction`, `OpFunctionParameter`, `OpLabel`, `OpFunctionEnd`.
+
+```spirv
+// int f(int v) { return v; }
+%1 = OpTypeInt 32 0
+%2 = OpTypeFunction %1 %1
+%3 = OpFunction %1 %2
+%4 = OpFunctionParameter %1
+%5 = OpLabel
+ OpReturnValue %4
+ OpFunctionEnd
+```
+
+This construct is very clear yet quite verbose. It is intended for driver
+consumption. There is little benefit to literally replicate this construct in
+the SPIR-V dialect. Instead, we reuse the builtin `func` op to express functions
+more concisely:
+
+```mlir
+func @f(%arg: i32) -> i32 {
+ "spv.ReturnValue"(%arg) : (i32) -> ()
+}
+```
+
+A SPIR-V function can have at most one result. It cannot contain nested
+functions or non-SPIR-V operations. `spv.module` verifies these requirements.
+
+A major difference between the SPIR-V dialect and the SPIR-V specification for
+functions is that the former are isolated and require explicit capturing, while
+the latter allow implicit capturing. In SPIR-V specification, functions can
+refer to SSA values (generated by constants, global variables, etc.) defined in
+modules. The SPIR-V dialect adjusted how constants and global variables are
+modeled to enable isolated functions. Isolated functions are more friendly to
+compiler analyses and transformations. This also enables the SPIR-V dialect to
+better utilize core infrastructure: many functionalities in the core
+infrastructure require ops to be isolated, e.g., the
+[greedy pattern rewriter][GreedyPatternRewriter] can only act on ops isolated
+from above.
+
+(TODO: create a dedicated `spv.fn` op for SPIR-V functions.)
## Operations
+In SPIR-V, instruction is a generalized concept; a SPIR-V module is just a
+sequence of instructions. Declaring types, expressing computations, annotating
+result ids, expressing control flows and others are all in the form of
+instructions.
+
+We only discuss instructions expressing computations here, which can be
+represented via SPIR-V dialect ops. Module-level instructions for declarations
+and definitions are represented differently in the SPIR-V dialect as explained
+earlier in the [Module-level operations](#module-level-operations) section.
+
+An instruction computes zero or one result from zero or more operands. The
+result is a new result id. An operand can be a result id generated by a previous
+instruction, an immediate value, or a case of an enum type. We can model result
+id operands and results with MLIR SSA values; for immediate value and enum
+cases, we can model them with MLIR attributes.
+
+For example,
+
+```spirv
+%i32 = OpTypeInt 32 0
+%ptr = OpTypePointer Function %i32
+%c42 = OpConstant %i32 42
+%3 = OpVariable %ptr Function %c42
+%4 = OpIAdd %i32 %c42 %c42
+```
+
+can be represented in the dialect as
+
+```mlir
+%0 = "spv.constant"() { value = 42 : i32 } : () -> i32
+%1 = "spv.Variable"(%0) { storage_class = "Function" } : (i32) -> !spv.ptr<i32, Function>
+%2 = "spv.IAdd"(%0, %0) : (i32, i32) -> i32
+```
+
Operation documentation is written in each op's Op Definition Spec using
TableGen. A markdown version of the doc can be generated using `mlir-tblgen
-gen-doc`.
+### Ops from extended instruction sets
+
+An extended instruction set is a mechanism to import SPIR-V
+instructions within another namespace. [`GLSL.std.450`][GlslStd450] is an
+extended instruction set that provides common mathematical routines that should
+be supported. Instead of modeling `OpExtInstImport` as a separate op and using a
+single op to model `OpExtInst` for all extended instructions, we model each
+SPIR-V instruction in an extended instruction set as a separate op with the
+proper name prefix. For example, for
+
+```spirv
+%glsl = OpExtInstImport "GLSL.std.450"
+
+%f32 = OpTypeFloat 32
+%cst = OpConstant %f32 ...
+
+%1 = OpExtInst %f32 %glsl 28 %cst
+%2 = OpExtInst %f32 %glsl 31 %cst
+```
+
+we can have
+
+```mlir
+%1 = "spv.GLSL.Log"(%cst) : (f32) -> (f32)
+%2 = "spv.GLSL.Sqrt"(%cst) : (f32) -> (f32)
+```
+
## Control Flow
SPIR-V binary format uses merge instructions (`OpSelectionMerge` and
@@ -447,44 +725,315 @@
}
```
+## Shader interface (ABI)
+
+SPIR-V itself just expresses computation happening on the GPU device. SPIR-V
+programs themselves are not enough for running workloads on GPU; a companion
+host application is needed to manage the resources referenced by SPIR-V programs
+and dispatch the workload. For the Vulkan execution environment, the host
+application will be written using the Vulkan API. Unlike CUDA, the SPIR-V program
+and the Vulkan application are typically authored with different front-end
+languages, which isolates these two worlds. Yet they still need to match
+_interfaces_: the variables declared in a SPIR-V program for referencing
+resources need to match with the actual resources managed by the application
+regarding their parameters.
+
+Still using Vulkan as an example execution environment, there are two primary
+resource types in Vulkan: buffers and images. They are used to back various uses
+that may differ regarding the classes of operations (load, store, atomic) to be
+performed. These uses are differentiated via descriptor types. (For example,
+uniform storage buffer descriptors can only support load operations while
+storage buffer descriptors can support load, store, and atomic operations.)
+Vulkan uses a binding model for resources. Resources are associated with
+descriptors and descriptors are further grouped into sets. Each descriptor thus
+has a set number and a binding number. Descriptors in the application
+correspond to variables in the SPIR-V program. Their parameters must match,
+including but not limited to set and binding numbers.
+
+Apart from buffers and images, there is other data that is set up by Vulkan and
+referenced inside the SPIR-V program, for example, push constants. They also
+have parameters that require matching between the two worlds.
+
+The interface requirements are external information to the SPIR-V compilation
+path in MLIR. Besides, each Vulkan application may want to handle resources
+differently. To avoid duplication and to share common utilities, a SPIR-V shader
+interface specification needs to be defined to provide the external requirements
+to and guide the SPIR-V compilation path.
+
+### Shader interface attributes
+
+The SPIR-V dialect defines [a few attributes][MlirSpirvAbi] for specifying these
+interfaces:
+
+* `spv.entry_point_abi` is a struct attribute that should be attached to the
+ entry function. It contains:
+ * `local_size` for specifying the local work group size for the dispatch.
+* `spv.interface_var_abi` is a struct attribute that should be attached to
+ each operand and result of the entry function. It contains:
+ * `descriptor_set` for specifying the descriptor set number for the
+ corresponding resource variable.
+ * `binding` for specifying the binding number for the corresponding
+ resource variable.
+ * `storage_class` for specifying the storage class for the corresponding
+ resource variable.
+
+The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for
+consuming these attributes and creating a SPIR-V module complying with the
+interface.
+
## Serialization and deserialization
+Although the main objective of the SPIR-V dialect is to act as a proper IR for
+compiler transformations, being able to serialize to and deserialize from the
+binary format is still very valuable for many good reasons. Serialization
+enables the artifacts of SPIR-V compilation to be consumed by an execution
+environment; deserialization allows us to import SPIR-V binary modules and run
+transformations on them. So serialization and deserialization have been supported from
+the very beginning of the development of the SPIR-V dialect.
+
The serialization library provides two entry points, `mlir::spirv::serialize()`
and `mlir::spirv::deserialize()`, for converting a MLIR SPIR-V module to binary
-format and back.
+format and back. The [Code organization](#code-organization) section explains more about
+this.
-The purpose of this library is to enable importing SPIR-V binary modules to run
-transformations on them and exporting SPIR-V modules to be consumed by execution
-environments. The focus is transformations, which inevitably means changes to
-the binary module; so it is not designed to be a general tool for investigating
-the SPIR-V binary module and does not guarantee roundtrip equivalence (at least
-for now). For the latter, please use the assembler/disassembler in the
-[SPIRV-Tools][SPIRV-Tools] project.
+Given that the focus is transformations, which inevitably means changes to the
+binary module, serialization is not designed to be a general tool for
+investigating the SPIR-V binary module and does not guarantee roundtrip
+equivalence (at least for now). For the latter, please use the
+assembler/disassembler in the [SPIRV-Tools][SpirvTools] project.
A few transformations are performed in the process of serialization because of
the representational differences between SPIR-V dialect and binary format:
* Attributes on `spv.module` are emitted as their corresponding SPIR-V
instructions.
+* Types are serialized into `OpType*` instructions in the SPIR-V binary module
+ section for types, constants, and global variables.
* `spv.constant`s are unified and placed in the SPIR-V binary module section
for types, constants, and global variables.
+* Attributes on ops, if not part of the op's binary encoding, are emitted as
+ `OpDecorate*` instructions in the SPIR-V binary module section for
+ decorations.
* `spv.selection`s and `spv.loop`s are emitted as basic blocks with `Op*Merge`
instructions in the header block as required by the binary format.
+* Block arguments are materialized as `OpPhi` instructions at the beginning of
+ the corresponding blocks.
Similarly, a few transformations are performed during deserialization:
-* Instructions for execution environment requirements will be placed as
- attributes on `spv.module`.
+* Instructions for execution environment requirements (extensions,
+ capabilities, extended instruction sets, etc.) will be placed as attributes
+ on `spv.module`.
+* `OpType*` instructions will be converted into proper `mlir::Type`s.
* `OpConstant*` instructions are materialized as `spv.constant` at each use
site.
+* `OpVariable` instructions will be converted to `spv.globalVariable` ops if
+ at module level; otherwise they will be converted into `spv.Variable` ops.
+* Every use of a module-level `OpVariable` instruction will materialize a
+ `spv._address_of` op to turn the symbol of the corresponding
+ `spv.globalVariable` into an SSA value.
+* Every use of an `OpSpecConstant` instruction will materialize a
+ `spv._reference_of` op to turn the symbol of the corresponding
+ `spv.specConstant` into an SSA value.
* `OpPhi` instructions are converted to block arguments.
* Structured control flow are placed inside `spv.selection` and `spv.loop`.
-[SPIR-V]: https://www.khronos.org/registry/spir-v/
+## Conversions
+
+(TODO: expand this section)
+
+## Code organization
+
+We aim to provide multiple libraries with clear dependencies for SPIR-V related
+functionalities in MLIR so developers can just choose the needed components
+without pulling in the whole world.
+
+### The dialect
+
+The code for the SPIR-V dialect resides in a few places:
+
+* Public headers are placed in [include/mlir/Dialect/SPIRV][MlirSpirvHeaders].
+* Libraries are placed in [lib/Dialect/SPIRV][MlirSpirvLibs].
+* IR tests are placed in [test/Dialect/SPIRV][MlirSpirvTests].
+* Unit tests are placed in [unittests/Dialect/SPIRV][MlirSpirvUnittests].
+
+The whole SPIR-V dialect is exposed via multiple headers for better
+organization:
+
+* [SPIRVDialect.h][MlirSpirvDialect] defines the SPIR-V dialect.
+* [SPIRVTypes.h][MlirSpirvTypes] defines all SPIR-V specific types.
+* [SPIRVOps.h][MlirSPirvOps] defines all SPIR-V operations.
+* [Serialization.h][MlirSpirvSerialization] defines the entry points for
+ serialization and deserialization.
+
+The dialect itself, including all types and ops, is in the `MLIRSPIRV` library.
+Serialization functionalities are in the `MLIRSPIRVSerialization` library.
+
+### Op definitions
+
+We use [Op Definition Spec][ODS] to define all SPIR-V ops. They are written in
+TableGen syntax and placed in various `*Ops.td` files in the header directory.
+Those `*Ops.td` files are organized according to the instruction categories used
+in the SPIR-V specification, for example, an op belonging to the "Atomics
+Instructions" section is put in the `SPIRVAtomicOps.td` file.
+
+`SPIRVOps.td` serves as the master op definition file that includes all files
+for specific categories.
+
+`SPIRVBase.td` defines common classes and utilities used by various op
+definitions. It contains the TableGen SPIR-V dialect definition, SPIR-V
+versions, known extensions, various SPIR-V enums, TableGen SPIR-V types, and
+base op classes, etc.
+
+Many of the contents in `SPIRVBase.td`, e.g., the opcodes and various enums, and
+all `*Ops.td` files can be automatically updated via a Python script, which
+queries the SPIR-V specification and grammar. This greatly reduces the burden of
+supporting new ops and keeping up to date with the SPIR-V spec. More details on
+this automated development can be found in the
+[Automated development flow](#automated-development-flow) section.
+
+### Dialect conversions
+
+The code for conversions from other dialects to the SPIR-V dialect also resides
+in a few places:
+
+* From GPU dialect: headers are at
+ [include/mlir/Conversion/GPUToSPIRV][MlirGpuToSpirvHeaders]; libraries are
+ at [lib/Conversion/GPUToSPIRV][MlirGpuToSpirvLibs].
+* From standard dialect: headers are at
+ [include/mlir/Conversion/StandardToSPIRV][MlirStdToSpirvHeaders]; libraries
+ are at [lib/Conversion/StandardToSPIRV][MlirStdToSpirvLibs].
+
+These dialect to dialect conversions have their dedicated libraries,
+`MLIRGPUToSPIRVTransforms` and `MLIRStandardToSPIRVTransforms`, respectively.
+
+There are also common utilities when targeting SPIR-V from any dialect:
+
+* [include/mlir/Dialect/SPIRV/Passes.h][MlirSpirvPasses] contains SPIR-V
+ specific analyses and transformations.
+* [include/mlir/Dialect/SPIRV/SPIRVLowering.h][MlirSpirvLowering] contains
+ type converters and other utility functions.
+
+These common utilities are implemented in the `MLIRSPIRVTransforms` library.
+
+## Contribution
+
+All kinds of contributions are highly appreciated! :) We have GitHub issues for
+tracking the [dialect][GitHubDialectTracking] and
+[lowering][GitHubLoweringTracking] development. You can find todo tasks there.
+The [Code organization](#code-organization) section gives an overview of how
+SPIR-V related functionalities are implemented in MLIR. This section gives more
+concrete steps on how to contribute.
+
+### Automated development flow
+
+One of the goals of SPIR-V dialect development is to leverage both the SPIR-V
+[human-readable specification][SpirvSpec] and
+[machine-readable grammar][SpirvGrammar] to auto-generate as much content as
+possible. Specifically, the following tasks can be automated (partially or
+fully):
+
+* Adding support for a new operation.
+* Adding support for a new SPIR-V enum.
+* Serialization and deserialization of a new operation.
+
+We achieve this using the Python script
+[`gen_spirv_dialect.py`][GenSpirvUtilsPy]. It fetches the human-readable
+specification and machine-readable grammar directly from the Internet and
+updates various SPIR-V `*.td` files in place. The script gives us an automated
+flow for adding support for new ops or enums.
+
+Afterwards, we have SPIR-V-specific `mlir-tblgen` backends that read the Op
+Definition Spec and generate various components, including (de)serialization
+logic for ops. Together with standard `mlir-tblgen` backends, we auto-generate
+all op classes, enum classes, etc.
+
+In the following subsections, we list the detailed steps to follow for common
+tasks.
+
+### Add a new op
+
+To add a new op, invoke the `define_inst.sh` script wrapper in utils/spirv.
+`define_inst.sh` requires a few parameters:
+
+```sh
+./define_inst.sh <filename> <base-class-name> <opname>
+```
+
+For example, to define the op for `OpIAdd`, invoke
+
+```sh
+./define_inst.sh SPIRVArithmeticOps.td ArithmeticBinaryOp OpIAdd
+```
+
+where `SPIRVArithmeticOps.td` is the filename for hosting the new op and
+`ArithmeticBinaryOp` is the direct base class the newly defined op will derive
+from.
+
+Similarly, to define the op for `OpAtomicAnd`,
+
+```sh
+./define_inst.sh SPIRVAtomicOps.td AtomicUpdateWithValueOp OpAtomicAnd
+```
+
+Note that the generated SPIR-V op definition is just a best-effort template; it
+is still expected to be updated to have more accurate traits, arguments, and
+results.
+
+The generated op will automatically gain the logic for (de)serialization.
+However, tests should still accompany the change to make sure there are no
+surprises. Serialization tests live in test/Dialect/SPIRV/Serialization.
+
+### Add a new enum
+
+To add a new enum, invoke the `define_enum.sh` script wrapper in utils/spirv.
+`define_enum.sh` expects the following parameters:
+
+```sh
+./define_enum.sh <enum-class-name>
+```
+
+For example, to add the definition for SPIR-V storage class into
+`SPIRVBase.td`:
+
+```sh
+./define_enum.sh StorageClass
+```
+
+### Add a new conversion
+
+(TODO: add details for this section)
+
+[Spirv]: https://www.khronos.org/registry/spir-v/
+[SpirvSpec]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html
+[SpirvLogicalLayout]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
+[SpirvGrammar]: https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json
+[GlslStd450]: https://www.khronos.org/registry/spir-v/specs/1.0/GLSL.std.450.html
[ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray
[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
[PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer
[RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray
[StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure
-[SPIRV-Tools]: https://github.com/KhronosGroup/SPIRV-Tools
+[SpirvTools]: https://github.com/KhronosGroup/SPIRV-Tools
[Rationale]: https://github.com/tensorflow/mlir/blob/master/g3doc/Rationale.md#block-arguments-vs-phi-nodes
+[ODS]: https://github.com/tensorflow/mlir/blob/master/g3doc/OpDefinitions.md
+[GreedyPatternRewriter]: https://github.com/tensorflow/mlir/blob/master/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+[MlirSpirvHeaders]: https://github.com/tensorflow/mlir/tree/master/include/mlir/Dialect/SPIRV
+[MlirSpirvLibs]: https://github.com/tensorflow/mlir/tree/master/lib/Dialect/SPIRV
+[MlirSpirvTests]: https://github.com/tensorflow/mlir/tree/master/test/Dialect/SPIRV
+[MlirSpirvUnittests]: https://github.com/tensorflow/mlir/tree/master/unittests/Dialect/SPIRV
+[MlirGpuToSpirvHeaders]: https://github.com/tensorflow/mlir/tree/master/include/mlir/Conversion/GPUToSPIRV
+[MlirGpuToSpirvLibs]: https://github.com/tensorflow/mlir/tree/master/lib/Conversion/GPUToSPIRV
+[MlirStdToSpirvHeaders]: https://github.com/tensorflow/mlir/tree/master/include/mlir/Conversion/StandardToSPIRV
+[MlirStdToSpirvLibs]: https://github.com/tensorflow/mlir/tree/master/lib/Conversion/StandardToSPIRV
+[MlirSpirvDialect]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVDialect.h
+[MlirSpirvTypes]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+[MlirSpirvOps]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVOps.h
+[MlirSpirvSerialization]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/Serialization.h
+[MlirSpirvBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVBase.td
+[MlirSpirvPasses]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/Passes.h
+[MlirSpirvLowering]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+[MlirSpirvAbi]: https://github.com/tensorflow/mlir/blob/master/include/mlir/Dialect/SPIRV/SPIRVLowering.td
+[GitHubDialectTracking]: https://github.com/tensorflow/mlir/issues/302
+[GitHubLoweringTracking]: https://github.com/tensorflow/mlir/issues/303
+[GenSpirvUtilsPy]: https://github.com/tensorflow/mlir/blob/master/utils/spirv/gen_spirv_dialect.py
diff --git a/third_party/mlir/g3doc/EDSC.md b/third_party/mlir/g3doc/EDSC.md
index afceac2..eaaeb6c 100644
--- a/third_party/mlir/g3doc/EDSC.md
+++ b/third_party/mlir/g3doc/EDSC.md
@@ -15,10 +15,10 @@
## ValueHandle and IndexHandle
`mlir::edsc::ValueHandle` and `mlir::edsc::IndexHandle` provide typed
-abstractions around an `mlir::Value*`. These abstractions are "delayed", in the
-sense that they allow separating declaration from definition. They may
-capture IR snippets, as they are built, for programmatic manipulation.
-Intuitive operators are provided to allow concise and idiomatic expressions.
+abstractions around an `mlir::Value`. These abstractions are "delayed", in the
+sense that they allow separating declaration from definition. They may capture
+IR snippets, as they are built, for programmatic manipulation. Intuitive
+operators are provided to allow concise and idiomatic expressions.
```c++
ValueHandle zero = constant_index(0);
diff --git a/third_party/mlir/g3doc/GenericDAGRewriter.md b/third_party/mlir/g3doc/GenericDAGRewriter.md
index 3b26c22..8cc09f7 100644
--- a/third_party/mlir/g3doc/GenericDAGRewriter.md
+++ b/third_party/mlir/g3doc/GenericDAGRewriter.md
@@ -128,7 +128,7 @@
if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
if (C1->countTrailingZeros() == 0)
if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
- Value *NewOr = Builder.CreateOr(Z, ~(*C2));
+ Value NewOr = Builder.CreateOr(Z, ~(*C2));
return Builder.CreateSub(RHS, NewOr, "sub");
}
```
diff --git a/third_party/mlir/g3doc/OpDefinitions.md b/third_party/mlir/g3doc/OpDefinitions.md
index 1f98671..ff3a21f 100644
--- a/third_party/mlir/g3doc/OpDefinitions.md
+++ b/third_party/mlir/g3doc/OpDefinitions.md
@@ -360,7 +360,7 @@
// A new non-static method accepting an input argument.
InterfaceMethod<"/*insert doc here*/",
- "Value *", "bar", (ins "unsigned":$i)
+ "Value", "bar", (ins "unsigned":$i)
>,
// Query a static property of the derived operation.
@@ -438,7 +438,7 @@
// for attributes are of mlir::Attribute types.
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
Type i32_result, Type f32_result, ...,
- Value *i32_operand, Value *f32_operand, ...,
+ Value i32_operand, Value f32_operand, ...,
IntegerAttr i32_attr, FloatAttr f32_attr, ...);
// Each result-type/operand/attribute has a separate parameter. The parameters
@@ -447,13 +447,13 @@
// explanation for more details.)
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
Type i32_result, Type f32_result, ...,
- Value *i32_operand, Value *f32_operand, ...,
+ Value i32_operand, Value f32_operand, ...,
APInt i32_attr, StringRef f32_attr, ...);
// Each operand/attribute has a separate parameter but result type is aggregate.
static void build(Builder *tblgen_builder, OperationState &tblgen_state,
ArrayRef<Type> resultTypes,
- Value *i32_operand, Value *f32_operand, ...,
+ Value i32_operand, Value f32_operand, ...,
IntegerAttr i32_attr, FloatAttr f32_attr, ...);
// All operands/attributes have aggregate parameters.
@@ -615,10 +615,9 @@
For each operation, we automatically generate an _operand adaptor_. This class
solves the problem of accessing operands provided as a list of `Value`s without
using "magic" constants. The operand adaptor takes a reference to an array of
-`Value *` and provides methods with the same names as those in the operation
-class to access them. For example, for a binary arithmetic operation, it may
-provide `.lhs()` to access the first operand and `.rhs()` to access the second
-operand.
+`Value` and provides methods with the same names as those in the operation class
+to access them. For example, for a binary arithmetic operation, it may provide
+`.lhs()` to access the first operand and `.rhs()` to access the second operand.
The operand adaptor class lives in the same namespace as the operation class,
and has the name of the operation followed by `OperandAdaptor`. A template
@@ -629,11 +628,11 @@
```c++
template <typename BinaryOpTy>
-std::pair<Value *, Value *> zip(BinaryOpTy &&op) {
+std::pair<Value, Value> zip(BinaryOpTy &&op) {
return std::make_pair(op.lhs(), op.rhs());;
}
-void process(AddOp op, ArrayRef<Value *> newOperands) {
+void process(AddOp op, ArrayRef<Value> newOperands) {
zip(op);
zip(OperandAdaptor<AddOp>(newOperands));
/*...*/
diff --git a/third_party/mlir/g3doc/QuickstartRewrites.md b/third_party/mlir/g3doc/QuickstartRewrites.md
index d7bf9a5..6a4a7cc 100644
--- a/third_party/mlir/g3doc/QuickstartRewrites.md
+++ b/third_party/mlir/g3doc/QuickstartRewrites.md
@@ -128,8 +128,8 @@
```
```c++
-static Value* createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
- Value* operand, Attribute attr) {
+static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
+ Value operand, Attribute attr) {
return rewriter.create<mlir::TFL::LeakyReluOp>(
op->getLoc(), operands[0]->getType(), /*arg=*/operands[0],
/*alpha=*/attrs[0].cast<FloatAttr>());
diff --git a/third_party/mlir/g3doc/Rationale.md b/third_party/mlir/g3doc/Rationale.md
index 66cf800..763442d 100644
--- a/third_party/mlir/g3doc/Rationale.md
+++ b/third_party/mlir/g3doc/Rationale.md
@@ -1099,7 +1099,7 @@
The problem is that LLVM has several objects in its IR that are globally uniqued
and also mutable: notably constants like `i32 0`. In LLVM, these constants are
-`Value*r`'s, which allow them to be used as operands to instructions, and that
+`Value`s, which allow them to be used as operands to instructions, and that
they also have SSA use lists. Because these things are uniqued, every `i32 0` in
any function shares a use list. This means that optimizing multiple functions in
parallel won't work (at least without some sort of synchronization on the use
diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-3.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-3.md
index 07ead64..615c2c1 100644
--- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-3.md
+++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-3.md
@@ -90,7 +90,7 @@
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
- mlir::Value *transposeInput = op.getOperand();
+ mlir::Value transposeInput = op.getOperand();
TransposeOp transposeInputOp =
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
// If the input is defined by another Transpose, bingo!
diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-4.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-4.md
index ac12469..4a4e11c 100644
--- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-4.md
+++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-4.md
@@ -75,7 +75,7 @@
/// previously returned by the call operation with the operands of the
/// return.
void handleTerminator(Operation *op,
- ArrayRef<Value *> valuesToRepl) const final {
+ ArrayRef<Value> valuesToRepl) const final {
// Only "toy.return" needs to be handled here.
auto returnOp = cast<ReturnOp>(op);
@@ -207,7 +207,7 @@
/// operation that takes 'input' as the only operand, and produces a single
/// result of 'resultType'. If a conversion can not be generated, nullptr
/// should be returned.
- Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+ Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type resultType,
Location conversionLoc) const final {
return builder.create<CastOp>(conversionLoc, resultType, input);
diff --git a/third_party/mlir/g3doc/Tutorials/Toy/Ch-5.md b/third_party/mlir/g3doc/Tutorials/Toy/Ch-5.md
index 1124cf1..8a4268b 100644
--- a/third_party/mlir/g3doc/Tutorials/Toy/Ch-5.md
+++ b/third_party/mlir/g3doc/Tutorials/Toy/Ch-5.md
@@ -101,7 +101,7 @@
/// Match and rewrite the given `toy.transpose` operation, with the given
/// operands that have been remapped from `tensor<...>` to `memref<...>`.
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value *> operands,
+ matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@@ -112,18 +112,18 @@
lowerOpToLoops(
op, operands, rewriter,
[loc](mlir::PatternRewriter &rewriter,
- ArrayRef<mlir::Value *> memRefOperands,
- ArrayRef<mlir::Value *> loopIvs) {
+ ArrayRef<mlir::Value> memRefOperands,
+ ArrayRef<mlir::Value> loopIvs) {
// Generate an adaptor for the remapped operands of the TransposeOp.
// This allows for using the nice named accessors that are generated
// by the ODS. This adaptor is automatically provided by the ODS
// framework.
TransposeOpOperandAdaptor transposeAdaptor(memRefOperands);
- mlir::Value *input = transposeAdaptor.input();
+ mlir::Value input = transposeAdaptor.input();
// Transpose the elements by generating a load from the reverse
// indices.
- SmallVector<mlir::Value *, 2> reverseIvs(llvm::reverse(loopIvs));
+ SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
});
return matchSuccess();
diff --git a/third_party/mlir/g3doc/UsageOfConst.md b/third_party/mlir/g3doc/UsageOfConst.md
index 052f14d..6e8ce78 100644
--- a/third_party/mlir/g3doc/UsageOfConst.md
+++ b/third_party/mlir/g3doc/UsageOfConst.md
@@ -10,8 +10,8 @@
The design team since decided to change to a different module, which eschews
`const` entirely for the core IR types: you should never see a `const` method on
-`Operation`, should never see the type `const Value *`, and you shouldn't feel
-bad about this. That said, you *should* use `const` for non-IR types, like
+`Operation`, should never see the type `const Value`, and you shouldn't feel bad
+about this. That said, you *should* use `const` for non-IR types, like
`SmallVector`'s and many other things.
The document below explains this design point from the viewpoint of "why make a
@@ -39,7 +39,7 @@
a poor tradeoff, and proposes switching to a much simpler approach - eliminating
the use of const of these IR types entirely.
-**Note:** **This document is only discussing things like `const Value*` and
+**Note:** **This document is only discussing things like `const Value` and
`const Operation*`. There is no proposed change for other types, e.g.
`SmallVector` references, the immutable types like `Attribute`, etc.**
@@ -130,7 +130,7 @@
operand_iterator operand_begin();
operand_iterator operand_end();
- /// Returns an iterator on the underlying Value's (Value *).
+ /// Returns an iterator on the underlying Values.
operand_range getOperands();
// Support const operand iteration.
@@ -141,7 +141,7 @@
const_operand_iterator operand_begin() const;
const_operand_iterator operand_end() const;
- /// Returns a const iterator on the underlying Value's (Value *).
+ /// Returns a const iterator on the underlying Values.
llvm::iterator_range<const_operand_iterator> getOperands() const;
ArrayRef<OpOperand> getOpOperands() const {
diff --git a/third_party/mlir/g3doc/WritingAPass.md b/third_party/mlir/g3doc/WritingAPass.md
index 7847571..5119c46 100644
--- a/third_party/mlir/g3doc/WritingAPass.md
+++ b/third_party/mlir/g3doc/WritingAPass.md
@@ -421,7 +421,8 @@
pass pipeline, e.g. `cse` or `canonicalize`.
* `options`
* Options are pass specific key value pairs that are handled as described
- in the instance specific pass options section.
+ in the [instance specific pass options](#instance-specific-pass-options)
+ section.
For example, the following pipeline:
@@ -443,30 +444,47 @@
### Instance Specific Pass Options
Options may be specified for a parametric pass. Individual options are defined
-using `llvm::cl::opt` flag definition rules. These options will then be parsed
-at pass construction time independently for each instance of the pass. The
-`PassRegistration` and `PassPipelineRegistration` templates take an additional
-optional template parameter that is the Option struct definition to be used for
-that pass. To use pass specific options, create a class that inherits from
-`mlir::PassOptions` and then add a new constructor that takes `const
-MyPassOptions&` and constructs the pass. When using `PassPipelineRegistration`,
-the constructor now takes a function with the signature `void (OpPassManager
-&pm, const MyPassOptions&)` which should construct the passes from the options
-and pass them to the pm. The user code will look like the following:
+using the [LLVM command line](https://llvm.org/docs/CommandLine.html) flag
+definition rules. These options will then be parsed at pass construction time
+independently for each instance of the pass. To provide options for passes, the
+`Option<>` and `ListOption<>` classes may be used:
```c++
-class MyPass ... {
-public:
- MyPass(const MyPassOptions& options) ...
-};
+struct MyPass ... {
+ /// Provide a valid default constructor and copy constructor so that the
+ /// options are initialized properly.
+ MyPass() = default;
+ MyPass(const MyPass& pass) {}
-struct MyPassOptions : public PassOptions<MyPassOptions> {
// These just forward onto llvm::cl::list and llvm::cl::opt respectively.
Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
- List<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")};
+ ListOption<int> exampleListOption{*this, "list-flag-name",
+ llvm::cl::desc("...")};
+};
+```
+
+For pass pipelines, the `PassPipelineRegistration` templates take an additional
+optional template parameter that is the Option struct definition to be used for
+that pipeline. To use pipeline specific options, create a class that inherits
+from `mlir::PassPipelineOptions` that contains the desired options. When using
+`PassPipelineRegistration`, the constructor now takes a function with the
+signature `void (OpPassManager &pm, const MyPipelineOptions&)` which should
+construct the passes from the options and pass them to the pm:
+
+```c++
+struct MyPipelineOptions : public PassPipelineOptions {
+ // These just forward onto llvm::cl::list and llvm::cl::opt respectively.
+ Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
+ ListOption<int> exampleListOption{*this, "list-flag-name",
+ llvm::cl::desc("...")};
};
-static PassRegistration<MyPass, MyPassOptions> pass("my-pass", "description");
+
+static mlir::PassPipelineRegistration<MyPipelineOptions> pipeline(
+ "example-pipeline", "Run an example pipeline.",
+ [](OpPassManager &pm, const MyPipelineOptions &pipelineOptions) {
+ // Initialize the pass manager.
+ });
```
## Pass Statistics
diff --git a/third_party/mlir/include/mlir-c/Core.h b/third_party/mlir/include/mlir-c/Core.h
index c205e89..5e3e208 100644
--- a/third_party/mlir/include/mlir-c/Core.h
+++ b/third_party/mlir/include/mlir-c/Core.h
@@ -1,18 +1,9 @@
/*===-- mlir-c/Core.h - Core Library C Interface ------------------*- C -*-===*\
|* *|
-|* Copyright 2019 The MLIR Authors. *|
-|* *|
-|* Licensed under the Apache License, Version 2.0 (the "License"); *|
-|* you may not use this file except in compliance with the License. *|
-|* You may obtain a copy of the License at *|
-|* *|
-|* http://www.apache.org/licenses/LICENSE-2.0 *|
-|* *|
-|* Unless required by applicable law or agreed to in writing, software *|
-|* distributed under the License is distributed on an "AS IS" BASIS, *|
-|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *|
-|* See the License for the specific language governing permissions and *|
-|* limitations under the License. *|
+|* Part of the MLIR Project, under the Apache License v2.0 with LLVM *|
+|* Exceptions. *|
+|* See https://llvm.org/LICENSE.txt for license information. *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|* *|
|*===----------------------------------------------------------------------===*|
|* *|
diff --git a/third_party/mlir/include/mlir/ADT/TypeSwitch.h b/third_party/mlir/include/mlir/ADT/TypeSwitch.h
index 75051b6..2dbc611 100644
--- a/third_party/mlir/include/mlir/ADT/TypeSwitch.h
+++ b/third_party/mlir/include/mlir/ADT/TypeSwitch.h
@@ -1,19 +1,10 @@
//===- TypeSwitch.h - Switch functionality for RTTI casting -*- C++ -*-----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the TypeSwitch template, which mimics a switch()
// statement whose cases are type names.
diff --git a/third_party/mlir/include/mlir/Analysis/AffineAnalysis.h b/third_party/mlir/include/mlir/Analysis/AffineAnalysis.h
index 8243d1f..d0bcb93 100644
--- a/third_party/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/third_party/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -1,19 +1,10 @@
//===- AffineAnalysis.h - analyses for affine structures --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for methods that perform analysis
// involving affine structures (AffineExprStorage, AffineMap, IntegerSet, etc.)
@@ -24,9 +15,7 @@
#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/Value.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
@@ -37,12 +26,11 @@
class AffineValueMap;
class FlatAffineConstraints;
class Operation;
-class Value;
/// Returns in `affineApplyOps`, the sequence of those AffineApplyOp
/// Operations that are reachable via a search starting from `operands` and
/// ending at those operands that are not the result of an AffineApplyOp.
-void getReachableAffineApplyOps(ArrayRef<Value *> operands,
+void getReachableAffineApplyOps(ArrayRef<Value> operands,
SmallVectorImpl<Operation *> &affineApplyOps);
/// Builds a system of constraints with dimensional identifiers corresponding to
@@ -56,9 +44,9 @@
/// Encapsulates a memref load or store access information.
struct MemRefAccess {
- Value *memref;
+ Value memref;
Operation *opInst;
- SmallVector<Value *, 4> indices;
+ SmallVector<Value, 4> indices;
/// Constructs a MemRefAccess from a load or store operation.
// TODO(b/119949820): add accessors to standard op's load, store, DMA op's to
diff --git a/third_party/mlir/include/mlir/Analysis/AffineStructures.h b/third_party/mlir/include/mlir/Analysis/AffineStructures.h
index e53af50..47e0dda 100644
--- a/third_party/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/third_party/mlir/include/mlir/Analysis/AffineStructures.h
@@ -1,19 +1,10 @@
//===- AffineStructures.h - MLIR Affine Structures Class --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Structures for affine/polyhedral analysis of ML functions.
//
@@ -123,8 +114,8 @@
// Creates an empty AffineValueMap (users should call 'reset' to reset map
// and operands).
AffineValueMap() {}
- AffineValueMap(AffineMap map, ArrayRef<Value *> operands,
- ArrayRef<Value *> results = llvm::None);
+ AffineValueMap(AffineMap map, ArrayRef<Value> operands,
+ ArrayRef<Value> results = llvm::None);
explicit AffineValueMap(AffineApplyOp applyOp);
explicit AffineValueMap(AffineBound bound);
@@ -132,8 +123,8 @@
~AffineValueMap();
// Resets this AffineValueMap with 'map', 'operands', and 'results'.
- void reset(AffineMap map, ArrayRef<Value *> operands,
- ArrayRef<Value *> results = llvm::None);
+ void reset(AffineMap map, ArrayRef<Value> operands,
+ ArrayRef<Value> results = llvm::None);
/// Return the value map that is the difference of value maps 'a' and 'b',
/// represented as an affine map and its operands. The output map + operands
@@ -146,7 +137,7 @@
inline bool isMultipleOf(unsigned idx, int64_t factor) const;
/// Return true if the idx^th result depends on 'value', false otherwise.
- bool isFunctionOf(unsigned idx, Value *value) const;
+ bool isFunctionOf(unsigned idx, Value value) const;
/// Return true if the result at 'idx' is a constant, false
/// otherwise.
@@ -162,8 +153,8 @@
inline unsigned getNumSymbols() const { return map.getNumSymbols(); }
inline unsigned getNumResults() const { return map.getNumResults(); }
- Value *getOperand(unsigned i) const;
- ArrayRef<Value *> getOperands() const;
+ Value getOperand(unsigned i) const;
+ ArrayRef<Value> getOperands() const;
AffineMap getAffineMap() const;
private:
@@ -172,9 +163,9 @@
// TODO: make these trailing objects?
/// The SSA operands binding to the dim's and symbols of 'map'.
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
/// The SSA results binding to the results of 'map'.
- SmallVector<Value *, 4> results;
+ SmallVector<Value, 4> results;
};
/// An IntegerValueSet is an integer set plus its operands.
@@ -207,7 +198,7 @@
// 'AffineCondition'.
MutableIntegerSet set;
/// The SSA operands binding to the dim's and symbols of 'set'.
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
};
/// A flat list of affine equalities and inequalities in the form.
@@ -245,7 +236,7 @@
unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims = 0,
unsigned numSymbols = 0, unsigned numLocals = 0,
- ArrayRef<Optional<Value *>> idArgs = {})
+ ArrayRef<Optional<Value>> idArgs = {})
: numReservedCols(numReservedCols), numDims(numDims),
numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1);
@@ -264,7 +255,7 @@
/// dimensions and symbols.
FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0,
unsigned numLocals = 0,
- ArrayRef<Optional<Value *>> idArgs = {})
+ ArrayRef<Optional<Value>> idArgs = {})
: numReservedCols(numDims + numSymbols + numLocals + 1), numDims(numDims),
numSymbols(numSymbols) {
assert(numReservedCols >= numDims + numSymbols + 1);
@@ -304,10 +295,10 @@
// Clears any existing data and reserves memory for the specified constraints.
void reset(unsigned numReservedInequalities, unsigned numReservedEqualities,
unsigned numReservedCols, unsigned numDims, unsigned numSymbols,
- unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
+ unsigned numLocals = 0, ArrayRef<Value> idArgs = {});
void reset(unsigned numDims = 0, unsigned numSymbols = 0,
- unsigned numLocals = 0, ArrayRef<Value *> idArgs = {});
+ unsigned numLocals = 0, ArrayRef<Value> idArgs = {});
/// Appends constraints from 'other' into this. This is equivalent to an
/// intersection with no simplification of any sort attempted.
@@ -396,7 +387,7 @@
/// operands. If `eq` is true, add a single equality equal to the bound map's
/// first result expr.
LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
- ArrayRef<Value *> operands, bool eq,
+ ArrayRef<Value> operands, bool eq,
bool lower = true);
/// Computes the lower and upper bounds of the first 'num' dimensional
@@ -415,10 +406,10 @@
/// operand list 'operands'.
/// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size'.
/// Note that both lower/upper bounds use operands from 'operands'.
- LogicalResult addSliceBounds(ArrayRef<Value *> values,
+ LogicalResult addSliceBounds(ArrayRef<Value> values,
ArrayRef<AffineMap> lbMaps,
ArrayRef<AffineMap> ubMaps,
- ArrayRef<Value *> operands);
+ ArrayRef<Value> operands);
// Adds an inequality (>= 0) from the coefficients specified in inEq.
void addInequality(ArrayRef<int64_t> inEq);
@@ -447,25 +438,25 @@
/// Sets the identifier corresponding to the specified Value id to a
/// constant. Asserts if the 'id' is not found.
- void setIdToConstant(Value &id, int64_t val);
+ void setIdToConstant(Value id, int64_t val);
/// Looks up the position of the identifier with the specified Value. Returns
/// true if found (false otherwise). `pos' is set to the (column) position of
/// the identifier.
- bool findId(Value &id, unsigned *pos) const;
+ bool findId(Value id, unsigned *pos) const;
/// Returns true if an identifier with the specified Value exists, false
/// otherwise.
- bool containsId(Value &id) const;
+ bool containsId(Value id) const;
// Add identifiers of the specified kind - specified positions are relative to
// the kind of identifier. The coefficient column corresponding to the added
// identifier is initialized to zero. 'id' is the Value corresponding to the
// identifier that can optionally be provided.
- void addDimId(unsigned pos, Value *id = nullptr);
- void addSymbolId(unsigned pos, Value *id = nullptr);
+ void addDimId(unsigned pos, Value id = nullptr);
+ void addSymbolId(unsigned pos, Value id = nullptr);
void addLocalId(unsigned pos);
- void addId(IdKind kind, unsigned pos, Value *id = nullptr);
+ void addId(IdKind kind, unsigned pos, Value id = nullptr);
/// Add the specified values as a dim or symbol id depending on its nature, if
/// it already doesn't exist in the system. `id' has to be either a terminal
@@ -473,7 +464,7 @@
/// symbols or loop IVs. The identifier is added to the end of the existing
/// dims or symbols. Additional information on the identifier is extracted
/// from the IR and added to the constraint system.
- void addInductionVarOrTerminalSymbol(Value *id);
+ void addInductionVarOrTerminalSymbol(Value id);
/// Composes the affine value map with this FlatAffineConstrains, adding the
/// results of the map as dimensions at the front [0, vMap->getNumResults())
@@ -500,8 +491,8 @@
void projectOut(unsigned pos, unsigned num);
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
- /// Projects out the identifier that is associate with Value *.
- void projectOut(Value *id);
+ /// Projects out the identifier that is associated with Value.
+ void projectOut(Value id);
void removeId(IdKind idKind, unsigned pos);
void removeId(unsigned pos);
@@ -577,20 +568,20 @@
return numIds - numDims - numSymbols;
}
- inline ArrayRef<Optional<Value *>> getIds() const {
+ inline ArrayRef<Optional<Value>> getIds() const {
return {ids.data(), ids.size()};
}
- inline MutableArrayRef<Optional<Value *>> getIds() {
+ inline MutableArrayRef<Optional<Value>> getIds() {
return {ids.data(), ids.size()};
}
/// Returns the optional Value corresponding to the pos^th identifier.
- inline Optional<Value *> getId(unsigned pos) const { return ids[pos]; }
- inline Optional<Value *> &getId(unsigned pos) { return ids[pos]; }
+ inline Optional<Value> getId(unsigned pos) const { return ids[pos]; }
+ inline Optional<Value> &getId(unsigned pos) { return ids[pos]; }
/// Returns the Value associated with the pos^th identifier. Asserts if
/// no Value identifier was associated.
- inline Value *getIdValue(unsigned pos) const {
+ inline Value getIdValue(unsigned pos) const {
assert(ids[pos].hasValue() && "identifier's Value not set");
return ids[pos].getValue();
}
@@ -598,7 +589,7 @@
/// Returns the Values associated with identifiers in range [start, end).
/// Asserts if no Value was associated with one of these identifiers.
void getIdValues(unsigned start, unsigned end,
- SmallVectorImpl<Value *> *values) const {
+ SmallVectorImpl<Value> *values) const {
assert((start < numIds || start == end) && "invalid start position");
assert(end <= numIds && "invalid end position");
values->clear();
@@ -607,17 +598,17 @@
values->push_back(getIdValue(i));
}
}
- inline void getAllIdValues(SmallVectorImpl<Value *> *values) const {
+ inline void getAllIdValues(SmallVectorImpl<Value> *values) const {
getIdValues(0, numIds, values);
}
/// Sets Value associated with the pos^th identifier.
- inline void setIdValue(unsigned pos, Value *val) {
+ inline void setIdValue(unsigned pos, Value val) {
assert(pos < numIds && "invalid id position");
ids[pos] = val;
}
/// Sets Values associated with identifiers in the range [start, end).
- void setIdValues(unsigned start, unsigned end, ArrayRef<Value *> values) {
+ void setIdValues(unsigned start, unsigned end, ArrayRef<Value> values) {
assert((start < numIds || end == start) && "invalid start position");
assert(end <= numIds && "invalid end position");
assert(values.size() == end - start);
@@ -766,7 +757,7 @@
/// system appearing in the order the identifiers correspond to columns.
/// Temporary ones or those that aren't associated to any Value are set to
/// None.
- SmallVector<Optional<Value *>, 8> ids;
+ SmallVector<Optional<Value>, 8> ids;
/// A parameter that controls detection of an unrealistic number of
/// constraints. If the number of constraints is this many times the number of
diff --git a/third_party/mlir/include/mlir/Analysis/CallGraph.h b/third_party/mlir/include/mlir/Analysis/CallGraph.h
index 700a016..8f95416 100644
--- a/third_party/mlir/include/mlir/Analysis/CallGraph.h
+++ b/third_party/mlir/include/mlir/Analysis/CallGraph.h
@@ -1,19 +1,10 @@
//===- CallGraph.h - CallGraph analysis for MLIR ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains an analysis for computing the multi-level callgraph from a
// given top-level operation. This nodes within this callgraph are defined by
diff --git a/third_party/mlir/include/mlir/Analysis/CallInterfaces.h b/third_party/mlir/include/mlir/Analysis/CallInterfaces.h
index dd23d77..b5870ba 100644
--- a/third_party/mlir/include/mlir/Analysis/CallInterfaces.h
+++ b/third_party/mlir/include/mlir/Analysis/CallInterfaces.h
@@ -1,19 +1,10 @@
//===- CallInterfaces.h - Call Interfaces for MLIR --------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the call interfaces defined in
// `CallInterfaces.td`.
@@ -30,8 +21,8 @@
/// A callable is either a symbol, or an SSA value, that is referenced by a
/// call-like operation. This represents the destination of the call.
-struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value *> {
- using PointerUnion<SymbolRefAttr, Value *>::PointerUnion;
+struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
+ using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
};
#include "mlir/Analysis/CallInterfaces.h.inc"
diff --git a/third_party/mlir/include/mlir/Analysis/CallInterfaces.td b/third_party/mlir/include/mlir/Analysis/CallInterfaces.td
index 043f009..3e5b599 100644
--- a/third_party/mlir/include/mlir/Analysis/CallInterfaces.td
+++ b/third_party/mlir/include/mlir/Analysis/CallInterfaces.td
@@ -1,19 +1,10 @@
//===- CallInterfaces.td - Call Interfaces for ops -*- tablegen ---------*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains a set of interfaces that can be used to define information
// related to call-like and callable operations. Each of which are defined along
diff --git a/third_party/mlir/include/mlir/Analysis/Dominance.h b/third_party/mlir/include/mlir/Analysis/Dominance.h
index 09114ea..ead54b9 100644
--- a/third_party/mlir/include/mlir/Analysis/Dominance.h
+++ b/third_party/mlir/include/mlir/Analysis/Dominance.h
@@ -1,19 +1,10 @@
//===- Dominance.h - Dominator analysis for CFGs ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_DOMINANCE_H
#define MLIR_ANALYSIS_DOMINANCE_H
@@ -74,10 +65,10 @@
}
/// Return true if value A properly dominates operation B.
- bool properlyDominates(Value *a, Operation *b);
+ bool properlyDominates(Value a, Operation *b);
/// Return true if operation A dominates operation B.
- bool dominates(Value *a, Operation *b) {
+ bool dominates(Value a, Operation *b) {
return (Operation *)a->getDefiningOp() == b || properlyDominates(a, b);
}
diff --git a/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h
index 2d68ada..baf1616 100644
--- a/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h
+++ b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.h
@@ -1,19 +1,10 @@
//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
diff --git a/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td
index 14d5809..bbcea6b 100644
--- a/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td
+++ b/third_party/mlir/include/mlir/Analysis/InferTypeOpInterface.td
@@ -1,19 +1,10 @@
//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains a set of interfaces that can be used to define information
// related to type inference.
diff --git a/third_party/mlir/include/mlir/Analysis/Liveness.h b/third_party/mlir/include/mlir/Analysis/Liveness.h
index 0bdb474..7e1dc29 100644
--- a/third_party/mlir/include/mlir/Analysis/Liveness.h
+++ b/third_party/mlir/include/mlir/Analysis/Liveness.h
@@ -1,19 +1,10 @@
//===- Liveness.h - Liveness analysis for MLIR ------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains an analysis for computing liveness information from a
// given top-level operation. The current version of the analysis uses a
@@ -57,7 +48,7 @@
public:
using OperationListT = std::vector<Operation *>;
using BlockMapT = DenseMap<Block *, LivenessBlockInfo>;
- using ValueSetT = SmallPtrSet<Value *, 16>;
+ using ValueSetT = SmallPtrSet<Value, 16>;
public:
/// Creates a new Liveness analysis that computes liveness
@@ -72,7 +63,7 @@
/// Note that the operations in this list are not ordered and the current
/// implementation is computationally expensive (as it iterates over all
/// blocks in which the given value is live).
- OperationListT resolveLiveness(Value *value) const;
+ OperationListT resolveLiveness(Value value) const;
/// Gets liveness info (if any) for the block.
const LivenessBlockInfo *getLiveness(Block *block) const;
@@ -85,7 +76,7 @@
/// Returns true if the given operation represent the last use of the
/// given value.
- bool isLastUse(Value *value, Operation *operation) const;
+ bool isLastUse(Value value, Operation *operation) const;
/// Dumps the liveness information in a human readable format.
void dump() const;
@@ -124,20 +115,20 @@
const ValueSetT &out() const { return outValues; }
/// Returns true if the given value is in the live-in set.
- bool isLiveIn(Value *value) const;
+ bool isLiveIn(Value value) const;
/// Returns true if the given value is in the live-out set.
- bool isLiveOut(Value *value) const;
+ bool isLiveOut(Value value) const;
/// Gets the start operation for the given value. This is the first operation
/// the given value is considered to be live. This could either be the start
/// operation of the current block (in case the value is live-in) or the
/// operation that defines the given value (must be referenced in this block).
- Operation *getStartOperation(Value *value) const;
+ Operation *getStartOperation(Value value) const;
/// Gets the end operation for the given value using the start operation
/// provided (must be referenced in this block).
- Operation *getEndOperation(Value *value, Operation *startOperation) const;
+ Operation *getEndOperation(Value value, Operation *startOperation) const;
private:
/// The underlying block.
diff --git a/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h b/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h
index 47cc22a..0dd89e4 100644
--- a/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h
+++ b/third_party/mlir/include/mlir/Analysis/LoopAnalysis.h
@@ -1,19 +1,10 @@
//===- LoopAnalysis.h - loop analysis methods -------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for methods to analyze loops.
//
@@ -45,7 +36,7 @@
// TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a
// pure analysis method relying on FlatAffineConstraints
void buildTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
- SmallVectorImpl<Value *> *operands);
+ SmallVectorImpl<Value> *operands);
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// uses affine expression analysis and is able to determine constant trip count
@@ -66,8 +57,8 @@
///
/// Emits a note if it encounters a chain of affine.apply and conservatively
/// those cases.
-DenseSet<Value *, DenseMapInfo<Value *>>
-getInvariantAccesses(Value *iv, ArrayRef<Value *> indices);
+DenseSet<Value, DenseMapInfo<Value>>
+getInvariantAccesses(Value iv, ArrayRef<Value> indices);
using VectorizableLoopFun = std::function<bool(AffineForOp)>;
diff --git a/third_party/mlir/include/mlir/Analysis/NestedMatcher.h b/third_party/mlir/include/mlir/Analysis/NestedMatcher.h
index 9af26e8..2da64e8 100644
--- a/third_party/mlir/include/mlir/Analysis/NestedMatcher.h
+++ b/third_party/mlir/include/mlir/Analysis/NestedMatcher.h
@@ -1,19 +1,10 @@
//===- NestedMacher.h - Nested matcher for Function -------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
diff --git a/third_party/mlir/include/mlir/Analysis/Passes.h b/third_party/mlir/include/mlir/Analysis/Passes.h
index b233ab5..0bbc850 100644
--- a/third_party/mlir/include/mlir/Analysis/Passes.h
+++ b/third_party/mlir/include/mlir/Analysis/Passes.h
@@ -1,19 +1,10 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes that expose pass constructors in the
// analysis library.
diff --git a/third_party/mlir/include/mlir/Analysis/SliceAnalysis.h b/third_party/mlir/include/mlir/Analysis/SliceAnalysis.h
index ad6b653..d7b6e95 100644
--- a/third_party/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/third_party/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -1,19 +1,10 @@
//===- SliceAnalysis.h - Analysis for Transitive UseDef chains --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_SLICEANALYSIS_H_
#define MLIR_ANALYSIS_SLICEANALYSIS_H_
diff --git a/third_party/mlir/include/mlir/Analysis/Utils.h b/third_party/mlir/include/mlir/Analysis/Utils.h
index cffa222..7cf1e5c 100644
--- a/third_party/mlir/include/mlir/Analysis/Utils.h
+++ b/third_party/mlir/include/mlir/Analysis/Utils.h
@@ -1,19 +1,10 @@
//===- Utils.h - General analysis utilities ---------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various transformation utilities for
// memref's and non-loop IR structures. These are not passes by themselves but
@@ -55,7 +46,7 @@
/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
/// at 'forOp'.
void getSequentialLoops(AffineForOp forOp,
- llvm::SmallDenseSet<Value *, 8> *sequentialLoops);
+ llvm::SmallDenseSet<Value, 8> *sequentialLoops);
/// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their
/// associated operands for a set of loops within a loop nest (typically the
@@ -64,15 +55,15 @@
struct ComputationSliceState {
// List of sliced loop IVs (ordered from outermost to innermost).
// EX: 'ivs[i]' has lower bound 'lbs[i]' and upper bound 'ubs[i]'.
- SmallVector<Value *, 4> ivs;
+ SmallVector<Value, 4> ivs;
// List of lower bound AffineMaps.
SmallVector<AffineMap, 4> lbs;
// List of upper bound AffineMaps.
SmallVector<AffineMap, 4> ubs;
// List of lower bound operands (lbOperands[i] are used by 'lbs[i]').
- std::vector<SmallVector<Value *, 4>> lbOperands;
+ std::vector<SmallVector<Value, 4>> lbOperands;
// List of upper bound operands (ubOperands[i] are used by 'ubs[i]').
- std::vector<SmallVector<Value *, 4>> ubOperands;
+ std::vector<SmallVector<Value, 4>> ubOperands;
// Slice loop nest insertion point in target loop nest.
Block::iterator insertPoint;
// Adds to 'cst' with constraints which represent the slice bounds on 'ivs'
@@ -257,7 +248,7 @@
unsigned getRank() const;
/// Memref that this region corresponds to.
- Value *memref;
+ Value memref;
/// Read or write.
bool write;
diff --git a/third_party/mlir/include/mlir/Analysis/Verifier.h b/third_party/mlir/include/mlir/Analysis/Verifier.h
index daaff57..b7075b4 100644
--- a/third_party/mlir/include/mlir/Analysis/Verifier.h
+++ b/third_party/mlir/include/mlir/Analysis/Verifier.h
@@ -1,19 +1,10 @@
//===- Verifier.h - Verifier analysis for MLIR structures -------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_VERIFIER_H
#define MLIR_ANALYSIS_VERIFIER_H
diff --git a/third_party/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/third_party/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
index b5c51ad..c6a2fac 100644
--- a/third_party/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
+++ b/third_party/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
@@ -1,19 +1,10 @@
//===- AffineToStandard.h - Convert Affine to Standard dialect --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H
#define MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H
@@ -35,9 +26,8 @@
/// Emit code that computes the given affine expression using standard
/// arithmetic operations applied to the provided dimension and symbol values.
-Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
- ArrayRef<Value *> dimValues,
- ArrayRef<Value *> symbolValues);
+Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
+ ArrayRef<Value> dimValues, ArrayRef<Value> symbolValues);
/// Collect a set of patterns to convert from the Affine dialect to the Standard
/// dialect, in particular convert structured affine control flow into CFG
@@ -47,11 +37,11 @@
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.
-Value *lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);
+Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);
/// Emit code that computes the upper bound of the given affine loop using
/// standard arithmetic operations.
-Value *lowerAffineUpperBound(AffineForOp op, OpBuilder &builder);
+Value lowerAffineUpperBound(AffineForOp op, OpBuilder &builder);
} // namespace mlir
#endif // MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H
diff --git a/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
index 6b9b08e..4eb6379 100644
--- a/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
+++ b/third_party/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
@@ -1,19 +1,10 @@
//===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 635d436..75e4f7e 100644
--- a/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/third_party/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -1,19 +1,10 @@
//===- GPUToNVVMPass.h - Convert GPU kernel to NVVM dialect -----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/third_party/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index 54cda41..e913c2e 100644
--- a/third_party/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/third_party/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -1,19 +1,10 @@
//===- GPUToROCDLPass.h - Convert GPU kernel to ROCDL dialect ---*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
#define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h b/third_party/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h
index 134dbf4..762a6e5 100644
--- a/third_party/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h
+++ b/third_party/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h
@@ -1,19 +1,10 @@
//===- ConvertGPUToSPIRV.h - GPU Ops to SPIR-V dialect patterns ----C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides patterns for lowering GPU Ops to SPIR-V dialect.
//
diff --git a/third_party/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h b/third_party/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h
index 8f0a910..37230f4 100644
--- a/third_party/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h
+++ b/third_party/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h
@@ -1,19 +1,10 @@
//===- ConvertGPUToSPIRVPass.h - GPU to SPIR-V conversion pass --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides a pass to convert GPU ops to SPIRV ops.
//
diff --git a/third_party/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/third_party/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
index 6bae08e..2795017 100644
--- a/third_party/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
+++ b/third_party/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h
@@ -1,19 +1,10 @@
//===- LinalgToLLVM.h - Utils to convert from the linalg dialect ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_
#define MLIR_CONVERSION_LINALGTOLLVM_LINALGTOLLVM_H_
diff --git a/third_party/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h b/third_party/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h
index 095c9f4..5cb8f59 100644
--- a/third_party/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h
+++ b/third_party/mlir/include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h
@@ -1,19 +1,10 @@
//===- ConvertLoopToStandard.h - Pass entrypoint ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LOOPTOSTANDARD_CONVERTLOOPTOSTANDARD_H_
#define MLIR_CONVERSION_LOOPTOSTANDARD_CONVERTLOOPTOSTANDARD_H_
diff --git a/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
index 0aab872..80faa03 100644
--- a/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
+++ b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h
@@ -1,19 +1,10 @@
//===- LoopsToGPU.h - Convert loop nests to GPU kernels ---------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPU_H_
@@ -78,8 +69,8 @@
/// The above conditions are assumed to be satisfied by the computation rooted
/// at `forOp`.
LogicalResult convertLoopToGPULaunch(loop::ForOp forOp,
- ArrayRef<Value *> numWorkGroups,
- ArrayRef<Value *> workGroupSizes);
+ ArrayRef<Value> numWorkGroups,
+ ArrayRef<Value> workGroupSizes);
} // namespace mlir
diff --git a/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
index a42320c..a3d663a 100644
--- a/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
+++ b/third_party/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
@@ -1,19 +1,10 @@
//===- LoopsToGPUPass.h - Pass converting loops to GPU kernels --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
#define MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index e8d16f0..e78859f 100644
--- a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -1,19 +1,10 @@
//===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides a dialect conversion targeting the LLVM IR dialect. By default, it
// converts Standard ops and types and provides hooks for dialect-specific
@@ -74,16 +65,16 @@
/// Promote the LLVM struct representation of all MemRef descriptors to stack
/// and use pointers to struct to avoid the complexity of the
/// platform-specific C/C++ ABI lowering related to struct argument passing.
- SmallVector<Value *, 4> promoteMemRefDescriptors(Location loc,
- ValueRange opOperands,
- ValueRange operands,
- OpBuilder &builder);
+ SmallVector<Value, 4> promoteMemRefDescriptors(Location loc,
+ ValueRange opOperands,
+ ValueRange operands,
+ OpBuilder &builder);
/// Promote the LLVM struct representation of one MemRef descriptor to stack
/// and use pointer to struct to avoid the complexity of the platform-specific
/// C/C++ ABI lowering related to struct argument passing.
- Value *promoteOneMemRefDescriptor(Location loc, Value *operand,
- OpBuilder &builder);
+ Value promoteOneMemRefDescriptor(Location loc, Value operand,
+ OpBuilder &builder);
protected:
/// LLVM IR module used to parse/create types.
@@ -139,24 +130,24 @@
class StructBuilder {
public:
/// Construct a helper for the given value.
- explicit StructBuilder(Value *v);
+ explicit StructBuilder(Value v);
/// Builds IR creating an `undef` value of the descriptor type.
static StructBuilder undef(OpBuilder &builder, Location loc,
Type descriptorType);
- /*implicit*/ operator Value *() { return value; }
+ /*implicit*/ operator Value() { return value; }
protected:
// LLVM value
- Value *value;
+ Value value;
// Cached struct type.
Type structType;
protected:
/// Builds IR to extract a value from the struct at position pos
- Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos);
+ Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR to set a value in the struct at position pos
- void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr);
+ void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
};
/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
@@ -164,7 +155,7 @@
class MemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
- explicit MemRefDescriptor(Value *descriptor);
+ explicit MemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
@@ -173,39 +164,39 @@
/// type.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
- MemRefType type, Value *memory);
+ MemRefType type, Value memory);
/// Builds IR extracting the allocated pointer from the descriptor.
- Value *allocatedPtr(OpBuilder &builder, Location loc);
+ Value allocatedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the allocated pointer into the descriptor.
- void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr);
+ void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);
/// Builds IR extracting the aligned pointer from the descriptor.
- Value *alignedPtr(OpBuilder &builder, Location loc);
+ Value alignedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the aligned pointer into the descriptor.
- void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr);
+ void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);
/// Builds IR extracting the offset from the descriptor.
- Value *offset(OpBuilder &builder, Location loc);
+ Value offset(OpBuilder &builder, Location loc);
/// Builds IR inserting the offset into the descriptor.
- void setOffset(OpBuilder &builder, Location loc, Value *offset);
+ void setOffset(OpBuilder &builder, Location loc, Value offset);
void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);
/// Builds IR extracting the pos-th size from the descriptor.
- Value *size(OpBuilder &builder, Location loc, unsigned pos);
+ Value size(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR inserting the pos-th size into the descriptor
- void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size);
+ void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
uint64_t size);
/// Builds IR extracting the pos-th size from the descriptor.
- Value *stride(OpBuilder &builder, Location loc, unsigned pos);
+ Value stride(OpBuilder &builder, Location loc, unsigned pos);
/// Builds IR inserting the pos-th stride into the descriptor
- void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
+ void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride);
@@ -220,19 +211,19 @@
class UnrankedMemRefDescriptor : public StructBuilder {
public:
/// Construct a helper for the given descriptor value.
- explicit UnrankedMemRefDescriptor(Value *descriptor);
+ explicit UnrankedMemRefDescriptor(Value descriptor);
/// Builds IR creating an `undef` value of the descriptor type.
static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR extracting the rank from the descriptor
- Value *rank(OpBuilder &builder, Location loc);
+ Value rank(OpBuilder &builder, Location loc);
/// Builds IR setting the rank in the descriptor
- void setRank(OpBuilder &builder, Location loc, Value *value);
+ void setRank(OpBuilder &builder, Location loc, Value value);
/// Builds IR extracting ranked memref descriptor ptr
- Value *memRefDescPtr(OpBuilder &builder, Location loc);
+ Value memRefDescPtr(OpBuilder &builder, Location loc);
/// Builds IR setting ranked memref descriptor ptr
- void setMemRefDescPtr(OpBuilder &builder, Location loc, Value *value);
+ void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
};
/// Base class for operation conversions targeting the LLVM IR dialect. Provides
/// conversion patterns with an access to the containing LLVMLowering for the
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index d49c1c2..a4d95da 100644
--- a/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/third_party/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -1,19 +1,10 @@
//===- ConvertStandardToLLVMPass.h - Pass entrypoint ------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
index 4caa6d9..e0e8740 100644
--- a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
+++ b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
@@ -1,19 +1,10 @@
//===- ConvertStandardToSPIRV.h - Convert to SPIR-V dialect -----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides patterns to lower StandardOps to SPIR-V dialect.
//
diff --git a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
index e8a71fe..7dbaf1c 100644
--- a/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
+++ b/third_party/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
@@ -1,19 +1,10 @@
//===- ConvertStandardToSPIRVPass.h - StdOps to SPIR-V pass -----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides a pass to lower from StandardOps to SPIR-V dialect.
//
diff --git a/third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index a87e1c6..b8b97c2 100644
--- a/third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/third_party/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -1,19 +1,10 @@
//===- ConvertVectorToLLVM.h - Utils to convert from the vector dialect ---===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
diff --git a/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h b/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h
index 198eace..4f7d084 100644
--- a/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h
+++ b/third_party/mlir/include/mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h
@@ -1,19 +1,10 @@
//===- ConvertVectorToLoops.h - Utils to convert from the vector dialect --===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_
#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLOOPS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
index 36b4e55..b884ac5 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
@@ -1,19 +1,10 @@
//===- AffineOps.h - MLIR Affine Operations -------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines convenience types for working with Affine operations
// in the MLIR operation set.
@@ -41,7 +32,7 @@
/// A utility function to check if a value is defined at the top level of a
/// function. A value of index type defined at the top level is always a valid
/// symbol.
-bool isTopLevelValue(Value *value);
+bool isTopLevelValue(Value value);
class AffineOpsDialect : public Dialect {
public:
@@ -148,18 +139,17 @@
public:
using Op::Op;
- static void build(Builder *builder, OperationState &result, Value *srcMemRef,
- AffineMap srcMap, ValueRange srcIndices, Value *destMemRef,
- AffineMap dstMap, ValueRange destIndices, Value *tagMemRef,
- AffineMap tagMap, ValueRange tagIndices, Value *numElements,
- Value *stride = nullptr,
- Value *elementsPerStride = nullptr);
+ static void build(Builder *builder, OperationState &result, Value srcMemRef,
+ AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
+ AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
+ AffineMap tagMap, ValueRange tagIndices, Value numElements,
+ Value stride = nullptr, Value elementsPerStride = nullptr);
/// Returns the operand index of the src memref.
unsigned getSrcMemRefOperandIndex() { return 0; }
/// Returns the source MemRefType for this DMA operation.
- Value *getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
+ Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
MemRefType getSrcMemRefType() {
return getSrcMemRef()->getType().cast<MemRefType>();
}
@@ -191,7 +181,7 @@
}
/// Returns the destination MemRefType for this DMA operations.
- Value *getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
+ Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
MemRefType getDstMemRefType() {
return getDstMemRef()->getType().cast<MemRefType>();
}
@@ -225,7 +215,7 @@
}
/// Returns the Tag MemRef for this DMA operation.
- Value *getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
+ Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
MemRefType getTagMemRefType() {
return getTagMemRef()->getType().cast<MemRefType>();
}
@@ -249,13 +239,13 @@
}
/// Returns the number of elements being transferred by this DMA operation.
- Value *getNumElements() {
+ Value getNumElements() {
return getOperand(getTagMemRefOperandIndex() + 1 +
getTagMap().getNumInputs());
}
/// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
+ NamedAttribute getAffineMapAttrForMemRef(Value memref) {
if (memref == getSrcMemRef())
return {Identifier::get(getSrcMapAttrName(), getContext()),
getSrcMapAttr()};
@@ -305,14 +295,14 @@
}
/// Returns the stride value for this DMA operation.
- Value *getStride() {
+ Value getStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1 - 1);
}
/// Returns the number of elements to transfer per stride for this DMA op.
- Value *getNumElementsPerStride() {
+ Value getNumElementsPerStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1);
@@ -337,14 +327,13 @@
public:
using Op::Op;
- static void build(Builder *builder, OperationState &result, Value *tagMemRef,
- AffineMap tagMap, ValueRange tagIndices,
- Value *numElements);
+ static void build(Builder *builder, OperationState &result, Value tagMemRef,
+ AffineMap tagMap, ValueRange tagIndices, Value numElements);
static StringRef getOperationName() { return "affine.dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on.
- Value *getTagMemRef() { return getOperand(0); }
+ Value getTagMemRef() { return getOperand(0); }
MemRefType getTagMemRefType() {
return getTagMemRef()->getType().cast<MemRefType>();
}
@@ -367,14 +356,14 @@
}
/// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
+ NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getTagMemRef());
return {Identifier::get(getTagMapAttrName(), getContext()),
getTagMapAttr()};
}
/// Returns the number of elements transferred in the associated DMA op.
- Value *getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); }
+ Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); }
static StringRef getTagMapAttrName() { return "tag_map"; }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
@@ -409,18 +398,18 @@
static void build(Builder *builder, OperationState &result, AffineMap map,
ValueRange operands);
/// Builds an affine load op with an identity map and operands.
- static void build(Builder *builder, OperationState &result, Value *memref,
+ static void build(Builder *builder, OperationState &result, Value memref,
ValueRange indices = {});
/// Builds an affine load op with the specified map and its operands.
- static void build(Builder *builder, OperationState &result, Value *memref,
+ static void build(Builder *builder, OperationState &result, Value memref,
AffineMap map, ValueRange mapOperands);
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 0; }
/// Get memref operand.
- Value *getMemRef() { return getOperand(getMemRefOperandIndex()); }
- void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); }
+ Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+ void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
MemRefType getMemRefType() {
return getMemRef()->getType().cast<MemRefType>();
}
@@ -435,7 +424,7 @@
}
/// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
+ NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getMemRef());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
@@ -476,21 +465,21 @@
/// Builds an affine store operation with the provided indices (identity map).
static void build(Builder *builder, OperationState &result,
- Value *valueToStore, Value *memref, ValueRange indices);
+ Value valueToStore, Value memref, ValueRange indices);
/// Builds an affine store operation with the specified map and its operands.
static void build(Builder *builder, OperationState &result,
- Value *valueToStore, Value *memref, AffineMap map,
+ Value valueToStore, Value memref, AffineMap map,
ValueRange mapOperands);
/// Get value to be stored by store operation.
- Value *getValueToStore() { return getOperand(0); }
+ Value getValueToStore() { return getOperand(0); }
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 1; }
/// Get memref operand.
- Value *getMemRef() { return getOperand(getMemRefOperandIndex()); }
- void setMemRef(Value *value) { setOperand(getMemRefOperandIndex(), value); }
+ Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
+ void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
MemRefType getMemRefType() {
return getMemRef()->getType().cast<MemRefType>();
@@ -506,7 +495,7 @@
}
/// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value *memref) {
+ NamedAttribute getAffineMapAttrForMemRef(Value memref) {
assert(memref == getMemRef());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
@@ -526,10 +515,10 @@
};
/// Returns true if the given Value can be used as a dimension id.
-bool isValidDim(Value *value);
+bool isValidDim(Value value);
/// Returns true if the given Value can be used as a symbol.
-bool isValidSymbol(Value *value);
+bool isValidSymbol(Value value);
/// Modifies both `map` and `operands` in-place so as to:
/// 1. drop duplicate operands
@@ -538,17 +527,17 @@
/// dimensional operands
/// 4. propagate constant operands and drop them
void canonicalizeMapAndOperands(AffineMap *map,
- SmallVectorImpl<Value *> *operands);
+ SmallVectorImpl<Value> *operands);
/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
/// for affine maps.
void canonicalizeSetAndOperands(IntegerSet *set,
- SmallVectorImpl<Value *> *operands);
+ SmallVectorImpl<Value> *operands);
/// Returns a composed AffineApplyOp by composing `map` and `operands` with
/// other AffineApplyOps supplying those operands. The operands of the resulting
/// AffineApplyOp do not change the length of AffineApplyOp chains.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
- ArrayRef<Value *> operands);
+ ArrayRef<Value> operands);
/// Given an affine map `map` and its input `operands`, this method composes
/// into `map`, maps of AffineApplyOps whose results are the values in
@@ -558,22 +547,22 @@
/// terminal symbol, i.e., a symbol defined at the top level or a block/function
/// argument.
void fullyComposeAffineMapAndOperands(AffineMap *map,
- SmallVectorImpl<Value *> *operands);
+ SmallVectorImpl<Value> *operands);
#define GET_OP_CLASSES
#include "mlir/Dialect/AffineOps/AffineOps.h.inc"
/// Returns if the provided value is the induction variable of a AffineForOp.
-bool isForInductionVar(Value *val);
+bool isForInductionVar(Value val);
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
-AffineForOp getForInductionVarOwner(Value *val);
+AffineForOp getForInductionVarOwner(Value val);
/// Extracts the induction variables from a list of AffineForOps and places them
/// in the output argument `ivs`.
void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
- SmallVectorImpl<Value *> *ivs);
+ SmallVectorImpl<Value> *ivs);
/// AffineBound represents a lower or upper bound in the for operation.
/// This class does not own the underlying operands. Instead, it refers
@@ -588,7 +577,7 @@
AffineValueMap getAsAffineValueMap();
unsigned getNumOperands() { return opEnd - opStart; }
- Value *getOperand(unsigned idx) { return op.getOperand(opStart + idx); }
+ Value getOperand(unsigned idx) { return op.getOperand(opStart + idx); }
using operand_iterator = AffineForOp::operand_iterator;
using operand_range = AffineForOp::operand_range;
@@ -613,7 +602,7 @@
};
/// An `AffineApplyNormalizer` is a helper class that supports renumbering
-/// operands of AffineApplyOp. This acts as a reindexing map of Value* to
+/// operands of AffineApplyOp. This acts as a reindexing map of Value to
/// positional dims or symbols and allows simplifications such as:
///
/// ```mlir
@@ -626,13 +615,13 @@
/// %1 = affine.apply () -> (0)
/// ```
struct AffineApplyNormalizer {
- AffineApplyNormalizer(AffineMap map, ArrayRef<Value *> operands);
+ AffineApplyNormalizer(AffineMap map, ArrayRef<Value> operands);
/// Returns the AffineMap resulting from normalization.
AffineMap getAffineMap() { return affineMap; }
- SmallVector<Value *, 8> getOperands() {
- SmallVector<Value *, 8> res(reorderedDims);
+ SmallVector<Value, 8> getOperands() {
+ SmallVector<Value, 8> res(reorderedDims);
res.append(concatenatedSymbols.begin(), concatenatedSymbols.end());
return res;
}
@@ -642,13 +631,13 @@
/// Normalizes 'otherMap' and its operands 'otherOperands' to map to this
/// normalizer's coordinate space.
- void normalize(AffineMap *otherMap, SmallVectorImpl<Value *> *otherOperands);
+ void normalize(AffineMap *otherMap, SmallVectorImpl<Value> *otherOperands);
private:
/// Helper function to insert `v` into the coordinate system of the current
/// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding
/// renumbered position.
- AffineDimExpr renumberOneDim(Value *v);
+ AffineDimExpr renumberOneDim(Value v);
/// Given an `other` normalizer, this rewrites `other.affineMap` in the
/// coordinate system of the current AffineApplyNormalizer.
@@ -656,13 +645,13 @@
/// `this`.
AffineMap renumber(const AffineApplyNormalizer &other);
- /// Maps of Value* to position in `affineMap`.
- DenseMap<Value *, unsigned> dimValueToPosition;
+ /// Maps of Value to position in `affineMap`.
+ DenseMap<Value, unsigned> dimValueToPosition;
/// Ordered dims and symbols matching positional dims and symbols in
/// `affineMap`.
- SmallVector<Value *, 8> reorderedDims;
- SmallVector<Value *, 8> concatenatedSymbols;
+ SmallVector<Value, 8> reorderedDims;
+ SmallVector<Value, 8> concatenatedSymbols;
AffineMap affineMap;
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
index b40990e..114e205 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
@@ -1,19 +1,10 @@
//===- AffineOps.td - Affine operation definitions ---------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines MLIR affine operations.
//
@@ -101,7 +92,7 @@
static StringRef getUpperBoundAttrName() { return "upper_bound"; }
Block *getBody() { return &region().front(); }
- Value *getInductionVar() { return getBody()->getArgument(0); }
+ Value getInductionVar() { return getBody()->getArgument(0); }
OpBuilder getBodyBuilder() {
return OpBuilder(getBody(), std::prev(getBody()->end()));
}
@@ -286,8 +277,8 @@
BoolAttr:$isDataCache);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *memref,"
- "AffineMap map, ArrayRef<Value *> mapOperands, bool isWrite,"
+ "Builder *builder, OperationState &result, Value memref,"
+ "AffineMap map, ArrayRef<Value> mapOperands, bool isWrite,"
"unsigned localityHint, bool isDataCache",
[{
assert(map.getNumInputs() == mapOperands.size()
@@ -315,7 +306,7 @@
}
/// Returns the AffineMapAttr associated with 'memref'.
- NamedAttribute getAffineMapAttrForMemRef(Value *mref) {
+ NamedAttribute getAffineMapAttrForMemRef(Value mref) {
assert(mref == memref());
return {Identifier::get(getMapAttrName(), getContext()),
getAffineMapAttr()};
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td
index 755f65c..6aee5f3 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/AffineOpsBase.td
@@ -1,19 +1,10 @@
//===- AffineOpsBase.td - Affine operation definitions -----*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines base support for MLIR affine operations.
//
diff --git a/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
index 8f812b3..7339bcc 100644
--- a/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(AffineOps)
+add_mlir_dialect(AffineOps AffineOps)
diff --git a/third_party/mlir/include/mlir/Dialect/CommonFolders.h b/third_party/mlir/include/mlir/Dialect/CommonFolders.h
index 4555294..d667de7 100644
--- a/third_party/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/third_party/mlir/include/mlir/Dialect/CommonFolders.h
@@ -1,19 +1,10 @@
//===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file declares various common operation folders. These folders
// are intended to be used by dialects to support common folding behavior
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
index a8fb5e0..4842307 100644
--- a/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(FxpMathOps)
+add_mlir_dialect(FxpMathOps FxpMathOps)
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h b/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
index 88a4234..8c0e7aa 100644
--- a/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
@@ -1,19 +1,10 @@
//===- FxpMathOps.h - Fixed point ops ---------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
#define MLIR_DIALECT_FXPMATHOPS_FXPMATHOPS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td b/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
index b1bfb27..d527b75 100644
--- a/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
@@ -1,19 +1,10 @@
//===- FxpMathOps.td - Fixed point ops --------------------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the operation definition file for fixed point ops (and real
// equivalents).
diff --git a/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h b/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h
index 415b1c0..aec21c4 100644
--- a/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h
+++ b/third_party/mlir/include/mlir/Dialect/FxpMathOps/Passes.h
@@ -1,19 +1,10 @@
//===- Passes.h - Fixed point math passes -----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines all of the passes owned by the FxpMathOps dialect.
//
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
index bdb5dec..fd85b5b 100644
--- a/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(GPUOps)
+add_mlir_dialect(GPUOps GPUOps)
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/third_party/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 495238f..1776ff7 100644
--- a/third_party/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/third_party/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -1,19 +1,10 @@
//===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the GPU kernel-related operations and puts them in the
// corresponding dialect.
@@ -26,6 +17,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir {
@@ -76,9 +68,9 @@
/// Utility class for the GPU dialect to represent triples of `Value`s
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
struct KernelDim3 {
- Value *x;
- Value *y;
- Value *z;
+ Value x;
+ Value y;
+ Value z;
};
#define GET_OP_CLASSES
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td b/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 46433c6..b5b93e9 100644
--- a/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/third_party/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1,19 +1,10 @@
//===-- GPUOps.td - GPU dialect operation definitions ------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines some operations of the GPU dialect.
//
@@ -157,7 +148,7 @@
/// Returns a list of block arguments that correspond to buffers located in
/// the workgroup memory
- ArrayRef<BlockArgument *> getWorkgroupAttributions() {
+ ArrayRef<BlockArgument> getWorkgroupAttributions() {
auto begin =
std::next(getBody().front().args_begin(), getType().getNumInputs());
auto end = std::next(begin, getNumWorkgroupAttributions());
@@ -166,7 +157,7 @@
/// Returns a list of block arguments that correspond to buffers located in
/// the private memory.
- ArrayRef<BlockArgument *> getPrivateAttributions() {
+ ArrayRef<BlockArgument> getPrivateAttributions() {
auto begin =
std::next(getBody().front().args_begin(),
getType().getNumInputs() + getNumWorkgroupAttributions());
@@ -282,8 +273,8 @@
let builders = [
OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, "
- "Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, "
- "Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, "
+ "Value gridSizeX, Value gridSizeY, Value gridSizeZ, "
+ "Value blockSizeX, Value blockSizeY, Value blockSizeZ, "
"ValueRange kernelOperands">,
OpBuilder<"Builder *builder, OperationState &result, GPUFuncOp kernelFunc, "
"KernelDim3 gridSize, KernelDim3 blockSize, "
@@ -302,7 +293,7 @@
StringRef getKernelModuleName();
/// The i-th operand passed to the kernel function.
- Value *getKernelOperand(unsigned i);
+ Value getKernelOperand(unsigned i);
/// Get the SSA values passed as operands to specify the grid size.
KernelDim3 getGridSizeOperandValues();
@@ -415,9 +406,9 @@
let skipDefaultBuilders = 1;
let builders = [
- OpBuilder<"Builder *builder, OperationState &result, Value *gridSizeX,"
- "Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,"
- "Value *blockSizeY, Value *blockSizeZ,"
+ OpBuilder<"Builder *builder, OperationState &result, Value gridSizeX,"
+ "Value gridSizeY, Value gridSizeZ, Value blockSizeX,"
+ "Value blockSizeY, Value blockSizeZ,"
"ValueRange operands">
];
@@ -536,6 +527,41 @@
let verifier = [{ return ::verifyAllReduce(*this); }];
}
+def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;
+
+def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
+ "Indexing modes supported by gpu.shuffle.",
+ [
+ GPU_ShuffleOpXor,
+ ]>;
+
+def GPU_ShuffleOp : GPU_Op<"shuffle", [NoSideEffect]>,
+ Arguments<(ins AnyType:$value, I32:$offset, I32:$width,
+ GPU_ShuffleModeAttr:$mode)>,
+ Results<(outs AnyType:$result, I1:$valid)> {
+ let summary = "Shuffles values within a subgroup.";
+ let description = [{
+ The "shuffle" op moves values to a different invocation within the same
+ subgroup.
+
+ For example:
+ ```
+ %1, %2 = gpu.shuffle %0, %offset, %width xor : f32
+ ```
+ for lane `k`, returns the value from lane `k ^ offset` and `true` if that
+ lane is smaller than `%width`. Otherwise it returns an unspecified value and
+ `false`. A lane is the index of an invocation relative to its subgroup.
+
+ The width specifies the number of invocations that participate in the
+ shuffle. The width needs to be the same for all invocations that participate
+ in the shuffle. Exactly the first `width` invocations of a subgroup need to
+ execute this op in convergence.
+ }];
+ let verifier = [{ return ::verifyShuffleOp(*this); }];
+ let printer = [{ printShuffleOp(p, *this); }];
+ let parser = [{ return parseShuffleOp(parser, result); }];
+}
+
def GPU_BarrierOp : GPU_Op<"barrier"> {
let summary = "Synchronizes all work items of a workgroup.";
let description = [{
diff --git a/third_party/mlir/include/mlir/Dialect/GPU/Passes.h b/third_party/mlir/include/mlir/Dialect/GPU/Passes.h
index 7c8ce02..daf6d28 100644
--- a/third_party/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/third_party/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -1,19 +1,10 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes that expose pass constructors.
//
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 4ecc71a..fa68eff 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -5,8 +5,8 @@
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRLLVMOpsIncGen)
-add_mlir_dialect(NVVMOps)
-add_mlir_dialect(ROCDLOps)
+add_mlir_dialect(NVVMOps NVVMOps)
+add_mlir_dialect(ROCDLOps ROCDLOps)
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index dae27d0..d36619b 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -1,19 +1,10 @@
//===- LLVMDialect.h - MLIR LLVM IR dialect ---------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the LLVM IR dialect in MLIR, containing LLVM operations and
// LLVM type system.
@@ -194,9 +185,9 @@
/// surrounding the insertion point of builder. Obtain the address of that
/// global and use it to compute the address of the first character in the
/// string (operations inserted at the builder insertion point).
-Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
- StringRef value, LLVM::Linkage linkage,
- LLVM::LLVMDialect *llvmDialect);
+Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
+ StringRef value, LLVM::Linkage linkage,
+ LLVM::LLVMDialect *llvmDialect);
/// LLVM requires some operations to be inside of a Module operation. This
/// function confirms that the Operation has the desired properties.
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 6257b4a..ed935d5 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -1,19 +1,10 @@
//===-- LLVMOpBase.td - LLVM IR dialect shared definitions -*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains shared definitions for the LLVM IR dialect and its
// subdialects.
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 2f7a980..2e47eb0 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1,19 +1,10 @@
//===-- LLVMOps.td - LLVM IR dialect op definition file ----*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the LLVM IR operation definition file.
//
@@ -185,8 +176,8 @@
$res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
let builders = [OpBuilder<
- "Builder *b, OperationState &result, ICmpPredicate predicate, Value *lhs, "
- "Value *rhs", [{
+ "Builder *b, OperationState &result, ICmpPredicate predicate, Value lhs, "
+ "Value rhs", [{
LLVMDialect *dialect = &lhs->getType().cast<LLVMType>().getDialect();
build(b, result, LLVMType::getInt1Ty(dialect),
b->getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
@@ -232,8 +223,8 @@
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
let builders = [OpBuilder<
- "Builder *b, OperationState &result, FCmpPredicate predicate, Value *lhs, "
- "Value *rhs", [{
+ "Builder *b, OperationState &result, FCmpPredicate predicate, Value lhs, "
+ "Value rhs", [{
LLVMDialect *dialect = &lhs->getType().cast<LLVMType>().getDialect();
build(b, result, LLVMType::getInt1Ty(dialect),
b->getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
@@ -265,7 +256,7 @@
$res = alloca;
}];
let builders = [OpBuilder<
- "Builder *b, OperationState &result, Type resultType, Value *arraySize, "
+ "Builder *b, OperationState &result, Type resultType, Value arraySize, "
"unsigned alignment",
[{
if (alignment == 0)
@@ -292,7 +283,7 @@
def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
LLVM_Builder<"$res = builder.CreateLoad($addr);"> {
let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value *addr",
+ "Builder *b, OperationState &result, Value addr",
[{
auto type = addr->getType().cast<LLVM::LLVMType>().getPointerElementTy();
build(b, result, type, addr);
@@ -353,7 +344,7 @@
$res = builder.CreateExtractElement($vector, $position);
}];
let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value *vector, Value *position,"
+ "Builder *b, OperationState &result, Value vector, Value position,"
"ArrayRef<NamedAttribute> attrs = {}">];
let parser = [{ return parseExtractElementOp(parser, result); }];
let printer = [{ printExtractElementOp(p, *this); }];
@@ -384,7 +375,7 @@
extractPosition($position));
}];
let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value *container, Value *value, "
+ "Builder *b, OperationState &result, Value container, Value value, "
"ArrayAttr position",
[{
build(b, result, container->getType(), container, value, position);
@@ -394,11 +385,11 @@
}
def LLVM_ShuffleVectorOp
: LLVM_OneResultOp<"shufflevector", [NoSideEffect]>,
- Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, I32ArrayAttr:$mask)>,
+ Arguments<(ins LLVM_Type:$v1, LLVM_Type:$v2, ArrayAttr:$mask)>,
LLVM_Builder<
"$res = builder.CreateShuffleVector($v1, $v2, extractPosition($mask));"> {
let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value *v1, Value *v2, "
+ "Builder *b, OperationState &result, Value v1, Value v2, "
"ArrayAttr mask, ArrayRef<NamedAttribute> attrs = {}">];
let verifier = [{
auto wrappedVectorType1 = v1()->getType().cast<LLVM::LLVMType>();
@@ -422,8 +413,8 @@
LLVM_Builder<
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value *condition, Value *lhs, "
- "Value *rhs", [{
+ "Builder *b, OperationState &result, Value condition, Value lhs, "
+ "Value rhs", [{
build(b, result, lhs->getType(), condition, lhs, rhs);
}]>];
let parser = [{ return parseSelectOp(parser, result); }];
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 0328cf4..afb6d4a 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -1,19 +1,10 @@
//===- NVVMDialect.h - MLIR NVVM IR dialect ---------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the NVVM IR dialect in MLIR, containing NVVM operations and
// NVVM specific extensions to the LLVM type system.
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index bc6887d..f35b779 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1,19 +1,10 @@
//===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the NVVM IR operation definition file.
//
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/third_party/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
index a34c112..dab32d3 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
@@ -1,19 +1,10 @@
//===- ROCDLDialect.h - MLIR ROCDL IR dialect -------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the ROCDL dialect in MLIR, containing ROCDL operations
// and ROCDL specific extensions to the LLVM type system.
diff --git a/third_party/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/third_party/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 79d4136..697ff97 100644
--- a/third_party/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/third_party/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -1,19 +1,10 @@
//===-- ROCDLOps.td - ROCDL IR dialect op definition file --*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the ROCDL IR operation definition file.
//
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/third_party/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 01d3e4b..dd5034e 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -1,19 +1,10 @@
//===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
@@ -37,15 +28,15 @@
class Aliases {
public:
/// Returns true if v1 and v2 alias.
- bool alias(Value *v1, Value *v2) { return find(v1) == find(v2); }
+ bool alias(Value v1, Value v2) { return find(v1) == find(v2); }
private:
/// Returns the base buffer or block argument into which the view `v` aliases.
/// This lazily records the new aliases discovered while walking back the
/// use-def chain.
- Value *find(Value *v);
+ Value find(Value v);
- DenseMap<Value *, Value *> aliases;
+ DenseMap<Value, Value> aliases;
};
/// Data structure for holding a dependence graph that operates on LinalgOp and
@@ -54,7 +45,7 @@
public:
struct LinalgOpView {
Operation *op;
- Value *view;
+ Value view;
};
struct LinalgDependenceGraphElem {
// dependentOpView may be either:
@@ -64,7 +55,7 @@
// View in the op that is used to index in the graph:
// 1. src in the case of dependencesFromDstGraphs.
// 2. dst in the case of dependencesIntoGraphs.
- Value *indexingView;
+ Value indexingView;
};
using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
@@ -97,14 +88,14 @@
/// Dependences are restricted to views aliasing `view`.
SmallVector<Operation *, 8> findCoveringReads(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp,
- Value *view) const;
+ Value view) const;
/// Returns the operations that are interleaved between `srcLinalgOp` and
/// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`.
/// Dependences are restricted to views aliasing `view`.
SmallVector<Operation *, 8> findCoveringWrites(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp,
- Value *view) const;
+ Value view) const;
private:
// Keep dependences in both directions, this is not just a performance gain
@@ -130,7 +121,7 @@
/// Implementation detail for findCoveringxxx.
SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,
- LinalgOp dstLinalgOp, Value *view,
+ LinalgOp dstLinalgOp, Value view,
ArrayRef<DependenceType> types) const;
Aliases &aliases;
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
index cf63352..97fbede 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h
@@ -1,19 +1,10 @@
//===- Builders.h - MLIR Declarative Linalg Builders ------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides intuitive composable interfaces for building structured MLIR
// snippets in a declarative fashion.
@@ -55,35 +46,34 @@
/// makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... );
/// ```
struct StructuredIndexed {
- StructuredIndexed(Value *v) : value(v) {}
+ StructuredIndexed(Value v) : value(v) {}
StructuredIndexed operator()(ArrayRef<AffineExpr> indexings) {
return StructuredIndexed(value, indexings);
}
- operator Value *() const /* implicit */ { return value; }
+ operator Value() const /* implicit */ { return value; }
ArrayRef<AffineExpr> getExprs() { return exprs; }
private:
- StructuredIndexed(Value *v, ArrayRef<AffineExpr> indexings)
+ StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
: value(v), exprs(indexings.begin(), indexings.end()) {
assert(v->getType().isa<MemRefType>() && "MemRefType expected");
}
StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
: StructuredIndexed(v.getValue(), indexings) {}
- Value *value;
+ Value value;
SmallVector<AffineExpr, 4> exprs;
};
-inline void defaultRegionBuilder(ArrayRef<BlockArgument *> args) {}
+inline void defaultRegionBuilder(ArrayRef<BlockArgument> args) {}
-Operation *makeLinalgGenericOp(ArrayRef<IterType> iteratorTypes,
- ArrayRef<StructuredIndexed> inputs,
- ArrayRef<StructuredIndexed> outputs,
- function_ref<void(ArrayRef<BlockArgument *>)>
- regionBuilder = defaultRegionBuilder,
- ArrayRef<Value *> otherValues = {},
- ArrayRef<Attribute> otherAttributes = {});
+Operation *makeLinalgGenericOp(
+ ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
+ ArrayRef<StructuredIndexed> outputs,
+ function_ref<void(ArrayRef<BlockArgument>)> regionBuilder =
+ defaultRegionBuilder,
+ ArrayRef<Value> otherValues = {}, ArrayRef<Attribute> otherAttributes = {});
namespace ops {
using edsc::StructuredIndexed;
@@ -96,7 +86,7 @@
/// Build the body of a region to compute a multiply-accumulate, under the
/// current ScopedContext, at the current insert point.
-void macRegionBuilder(ArrayRef<BlockArgument *> args);
+void macRegionBuilder(ArrayRef<BlockArgument> args);
/// TODO(ntv): In the future we should tie these implementations to something in
/// Tablegen that generates the proper interfaces and the proper sugared named
@@ -120,7 +110,7 @@
/// with in-place semantics and parallelism.
/// Unary pointwise operation (with broadcast) entry point.
-using UnaryPointwiseOpBuilder = function_ref<Value *(ValueHandle)>;
+using UnaryPointwiseOpBuilder = function_ref<Value(ValueHandle)>;
Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp,
StructuredIndexed I, StructuredIndexed O);
@@ -130,8 +120,7 @@
Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O);
/// Binary pointwise operation (with broadcast) entry point.
-using BinaryPointwiseOpBuilder =
- function_ref<Value *(ValueHandle, ValueHandle)>;
+using BinaryPointwiseOpBuilder = function_ref<Value(ValueHandle, ValueHandle)>;
Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
StructuredIndexed I1, StructuredIndexed I2,
StructuredIndexed O);
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
index f1acab6..b04c11f 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -1,19 +1,10 @@
//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
index 2a883a1..269729b 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -1,7 +1,8 @@
-add_mlir_dialect(LinalgOps)
-set(LLVM_TARGET_DEFINITIONS LinalgLibraryOps.td)
-mlir_tablegen(LinalgLibraryOps.h.inc -gen-op-decls)
-mlir_tablegen(LinalgLibraryOps.cpp.inc -gen-op-defs)
-mlir_tablegen(LinalgLibraryOpInterfaces.h.inc -gen-op-interface-decls)
-mlir_tablegen(LinalgLibraryOpInterfaces.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRLinalgLibraryOpsIncGen)
+add_mlir_dialect(LinalgOps LinalgDoc)
+set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td)
+mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs)
+mlir_tablegen(LinalgStructuredOpsInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(LinalgStructuredOpsInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen)
+
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index edc8125..c1adc8b 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -1,19 +1,10 @@
//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the definition file for base linear algebra support.
//
@@ -117,6 +108,4 @@
def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
def Range : Type<LinalgIsRangeTypePred, "range">;
-// TODO(ntv): inject the doc for LinalgLibraryOps.td here.
-
#endif // LINALG_BASE
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td
new file mode 100644
index 0000000..819d02d
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgDoc.td
@@ -0,0 +1,23 @@
+//===- LinalgDoc.td - Linalg documentation -----------------*- tablegen -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This documentation file exists to circumvent limitations on mixing different
+// .td files in cases where not all ops should belong to the same logical
+// unit. This file should only include other .td files and be used only for
+// the purpose of generating documentation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LINALG_DOC
+#define LINALG_DOC
+
+include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
+
+#endif // LINALG_DOC
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
index 12318a2..6fdb8a6 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
@@ -1,19 +1,10 @@
//===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the operation definition file for linear algebra operations that
// correspond to underlying library calls (e.g. BLAS).
@@ -92,22 +83,22 @@
"Query the number of loops within the current operation.",
"unsigned", "getNumLoops">,
InterfaceMethod<"Query the input view at the given index.",
- "Value *", "getInput", (ins "unsigned":$i)
+ "Value ", "getInput", (ins "unsigned":$i)
>,
InterfaceMethod<"Query the output view at the given index.",
- "Value *", "getOutput", (ins "unsigned":$i)
+ "Value ", "getOutput", (ins "unsigned":$i)
>,
InterfaceMethod<[{
Query the index of the given input value, or `None` if the value is not
an input.
}],
- "Optional<unsigned>", "getIndexOfInput", (ins "Value *":$view)
+ "Optional<unsigned>", "getIndexOfInput", (ins "Value ":$view)
>,
InterfaceMethod<[{
Query the index of the given view value, or `None` if the value is not
an view.
}],
- "Optional<unsigned>", "getIndexOfOutput", (ins "Value *":$view)
+ "Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view)
>,
InterfaceMethod<[{
Query the type of the input view at the given index.
@@ -228,7 +219,7 @@
// TODO(ntv) this should go away once the usage of OptionalAttr triggers
// emission of builders with default arguments left unspecified.
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *input, Value *output", [{
+ "Builder *builder, OperationState &result, Value input, Value output", [{
return build(
builder, result, input, output, AffineMapAttr(), AffineMapAttr());
}]>];
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 2226b5e..3249edb 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -1,19 +1,10 @@
//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_LINALGOPS_H_
#define MLIR_DIALECT_LINALG_LINALGOPS_H_
@@ -78,13 +69,13 @@
/// Only permutation maps are currently supported.
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
-#include "mlir/Dialect/Linalg/IR/LinalgLibraryOpInterfaces.h.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.h.inc"
#define GET_OP_CLASSES
-#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.h.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc"
} // namespace linalg
} // namespace mlir
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index b806d75..0445968 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -1,19 +1,10 @@
//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the operation definition file for linear algebra operations.
//
@@ -56,8 +47,8 @@
````
}];
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *min, Value *max, "
- "Value *step",
+ "Builder *builder, OperationState &result, Value min, Value max, "
+ "Value step",
[{
auto rangeType = RangeType::get(builder->getContext());
build(builder, result, rangeType, min, max, step);
@@ -112,7 +103,7 @@
}];
let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value *base, "
+ "Builder *b, OperationState &result, Value base, "
"ValueRange indexings">];
let extraClassDeclaration = [{
@@ -124,12 +115,12 @@
MemRefType getBaseViewType() { return view()->getType().cast<MemRefType>(); }
// Get the underlying indexing at a given rank.
- Value *indexing(unsigned rank) { return *(indexings().begin() + rank); }
+ Value indexing(unsigned rank) { return *(indexings().begin() + rank); }
// Get the subset of indexings that are of RangeType.
- SmallVector<Value *, 8> getRanges() {
- SmallVector<Value *, 8> res;
- for (auto *operand : indexings())
+ SmallVector<Value, 8> getRanges() {
+ SmallVector<Value, 8> res;
+ for (auto operand : indexings())
if (!operand->getType().isa<IndexType>())
res.push_back(operand);
return res;
@@ -154,7 +145,7 @@
}];
let builders = [OpBuilder<
- "Builder *b, OperationState &result, Value *view, "
+ "Builder *b, OperationState &result, Value view, "
"AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
let verifier = [{
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
new file mode 100644
index 0000000..dd9e09b
--- /dev/null
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -0,0 +1,616 @@
+//===- LinalgStructuredOps.td - Linalg dialect library ops -*- tablegen -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the operation definition file for structured operations on buffers
+// that correspond to underlying library calls (e.g. BLAS).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LINALG_STRUCTURED_OPS
+#define LINALG_STRUCTURED_OPS
+
+include "mlir/Dialect/AffineOps/AffineOpsBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+
+// The Linalg `NInputs` trait provides the API for ops that are known
+// to have a specified number of inputs, all passed as operands.
+// See Linalg/LinalgTraits.h for implementation details and usage.
+class NInputs<int args_in> :
+ NativeOpTrait<"linalg::NInputs<" # !cast<string>(args_in) # ">::Impl"> {} // i.e. linalg::NInputs<args_in>::Impl
+
+// The Linalg `NOutputs` trait provides the API for ops that are known
+// to have a specified number of outputs, all passed as operands.
+// See Linalg/LinalgTraits.h for implementation details and usage.
+class NOutputs<int args_out> :
+ NativeOpTrait<"linalg::NOutputs<" # !cast<string>(args_out) # ">::Impl"> {} // i.e. linalg::NOutputs<args_out>::Impl
+
+def ViewTraits : NativeOpTrait<"linalg::ViewTraits">; // Binds the native C++ trait `linalg::ViewTraits`.
+
+// `LinalgStructuredInterface` declares the C++ `LinalgOp` op interface; it is
+// attached to every op derived from LinalgStructuredBase_Op below.
+def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
+ let methods = [
+ InterfaceMethod<
+ "Query the number of inputs from the current operation.",
+ "unsigned", "getNumInputs"
+ >,
+ InterfaceMethod<
+ "Query the number of outputs from the current operation.",
+ "unsigned", "getNumOutputs"
+ >,
+ InterfaceMethod<
+ "Query the number of inputs and outputs from the current operation.",
+ "unsigned", "getNumInputsAndOutputs"
+ >,
+ InterfaceMethod<
+ "Query the input operands from the current operation.",
+ "Operation::operand_range", "getInputs"
+ >,
+ InterfaceMethod<
+ "Query the output operands from the current operation.",
+ "Operation::operand_range", "getOutputs"
+ >,
+ InterfaceMethod<
+ "Query the input and output operands from the current operation.",
+ "Operation::operand_range", "getInputsAndOutputs"
+ >,
+ InterfaceMethod<
+ "Query the iterator types attribute within the current operation.",
+ "ArrayAttr", "iterator_types"
+ >,
+ InterfaceMethod<
+ "Query the indexing maps attribute within the current operation.",
+ "ArrayAttr", "indexing_maps"
+ >,
+ InterfaceMethod<
+ "Query the number of parallel loops within the current operation.",
+ "unsigned", "getNumParallelLoops"
+ >,
+ InterfaceMethod<
+ "Query the number of reduction loops within the current operation.",
+ "unsigned", "getNumReductionLoops"
+ >,
+ InterfaceMethod<
+ "Query the number of window loops within the current operation.",
+ "unsigned", "getNumWindowLoops"
+ >,
+ InterfaceMethod<
+ "Query the number of loops within the current operation.",
+ "unsigned", "getNumLoops">,
+ InterfaceMethod<"Query the input view at the given index.",
+ "Value", "getInput", (ins "unsigned":$i)
+ >,
+ InterfaceMethod<"Query the output view at the given index.",
+ "Value", "getOutput", (ins "unsigned":$i)
+ >,
+ InterfaceMethod<[{
+ Query the index of the given input value, or `None` if the value is not
+ an input.
+ }],
+ "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value":$view)
+ >,
+ InterfaceMethod<[{
+ Query the index of the given output value, or `None` if the value is not
+ an output.
+ }],
+ "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value":$view)
+ >,
+ InterfaceMethod<[{
+ Query the type of the input view at the given index.
+ }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>,
+ InterfaceMethod<[{
+ Query the type of the output view at the given index.
+ }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>,
+
+ StaticInterfaceMethod<[{
+ Create an operation of the current type with the given location,
+ operands, and attributes.
+ }],
+ "Operation *", "create",
+ (ins "OpBuilder &":$builder, "Location":$loc,
+ "ValueRange":$operands,
+ "ArrayRef<NamedAttribute>":$attributes), [{
+ return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
+ attributes);
+ }]
+ >,
+
+ /// Clone an operation with the given location and operands. This is used to
+ /// abstract away the optional underlying region creation.
+ InterfaceMethod<[{
+ Clone the current operation with the given location and operands. This
+ is used to abstract away the optional underlying region creation.
+ }],
+ "Operation *", "clone",
+ (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{
+ BlockAndValueMapping map;
+ unsigned numRegions = op.getOperation()->getNumRegions();
+ Operation *res = create(b, loc, operands, op.getAttrs());
+ assert(res->getNumRegions() == numRegions && "inconsistent # regions");
+ for (unsigned ridx = 0; ridx < numRegions; ++ridx)
+ op.getOperation()->getRegion(ridx).cloneInto(
+ &res->getRegion(ridx), map);
+ return res;
+ }]
+ >
+ ];
+}
+
+// Base Tablegen class for Linalg structured ops.
+// Linalg ops that correspond to library calls operate on strided memref views
+// as their first operands. These may be optionally followed by non-view
+// operands depending on the specific Linalg op.
+class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
+ : Op<Linalg_Dialect, mnemonic,
+ !listconcat(props, [ViewTraits, LinalgStructuredInterface])> { // every structured op gets both
+ let parser = [{ return parseLinalgStructuredOp(parser, result); }];
+ let printer = [{ printLinalgStructuredOp(p, *this); }];
+}
+
+class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
+ : LinalgStructuredBase_Op<mnemonic, props> {
+ code libraryCallName = [{
+ std::string getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+ }
+ }];
+} // Concrete ops below splice `libraryCallName` into their extraClassDeclaration.
+
+////////////////////////////////////////////////////////////////////////////////
+// Concrete Linalg ops.
+////////////////////////////////////////////////////////////////////////////////
+def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
+ let description = [{
+ Copies the data in the input view into the output view.
+
+ Usage:
+ ```mlir
+ linalg.copy(%arg0, %arg1) : memref<?xf32, stride_specification>,
+ memref<?xf32, stride_specification>
+ ```
+
+ One possible lowering to loop form is:
+ ```mlir
+ %0 = linalg.dim %arg0, 0 : index
+ loop.for %i0 = %c0 to %0 step %c1 {
+ %1 = linalg.load %arg0[%i0] : memref<?xf32, stride_specification>
+ linalg.store %1, %arg1[%i0] : memref<?xf32, stride_specification>
+ }
+ ```
+
+ Optionally, can take `inputPermutation` and `outputPermutation` attributes
+ to reorder the dimensions of the input and output views.
+
+ Usage:
+ ```mlir
+ linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j),
+ outputPermutation = (i, j, k) -> (k, j, i)} :
+ memref<?x?x?xf32, stride_specification>,
+ memref<?x?x?xf32, stride_specification>
+ ```
+
+ One possible lowering to loop form is:
+ ```mlir
+ %0 = linalg.dim %arg0, 0
+ %1 = linalg.dim %arg0, 1
+ %2 = linalg.dim %arg0, 2
+ loop.for %i0 = %c0 to %{{.*}} step %c1 {
+ loop.for %i1 = %c0 to %{{.*}} step %c1 {
+ loop.for %i2 = %c0 to %{{.*}} step %c1 {
+ %3 = linalg.load %arg0[%i0, %i2, %i1] :
+ memref<?x?x?xf32, stride_specification>
+ linalg.store %3, %arg1[%i2, %i1, %i0] :
+ memref<?x?x?xf32, stride_specification>
+ ```
+
+ The views are expected to be compatible for correctness but this is not
+ enforced at the moment.
+ }];
+ let arguments = (ins
+ AnyStridedMemRef:$input,
+ AnyStridedMemRef:$output,
+ OptionalAttr<AffineMapAttr>:$inputPermutation,
+ OptionalAttr<AffineMapAttr>:$outputPermutation);
+ // TODO(ntv) this should go away once the usage of OptionalAttr triggers
+ // emission of builders with default arguments left unspecified.
+ let builders = [OpBuilder<
+ "Builder *builder, OperationState &result, Value input, Value output", [{
+ return build(
+ builder, result, input, output, AffineMapAttr(), AffineMapAttr());
+ }]>];
+ let extraClassDeclaration = libraryCallName # [{
+ ArrayAttr indexing_maps();
+
+ ArrayAttr iterator_types() {
+ unsigned nPar = input()->getType().cast<ShapedType>().getRank();
+ MLIRContext *ctx = getContext();
+ SmallVector<Attribute, 8> iters(
+ nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
+ return ArrayAttr::get(iters, ctx);
+ }
+ }];
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
+ let arguments = (ins AnyStridedMemRef:$input, // NB: despite its name, the op's only (output) view
+ AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value); // fill value: float, integer or vector
+ let extraClassDeclaration = libraryCallName # [{
+ ArrayAttr indexing_maps();
+
+ ArrayAttr iterator_types() {
+ unsigned nPar = input()->getType().cast<ShapedType>().getRank();
+ MLIRContext *ctx = getContext();
+ SmallVector<Attribute, 8> iters(
+ nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
+ return ArrayAttr::get(iters, ctx);
+ }
+ }];
+ let verifier = [{ return ::verify(*this); }]; // ::verify(FillOp) is a C++ function defined outside this file
+}
+
+def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
+ let arguments = (ins AnyStridedMemRefOfRank<1>, // input 0: rank-1 view
+ AnyStridedMemRefOfRank<1>, // input 1: rank-1 view
+ AnyStridedMemRefOfRank<0>); // output: rank-0 view
+ let extraClassDeclaration = libraryCallName # [{
+ ArrayAttr indexing_maps();
+
+ ArrayAttr iterator_types() {
+ MLIRContext *ctx = getContext();
+ return ArrayAttr::get(
+ StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
+ }
+ }];
+} // Loop nest: a single reduction loop.
+
+def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
+ let arguments = (ins AnyStridedMemRefOfRank<2>, // input 0: rank-2 view
+ AnyStridedMemRefOfRank<1>, // input 1: rank-1 view
+ AnyStridedMemRefOfRank<1>); // output: rank-1 view
+ let extraClassDeclaration = libraryCallName # [{
+ ArrayAttr indexing_maps();
+
+ ArrayAttr iterator_types() {
+ MLIRContext *ctx = getContext();
+ Attribute iters[2]{
+ StringAttr::get(getParallelIteratorTypeName(), ctx),
+ StringAttr::get(getReductionIteratorTypeName(), ctx)};
+ return ArrayAttr::get(iters, ctx);
+ }
+ }];
+} // Loop nest: (parallel, reduction).
+
+def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
+ let arguments = (ins AnyStridedMemRefOfRank<2>, // input 0: rank-2 view
+ AnyStridedMemRefOfRank<2>, // input 1: rank-2 view
+ AnyStridedMemRefOfRank<2>); // output: rank-2 view
+ let extraClassDeclaration = libraryCallName # [{
+ ArrayAttr indexing_maps();
+
+ ArrayAttr iterator_types() {
+ MLIRContext *ctx = getContext();
+ Attribute iters[3]{
+ StringAttr::get(getParallelIteratorTypeName(), ctx),
+ StringAttr::get(getParallelIteratorTypeName(), ctx),
+ StringAttr::get(getReductionIteratorTypeName(), ctx)};
+ return ArrayAttr::get(iters, ctx);
+ }
+ }];
+} // Loop nest: (parallel, parallel, reduction).
+
+def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
+ let description = [{
+ Generic n-D convolution as described in the TF documentation:
+ https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution
+
+ ```
+ output[b, x[0], ..., x[N-1], k] =
+ sum_{z[0], ..., z[N-1], q}
+ filter[z[0], ..., z[N-1], q, k] *
+ padded_input[b,
+ x[0] * strides[0] + dilation_rate[0] * z[0],
+ ...,
+ x[N-1] * strides[N-1] + dilation_rate[N-1] * z[N-1],
+ q]
+ ```
+ }];
+
+ // TODO(ntv) padding.
+ // Following the TF source of truth above, strides and dilations are integer
+ // attributes of the same rank as the number of window dimensions.
+ let arguments = (ins AnyStridedMemRef:$filter, AnyStridedMemRef:$input,
+ AnyStridedMemRef:$output,
+ OptionalAttr<I64ArrayAttr>:$strides,
+ OptionalAttr<I64ArrayAttr>:$dilations);
+ let extraClassDeclaration = libraryCallName # [{
+ // TODO(ntv) extend to support more than 1 dimension and potentially
+ // grouping too.
+ unsigned getNumBatchDimensions() { return 1; }
+ unsigned getNumInputFeatureDimensions() { return 1; }
+ unsigned getNumOutputFeatureDimensions() { return 1; }
+
+ ArrayAttr indexing_maps();
+
+ ArrayAttr iterator_types() {
+ // Outer parallel loops are always the number of output dimensions; i.e.
+ // [b, xs, k] in the TF notation above.
+ unsigned nPar = getOutputViewType(0).getRank();
+ unsigned nRed = getNumInputFeatureDimensions();
+ // Window loops are a special kind of reduction that is never tiled or
+ // parallelized across; i.e. [zs] in the TF notation above whose number
+ // matches `xs` (i.e. 1 window loop per "image" dimension).
+ // This may evolve in the future.
+ unsigned nWin =
+ nPar - getNumBatchDimensions() - getNumOutputFeatureDimensions();
+ MLIRContext *ctx = getContext();
+ SmallVector<Attribute, 8> iters(
+ nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
+ iters.reserve(nPar + nRed + nWin);
+ iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx));
+ iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx));
+ return ArrayAttr::get(iters, ctx);
+ }
+
+ int64_t getStride(unsigned i) {
+ assert(i < getNumWindowLoops());
+ if (!strides().hasValue()) return 1;
+ return strides()->getValue()[i]
+ .cast<IntegerAttr>().getValue().getSExtValue();
+ }
+
+ int64_t getDilation(unsigned i) {
+ assert(i < getNumWindowLoops());
+ if (!dilations().hasValue()) return 1;
+ return dilations()->getValue()[i]
+ .cast<IntegerAttr>().getValue().getSExtValue();
+ }
+ }];
+ let verifier = [{ return ::verify(*this); }];
+}
+
+class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
+ let arguments = (ins Variadic<AnyStridedMemRef>:$views,
+ I64Attr:$args_in, // number of leading input views (getNumInputs)
+ I64Attr:$args_out, // number of output views following the inputs
+ AffineMapArrayAttr:$indexing_maps, // one map per view: inputs, then outputs
+ ArrayAttr:$iterator_types,
+ OptionalAttr<StrAttr>:$doc,
+ OptionalAttr<FlatSymbolRefAttr>:$fun,
+ OptionalAttr<StrAttr>:$library_call);
+ let regions = (region AnyRegion:$region);
+ let extraClassDeclaration = [{
+ SmallVector<StringRef, 8> linalgTraitAttrNames() {
+ return SmallVector<StringRef, 8>{
+ getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
+ getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
+ getIteratorTypesAttrName()
+ };
+ }
+ unsigned getNumInputs() { return args_in().getSExtValue(); }
+ unsigned getNumOutputs() { return args_out().getSExtValue(); }
+ FuncOp getFunction() {
+ auto moduleOp = getParentOfType<ModuleOp>();
+ return fun().hasValue() ?
+ moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
+ }
+ StringRef getLibraryCallName() {
+ return library_call().hasValue() ? library_call().getValue() : "";
+ }
+ AffineMap getIndexingMap(unsigned i) {
+ assert(i < getNumInputsAndOutputs());
+ return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
+ }
+ AffineMap getInputIndexingMap(unsigned i) {
+ assert(i < getNumInputs());
+ return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
+ }
+ AffineMap getOutputIndexingMap(unsigned i) {
+ assert(i < getNumOutputs());
+ return indexing_maps().getValue()[i + getNumInputs()]
+ .cast<AffineMapAttr>().getValue();
+ }
+ }];
+ let printer = [{ return ::print(p, *this); }];
+ let parser = [{ return ::parseGenericOp(parser, result); }];
+} // Shared base of GenericOp and IndexedGenericOp below.
+
+def GenericOp : GenericOpBase<"generic"> {
+ let description = [{
+ Generic Linalg op form where the key properties of the computation are
+ specified as attributes. In pretty form, a linalg.generic op is written as:
+
+ ```mlir
+ linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>
+ ```
+
+ Where #trait_attribute is an alias of a dictionary attribute containing:
+ - args_in: an I64Attr representing the number of input (readonly) views
+ - args_out: an I64Attr representing the number of output (readwrite) views
+ - doc [optional]: a documentation string
+ - fun: a FlatSymbolRefAttr that must resolve to an existing function
+ symbol. To support inplace updates in a generic fashion, the signature
+ of the function must be:
+ ```
+ fun([input views element types], [output views element types])
+ -> ([output views element types])
+ ```
+ - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per input
+ and output view. Such AffineMapAttr specifies the mapping between the
+ loops and the indexing within each view.
+ - library_call [optional]: a StringAttr containing the name of an
+ external library function that the linalg.generic operation maps to.
+ The external library is assumed to be dynamically linked and no strong
+ compile-time guarantees are provided. In the absence of such a library
+ call, linalg.generic will always lower to loops.
+ - iterator_types: an ArrayAttr specifying the type of the enclosing loops.
+ Each element of the list represents an iterator of one of the following
+ types:
+ parallel, reduction, window
+
+ Example:
+ Defining a #matmul_trait attribute in MLIR can be done as follows:
+ ```mlir
+ func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
+ %d = mulf %a, %b: f32
+ %e = addf %c, %d: f32
+ return %e: f32
+ }
+ #matmul_accesses = [
+ (m, n, k) -> (m, k),
+ (m, n, k) -> (k, n),
+ (m, n, k) -> (m, n)
+ ]
+ #matmul_trait = {
+ doc = "C(m, n) += A(m, k) * B(k, n)",
+ fun = @fma,
+ indexing_maps = #matmul_accesses,
+ library_call = "linalg_matmul",
+ args_in = 2, args_out = 1,
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ```
+
+ And can be reused in multiple places as:
+ ```mlir
+ linalg.generic #matmul_trait %A, %B, %C {other-attributes} :
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>
+ ```
+
+ This may lower to either:
+ ```mlir
+ call @linalg_matmul(%A, %B, %C) :
+ (memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>)
+ -> ()
+ ```
+
+ or IR resembling:
+ ```mlir
+ loop.for %m = %c0 to %M step %c1 {
+ loop.for %n = %c0 to %N step %c1 {
+ loop.for %k = %c0 to %K step %c1 {
+ %a = linalg.load %A[%m, %k] : memref<?x?xf32, stride_specification>
+ %b = linalg.load %B[%k, %n] : memref<?x?xf32, stride_specification>
+ %c = linalg.load %C[%m, %n] : memref<?x?xf32, stride_specification>
+ %d = call @fma(%a, %b, %c)
+ : (f32, f32, f32) -> (f32)
+ linalg.store %d, %C[%m, %n] : memref<?x?xf32, stride_specification>
+ }
+ }
+ }
+ ```
+ }];
+ let verifier = [{ return ::verify(*this); }];
+}
+
+def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
+ let description = [{
+ Indexed Generic Linalg op form where the key properties of the computation
+ are specified as attributes. In pretty form, a linalg.indexed_generic op is
+ written as:
+
+ ```mlir
+ linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} :
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>
+ ```
+
+ Where #trait_attribute is an alias of a dictionary attribute containing:
+ - args_in: an I64Attr representing the number of input (readonly) views
+ - args_out: an I64Attr representing the number of output (readwrite) views
+ - doc [optional]: a documentation string
+ - fun: a FlatSymbolRefAttr that must resolve to an existing function
+ symbol. To support inplace updates in a generic fashion, the signature
+ of the function must be:
+ ```
+ fun([index types of induction variables], [input views element types],
+ [output views element types]) -> ([output views element types])
+ ```
+ - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per input
+ and output view. Such AffineMapAttr specifies the mapping between the
+ loops and the indexing within each view.
+ - library_call [optional]: a StringAttr containing the name of an
+ external library function that the linalg.indexed_generic operation
+ maps to. The external library is assumed to be dynamically linked and
+ no strong compile-time guarantees are provided. In the absence of such
+ a library call, linalg.indexed_generic will always lower to loops.
+ - iterator_types: an ArrayAttr specifying the type of the enclosing loops.
+ Each element of the list represents an iterator of one of the following
+ types:
+ parallel, reduction, window
+
+ Example:
+ Defining a #matmul_trait attribute in MLIR can be done as follows:
+ ```mlir
+ func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32)
+ -> f32
+ {
+ %d = mulf %a, %b: f32
+ %e = addf %c, %d: f32
+ return %e: f32
+ }
+ #matmul_accesses = [
+ (m, n, k) -> (m, k),
+ (m, n, k) -> (k, n),
+ (m, n, k) -> (m, n)
+ ]
+ #matmul_trait = {
+ doc = "C(m, n) += A(m, k) * B(k, n)",
+ fun = @fma,
+ indexing_maps = #matmul_accesses,
+ library_call = "linalg_matmul",
+ args_in = 2, args_out = 1,
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ```
+
+ And can be reused in multiple places as:
+ ```mlir
+ linalg.indexed_generic #matmul_trait %A, %B, %C {other-attributes} :
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>
+ ```
+
+ This may lower to either:
+ ```mlir
+ call @linalg_matmul(%A, %B, %C) :
+ (memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>,
+ memref<?x?xf32, stride_specification>)
+ -> ()
+ ```
+
+ or IR resembling:
+ ```mlir
+ loop.for %m = %c0 to %M step %c1 {
+ loop.for %n = %c0 to %N step %c1 {
+ loop.for %k = %c0 to %K step %c1 {
+ %a = linalg.load %A[%m, %k] : memref<?x?xf32, stride_specification>
+ %b = linalg.load %B[%k, %n] : memref<?x?xf32, stride_specification>
+ %c = linalg.load %C[%m, %n] : memref<?x?xf32, stride_specification>
+ %d = call @fma(%m, %n, %k, %a, %b, %c)
+ : (index, index, index, f32, f32, f32) -> (f32)
+ linalg.store %d, %C[%m, %n] : memref<?x?xf32, stride_specification>
+ }
+ }
+ }
+ ```
+ }];
+ let verifier = [{ return ::verify(*this); }];
+}
+
+#endif // LINALG_STRUCTURED_OPS
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index a24c1ca..e0d6518 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -1,19 +1,10 @@
//===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_
#define MLIR_DIALECT_LINALG_LINALGTRAITS_H_
@@ -77,13 +68,13 @@
public:
/// Return the `i`-th input view.
- Value *getInput(unsigned i) {
+ Value getInput(unsigned i) {
assert(i < nInputs());
return this->getOperation()->getOperand(i);
}
/// Return the index of `view` in the list of input views if found, llvm::None
/// otherwise.
- Optional<unsigned> getIndexOfInput(Value *view) {
+ Optional<unsigned> getIndexOfInput(Value view) {
auto it = llvm::find(getInputs(), view);
if (it != getInputs().end())
return it - getInputs().begin();
@@ -99,12 +90,12 @@
return {range.begin(), range.begin() + nInputs()};
}
/// Return the `i`-th output view.
- Value *getOutput(unsigned i) {
+ Value getOutput(unsigned i) {
return this->getOperation()->getOperand(nInputs() + i);
}
/// Return the index of `view` in the list of output views if found,
/// llvm::None otherwise.
- Optional<unsigned> getIndexOfOutput(Value *view) {
+ Optional<unsigned> getIndexOfOutput(Value view) {
auto it = llvm::find(getOutputs(), view);
if (it != getOutputs().end())
return it - getOutputs().begin();
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index f779c3de..abeda3e 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -1,19 +1,10 @@
//===- LinalgTypes.h - Linalg Types ---------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_
#define MLIR_DIALECT_LINALG_LINALGTYPES_H_
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h
index 7ae3877..86cf6fd 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -1,19 +1,10 @@
//===- Passes.h - Linalg pass entry points ----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes that expose pass constructors.
//
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index d92eb77..8f6762f 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -1,19 +1,10 @@
//===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the pattern definition file for declarative Linalg transformation.
//
@@ -23,7 +14,7 @@
#define LINALG_TRANSFORMS
include "mlir/Dialect/Linalg/IR/LinalgOps.td"
-include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
include "mlir/Dialect/AffineOps/AffineOps.td"
def HasNoLinalgTransformMarker : CPred<[{
@@ -45,7 +36,7 @@
class HasOperandsOfType<string type>: CPred<[{
llvm::any_of($0.getOperands(),
- [](Value* v) {
+ [](Value v) {
return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
})
}]>;
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index dfbac5a..757ee3a 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -1,19 +1,10 @@
//===- LinalgTransforms.h - Linalg transformations as patterns --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_
#define DIALECT_LINALG_TRANSFORMS_LINALGTRANSFORMS_H_
@@ -38,7 +29,7 @@
namespace detail {
// Implementation detail of isProducedByOpOfType avoids the need for explicit
// template instantiations.
-bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value *consumedView,
+bool isProducedByOpOfTypeImpl(Operation *consumerOp, Value consumedView,
function_ref<bool(Operation *)> isaOpType);
} // namespace detail
@@ -46,7 +37,7 @@
// an op of type `OpTy`. This is used to implement use-def type information on
// buffers.
template <typename OpTy>
-bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) {
+bool isProducedByOpOfType(Operation *consumerOp, Value consumedView) {
return detail::isProducedByOpOfTypeImpl(
consumerOp, consumedView, [](Operation *op) { return isa<OpTy>(op); });
}
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h b/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h
index 5a815ba..778d853 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h
@@ -1,19 +1,10 @@
//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_INTRINSICS_H_
#define MLIR_DIALECT_LINALG_INTRINSICS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index f8d10ec..996658b 100644
--- a/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/third_party/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -1,19 +1,10 @@
//===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_UTILS_H_
#define MLIR_DIALECT_LINALG_UTILS_H_
@@ -34,7 +25,7 @@
/// A LoopRangeBuilder is a generic NestedBuilder for loop.for operations.
/// More specifically it is meant to be used as a temporary object for
-/// representing any nested MLIR construct that is "related to" an mlir::Value*
+/// representing any nested MLIR construct that is "related to" an mlir::Value
/// (for now an induction variable).
class LoopRangeBuilder : public NestedBuilder {
public:
@@ -42,7 +33,7 @@
/// variable. A ValueHandle pointer is passed as the first argument and is the
/// *only* way to capture the loop induction variable.
LoopRangeBuilder(ValueHandle *iv, ValueHandle range);
- LoopRangeBuilder(ValueHandle *iv, Value *range);
+ LoopRangeBuilder(ValueHandle *iv, Value range);
LoopRangeBuilder(ValueHandle *iv, SubViewOp::Range range);
LoopRangeBuilder(const LoopRangeBuilder &) = delete;
@@ -65,7 +56,7 @@
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
ArrayRef<edsc::ValueHandle> ranges);
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
- ArrayRef<Value *> ranges);
+ ArrayRef<Value> ranges);
LoopNestRangeBuilder(ArrayRef<edsc::ValueHandle *> ivs,
ArrayRef<SubViewOp::Range> ranges);
edsc::ValueHandle operator()(std::function<void(void)> fun = nullptr);
@@ -88,14 +79,14 @@
/// whole `consumedView`. This checks structural dominance, that the dependence
/// is a RAW without any interleaved write to any piece of `consumedView`.
bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
- LinalgOp consumer, Value *consumedView,
+ LinalgOp consumer, Value consumedView,
LinalgOp producer);
/// Checks whether fusing the specific `producer` of the `consumedView` is
/// feasible. This checks `producer` is the last write of `consumedView` and
/// that no interleaved dependence would be violated (RAW, WAR or WAW).
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
- Value *consumedView, LinalgOp producer);
+ Value consumedView, LinalgOp producer);
/// Fuses producer into consumer if the producer is structurally feasible and
/// the fusion would not violate dependencies.
@@ -111,8 +102,8 @@
/// the inverse, concatenated loopToOperandRangeMaps to this list allows the
/// derivation of loop ranges for any linalgOp.
template <typename ConcreteOp>
-SmallVector<Value *, 8> getViewSizes(ConcreteOp linalgOp) {
- SmallVector<Value *, 8> res;
+SmallVector<Value, 8> getViewSizes(ConcreteOp linalgOp) {
+ SmallVector<Value, 8> res;
for (auto v : linalgOp.getInputsAndOutputs()) {
MemRefType t = v->getType().template cast<MemRefType>();
for (unsigned i = 0; i < t.getRank(); ++i)
@@ -125,10 +116,9 @@
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
-SmallVector<Value *, 4> applyMapToValues(OpBuilder &b, Location loc,
- AffineMap map,
- ArrayRef<Value *> values,
- OperationFolder *folder = nullptr);
+SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
+ AffineMap map, ArrayRef<Value> values,
+ OperationFolder *folder = nullptr);
struct TiledLinalgOp {
LinalgOp op;
@@ -151,7 +141,7 @@
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
- ArrayRef<Value *> tileSizes,
+ ArrayRef<Value> tileSizes,
ArrayRef<unsigned> permutation = {},
OperationFolder *folder = nullptr);
@@ -182,9 +172,9 @@
}
struct PromotionInfo {
- Value *buffer;
- Value *fullLocalView;
- Value *partialLocalView;
+ Value buffer;
+ Value fullLocalView;
+ Value partialLocalView;
};
/// Promotes the `subViews` into a new buffer allocated at the insertion point
@@ -199,13 +189,13 @@
/// Returns a list of PromotionInfo which hold the promoted buffer and the
/// full and partial views indexing into the buffer.
SmallVector<PromotionInfo, 8>
-promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value *> subViews,
+promoteSubViews(OpBuilder &b, Location loc, ArrayRef<Value> subViews,
bool dynamicBuffers = false, OperationFolder *folder = nullptr);
/// Returns all the operands of `linalgOp` that are not views.
/// Asserts that these operands are value types to allow transformations like
/// tiling to just use the values when cloning `linalgOp`.
-SmallVector<Value *, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
+SmallVector<Value, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
@@ -226,7 +216,7 @@
/// It is the entry point for declarative transformation
/// Returns the cloned `LinalgOp` with the new operands
LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
- llvm::SetVector<Value *> subViews,
+ llvm::SetVector<Value> subViews,
bool dynamicBuffers = false,
OperationFolder *folder = nullptr);
diff --git a/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
index 9f5863f..0fda882 100644
--- a/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(LoopOps)
+add_mlir_dialect(LoopOps LoopOps)
diff --git a/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
index fdadf4a..2617d7f 100644
--- a/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
+++ b/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
@@ -1,19 +1,10 @@
//===- Ops.h - Loop MLIR Operations -----------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines convenience types for working with loop operations.
//
@@ -50,7 +41,7 @@
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
-ForOp getForInductionVarOwner(Value *val);
+ForOp getForInductionVarOwner(Value val);
} // end namespace loop
} // end namespace mlir
diff --git a/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index 5e0b809..707b788 100644
--- a/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/third_party/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -1,19 +1,10 @@
//===- Ops.td - Loop operation definitions ---------------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines MLIR loop operations.
//
@@ -74,18 +65,18 @@
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState &result, "
- "Value *lowerBound, Value *upperBound, Value *step">
+ "Value lowerBound, Value upperBound, Value step">
];
let extraClassDeclaration = [{
Block *getBody() { return ®ion().front(); }
- Value *getInductionVar() { return getBody()->getArgument(0); }
+ Value getInductionVar() { return getBody()->getArgument(0); }
OpBuilder getBodyBuilder() {
return OpBuilder(getBody(), std::prev(getBody()->end()));
}
- void setLowerBound(Value *bound) { getOperation()->setOperand(0, bound); }
- void setUpperBound(Value *bound) { getOperation()->setOperand(1, bound); }
- void setStep(Value *step) { getOperation()->setOperand(2, step); }
+ void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
+ void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
+ void setStep(Value step) { getOperation()->setOperand(2, step); }
}];
}
@@ -116,7 +107,7 @@
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState &result, "
- "Value *cond, bool withElseRegion">
+ "Value cond, bool withElseRegion">
];
let extraClassDeclaration = [{
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
index f95532e..90a61c4 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(QuantOps)
+add_mlir_dialect(QuantOps QuantOps)
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h b/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h
index 23e2967..1a141e3 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h
@@ -1,19 +1,10 @@
//===- FakeQuantSupport.h - Support utilities for FakeQuant ops -*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines support utilities for interoperating with FakeQuant* based
// QAT (Quantized Aware Training) computations, as implemented by TFLite. Note
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h b/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h
index c57d7bf..d310977 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/Passes.h
@@ -1,19 +1,10 @@
//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines all of the passes owned by the quantization dialect. As
// things mature, it is expected that passes specific to certain frontend or
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.h b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.h
index 020d349..9a4eec6 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.h
@@ -1,19 +1,10 @@
//===- QuantOps.h - Quantization Ops and Types ------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_QUANTOPS_QUANTOPS_H_
#define MLIR_DIALECT_QUANTOPS_QUANTOPS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
index 072715d..bbeb941 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
@@ -1,19 +1,10 @@
//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the operation definition file for Quantization.
//
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
index 2fbb799..7225dcc 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
@@ -1,19 +1,10 @@
//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Predicates for types in the Quantization dialect.
//
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
index 55e921f..daeb037 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h
@@ -1,19 +1,10 @@
//===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
#define MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h
index de87ca1..c40b9e6 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/QuantizeUtils.h
@@ -1,19 +1,10 @@
//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_
#define MLIR_DIALECT_QUANTOPS_QUANTIZEUTILS_H_
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
index 0416db3..7c74fc5 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
@@ -1,19 +1,10 @@
//===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
#define MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
diff --git a/third_party/mlir/include/mlir/Dialect/SDBM/SDBM.h b/third_party/mlir/include/mlir/Dialect/SDBM/SDBM.h
index f95a51e..c8a0eec 100644
--- a/third_party/mlir/include/mlir/Dialect/SDBM/SDBM.h
+++ b/third_party/mlir/include/mlir/Dialect/SDBM/SDBM.h
@@ -1,19 +1,10 @@
//===- SDBM.h - MLIR SDBM declaration ---------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
diff --git a/third_party/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/third_party/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
index e3573ba..501c661 100644
--- a/third_party/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
+++ b/third_party/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
@@ -1,19 +1,10 @@
//===- SDBMDialect.h - Dialect for striped DBMs -----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SDBM_SDBMDIALECT_H
#define MLIR_DIALECT_SDBM_SDBMDIALECT_H
diff --git a/third_party/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h b/third_party/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h
index 8cb5ef0..84a9a84 100644
--- a/third_party/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h
+++ b/third_party/mlir/include/mlir/Dialect/SDBM/SDBMExpr.h
@@ -1,19 +1,10 @@
//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// A striped difference-bound matrix (SDBM) expression is a constant expression,
// an identifier, a binary expression with constant RHS and +, stripe operators
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
index b6759a9..fc7180d 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
@@ -3,7 +3,7 @@
mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
-add_mlir_dialect(SPIRVOps)
+add_mlir_dialect(SPIRVOps SPIRVOps)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h b/third_party/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h
index 7537e5f..329caa2 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h
@@ -1,19 +1,10 @@
//===-- LayoutUtils.h - Decorate composite type with layout information ---===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines utilities used to get alignment and layout information for
// types in SPIR-V dialect.
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/Passes.h b/third_party/mlir/include/mlir/Dialect/SPIRV/Passes.h
index fe029ff..68f149b 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -1,19 +1,10 @@
//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes that expose pass constructors.
//
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
index f15d274..39858f3 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVArithmeticOps.td - MLIR SPIR-V Arithmetic Ops -*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains arithmetic ops for the SPIR-V dialect. It corresponds
// to "3.32.13. Arithmetic Instructions" of the SPIR-V specification.
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
index 15b6ab0..c2ea100 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVAtomicOps.td - MLIR SPIR-V Atomic Ops ---------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains atomic ops for the SPIR-V dialect. It corresponds to
// "3.32.18. Atomic Instructions" of the SPIR-V specification.
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 8383988..5751a32 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -1,19 +1,10 @@
//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the base file for SPIR-V operation definition specification.
// This file defines the SPIR-V dialect, common SPIR-V types, and utilities
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h
index 3229e28..6a42648 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h
@@ -1,19 +1,10 @@
//===- SPIRVBinaryUtils.cpp - SPIR-V Binary Module Utils --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares common utilities for SPIR-V binary module.
//
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td
index d76a1e3..360edee 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVBitOps.td - MLIR SPIR-V Bit Ops -*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains bit ops for the SPIR-V dialect. It corresponds
// to "3.32.13. Bit Instructions" of the SPIR-V specification.
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td
index e4fe526..99fe0bb 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVCastOps.td - MLIR SPIR-V Cast Ops -------*- tablegen -*-------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains cast ops for the SPIR-V dialect. It corresponds
// to "3.32.11. Convertion Instructions" of the SPIR-V specification.
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
index d6e2e1c..5a8235f 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVCompositeOps.td - MLIR SPIR-V Composite Ops ---*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains composite ops for SPIR-V dialect. It corresponds
// to "3.32.12. Composite Instructions" of the SPIR-V spec.
@@ -120,7 +111,7 @@
let builders = [
OpBuilder<[{Builder *builder, OperationState &state,
- Value *composite, ArrayRef<int32_t> indices}]>
+ Value composite, ArrayRef<int32_t> indices}]>
];
let hasFolder = 1;
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 464b670..be09557 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVControlFlowOps.td - SPIR-V Control Flow Ops ---*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains control flow ops for the SPIR-V dialect. It corresponds
// to "3.32.17. Control-Flow Instructions" of the SPIR-V specification.
@@ -132,7 +123,7 @@
let builders = [
OpBuilder<
- "Builder *builder, OperationState &state, Value *condition, "
+ "Builder *builder, OperationState &state, Value condition, "
"Block *trueBlock, ValueRange trueArguments, "
"Block *falseBlock, ValueRange falseArguments, "
"Optional<std::pair<uint32_t, uint32_t>> weights = {}",
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
index 2571e5d..0c0eebd 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
@@ -1,19 +1,10 @@
//===- SPIRVDialect.h - MLIR SPIR-V dialect ---------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the SPIR-V dialect in MLIR.
//
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td
index a031fac..b2eacbf 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td
@@ -1,19 +1,10 @@
//===- SPIRVGLSLOps.td - GLSL extended insts spec file -----*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the op definition spec of GLSL extension ops.
//
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
index c0388fe..827636a 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVGroupOps.td - MLIR SPIR-V (Sub)Group Ops ------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains group and subgroup ops for the SPIR-V dialect. It
// corresponds to "3.32.21. Group and Subgroup Instructions" of the SPIR-V
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
index 0c4b290..ac377d5 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVLogicalOps.td - MLIR SPIR-V Logical Ops -------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains arithmetic ops for the SPIR-V dialect. It corresponds
// to "3.32.15. Relational and Logical Instructions" of the SPIR-V spec.
@@ -858,8 +849,8 @@
);
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
- Value *cond, Value *trueValue,
- Value *falseValue}]>];
+ Value cond, Value trueValue,
+ Value falseValue}]>];
}
// -----
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index f48a1d0..0f481f5 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -1,19 +1,10 @@
//===- SPIRVLowering.h - SPIR-V lowering utilities -------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines utilities to use while targeting SPIR-V dialect.
//
@@ -64,8 +55,8 @@
namespace spirv {
/// Returns a value that represents a builtin variable value within the SPIR-V
/// module.
-Value *getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin,
- OpBuilder &builder);
+Value getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin,
+ OpBuilder &builder);
/// Attribute name for specifying argument ABI information.
StringRef getInterfaceVarABIAttrName();
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td
index d9cf0a7..91a8ff6 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td
@@ -1,19 +1,10 @@
//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the base file for supporting lowering to SPIR-V dialect. This
// file defines SPIR-V attributes used for specifying the shader
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
index 1b3174c..f3a9a61 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to
// "3.32.24. Non-Uniform Instructions" of the SPIR-V specification.
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
index cb33146..2fa417b 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -1,19 +1,10 @@
//===- SPIRVOps.h - MLIR SPIR-V operations ----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the operations in the SPIR-V dialect.
//
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 91ea8d7..1ce2892 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the main operation definition specification file for SPIR-V
// operations.
@@ -102,7 +93,7 @@
);
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
- Value *basePtr, ValueRange indices}]>];
+ Value basePtr, ValueRange indices}]>];
let hasCanonicalizer = 1;
}
@@ -272,7 +263,7 @@
);
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
- Value *basePtr, /*optional*/IntegerAttr memory_access,
+ Value basePtr, /*optional*/IntegerAttr memory_access,
/*optional*/IntegerAttr alignment}]>];
}
@@ -367,7 +358,7 @@
let builders = [
OpBuilder<"Builder *builder, OperationState &state, "
- "Value *ptr, Value *value, ArrayRef<NamedAttribute> namedAttrs", [{
+ "Value ptr, Value value, ArrayRef<NamedAttribute> namedAttrs", [{
state.addOperands(ptr);
state.addOperands(value);
state.addAttributes(namedAttrs);
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index d1dacf3..c37796b 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -1,19 +1,10 @@
//===-- SPIRVStructureOps.td - MLIR SPIR-V Structure Ops ---*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains ops for defining the SPIR-V structure: module, function,
// and module-level operations. The representational form of these ops deviate
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index bc3083e..001d313 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -1,19 +1,10 @@
//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the types in the SPIR-V dialect.
//
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/third_party/mlir/include/mlir/Dialect/SPIRV/Serialization.h
index bad7355..e8240b0 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/Serialization.h
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/Serialization.h
@@ -1,19 +1,10 @@
//===- Serialization.h - MLIR SPIR-V (De)serialization ----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the entry points for serialize and deserialize SPIR-V
// binary modules.
diff --git a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.h b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.h
index 1b1cf02..0ba16c5 100644
--- a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.h
+++ b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.h
@@ -1,19 +1,10 @@
//===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines convenience types for working with standard operations
// in the MLIR operation set.
@@ -182,15 +173,14 @@
public:
using Op::Op;
- static void build(Builder *builder, OperationState &result, Value *srcMemRef,
- ValueRange srcIndices, Value *destMemRef,
- ValueRange destIndices, Value *numElements,
- Value *tagMemRef, ValueRange tagIndices,
- Value *stride = nullptr,
- Value *elementsPerStride = nullptr);
+ static void build(Builder *builder, OperationState &result, Value srcMemRef,
+ ValueRange srcIndices, Value destMemRef,
+ ValueRange destIndices, Value numElements, Value tagMemRef,
+ ValueRange tagIndices, Value stride = nullptr,
+ Value elementsPerStride = nullptr);
// Returns the source MemRefType for this DMA operation.
- Value *getSrcMemRef() { return getOperand(0); }
+ Value getSrcMemRef() { return getOperand(0); }
// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() {
return getSrcMemRef()->getType().cast<MemRefType>().getRank();
@@ -202,7 +192,7 @@
}
// Returns the destination MemRefType for this DMA operations.
- Value *getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
+ Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() {
return getDstMemRef()->getType().cast<MemRefType>().getRank();
@@ -222,12 +212,12 @@
}
// Returns the number of elements being transferred by this DMA operation.
- Value *getNumElements() {
+ Value getNumElements() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
}
// Returns the Tag MemRef for this DMA operation.
- Value *getTagMemRef() {
+ Value getTagMemRef() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
// Returns the rank (number of indices) of the tag MemRefType.
@@ -276,13 +266,13 @@
1 + 1 + getTagMemRefRank();
}
- Value *getStride() {
+ Value getStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1 - 1);
}
- Value *getNumElementsPerStride() {
+ Value getNumElementsPerStride() {
if (!isStrided())
return nullptr;
return getOperand(getNumOperands() - 1);
@@ -307,13 +297,13 @@
public:
using Op::Op;
- static void build(Builder *builder, OperationState &result, Value *tagMemRef,
- ValueRange tagIndices, Value *numElements);
+ static void build(Builder *builder, OperationState &result, Value tagMemRef,
+ ValueRange tagIndices, Value numElements);
static StringRef getOperationName() { return "std.dma_wait"; }
// Returns the Tag MemRef associated with the DMA operation being waited on.
- Value *getTagMemRef() { return getOperand(0); }
+ Value getTagMemRef() { return getOperand(0); }
// Returns the tag memref index for this DMA operation.
operand_range getTagIndices() {
@@ -327,7 +317,7 @@
}
// Returns the number of elements transferred in the associated DMA operation.
- Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); }
+ Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
@@ -342,7 +332,7 @@
/// Parses dimension and symbol list and returns true if parsing failed.
ParseResult parseDimAndSymbolList(OpAsmParser &parser,
- SmallVectorImpl<Value *> &operands,
+ SmallVectorImpl<Value> &operands,
unsigned &numDims);
raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
diff --git a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 76c2ba5..1c8bb25 100644
--- a/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/third_party/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1,19 +1,10 @@
//===- Ops.td - Standard operation definitions -------------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines some MLIR standard operations.
//
@@ -52,7 +43,7 @@
let results = (outs AnyType);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *source, Type destType", [{
+ "Builder *builder, OperationState &result, Value source, Type destType", [{
impl::buildCastOp(builder, result, source, destType);
}]>];
@@ -191,7 +182,7 @@
}]>,
OpBuilder<
"Builder *builder, OperationState &result, MemRefType memrefType, " #
- "ArrayRef<Value*> operands, IntegerAttr alignment = IntegerAttr()", [{
+ "ArrayRef<Value> operands, IntegerAttr alignment = IntegerAttr()", [{
result.addOperands(operands);
result.types.push_back(memrefType);
if (alignment)
@@ -330,7 +321,7 @@
let results = (outs Variadic<AnyType>);
let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *callee,"
+ "Builder *, OperationState &result, Value callee,"
"ValueRange operands = {}", [{
result.operands.push_back(callee);
result.addOperands(operands);
@@ -338,7 +329,7 @@
}]>];
let extraClassDeclaration = [{
- Value *getCallee() { return getOperand(0); }
+ Value getCallee() { return getOperand(0); }
/// Get the argument operands to the called function.
operand_range getArgOperands() {
@@ -395,7 +386,7 @@
let builders = [OpBuilder<
"Builder *builder, OperationState &result, CmpFPredicate predicate,"
- "Value *lhs, Value *rhs", [{
+ "Value lhs, Value rhs", [{
::buildCmpFOp(builder, result, predicate, lhs, rhs);
}]>];
@@ -463,7 +454,7 @@
let builders = [OpBuilder<
"Builder *builder, OperationState &result, CmpIPredicate predicate,"
- "Value *lhs, Value *rhs", [{
+ "Value lhs, Value rhs", [{
::buildCmpIOp(builder, result, predicate, lhs, rhs);
}]>];
@@ -502,7 +493,7 @@
let arguments = (ins I1:$condition, Variadic<AnyType>:$branchOperands);
let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *condition,"
+ "Builder *, OperationState &result, Value condition,"
"Block *trueDest, ValueRange trueOperands,"
"Block *falseDest, ValueRange falseOperands", [{
result.addOperands(condition);
@@ -518,7 +509,7 @@
enum { trueIndex = 0, falseIndex = 1 };
// The condition operand is the first operand in the list.
- Value *getCondition() { return getOperand(0); }
+ Value getCondition() { return getOperand(0); }
/// Return the destination if the condition is true.
Block *getTrueDest() {
@@ -531,12 +522,12 @@
}
// Accessors for operands to the 'true' destination.
- Value *getTrueOperand(unsigned idx) {
+ Value getTrueOperand(unsigned idx) {
assert(idx < getNumTrueOperands());
return getOperand(getTrueDestOperandIndex() + idx);
}
- void setTrueOperand(unsigned idx, Value *value) {
+ void setTrueOperand(unsigned idx, Value value) {
assert(idx < getNumTrueOperands());
setOperand(getTrueDestOperandIndex() + idx, value);
}
@@ -561,11 +552,11 @@
}
// Accessors for operands to the 'false' destination.
- Value *getFalseOperand(unsigned idx) {
+ Value getFalseOperand(unsigned idx) {
assert(idx < getNumFalseOperands());
return getOperand(getFalseDestOperandIndex() + idx);
}
- void setFalseOperand(unsigned idx, Value *value) {
+ void setFalseOperand(unsigned idx, Value value) {
assert(idx < getNumFalseOperands());
setOperand(getFalseDestOperandIndex() + idx, value);
}
@@ -678,7 +669,7 @@
let results = (outs Index);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *memrefOrTensor,"
+ "Builder *builder, OperationState &result, Value memrefOrTensor,"
"unsigned index", [{
auto indexType = builder->getIndexType();
auto indexAttr = builder->getIntegerAttr(indexType, index);
@@ -698,12 +689,12 @@
let summary = "floating point division operation";
}
-def DivISOp : IntArithmeticOp<"divis"> {
+def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
let summary = "signed integer division operation";
let hasFolder = 1;
}
-def DivIUOp : IntArithmeticOp<"diviu"> {
+def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
let summary = "unsigned integer division operation";
let hasFolder = 1;
}
@@ -730,7 +721,7 @@
let results = (outs AnyType);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *aggregate,"
+ "Builder *builder, OperationState &result, Value aggregate,"
"ValueRange indices = {}", [{
auto resType = aggregate->getType().cast<ShapedType>()
.getElementType();
@@ -738,7 +729,7 @@
}]>];
let extraClassDeclaration = [{
- Value *getAggregate() { return getOperand(0); }
+ Value getAggregate() { return getOperand(0); }
operand_range getIndices() {
return {operand_begin() + 1, operand_end()};
@@ -816,7 +807,7 @@
let results = (outs AnyType);
let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *memref,"
+ "Builder *, OperationState &result, Value memref,"
"ValueRange indices = {}", [{
auto memrefType = memref->getType().cast<MemRefType>();
result.addOperands(memref);
@@ -825,8 +816,8 @@
}]>];
let extraClassDeclaration = [{
- Value *getMemRef() { return getOperand(0); }
- void setMemRef(Value *value) { setOperand(0, value); }
+ Value getMemRef() { return getOperand(0); }
+ void setMemRef(Value value) { setOperand(0, value); }
MemRefType getMemRefType() {
return getMemRef()->getType().cast<MemRefType>();
}
@@ -952,8 +943,8 @@
BoolAttr:$isDataCache);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *memref,"
- "ArrayRef<Value *> indices, bool isWrite, unsigned hint, bool isData",
+ "Builder *builder, OperationState &result, Value memref,"
+ "ArrayRef<Value> indices, bool isWrite, unsigned hint, bool isData",
[{
auto hintAttr = builder->getI32IntegerAttr(hint);
auto isWriteAttr = builder->getBoolAttr(isWrite);
@@ -990,7 +981,7 @@
let verifier = ?;
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *tensor", [{
+ "Builder *builder, OperationState &result, Value tensor", [{
auto indexType = builder->getIndexType();
build(builder, result, indexType, tensor);
}]>];
@@ -1002,12 +993,12 @@
let summary = "floating point division remainder operation";
}
-def RemISOp : IntArithmeticOp<"remis"> {
+def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
let summary = "signed integer division remainder operation";
let hasFolder = 1;
}
-def RemIUOp : IntArithmeticOp<"remiu"> {
+def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
let summary = "unsigned integer division remainder operation";
let hasFolder = 1;
}
@@ -1052,16 +1043,16 @@
let results = (outs AnyType);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *condition,"
- "Value *trueValue, Value *falseValue", [{
+ "Builder *builder, OperationState &result, Value condition,"
+ "Value trueValue, Value falseValue", [{
result.addOperands({condition, trueValue, falseValue});
result.addTypes(trueValue->getType());
}]>];
let extraClassDeclaration = [{
- Value *getCondition() { return condition(); }
- Value *getTrueValue() { return true_value(); }
- Value *getFalseValue() { return false_value(); }
+ Value getCondition() { return condition(); }
+ Value getTrueValue() { return true_value(); }
+ Value getFalseValue() { return false_value(); }
}];
let hasFolder = 1;
@@ -1089,7 +1080,7 @@
let results = (outs IntegerLike);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *value, Type destType", [{
+ "Builder *builder, OperationState &result, Value value, Type destType", [{
result.addOperands(value);
result.addTypes(destType);
}]>];
@@ -1102,8 +1093,45 @@
}];
}
-def ShlISOp : IntArithmeticOp<"shlis"> {
- let summary = "signed integer shift left";
+def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
+ let summary = "integer left-shift";
+ let description = [{
+ The shift_left operation shifts an integer value to the left by a variable
+ amount. The low order bits are filled with zeros.
+
+ %1 = constant 5 : i8 // %1 is 0b00000101
+ %2 = constant 3 : i8
+ %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000
+ }];
+}
+
+def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
+ let summary = "signed integer right-shift";
+ let description = [{
+ The shift_right_signed operation shifts an integer value to the right by
+ a variable amount. The integer is interpreted as signed. The high order
+ bits in the output are filled with copies of the most-significant bit
+ of the shifted value (which means that the sign of the value is preserved).
+
+ %1 = constant 160 : i8 // %1 is 0b10100000
+ %2 = constant 3 : i8
+ %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100
+ %4 = constant 96 : i8 // %4 is 0b01100000
+ %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
+ }];
+}
+
+def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> {
+ let summary = "unsigned integer right-shift";
+ let description = [{
+ The shift_right_unsigned operation shifts an integer value to the right by
+ a variable amount. The integer is interpreted as unsigned. The high order
+ bits are always filled with zeros.
+
+ %1 = constant 160 : i8 // %1 is 0b10100000
+ %2 = constant 3 : i8
+ %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
+ }];
}
def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
@@ -1152,7 +1180,7 @@
let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);
let builders =
- [OpBuilder<"Builder *builder, OperationState &result, Value *element, "
+ [OpBuilder<"Builder *builder, OperationState &result, Value element, "
"Type aggregateType",
[{ build(builder, result, aggregateType, element); }]>];
@@ -1176,16 +1204,16 @@
Variadic<Index>:$indices);
let builders = [OpBuilder<
- "Builder *, OperationState &result, Value *valueToStore, Value *memref", [{
+ "Builder *, OperationState &result, Value valueToStore, Value memref", [{
result.addOperands(valueToStore);
result.addOperands(memref);
}]>];
let extraClassDeclaration = [{
- Value *getValueToStore() { return getOperand(0); }
+ Value getValueToStore() { return getOperand(0); }
- Value *getMemRef() { return getOperand(1); }
- void setMemRef(Value *value) { setOperand(1, value); }
+ Value getMemRef() { return getOperand(1); }
+ void setMemRef(Value value) { setOperand(1, value); }
MemRefType getMemRefType() {
return getMemRef()->getType().cast<MemRefType>();
}
@@ -1327,13 +1355,13 @@
let builders = [
OpBuilder<
- "Builder *b, OperationState &result, Value *source, "
+ "Builder *b, OperationState &result, Value source, "
"ValueRange offsets, ValueRange sizes, "
"ValueRange strides, Type resultType = Type(), "
"ArrayRef<NamedAttribute> attrs = {}">,
OpBuilder<
"Builder *builder, OperationState &result, "
- "Type resultType, Value *source">
+ "Type resultType, Value source">
];
let extraClassDeclaration = [{
@@ -1366,7 +1394,7 @@
// offset, size and stride operands of the SubViewOp into a list of triples.
// Such a list of triple is sometimes more convenient to manipulate.
struct Range {
- Value *offset, *size, *stride;
+ Value offset, size, stride;
};
SmallVector<Range, 8> getRanges();
}];
@@ -1428,7 +1456,7 @@
let verifier = ?;
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *memref", [{
+ "Builder *builder, OperationState &result, Value memref", [{
auto memrefType = memref->getType().cast<MemRefType>();
auto resultType = RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType());
@@ -1482,7 +1510,7 @@
let results = (outs IntegerLike);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *value, Type destType", [{
+ "Builder *builder, OperationState &result, Value value, Type destType", [{
result.addOperands(value);
result.addTypes(destType);
}]>];
@@ -1541,7 +1569,7 @@
/// Returns the dynamic offset for this view operation if specified.
/// Returns nullptr if no dynamic offset was specified.
- Value *getDynamicOffset();
+ Value getDynamicOffset();
/// Returns the starting operand list position of the dynamic size operands.
unsigned getDynamicSizesOperandStart() {
@@ -1582,7 +1610,7 @@
let results = (outs IntegerLike);
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *value, Type destType", [{
+ "Builder *builder, OperationState &result, Value value, Type destType", [{
result.addOperands(value);
result.addTypes(destType);
}]>];
diff --git a/third_party/mlir/include/mlir/Dialect/Traits.h b/third_party/mlir/include/mlir/Dialect/Traits.h
index e04eb82..87c8e66 100644
--- a/third_party/mlir/include/mlir/Dialect/Traits.h
+++ b/third_party/mlir/include/mlir/Dialect/Traits.h
@@ -1,19 +1,10 @@
//===- Traits.h - Common op traits shared by dialects -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares common op traits that are not core to MLIR but can be
// shared by multiple dialects.
diff --git a/third_party/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/third_party/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index b7e3990..9e7cbba 100644
--- a/third_party/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/third_party/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -1,19 +1,10 @@
//===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file define utilities that operate on standard types and are
// useful across multiple dialects that use structured ops abstractions. These
diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt b/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
index c165c5e..5ce3168 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
@@ -1,4 +1,4 @@
-add_mlir_dialect(VectorOps)
+add_mlir_dialect(VectorOps VectorOps)
set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td)
mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters)
diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h b/third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h
index f61a813..5f19f84 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/Utils.h
@@ -1,19 +1,10 @@
//===- Utils.h - VectorOps Utils ----------------------------*- C++ -*-=======//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOROPS_UTILS_H_
#define MLIR_DIALECT_VECTOROPS_UTILS_H_
@@ -122,7 +113,7 @@
/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
///
AffineMap
-makePermutationMap(Operation *op, ArrayRef<Value *> indices,
+makePermutationMap(Operation *op, ArrayRef<Value> indices,
const DenseMap<Operation *, unsigned> &loopToVectorDim);
namespace matcher {
diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
index 06672c7..7234d46 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
@@ -1,19 +1,10 @@
//===- VectorOps.h - MLIR Super Vectorizer Operations -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the Vector dialect.
//
@@ -52,6 +43,13 @@
void populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
+/// Returns the integer type required for subscripts in the vector dialect.
+IntegerType getVectorSubscriptType(Builder &builder);
+
+/// Returns an integer array attribute containing the given values using
+/// the integer type required for subscripts in the vector dialect.
+ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
+
#define GET_OP_CLASSES
#include "mlir/Dialect/VectorOps/VectorOps.h.inc"
diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 401e424..8726b16 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1,19 +1,10 @@
//===- VectorOps.td - Vector op definitions ---------------*- tablegen -*-====//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines MLIR vector operations.
//
@@ -128,8 +119,8 @@
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
}];
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *lhs, Value *rhs, "
- "Value *acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
+ "Builder *builder, OperationState &result, Value lhs, Value rhs, "
+ "Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
let extraClassDeclaration = [{
VectorType getLhsType() {
return lhs()->getType().cast<VectorType>();
@@ -220,7 +211,7 @@
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand v2 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>]>,
- Arguments<(ins AnyVector:$v1, AnyVector:$v2, I32ArrayAttr:$mask)>,
+ Arguments<(ins AnyVector:$v1, AnyVector:$v2, I64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
let description = [{
@@ -243,16 +234,17 @@
Examples:
```
- %0 = vector.shuffle %a, %b[0:i32, 3:i32]
+ %0 = vector.shuffle %a, %b[0, 3]
: vector<2xf32>, vector<2xf32> ; yields vector<2xf32>
- %1 = vector.shuffle %c, %b[0:i32, 1:i32, 2:i32]
+ %1 = vector.shuffle %c, %b[0, 1, 2]
: vector<2x16xf32>, vector<1x16xf32> ; yields vector<3x16xf32>
- %2 = vector.shuffle %a, %b[3:i32, 2:i32, 1:i32 : 0:i32]
+ %2 = vector.shuffle %a, %b[3, 2, 1, 0]
: vector<2xf32>, vector<2xf32> ; yields vector<4xf32>
```
}];
- let builders = [OpBuilder<"Builder *builder, OperationState &result, Value *v1, Value *v2, ArrayRef<int32_t>">];
+ let builders = [OpBuilder<"Builder *builder, OperationState &result,"
+ "Value v1, Value v2, ArrayRef<int64_t>">];
let extraClassDeclaration = [{
static StringRef getMaskAttrName() { return "mask"; }
VectorType getV1VectorType() {
@@ -271,7 +263,7 @@
Vector_Op<"extractelement", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVector:$vector, Index:$position)>,
+ Arguments<(ins AnyVector:$vector, AnyInteger:$position)>,
Results<(outs AnyType)> {
let summary = "extractelement operation";
let description = [{
@@ -298,7 +290,7 @@
Vector_Op<"extract", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVector:$vector, I32ArrayAttr:$position)>,
+ Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>,
Results<(outs AnyType)> {
let summary = "extract operation";
let description = [{
@@ -312,7 +304,8 @@
```
}];
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *source, ArrayRef<int32_t>">];
+ "Builder *builder, OperationState &result, Value source,"
+ "ArrayRef<int64_t>">];
let extraClassDeclaration = [{
static StringRef getPositionAttrName() { return "position"; }
VectorType getVectorType() {
@@ -357,7 +350,7 @@
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, TupleType tupleType, " #
- "Value *vector, ArrayRef<int64_t> sizes, " #
+ "Value vector, ArrayRef<int64_t> sizes, " #
"ArrayRef<int64_t> strides">];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
@@ -379,7 +372,7 @@
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"dest operand and result have same type",
TCresIsSameAsOpBase<0, 1>>]>,
- Arguments<(ins AnyType:$source, AnyVector:$dest, Index:$position)>,
+ Arguments<(ins AnyType:$source, AnyVector:$dest, AnyInteger:$position)>,
Results<(outs AnyVector)> {
let summary = "insertelement operation";
let description = [{
@@ -411,7 +404,7 @@
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"dest operand and result have same type",
TCresIsSameAsOpBase<0, 1>>]>,
- Arguments<(ins AnyType:$source, AnyVector:$dest, I32ArrayAttr:$position)>,
+ Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>,
Results<(outs AnyVector)> {
let summary = "insert operation";
let description = [{
@@ -421,15 +414,15 @@
Examples:
```
- %2 = vector.insert %0, %1[3 : i32]:
+ %2 = vector.insert %0, %1[3]:
vector<8x16xf32> into vector<4x8x16xf32>
- %5 = vector.insert %3, %4[3 : i32, 3 : i32, 3 : i32]:
+ %5 = vector.insert %3, %4[3, 3, 3]:
f32 into vector<4x8x16xf32>
```
}];
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *source, " #
- "Value *dest, ArrayRef<int32_t>">];
+ "Builder *builder, OperationState &result, Value source, " #
+ "Value dest, ArrayRef<int64_t>">];
let extraClassDeclaration = [{
static StringRef getPositionAttrName() { return "position"; }
Type getSourceType() { return source()->getType(); }
@@ -521,7 +514,7 @@
```
}];
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *source, Value *dest, " #
+ "Builder *builder, OperationState &result, Value source, Value dest, " #
"ArrayRef<int64_t> offsets, ArrayRef<int64_t> strides">];
let extraClassDeclaration = [{
static StringRef getOffsetsAttrName() { return "offsets"; }
@@ -574,6 +567,123 @@
}];
}
+// TODO(andydavis) Add transformation which decomposes ReshapeOp into an
+// optimized sequence of vector rotate/shuffle/select operations.
+def Vector_ReshapeOp :
+ Vector_Op<"reshape", [AttrSizedOperandSegments, NoSideEffect]>,
+ Arguments<(ins AnyVector:$vector, Variadic<Index>:$input_shape,
+ Variadic<Index>:$output_shape,
+ I64ArrayAttr:$fixed_vector_sizes,
+ I32ElementsAttr:$operand_segment_sizes)>,
+ Results<(outs AnyVector)> {
+ let summary = "vector reshape operation";
+ let description = [{
+ Reshapes its vector operand from 'input_shape' to 'output_shape' maintaining
+ fixed vector dimension 'fixed_vector_sizes' on the innermost vector
+ dimensions.
+
+ The parameters 'input_shape' and 'output_shape' represent valid data shapes
+ across fixed vector shapes. For example, if a vector has a valid data
+ shape [6] with fixed vector size [8], then the valid data elements are
+ assumed to be stored at the beginning of the vector with the remaining
+ vector elements undefined.
+
+ In the examples below, valid data elements are represented by an alphabetic
+ character, and undefined data elements are represented by '-'.
+
+ Example
+
+ vector<1x8xf32> with valid data shape [6], fixed vector sizes [8]
+
+ input: [a, b, c, d, e, f]
+
+ layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
+
+ vector layout: [a, b, c, d, e, f, -, -]
+
+ Example
+
+ vector<2x8xf32> with valid data shape [10], fixed vector sizes [8]
+
+ input: [a, b, c, d, e, f, g, h, i, j]
+
+ layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
+
+ vector layout: [[a, b, c, d, e, f, g, h],
+ [i, j, -, -, -, -, -, -]]
+
+ Example
+
+ vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes
+ [2, 3]
+
+ input: [[a, b, c, d, e],
+ [f, g, h, i, j],
+ [k, l, m, n, o]]
+
+ layout map: (d0, d1) -> (d0 floordiv 2, d1 floordiv 3,
+ d0 mod 2, d1 mod 3)
+
+ vector layout: [[[[a, b, c],
+ [f, g, h]],
+ [[d, e, -],
+ [i, j, -]]],
+ [[[k, l, m],
+ [-, -, -]],
+ [[n, o, -],
+ [-, -, -]]]]
+
+ Example
+
+ %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
+ : vector<3x2x4xf32> to vector<2x3x4xf32>
+
+ input: [[a, b, c, d, e, f],
+ [g, h, i, j, k, l],
+ [m, n, o, p, q, r]]
+
+ layout map: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
+
+
+ Input vector: [[[a, b, c, d],
+ [e, f, -, -]],
+ [[g, h, i, j],
+ [k, l, -, -]],
+ [[m, n, o, p],
+ [q, r, -, -]]]
+
+ Output vector: [[[a, b, c, d],
+ [e, f, g, h],
+ [i, -, -, -]],
+ [[j, k, l, m],
+ [n, o, p, q],
+ [r, -, -, -]]]
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getInputVectorType() {
+ return vector()->getType().cast<VectorType>();
+ }
+ VectorType getOutputVectorType() {
+ return getResult()->getType().cast<VectorType>();
+ }
+
+ /// Returns as integer value the number of input shape operands.
+ int64_t getNumInputShapeSizes() { return input_shape().size(); }
+
+ /// Returns as integer value the number of output shape operands.
+ int64_t getNumOutputShapeSizes() { return output_shape().size(); }
+
+ void getFixedVectorSizes(SmallVectorImpl<int64_t> &results);
+
+ static StringRef getFixedVectorSizesAttrName() {
+ return "fixed_vector_sizes";
+ }
+ static StringRef getInputShapeAttrName() { return "input_shape"; }
+ static StringRef getOutputShapeAttrName() { return "output_shape"; }
+ }];
+}
+
def Vector_StridedSliceOp :
Vector_Op<"strided_slice", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
@@ -606,7 +716,7 @@
vector<4x8x16xf32> to vector<2x4x16xf32>
}];
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *source, " #
+ "Builder *builder, OperationState &result, Value source, " #
"ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, " #
"ArrayRef<int64_t> strides">];
let extraClassDeclaration = [{
@@ -629,10 +739,15 @@
let description = [{
The `vector.transfer_read` op performs a blocking read from a slice within
- a scalar [MemRef](../LangRef.md#memref-type) supplied as its first operand
- into a [vector](../LangRef.md#vector-type) of the same elemental type. The
- slice is further defined by a full-rank index within the MemRef, supplied as
- the operands `2 .. 1 + rank(memref)`. The permutation_map
+ a [MemRef](../LangRef.md#memref-type) supplied as its first operand
+ into a [vector](../LangRef.md#vector-type) of the same base elemental type.
+
+ A vector memref operand must have its vector element type match a suffix
+ (shape and element type) of the vector (e.g. memref<3x2x6xvector<4x3xf32>>,
+ vector<1x1x4x3xf32>).
+
+ The slice is further defined by a full-rank index within the MemRef,
+ supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map
[attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The size of the slice is specified by the
@@ -737,6 +852,11 @@
memref<?x?xf32>, vector<128xf32>
}
}
+
+ // Read from a memref with vector element type.
+ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0
+ {permutation_map = (d0, d1)->(d0, d1)}
+ : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
```
}];
@@ -761,10 +881,15 @@
let description = [{
The `vector.transfer_write` performs a blocking write from a
[vector](../LangRef.md#vector-type), supplied as its first operand, into a
- slice within a scalar [MemRef](../LangRef.md#memref-type) of the same
- elemental type, supplied as its second operand. The slice is further defined
- by a full-rank index within the MemRef, supplied as the operands
- `3 .. 2 + rank(memref)`.
+ slice within a [MemRef](../LangRef.md#memref-type) of the same base
+ elemental type, supplied as its second operand.
+
+ A vector memref operand must have its vector element type match a suffix
+ (shape and element type) of the vector (e.g. memref<3x2x6xvector<4x3xf32>>,
+ vector<1x1x4x3xf32>).
+
+ The slice is further defined by a full-rank index within the MemRef,
+ supplied as the operands `3 .. 2 + rank(memref)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The size of the slice is specified by the
@@ -798,6 +923,11 @@
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
vector<16x32x64xf32>, memref<?x?x?x?xf32>
}}}}
+
+ // Write to a memref with vector element type.
+ vector.transfer_write %4, %arg1[%c3, %c3]
+ {permutation_map = (d0, d1)->(d0, d1)}
+ : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
```
}];
@@ -838,7 +968,7 @@
}];
let builders = [OpBuilder<
- "Builder *builder, OperationState &result, Value *source">];
+ "Builder *builder, OperationState &result, Value source">];
let parser = [{
return impl::parseCastOp(parser, result);
@@ -931,7 +1061,7 @@
Note that this operation is used during the vector op unrolling
transformation and should be removed before lowering to lower-level
dialects.
-
+
Examples:
```
diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td
index 86ff9b5..5d0244f 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td
@@ -1,19 +1,10 @@
//===- VectorTransformPatterns.td - Vector-Vector patterns -*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the pattern definition file for declarative Vector transformations.
//
diff --git a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
index 2c2e4e7..feb8bd6 100644
--- a/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
+++ b/third_party/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
@@ -1,19 +1,10 @@
//===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef DIALECT_VECTOROPS_VECTORTRANSFORMS_H_
#define DIALECT_VECTOROPS_VECTORTRANSFORMS_H_
@@ -73,8 +64,8 @@
//
// This will be extended in the future to support more advanced use cases than
// simple pointwise ops.
-Value *unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
- ArrayRef<int64_t> targetShape);
+Value unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
+ ArrayRef<int64_t> targetShape);
} // namespace vector
} // namespace mlir
diff --git a/third_party/mlir/include/mlir/EDSC/Builders.h b/third_party/mlir/include/mlir/EDSC/Builders.h
index 69c72a5..d598c1c 100644
--- a/third_party/mlir/include/mlir/EDSC/Builders.h
+++ b/third_party/mlir/include/mlir/EDSC/Builders.h
@@ -1,19 +1,10 @@
//===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides intuitive composable interfaces for building structured MLIR
// snippets in a declarative fashion.
@@ -152,7 +143,7 @@
/// A LoopBuilder is a generic NestedBuilder for loop-like MLIR operations.
/// More specifically it is meant to be used as a temporary object for
-/// representing any nested MLIR construct that is "related to" an mlir::Value*
+/// representing any nested MLIR construct that is "related to" an mlir::Value
/// (for now an induction variable).
/// This is extensible and will evolve in the future as MLIR evolves, hence
/// the name LoopBuilder (as opposed to say ForBuilder or AffineForBuilder).
@@ -242,7 +233,7 @@
/// A BlockBuilder is a NestedBuilder for mlir::Block*.
/// This exists by opposition to LoopBuilder which is not related to an
-/// mlir::Block* but to a mlir::Value*.
+/// mlir::Block* but to a mlir::Value.
/// It is meant to be used as a temporary object for representing any nested
/// MLIR construct that is "related to" an mlir::Block*.
class BlockBuilder : public NestedBuilder {
@@ -257,7 +248,7 @@
///
/// Prerequisites:
/// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are
- /// not yet bound to mlir::Value*.
+ /// not yet bound to mlir::Value.
BlockBuilder(BlockHandle *bh, ArrayRef<ValueHandle *> args);
/// The only purpose of this operator is to serve as a sequence point so that
@@ -291,10 +282,10 @@
/// typed "delayed" value that can be hold a Value in the future;
/// 3. constructed state,in which case it holds a Value.
///
-/// A ValueHandle is meant to capture a single Value* and should be used for
+/// A ValueHandle is meant to capture a single Value and should be used for
/// operations that have a single result. For convenience of use, we also
/// include AffineForOp in this category although it does not return a value.
-/// In the case of AffineForOp, the captured Value* is the loop induction
+/// In the case of AffineForOp, the captured Value is the loop induction
/// variable.
class ValueHandle : public CapturableHandle {
public:
@@ -304,15 +295,15 @@
/// A ValueHandle that is constructed from a Type represents a typed "delayed"
/// Value. A delayed Value can only capture Values of the specified type.
/// Such a delayed value represents the declaration (in the PL sense) of a
- /// placeholder for an mlir::Value* that will be constructed and captured at
+ /// placeholder for an mlir::Value that will be constructed and captured at
/// some later point in the program.
explicit ValueHandle(Type t) : t(t), v(nullptr) {}
- /// A ValueHandle that is constructed from an mlir::Value* is an "eager"
+ /// A ValueHandle that is constructed from an mlir::Value is an "eager"
/// Value. An eager Value represents both the declaration and the definition
- /// (in the PL sense) of a placeholder for an mlir::Value* that has already
+ /// (in the PL sense) of a placeholder for an mlir::Value that has already
/// been constructed in the past and that is captured "now" in the program.
- explicit ValueHandle(Value *v) : t(v->getType()), v(v) {}
+ explicit ValueHandle(Value v) : t(v->getType()), v(v) {}
/// Builds a ConstantIndexOp of value `cst`. The constant is created at the
/// current insertion point.
@@ -336,8 +327,9 @@
std::swap(v, other.v);
}
- /// Implicit conversion useful for automatic conversion to Container<Value*>.
- operator Value *() const { return getValue(); }
+ /// Implicit conversion useful for automatic conversion to Container<Value>.
+ operator Value() const { return getValue(); }
+ operator bool() const { return hasValue(); }
/// Generic mlir::Op create. This is the key to being extensible to the whole
/// of MLIR without duplicating the type system or the op definitions.
@@ -355,7 +347,7 @@
/// Special case to build composed AffineApply operations.
// TODO: createOrFold when available and move inside of the `create` method.
static ValueHandle createComposedAffineApply(AffineMap map,
- ArrayRef<Value *> operands);
+ ArrayRef<Value> operands);
/// Generic create for a named operation producing a single value.
static ValueHandle create(StringRef name, ArrayRef<ValueHandle> operands,
@@ -363,7 +355,7 @@
ArrayRef<NamedAttribute> attributes = {});
bool hasValue() const { return v != nullptr; }
- Value *getValue() const {
+ Value getValue() const {
assert(hasValue() && "Unexpected null value;");
return v;
}
@@ -380,12 +372,12 @@
ValueHandle() : t(), v(nullptr) {}
Type t;
- Value *v;
+ Value v;
};
/// An OperationHandle can be used in lieu of ValueHandle to capture the
/// operation in cases when one does not care about, or cannot extract, a
-/// unique Value* from the operation.
+/// unique Value from the operation.
/// This can be used for capturing zero result operations as well as
/// multi-result operations that are not supported by ValueHandle.
/// We do not distinguish further between zero and multi-result operations at
@@ -529,7 +521,7 @@
} // namespace op
-/// Entry point to build multiple ValueHandle from a `Container` of Value* or
+/// Entry point to build multiple ValueHandle from a `Container` of Value or
/// Type.
template <typename Container>
inline SmallVector<ValueHandle, 8> makeValueHandles(Container values) {
diff --git a/third_party/mlir/include/mlir/EDSC/Helpers.h b/third_party/mlir/include/mlir/EDSC/Helpers.h
index 423c92b..a7c0365 100644
--- a/third_party/mlir/include/mlir/EDSC/Helpers.h
+++ b/third_party/mlir/include/mlir/EDSC/Helpers.h
@@ -1,19 +1,10 @@
//===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides helper classes and syntactic sugar for declarative builders.
//
@@ -75,7 +66,7 @@
// TODO(ntv): Support MemRefs with layoutMaps.
class MemRefView : public View {
public:
- explicit MemRefView(Value *v);
+ explicit MemRefView(Value v);
MemRefView(const MemRefView &) = default;
MemRefView &operator=(const MemRefView &) = default;
@@ -91,7 +82,7 @@
/// a MemRefView but for vectors. This exists purely for boilerplate avoidance.
class VectorView : public View {
public:
- explicit VectorView(Value *v);
+ explicit VectorView(Value v);
VectorView(const VectorView &) = default;
VectorView &operator=(const VectorView &) = default;
@@ -120,7 +111,7 @@
template <typename Load, typename Store> class TemplatedIndexedValue {
public:
explicit TemplatedIndexedValue(Type t) : base(t) {}
- explicit TemplatedIndexedValue(Value *v)
+ explicit TemplatedIndexedValue(Value v)
: TemplatedIndexedValue(ValueHandle(v)) {}
explicit TemplatedIndexedValue(ValueHandle v) : base(v) {}
@@ -161,8 +152,8 @@
return Load(getBase(), {indices.begin(), indices.end()});
}
- /// Emits a `load` when converting to a Value*.
- Value *operator*(void)const {
+ /// Emits a `load` when converting to a Value.
+ Value operator*(void) const {
return Load(getBase(), {indices.begin(), indices.end()}).getValue();
}
diff --git a/third_party/mlir/include/mlir/EDSC/Intrinsics.h b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
index 06c7550..30cce6b 100644
--- a/third_party/mlir/include/mlir/EDSC/Intrinsics.h
+++ b/third_party/mlir/include/mlir/EDSC/Intrinsics.h
@@ -1,19 +1,10 @@
//===- Intrinsics.h - MLIR Operations for Declarative Builders ---*- C++-*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides intuitive composable intrinsics for building snippets of MLIR
// declaratively
@@ -44,7 +35,7 @@
explicit IndexHandle()
: ValueHandle(ScopedContext::getBuilder().getIndexType()) {}
explicit IndexHandle(index_t v) : ValueHandle(v) {}
- explicit IndexHandle(Value *v) : ValueHandle(v) {
+ explicit IndexHandle(Value v) : ValueHandle(v) {
assert(v->getType() == ScopedContext::getBuilder().getIndexType() &&
"Expected index type");
}
@@ -79,9 +70,9 @@
return pivs;
}
-/// Returns a vector of the underlying Value* from `ivs`.
-inline SmallVector<Value *, 8> extractValues(ArrayRef<IndexHandle> ivs) {
- SmallVector<Value *, 8> vals;
+/// Returns a vector of the underlying Value from `ivs`.
+inline SmallVector<Value, 8> extractValues(ArrayRef<IndexHandle> ivs) {
+ SmallVector<Value, 8> vals;
vals.reserve(ivs.size());
for (auto &iv : ivs) {
vals.push_back(iv.getValue());
@@ -96,7 +87,7 @@
namespace detail {
/// Helper structure to be used with ValueBuilder / OperationBuilder.
/// It serves the purpose of removing boilerplate specialization for the sole
-/// purpose of implicitly converting ArrayRef<ValueHandle> -> ArrayRef<Value*>.
+/// purpose of implicitly converting ArrayRef<ValueHandle> -> ArrayRef<Value>.
class ValueHandleArray {
public:
ValueHandleArray(ArrayRef<ValueHandle> vals) {
@@ -109,11 +100,11 @@
SmallVector<IndexHandle, 8> tmp(vals.begin(), vals.end());
values.append(tmp.begin(), tmp.end());
}
- operator ArrayRef<Value *>() { return values; }
+ operator ArrayRef<Value>() { return values; }
private:
ValueHandleArray() = default;
- SmallVector<Value *, 8> values;
+ SmallVector<Value, 8> values;
};
template <typename T> inline T unpack(T value) { return value; }
@@ -128,8 +119,8 @@
/// boilerplate or Tablegen.
/// Arguably a builder is not a ValueHandle but in practice it is only used as
/// an alias to a notional ValueHandle<Op>.
-/// Implementing it as a subclass allows it to compose all the way to Value*.
-/// Without subclassing, implicit conversion to Value* would fail when composing
+/// Implementing it as a subclass allows it to compose all the way to Value.
+/// Without subclassing, implicit conversion to Value would fail when composing
/// in patterns such as: `select(a, b, select(c, d, e))`.
template <typename Op> struct ValueBuilder : public ValueHandle {
// Builder-based
@@ -238,8 +229,8 @@
///
/// Prerequisites:
/// `b` has not yet captured an mlir::Block*.
-/// No `captures` have captured any mlir::Value*.
-/// All `operands` have already captured an mlir::Value*
+/// No `captures` have captured any mlir::Value.
+/// All `operands` have already captured an mlir::Value
/// captures.size() == operands.size()
/// captures and operands are pairwise of the same type.
OperationHandle br(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
@@ -266,8 +257,8 @@
///
/// Prerequisites:
/// `trueBranch`/`falseBranch` has not yet captured an mlir::Block*.
-/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value*.
-/// All `trueOperands`/`trueOperands` have already captured an mlir::Value*
+/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value.
+/// All `trueOperands`/`falseOperands` have already captured an mlir::Value
/// `trueCaptures`.size() == `trueOperands`.size()
/// `falseCaptures`.size() == `falseOperands`.size()
/// `trueCaptures` and `trueOperands` are pairwise of the same type
diff --git a/third_party/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/third_party/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
index 4e70a21..4f218bd 100644
--- a/third_party/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
+++ b/third_party/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -1,19 +1,10 @@
//===- ExecutionEngine.h - MLIR Execution engine and utils -----*- C++ -*--===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file provides a JIT-backed execution engine for MLIR modules.
//
diff --git a/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h b/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h
index 8c0249d..7b7b259 100644
--- a/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h
+++ b/third_party/mlir/include/mlir/ExecutionEngine/OptUtils.h
@@ -1,19 +1,10 @@
//===- OptUtils.h - MLIR Execution Engine opt pass utilities ----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the utility functions to trigger LLVM optimizations from
// MLIR Execution Engine.
diff --git a/third_party/mlir/include/mlir/IR/AffineExpr.h b/third_party/mlir/include/mlir/IR/AffineExpr.h
index b66933d..7059489 100644
--- a/third_party/mlir/include/mlir/IR/AffineExpr.h
+++ b/third_party/mlir/include/mlir/IR/AffineExpr.h
@@ -1,19 +1,10 @@
//===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// An affine expression is an affine combination of dimension identifiers and
// symbols, including ceildiv/floordiv/mod by a constant integer.
diff --git a/third_party/mlir/include/mlir/IR/AffineExprVisitor.h b/third_party/mlir/include/mlir/IR/AffineExprVisitor.h
index 9fa4021..7866d6b 100644
--- a/third_party/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/third_party/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -1,19 +1,10 @@
//===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the AffineExpr visitor class.
//
diff --git a/third_party/mlir/include/mlir/IR/AffineMap.h b/third_party/mlir/include/mlir/IR/AffineMap.h
index abd3712..3f9116c 100644
--- a/third_party/mlir/include/mlir/IR/AffineMap.h
+++ b/third_party/mlir/include/mlir/IR/AffineMap.h
@@ -1,19 +1,10 @@
//===- AffineMap.h - MLIR Affine Map Class ----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Affine maps are mathematical functions which map a list of dimension
// identifiers and symbols, to multidimensional affine expressions.
diff --git a/third_party/mlir/include/mlir/IR/AttributeSupport.h b/third_party/mlir/include/mlir/IR/AttributeSupport.h
index 78b3a27..9804d68 100644
--- a/third_party/mlir/include/mlir/IR/AttributeSupport.h
+++ b/third_party/mlir/include/mlir/IR/AttributeSupport.h
@@ -1,19 +1,10 @@
//===- AttributeSupport.h ---------------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines support types for registering dialect extended attributes.
//
diff --git a/third_party/mlir/include/mlir/IR/Attributes.h b/third_party/mlir/include/mlir/IR/Attributes.h
index 94aea94..b839858 100644
--- a/third_party/mlir/include/mlir/IR/Attributes.h
+++ b/third_party/mlir/include/mlir/IR/Attributes.h
@@ -1,19 +1,10 @@
//===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_ATTRIBUTES_H
#define MLIR_IR_ATTRIBUTES_H
@@ -82,11 +73,8 @@
/* implicit */ Attribute(const ImplType *impl)
: impl(const_cast<ImplType *>(impl)) {}
- Attribute(const Attribute &other) : impl(other.impl) {}
- Attribute &operator=(Attribute other) {
- impl = other.impl;
- return *this;
- }
+ Attribute(const Attribute &other) = default;
+ Attribute &operator=(const Attribute &other) = default;
bool operator==(Attribute other) const { return impl == other.impl; }
bool operator!=(Attribute other) const { return !(*this == other); }
diff --git a/third_party/mlir/include/mlir/IR/Block.h b/third_party/mlir/include/mlir/IR/Block.h
index 6c5099b..934eed9 100644
--- a/third_party/mlir/include/mlir/IR/Block.h
+++ b/third_party/mlir/include/mlir/IR/Block.h
@@ -1,19 +1,10 @@
//===- Block.h - MLIR Block Class -------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the Block class.
//
@@ -72,7 +63,7 @@
//===--------------------------------------------------------------------===//
// This is the list of arguments to the block.
- using BlockArgListType = ArrayRef<BlockArgument *>;
+ using BlockArgListType = MutableArrayRef<BlockArgument>;
BlockArgListType getArguments() { return arguments; }
@@ -86,7 +77,7 @@
bool args_empty() { return arguments.empty(); }
/// Add one value to the argument list.
- BlockArgument *addArgument(Type type);
+ BlockArgument addArgument(Type type);
/// Add one argument to the argument list for each type specified in the list.
iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
@@ -97,7 +88,7 @@
void eraseArgument(unsigned index, bool updatePredTerms = true);
unsigned getNumArguments() { return arguments.size(); }
- BlockArgument *getArgument(unsigned i) { return arguments[i]; }
+ BlockArgument getArgument(unsigned i) { return arguments[i]; }
//===--------------------------------------------------------------------===//
// Operation list management
@@ -332,7 +323,7 @@
OpListType operations;
/// This is the list of arguments to the block.
- std::vector<BlockArgument *> arguments;
+ std::vector<BlockArgument> arguments;
Block(Block &) = delete;
void operator=(Block &) = delete;
diff --git a/third_party/mlir/include/mlir/IR/BlockAndValueMapping.h b/third_party/mlir/include/mlir/IR/BlockAndValueMapping.h
index cd15d45..b7ad360 100644
--- a/third_party/mlir/include/mlir/IR/BlockAndValueMapping.h
+++ b/third_party/mlir/include/mlir/IR/BlockAndValueMapping.h
@@ -1,19 +1,10 @@
//===- BlockAndValueMapping.h -----------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a utility class for maintaining a mapping for multiple
// value types.
@@ -37,14 +28,18 @@
/// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping,
/// it is overwritten.
void map(Block *from, Block *to) { valueMap[from] = to; }
- void map(Value *from, Value *to) { valueMap[from] = to; }
+ void map(Value from, Value to) {
+ valueMap[from.getAsOpaquePointer()] = to.getAsOpaquePointer();
+ }
/// Erases a mapping for 'from'.
- void erase(IRObjectWithUseList *from) { valueMap.erase(from); }
+ void erase(Block *from) { valueMap.erase(from); }
+ void erase(Value from) { valueMap.erase(from.getAsOpaquePointer()); }
/// Checks to see if a mapping for 'from' exists.
- bool contains(IRObjectWithUseList *from) const {
- return valueMap.count(from);
+ bool contains(Block *from) const { return valueMap.count(from); }
+ bool contains(Value from) const {
+ return valueMap.count(from.getAsOpaquePointer());
}
/// Lookup a mapped value within the map. If a mapping for the provided value
@@ -52,23 +47,19 @@
Block *lookupOrNull(Block *from) const {
return lookupOrValue(from, (Block *)nullptr);
}
- Value *lookupOrNull(Value *from) const {
- return lookupOrValue(from, (Value *)nullptr);
- }
+ Value lookupOrNull(Value from) const { return lookupOrValue(from, Value()); }
/// Lookup a mapped value within the map. If a mapping for the provided value
/// does not exist then return the provided value.
Block *lookupOrDefault(Block *from) const {
return lookupOrValue(from, from);
}
- Value *lookupOrDefault(Value *from) const {
- return lookupOrValue(from, from);
- }
+ Value lookupOrDefault(Value from) const { return lookupOrValue(from, from); }
/// Lookup a mapped value within the map. This asserts the provided value
/// exists within the map.
- template <typename T> T *lookup(T *from) const {
- auto *result = lookupOrNull(from);
+ template <typename T> T lookup(T from) const {
+ auto result = lookupOrNull(from);
assert(result && "expected 'from' to be contained within the map");
return result;
}
@@ -78,14 +69,18 @@
private:
/// Utility lookupOrValue that looks up an existing key or returns the
- /// provided value. This function assumes that if a mapping does exist, then
- /// it is of 'T' type.
- template <typename T> T *lookupOrValue(T *from, T *value) const {
+ /// provided value.
+ Block *lookupOrValue(Block *from, Block *value) const {
auto it = valueMap.find(from);
- return it != valueMap.end() ? static_cast<T *>(it->second) : value;
+ return it != valueMap.end() ? reinterpret_cast<Block *>(it->second) : value;
+ }
+ Value lookupOrValue(Value from, Value value) const {
+ auto it = valueMap.find(from.getAsOpaquePointer());
+ return it != valueMap.end() ? Value::getFromOpaquePointer(it->second)
+ : value;
}
- DenseMap<IRObjectWithUseList *, IRObjectWithUseList *> valueMap;
+ DenseMap<void *, void *> valueMap;
};
} // end namespace mlir
diff --git a/third_party/mlir/include/mlir/IR/BlockSupport.h b/third_party/mlir/include/mlir/IR/BlockSupport.h
index fd30c36..bc6a824 100644
--- a/third_party/mlir/include/mlir/IR/BlockSupport.h
+++ b/third_party/mlir/include/mlir/IR/BlockSupport.h
@@ -1,19 +1,10 @@
//===- BlockSupport.h -------------------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a number of support types for the Block class.
//
@@ -70,6 +61,7 @@
public:
using RangeBaseT::RangeBaseT;
SuccessorRange(Block *block);
+ SuccessorRange(Operation *term);
private:
/// See `detail::indexed_accessor_range_base` for details.
diff --git a/third_party/mlir/include/mlir/IR/Builders.h b/third_party/mlir/include/mlir/IR/Builders.h
index 766902f..2db44cb 100644
--- a/third_party/mlir/include/mlir/IR/Builders.h
+++ b/third_party/mlir/include/mlir/IR/Builders.h
@@ -1,19 +1,10 @@
//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_BUILDERS_H
#define MLIR_IR_BUILDERS_H
@@ -313,7 +304,7 @@
/// and immediately try to fold it. This functions populates 'results' with
/// the results after folding the operation.
template <typename OpTy, typename... Args>
- void createOrFold(SmallVectorImpl<Value *> &results, Location location,
+ void createOrFold(SmallVectorImpl<Value> &results, Location location,
Args &&... args) {
// Create the operation without using 'createOperation' as we don't want to
// insert it yet.
@@ -331,9 +322,9 @@
/// Overload to create or fold a single result operation.
template <typename OpTy, typename... Args>
typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
- Value *>::type
+ Value>::type
createOrFold(Location location, Args &&... args) {
- SmallVector<Value *, 1> results;
+ SmallVector<Value, 1> results;
createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
return results.front();
}
@@ -344,7 +335,7 @@
OpTy>::type
createOrFold(Location location, Args &&... args) {
auto op = create<OpTy>(location, std::forward<Args>(args)...);
- SmallVector<Value *, 0> unused;
+ SmallVector<Value, 0> unused;
tryFold(op.getOperation(), unused);
// Folding cannot remove a zero-result operation, so for convenience we
@@ -355,7 +346,7 @@
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
- LogicalResult tryFold(Operation *op, SmallVectorImpl<Value *> &results);
+ LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
diff --git a/third_party/mlir/include/mlir/IR/Diagnostics.h b/third_party/mlir/include/mlir/IR/Diagnostics.h
index 9385de9..e3d0f83 100644
--- a/third_party/mlir/include/mlir/IR/Diagnostics.h
+++ b/third_party/mlir/include/mlir/IR/Diagnostics.h
@@ -1,19 +1,10 @@
//===- Diagnostics.h - MLIR Diagnostics -------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines utilities for emitting diagnostics.
//
diff --git a/third_party/mlir/include/mlir/IR/Dialect.h b/third_party/mlir/include/mlir/IR/Dialect.h
index a1855e7..d3b4b05 100644
--- a/third_party/mlir/include/mlir/IR/Dialect.h
+++ b/third_party/mlir/include/mlir/IR/Dialect.h
@@ -1,19 +1,10 @@
//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the 'dialect' abstraction.
//
diff --git a/third_party/mlir/include/mlir/IR/DialectHooks.h b/third_party/mlir/include/mlir/IR/DialectHooks.h
index c51fafb..7e4e1d8 100644
--- a/third_party/mlir/include/mlir/IR/DialectHooks.h
+++ b/third_party/mlir/include/mlir/IR/DialectHooks.h
@@ -1,19 +1,10 @@
//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines abstraction and registration mechanism for dialect hooks.
//
diff --git a/third_party/mlir/include/mlir/IR/DialectImplementation.h b/third_party/mlir/include/mlir/IR/DialectImplementation.h
index c645a24..1eada8f 100644
--- a/third_party/mlir/include/mlir/IR/DialectImplementation.h
+++ b/third_party/mlir/include/mlir/IR/DialectImplementation.h
@@ -1,19 +1,10 @@
//===- DialectImplementation.h ----------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains utilities classes for implementing dialect attributes and
// types.
diff --git a/third_party/mlir/include/mlir/IR/DialectInterface.h b/third_party/mlir/include/mlir/IR/DialectInterface.h
index 4eb4110..ff1f8fb 100644
--- a/third_party/mlir/include/mlir/IR/DialectInterface.h
+++ b/third_party/mlir/include/mlir/IR/DialectInterface.h
@@ -1,19 +1,10 @@
//===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECTINTERFACE_H
#define MLIR_IR_DIALECTINTERFACE_H
diff --git a/third_party/mlir/include/mlir/IR/DialectSymbolRegistry.def b/third_party/mlir/include/mlir/IR/DialectSymbolRegistry.def
index c1056bd..14b876a 100644
--- a/third_party/mlir/include/mlir/IR/DialectSymbolRegistry.def
+++ b/third_party/mlir/include/mlir/IR/DialectSymbolRegistry.def
@@ -1,19 +1,10 @@
//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file enumerates the different dialects that define custom classes
// within the attribute or type system.
diff --git a/third_party/mlir/include/mlir/IR/Function.h b/third_party/mlir/include/mlir/IR/Function.h
index 6731f54..3f788bb 100644
--- a/third_party/mlir/include/mlir/IR/Function.h
+++ b/third_party/mlir/include/mlir/IR/Function.h
@@ -1,19 +1,10 @@
//===- Function.h - MLIR Function Class -------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Functions are the basic unit of composition in MLIR.
//
diff --git a/third_party/mlir/include/mlir/IR/FunctionImplementation.h b/third_party/mlir/include/mlir/IR/FunctionImplementation.h
index c557d58..9d3e438 100644
--- a/third_party/mlir/include/mlir/IR/FunctionImplementation.h
+++ b/third_party/mlir/include/mlir/IR/FunctionImplementation.h
@@ -1,19 +1,10 @@
//===- FunctionImplementation.h - Function-like Op utilities ----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file provides utility functions for implementing function-like
// operations, in particular, parsing, printing and verification components
diff --git a/third_party/mlir/include/mlir/IR/FunctionSupport.h b/third_party/mlir/include/mlir/IR/FunctionSupport.h
index b15b056..e6cba2c 100644
--- a/third_party/mlir/include/mlir/IR/FunctionSupport.h
+++ b/third_party/mlir/include/mlir/IR/FunctionSupport.h
@@ -1,19 +1,10 @@
//===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines support types for Operations that represent function-like
// constructs to use.
@@ -183,7 +174,7 @@
}
/// Gets argument.
- BlockArgument *getArgument(unsigned idx) {
+ BlockArgument getArgument(unsigned idx) {
return getBlocks().front().getArgument(idx);
}
diff --git a/third_party/mlir/include/mlir/IR/Identifier.h b/third_party/mlir/include/mlir/IR/Identifier.h
index bc84c20..604eebf 100644
--- a/third_party/mlir/include/mlir/IR/Identifier.h
+++ b/third_party/mlir/include/mlir/IR/Identifier.h
@@ -1,19 +1,10 @@
//===- Identifier.h - MLIR Identifier Class ---------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_IDENTIFIER_H
#define MLIR_IR_IDENTIFIER_H
diff --git a/third_party/mlir/include/mlir/IR/IntegerSet.h b/third_party/mlir/include/mlir/IR/IntegerSet.h
index 6ffe830..1238511 100644
--- a/third_party/mlir/include/mlir/IR/IntegerSet.h
+++ b/third_party/mlir/include/mlir/IR/IntegerSet.h
@@ -1,19 +1,10 @@
//===- IntegerSet.h - MLIR Integer Set Class --------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Integer sets are sets of points from the integer lattice constrained by
// affine equality/inequality constraints. This class is meant to represent
diff --git a/third_party/mlir/include/mlir/IR/Location.h b/third_party/mlir/include/mlir/IR/Location.h
index bb55ad6..c36bcb3 100644
--- a/third_party/mlir/include/mlir/IR/Location.h
+++ b/third_party/mlir/include/mlir/IR/Location.h
@@ -1,19 +1,10 @@
//===- Location.h - MLIR Location Classes -----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// These classes provide the ability to relate MLIR objects back to source
// location position information.
diff --git a/third_party/mlir/include/mlir/IR/MLIRContext.h b/third_party/mlir/include/mlir/IR/MLIRContext.h
index a93cb8b..e0761bc 100644
--- a/third_party/mlir/include/mlir/IR/MLIRContext.h
+++ b/third_party/mlir/include/mlir/IR/MLIRContext.h
@@ -1,19 +1,10 @@
//===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_MLIRCONTEXT_H
#define MLIR_IR_MLIRCONTEXT_H
diff --git a/third_party/mlir/include/mlir/IR/Matchers.h b/third_party/mlir/include/mlir/IR/Matchers.h
index 1261916..2cfa242 100644
--- a/third_party/mlir/include/mlir/IR/Matchers.h
+++ b/third_party/mlir/include/mlir/IR/Matchers.h
@@ -1,19 +1,10 @@
//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file provides a simple and efficient mechanism for performing general
// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
@@ -142,7 +133,7 @@
/// Statically switch to a Value matcher.
template <typename MatcherClass>
typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
- MatcherClass, Value *>::value,
+ MatcherClass, Value>::value,
bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
return matcher.match(op->getOperand(idx));
@@ -161,14 +152,14 @@
/// Terminal matcher, always returns true.
struct AnyValueMatcher {
- bool match(Value *op) const { return true; }
+ bool match(Value op) const { return true; }
};
/// Binds to a specific value and matches it.
struct PatternMatcherValue {
- PatternMatcherValue(Value *val) : value(val) {}
- bool match(Value *val) const { return val == value; }
- Value *value;
+ PatternMatcherValue(Value val) : value(val) {}
+ bool match(Value val) const { return val == value; }
+ Value value;
};
template <typename TupleT, class CallbackT, std::size_t... Is>
@@ -235,7 +226,7 @@
/// Entry point for matching a pattern over a Value.
template <typename Pattern>
-inline bool matchPattern(Value *value, const Pattern &pattern) {
+inline bool matchPattern(Value value, const Pattern &pattern) {
// TODO: handle other cases
if (auto *op = value->getDefiningOp())
return const_cast<Pattern &>(pattern).match(op);
@@ -262,7 +253,7 @@
namespace matchers {
inline auto m_Any() { return detail::AnyValueMatcher(); }
-inline auto m_Val(Value *v) { return detail::PatternMatcherValue(v); }
+inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
} // namespace matchers
} // end namespace mlir
diff --git a/third_party/mlir/include/mlir/IR/Module.h b/third_party/mlir/include/mlir/IR/Module.h
index 52d2455..babc51a 100644
--- a/third_party/mlir/include/mlir/IR/Module.h
+++ b/third_party/mlir/include/mlir/IR/Module.h
@@ -1,19 +1,10 @@
//===- Module.h - MLIR Module Class -----------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Module is the top-level container for code in an MLIR program.
//
diff --git a/third_party/mlir/include/mlir/IR/OpAsmInterface.td b/third_party/mlir/include/mlir/IR/OpAsmInterface.td
index 85726a8..7e31c07 100644
--- a/third_party/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/third_party/mlir/include/mlir/IR/OpAsmInterface.td
@@ -1,19 +1,10 @@
//===- OpAsmInterface.td - Asm Interfaces for opse ---------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains Interfaces for interacting with the AsmParser and
// AsmPrinter.
diff --git a/third_party/mlir/include/mlir/IR/OpBase.td b/third_party/mlir/include/mlir/IR/OpBase.td
index 8f6770f..c457d25 100644
--- a/third_party/mlir/include/mlir/IR/OpBase.td
+++ b/third_party/mlir/include/mlir/IR/OpBase.td
@@ -1,19 +1,10 @@
//===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the base operation definition file.
//
@@ -1586,6 +1577,7 @@
bit hasFolder = 0;
// Op traits.
+ // Note: The list of traits will be uniqued by ODS.
list<OpTrait> traits = props;
// Additional code that will be added to the public part of the generated
diff --git a/third_party/mlir/include/mlir/IR/OpDefinition.h b/third_party/mlir/include/mlir/IR/OpDefinition.h
index c220120..1abf82f 100644
--- a/third_party/mlir/include/mlir/IR/OpDefinition.h
+++ b/third_party/mlir/include/mlir/IR/OpDefinition.h
@@ -1,19 +1,10 @@
//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements helper classes for implementing the "Op" types. This
// includes the Op type, which is the base class for Op class definitions,
@@ -257,8 +248,8 @@
}
/// This class represents a single result from folding an operation.
-class OpFoldResult : public PointerUnion<Attribute, Value *> {
- using PointerUnion<Attribute, Value *>::PointerUnion;
+class OpFoldResult : public PointerUnion<Attribute, Value> {
+ using PointerUnion<Attribute, Value>::PointerUnion;
};
/// This template defines the foldHook as used by AbstractOperation.
@@ -311,8 +302,8 @@
typename std::enable_if<isSingleResult>::type> {
public:
/// If the operation returns a single value, then the Op can be implicitly
- /// converted to an Value*. This yields the value of the only result.
- operator Value *() {
+ /// converted to a Value. This yields the value of the only result.
+ operator Value() {
return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
}
@@ -326,7 +317,7 @@
// Check if the operation was folded in place. In this case, the operation
// returns itself.
- if (result.template dyn_cast<Value *>() != op->getResult(0))
+ if (result.template dyn_cast<Value>() != op->getResult(0))
results.push_back(result);
return success();
}
@@ -428,10 +419,10 @@
unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
/// Return the operand at index 'i'.
- Value *getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
+ Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
/// Set the operand at index 'i' to 'value'.
- void setOperand(unsigned i, Value *value) {
+ void setOperand(unsigned i, Value value) {
this->getOperation()->setOperand(i, value);
}
@@ -475,9 +466,9 @@
template <typename ConcreteType>
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public:
- Value *getOperand() { return this->getOperation()->getOperand(0); }
+ Value getOperand() { return this->getOperation()->getOperand(0); }
- void setOperand(Value *value) { this->getOperation()->setOperand(0, value); }
+ void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyOneOperand(op);
@@ -550,7 +541,7 @@
unsigned getNumResults() { return this->getOperation()->getNumResults(); }
/// Return the result at index 'i'.
- Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
+ Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
/// Replace all uses of results of this operation with the provided 'values'.
/// 'values' may correspond to an existing operation, or a range of 'Value'.
@@ -586,13 +577,13 @@
template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
- Value *getResult() { return this->getOperation()->getResult(0); }
+ Value getResult() { return this->getOperation()->getResult(0); }
Type getType() { return getResult()->getType(); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
- void replaceAllUsesWith(Value *newValue) {
+ void replaceAllUsesWith(Value newValue) {
getResult()->replaceAllUsesWith(newValue);
}
@@ -820,10 +811,10 @@
return this->getOperation()->setSuccessor(block, index);
}
- void addSuccessorOperand(unsigned index, Value *value) {
+ void addSuccessorOperand(unsigned index, Value value) {
return this->getOperation()->addSuccessorOperand(index, value);
}
- void addSuccessorOperands(unsigned index, ArrayRef<Value *> values) {
+ void addSuccessorOperands(unsigned index, ArrayRef<Value> values) {
return this->getOperation()->addSuccessorOperand(index, values);
}
};
@@ -1209,8 +1200,8 @@
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
OperationState &result);
-void buildBinaryOp(Builder *builder, OperationState &result, Value *lhs,
- Value *rhs);
+void buildBinaryOp(Builder *builder, OperationState &result, Value lhs,
+ Value rhs);
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
OperationState &result);
@@ -1223,11 +1214,11 @@
// These functions are out-of-line implementations of the methods in CastOp,
// which avoids them being template instantiated/duplicated.
namespace impl {
-void buildCastOp(Builder *builder, OperationState &result, Value *source,
+void buildCastOp(Builder *builder, OperationState &result, Value source,
Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
-Value *foldCastOp(Operation *op);
+Value foldCastOp(Operation *op);
} // namespace impl
} // end namespace mlir
diff --git a/third_party/mlir/include/mlir/IR/OpImplementation.h b/third_party/mlir/include/mlir/IR/OpImplementation.h
index 97569cc..41acdba 100644
--- a/third_party/mlir/include/mlir/IR/OpImplementation.h
+++ b/third_party/mlir/include/mlir/IR/OpImplementation.h
@@ -1,19 +1,10 @@
//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This classes used by the implementation details of Op types.
//
@@ -45,7 +36,7 @@
virtual raw_ostream &getStream() const = 0;
/// Print implementations for various things an operation contains.
- virtual void printOperand(Value *value) = 0;
+ virtual void printOperand(Value value) = 0;
/// Print a comma separated list of operands.
template <typename ContainerType>
@@ -121,7 +112,7 @@
void printFunctionalType(Operation *op) {
auto &os = getStream();
os << "(";
- interleaveComma(op->getNonSuccessorOperands(), os, [&](Value *operand) {
+ interleaveComma(op->getNonSuccessorOperands(), os, [&](Value operand) {
if (operand)
printType(operand->getType());
else
@@ -150,17 +141,14 @@
};
// Make the implementations convenient to use.
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value &value) {
- p.printOperand(&value);
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
+ p.printOperand(value);
return p;
}
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value *value) {
- return p << *value;
-}
template <typename T,
typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
- !std::is_convertible<T &, Value *>::value,
+ !std::is_convertible<T &, Value &>::value,
T>::type * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
p.printOperands(values);
@@ -182,7 +170,6 @@
// FunctionType with the Type version above, not have it match this.
template <typename T, typename std::enable_if<
!std::is_convertible<T &, Value &>::value &&
- !std::is_convertible<T &, Value *>::value &&
!std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
@@ -467,13 +454,13 @@
/// Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult resolveOperand(const OperandType &operand, Type type,
- SmallVectorImpl<Value *> &result) = 0;
+ SmallVectorImpl<Value> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error on failure, or
/// appending the results to the list on success. This method should be used
/// when all operands have the same type.
ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
- SmallVectorImpl<Value *> &result) {
+ SmallVectorImpl<Value> &result) {
for (auto elt : operands)
if (resolveOperand(elt, type, result))
return failure();
@@ -485,7 +472,7 @@
/// to the list on success.
ParseResult resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type> types, llvm::SMLoc loc,
- SmallVectorImpl<Value *> &result) {
+ SmallVectorImpl<Value> &result) {
if (operands.size() != types.size())
return emitError(loc)
<< operands.size() << " operands present, but expected "
@@ -555,8 +542,7 @@
/// Parse a single operation successor and its operand list.
virtual ParseResult
- parseSuccessorAndUseList(Block *&dest,
- SmallVectorImpl<Value *> &operands) = 0;
+ parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
//===--------------------------------------------------------------------===//
// Type Parsing
@@ -634,7 +620,7 @@
/// A functor used to set the name of the start of a result group of an
/// operation. See 'getAsmResultNames' below for more details.
-using OpAsmSetValueNameFn = function_ref<void(Value *, StringRef)>;
+using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
@@ -661,6 +647,11 @@
/// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
virtual void getAsmResultNames(Operation *op,
OpAsmSetValueNameFn setNameFn) const {}
+
+ /// Get a special name to use when printing the entry block arguments of the
+ /// region contained by an operation in this dialect.
+ virtual void getAsmBlockArgumentNames(Block *block,
+ OpAsmSetValueNameFn setNameFn) const {}
};
//===--------------------------------------------------------------------===//
diff --git a/third_party/mlir/include/mlir/IR/Operation.h b/third_party/mlir/include/mlir/IR/Operation.h
index 2159d10..9ef1636 100644
--- a/third_party/mlir/include/mlir/IR/Operation.h
+++ b/third_party/mlir/include/mlir/IR/Operation.h
@@ -1,19 +1,10 @@
//===- Operation.h - MLIR Operation Class -----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the Operation class.
//
@@ -43,8 +34,7 @@
public:
/// Create a new Operation with the specific fields.
static Operation *create(Location location, OperationName name,
- ArrayRef<Type> resultTypes,
- ArrayRef<Value *> operands,
+ ArrayRef<Type> resultTypes, ArrayRef<Value> operands,
ArrayRef<NamedAttribute> attributes,
ArrayRef<Block *> successors, unsigned numRegions,
bool resizableOperandList);
@@ -52,8 +42,7 @@
/// Overload of create that takes an existing NamedAttributeList to avoid
/// unnecessarily uniquing a list of attributes.
static Operation *create(Location location, OperationName name,
- ArrayRef<Type> resultTypes,
- ArrayRef<Value *> operands,
+ ArrayRef<Type> resultTypes, ArrayRef<Value> operands,
NamedAttributeList attributes,
ArrayRef<Block *> successors, unsigned numRegions,
bool resizableOperandList);
@@ -62,11 +51,12 @@
static Operation *create(const OperationState &state);
/// Create a new Operation with the specific fields.
- static Operation *
- create(Location location, OperationName name, ArrayRef<Type> resultTypes,
- ArrayRef<Value *> operands, NamedAttributeList attributes,
- ArrayRef<Block *> successors = {}, RegionRange regions = {},
- bool resizableOperandList = false);
+ static Operation *create(Location location, OperationName name,
+ ArrayRef<Type> resultTypes, ArrayRef<Value> operands,
+ NamedAttributeList attributes,
+ ArrayRef<Block *> successors = {},
+ RegionRange regions = {},
+ bool resizableOperandList = false);
/// The name of an operation is the key identifier for it.
OperationName getName() { return name; }
@@ -149,7 +139,7 @@
}
/// Replace any uses of 'from' with 'to' within this operation.
- void replaceUsesOfWith(Value *from, Value *to);
+ void replaceUsesOfWith(Value from, Value to);
/// Replace all uses of results of this operation with the provided 'values'.
template <typename ValuesT,
@@ -215,8 +205,8 @@
unsigned getNumOperands() { return getOperandStorage().size(); }
- Value *getOperand(unsigned idx) { return getOpOperand(idx).get(); }
- void setOperand(unsigned idx, Value *value) {
+ Value getOperand(unsigned idx) { return getOpOperand(idx).get(); }
+ void setOperand(unsigned idx, Value value) {
return getOpOperand(idx).set(value);
}
@@ -227,7 +217,7 @@
operand_iterator operand_begin() { return getOperands().begin(); }
operand_iterator operand_end() { return getOperands().end(); }
- /// Returns an iterator on the underlying Value's (Value *).
+ /// Returns an iterator on the underlying Values.
operand_range getOperands() { return operand_range(this); }
/// Erase the operand at position `idx`.
@@ -255,7 +245,7 @@
unsigned getNumResults() { return numResults; }
- Value *getResult(unsigned idx) { return &getOpResult(idx); }
+ Value getResult(unsigned idx) { return getOpResult(idx); }
/// Support result iteration.
using result_range = ResultRange;
@@ -394,12 +384,18 @@
return {getTrailingObjects<BlockOperand>(), numSuccs};
}
+ // Successor iteration.
+ using succ_iterator = SuccessorRange::iterator;
+ succ_iterator successor_begin() { return getSuccessors().begin(); }
+ succ_iterator successor_end() { return getSuccessors().end(); }
+ SuccessorRange getSuccessors() { return SuccessorRange(this); }
+
/// Return the operands of this operation that are *not* successor arguments.
operand_range getNonSuccessorOperands();
operand_range getSuccessorOperands(unsigned index);
- Value *getSuccessorOperand(unsigned succIndex, unsigned opIndex) {
+ Value getSuccessorOperand(unsigned succIndex, unsigned opIndex) {
assert(!isKnownNonTerminator() && "only terminators may have successors");
assert(opIndex < getNumSuccessorOperands(succIndex));
return getOperand(getSuccessorOperandIndex(succIndex) + opIndex);
@@ -441,9 +437,9 @@
Optional<std::pair<unsigned, unsigned>>
decomposeSuccessorOperandIndex(unsigned operandIndex);
- /// Returns the `BlockArgument*` corresponding to operand `operandIndex` in
+ /// Returns the `BlockArgument` corresponding to operand `operandIndex` in
/// some successor, or None if `operandIndex` isn't a successor operand index.
- Optional<BlockArgument *> getSuccessorBlockArgument(unsigned operandIndex) {
+ Optional<BlockArgument> getSuccessorBlockArgument(unsigned operandIndex) {
auto decomposed = decomposeSuccessorOperandIndex(operandIndex);
if (!decomposed.hasValue())
return None;
diff --git a/third_party/mlir/include/mlir/IR/OperationSupport.h b/third_party/mlir/include/mlir/IR/OperationSupport.h
index 23ef0ce..30376b8 100644
--- a/third_party/mlir/include/mlir/IR/OperationSupport.h
+++ b/third_party/mlir/include/mlir/IR/OperationSupport.h
@@ -1,19 +1,10 @@
//===- OperationSupport.h ---------------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a number of support types that Operation and related
// classes build on top of.
@@ -270,7 +261,7 @@
struct OperationState {
Location location;
OperationName name;
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
/// Types of the results of this operation.
SmallVector<Type, 4> types;
SmallVector<NamedAttribute, 4> attributes;
@@ -534,8 +525,8 @@
/// This class implements iteration on the types of a given range of values.
template <typename ValueIteratorT>
class ValueTypeIterator final
- : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value *)> {
- static Type unwrap(Value *value) { return value->getType(); }
+ : public llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)> {
+ static Type unwrap(Value value) { return value.getType(); }
public:
using reference = Type;
@@ -545,7 +536,7 @@
/// Initializes the type iterator to the specified value iterator.
ValueTypeIterator(ValueIteratorT it)
- : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value *)>(it, &unwrap) {}
+ : llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
};
//===----------------------------------------------------------------------===//
@@ -554,7 +545,7 @@
/// This class implements the operand iterators for the Operation class.
class OperandRange final
: public detail::indexed_accessor_range_base<OperandRange, OpOperand *,
- Value *, Value *, Value *> {
+ Value, Value, Value> {
public:
using RangeBaseT::RangeBaseT;
OperandRange(Operation *op);
@@ -569,7 +560,7 @@
return object + index;
}
/// See `detail::indexed_accessor_range_base` for details.
- static Value *dereference_iterator(OpOperand *object, ptrdiff_t index) {
+ static Value dereference_iterator(OpOperand *object, ptrdiff_t index) {
return object[index].get();
}
@@ -582,8 +573,8 @@
/// This class implements the result iterators for the Operation class.
class ResultRange final
- : public detail::indexed_accessor_range_base<ResultRange, OpResult *,
- Value *, Value *, Value *> {
+ : public detail::indexed_accessor_range_base<ResultRange, OpResult *, Value,
+ Value, Value> {
public:
using RangeBaseT::RangeBaseT;
ResultRange(Operation *op);
@@ -598,8 +589,8 @@
return object + index;
}
/// See `detail::indexed_accessor_range_base` for details.
- static Value *dereference_iterator(OpResult *object, ptrdiff_t index) {
- return &object[index];
+ static Value dereference_iterator(OpResult *object, ptrdiff_t index) {
+ return object[index];
}
/// Allow access to `offset_base` and `dereference_iterator`.
@@ -610,31 +601,30 @@
// ValueRange
/// This class provides an abstraction over the different types of ranges over
-/// Value*s. In many cases, this prevents the need to explicitly materialize a
+/// Values. In many cases, this prevents the need to explicitly materialize a
/// SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
class ValueRange final
: public detail::indexed_accessor_range_base<
- ValueRange, PointerUnion<Value *const *, OpOperand *, OpResult *>,
- Value *, Value *, Value *> {
+ ValueRange, PointerUnion<const Value *, OpOperand *, OpResult *>,
+ Value, Value, Value> {
public:
using RangeBaseT::RangeBaseT;
template <typename Arg,
typename = typename std::enable_if_t<
- std::is_constructible<ArrayRef<Value *>, Arg>::value &&
- !std::is_convertible<Arg, Value *>::value>>
- ValueRange(Arg &&arg)
- : ValueRange(ArrayRef<Value *>(std::forward<Arg>(arg))) {}
- ValueRange(Value *const &value) : ValueRange(&value, /*count=*/1) {}
- ValueRange(const std::initializer_list<Value *> &values)
- : ValueRange(ArrayRef<Value *>(values)) {}
+ std::is_constructible<ArrayRef<Value>, Arg>::value &&
+ !std::is_convertible<Arg, Value>::value>>
+ ValueRange(Arg &&arg) : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {}
+ ValueRange(const Value &value) : ValueRange(&value, /*count=*/1) {}
+ ValueRange(const std::initializer_list<Value> &values)
+ : ValueRange(ArrayRef<Value>(values)) {}
ValueRange(iterator_range<OperandRange::iterator> values)
: ValueRange(OperandRange(values)) {}
ValueRange(iterator_range<ResultRange::iterator> values)
: ValueRange(ResultRange(values)) {}
- ValueRange(ArrayRef<Value *> values = llvm::None);
+ ValueRange(ArrayRef<Value> values = llvm::None);
ValueRange(OperandRange values);
ValueRange(ResultRange values);
@@ -645,12 +635,12 @@
private:
/// The type representing the owner of this range. This is either a list of
/// values, operands, or results.
- using OwnerT = PointerUnion<Value *const *, OpOperand *, OpResult *>;
+ using OwnerT = PointerUnion<const Value *, OpOperand *, OpResult *>;
/// See `detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
/// See `detail::indexed_accessor_range_base` for details.
- static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
+ static Value dereference_iterator(const OwnerT &owner, ptrdiff_t index);
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
diff --git a/third_party/mlir/include/mlir/IR/PatternMatch.h b/third_party/mlir/include/mlir/IR/PatternMatch.h
index 707bb7c..db160e3 100644
--- a/third_party/mlir/include/mlir/IR/PatternMatch.h
+++ b/third_party/mlir/include/mlir/IR/PatternMatch.h
@@ -1,19 +1,10 @@
//===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_PATTERNMATCHER_H
#define MLIR_PATTERNMATCHER_H
@@ -370,15 +361,31 @@
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
- /// This method is used as the final notification hook for patterns that end
- /// up modifying the pattern root in place, by changing its operands. This is
- /// a minor efficiency win (it avoids creating a new operation and removing
- /// the old one) but also often allows simpler code in the client.
- ///
- /// The valuesToRemoveIfDead list is an optional list of values that the
- /// rewriter should remove if they are dead at this point.
- ///
- void updatedRootInPlace(Operation *op, ValueRange valuesToRemoveIfDead = {});
+ /// This method is used to notify the rewriter that an in-place operation
+ /// modification is about to happen. A call to this function *must* be
+ /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
+ /// This is a minor efficiency win (it avoids creating a new operation and
+ /// removing the old one) but also often allows simpler code in the client.
+ virtual void startRootUpdate(Operation *op) {}
+
+ /// This method is used to signal the end of a root update on the given
+ /// operation. This can only be called on operations that were provided to a
+ /// call to `startRootUpdate`.
+ virtual void finalizeRootUpdate(Operation *op) {}
+
+ /// This method cancels a pending root update. This can only be called on
+ /// operations that were provided to a call to `startRootUpdate`.
+ virtual void cancelRootUpdate(Operation *op) {}
+
+ /// This method is a utility wrapper around a root update of an operation. It
+ /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
+ /// callable.
+ template <typename CallableT>
+ void updateRootInPlace(Operation *root, CallableT &&callable) {
+ startRootUpdate(root);
+ callable();
+ finalizeRootUpdate(root);
+ }
protected:
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
@@ -387,10 +394,6 @@
// These are the callback methods that subclasses can choose to implement if
// they would like to be notified about certain types of mutations.
- /// Notify the pattern rewriter that the specified operation has been mutated
- /// in place. This is called after the mutation is done.
- virtual void notifyRootUpdated(Operation *op) {}
-
/// Notify the pattern rewriter that the specified operation is about to be
/// replaced with another set of operations. This is called before the uses
/// of the operation have been changed.
diff --git a/third_party/mlir/include/mlir/IR/Region.h b/third_party/mlir/include/mlir/IR/Region.h
index c1390ad..00f3ca7 100644
--- a/third_party/mlir/include/mlir/IR/Region.h
+++ b/third_party/mlir/include/mlir/IR/Region.h
@@ -1,19 +1,10 @@
//===- Region.h - MLIR Region Class -----------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the Region class.
//
diff --git a/third_party/mlir/include/mlir/IR/RegionGraphTraits.h b/third_party/mlir/include/mlir/IR/RegionGraphTraits.h
index f45dcc4..b11c87d 100644
--- a/third_party/mlir/include/mlir/IR/RegionGraphTraits.h
+++ b/third_party/mlir/include/mlir/IR/RegionGraphTraits.h
@@ -1,19 +1,10 @@
//===- RegionGraphTraits.h - llvm::GraphTraits for CFGs ---------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements specializations of llvm::GraphTraits for various MLIR
// CFG data types. This allows the generic LLVM graph algorithms to be applied
diff --git a/third_party/mlir/include/mlir/IR/StandardTypes.h b/third_party/mlir/include/mlir/IR/StandardTypes.h
index b6b4b6e..89ffc45 100644
--- a/third_party/mlir/include/mlir/IR/StandardTypes.h
+++ b/third_party/mlir/include/mlir/IR/StandardTypes.h
@@ -1,19 +1,10 @@
//===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_STANDARDTYPES_H
#define MLIR_IR_STANDARDTYPES_H
diff --git a/third_party/mlir/include/mlir/IR/StorageUniquerSupport.h b/third_party/mlir/include/mlir/IR/StorageUniquerSupport.h
index 1a73073..f928819 100644
--- a/third_party/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/third_party/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -1,19 +1,10 @@
//===- StorageUniquerSupport.h - MLIR Storage Uniquer Utilities -*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines utility classes for interfacing with StorageUniquer.
//
diff --git a/third_party/mlir/include/mlir/IR/SymbolTable.h b/third_party/mlir/include/mlir/IR/SymbolTable.h
index e04beac..0782918 100644
--- a/third_party/mlir/include/mlir/IR/SymbolTable.h
+++ b/third_party/mlir/include/mlir/IR/SymbolTable.h
@@ -1,19 +1,10 @@
//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_SYMBOLTABLE_H
#define MLIR_IR_SYMBOLTABLE_H
diff --git a/third_party/mlir/include/mlir/IR/TypeSupport.h b/third_party/mlir/include/mlir/IR/TypeSupport.h
index 86620da..8cc811c 100644
--- a/third_party/mlir/include/mlir/IR/TypeSupport.h
+++ b/third_party/mlir/include/mlir/IR/TypeSupport.h
@@ -1,19 +1,10 @@
//===- TypeSupport.h --------------------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines support types for registering dialect extended types.
//
diff --git a/third_party/mlir/include/mlir/IR/TypeUtilities.h b/third_party/mlir/include/mlir/IR/TypeUtilities.h
index 2cce4db..b095683 100644
--- a/third_party/mlir/include/mlir/IR/TypeUtilities.h
+++ b/third_party/mlir/include/mlir/IR/TypeUtilities.h
@@ -1,19 +1,10 @@
//===- TypeUtilities.h - Helper function for type queries -------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines generic type utilities.
//
@@ -41,8 +32,7 @@
/// Return the element type or return the type itself.
Type getElementTypeOrSelf(Attribute attr);
-Type getElementTypeOrSelf(Value *val);
-Type getElementTypeOrSelf(Value &val);
+Type getElementTypeOrSelf(Value val);
/// Get the types within a nested Tuple. A helper for the class method that
/// handles storage concerns, which is tricky to do in tablegen.
@@ -72,7 +62,7 @@
// An iterator for the element types of an op's operands of shaped types.
class OperandElementTypeIterator final
: public llvm::mapped_iterator<Operation::operand_iterator,
- Type (*)(Value *)> {
+ Type (*)(Value)> {
public:
using reference = Type;
@@ -81,7 +71,7 @@
explicit OperandElementTypeIterator(Operation::operand_iterator it);
private:
- static Type unwrap(Value *value);
+ static Type unwrap(Value value);
};
using OperandElementTypeRange = iterator_range<OperandElementTypeIterator>;
@@ -89,7 +79,7 @@
// An iterator for the tensor element types of an op's results of shaped types.
class ResultElementTypeIterator final
: public llvm::mapped_iterator<Operation::result_iterator,
- Type (*)(Value *)> {
+ Type (*)(Value)> {
public:
using reference = Type;
@@ -98,7 +88,7 @@
explicit ResultElementTypeIterator(Operation::result_iterator it);
private:
- static Type unwrap(Value *value);
+ static Type unwrap(Value value);
};
using ResultElementTypeRange = iterator_range<ResultElementTypeIterator>;
diff --git a/third_party/mlir/include/mlir/IR/Types.h b/third_party/mlir/include/mlir/IR/Types.h
index 11af3eb..6246e9b 100644
--- a/third_party/mlir/include/mlir/IR/Types.h
+++ b/third_party/mlir/include/mlir/IR/Types.h
@@ -1,19 +1,10 @@
//===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_TYPES_H
#define MLIR_IR_TYPES_H
@@ -121,11 +112,8 @@
/* implicit */ Type(const ImplType *impl)
: impl(const_cast<ImplType *>(impl)) {}
- Type(const Type &other) : impl(other.impl) {}
- Type &operator=(Type other) {
- impl = other.impl;
- return *this;
- }
+ Type(const Type &other) = default;
+ Type &operator=(const Type &other) = default;
bool operator==(Type other) const { return impl == other.impl; }
bool operator!=(Type other) const { return !(*this == other); }
diff --git a/third_party/mlir/include/mlir/IR/UseDefLists.h b/third_party/mlir/include/mlir/IR/UseDefLists.h
index 96e4ace..05720ed 100644
--- a/third_party/mlir/include/mlir/IR/UseDefLists.h
+++ b/third_party/mlir/include/mlir/IR/UseDefLists.h
@@ -1,19 +1,10 @@
//===- UseDefLists.h --------------------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines generic use/def list machinery and manipulation utilities.
//
@@ -30,6 +21,7 @@
class IROperand;
class Operation;
+class Value;
template <typename OperandType> class ValueUseIterator;
template <typename OperandType> class ValueUserIterator;
@@ -176,6 +168,22 @@
}
};
+/// A reference to a value, suitable for use as an operand of an operation.
+class OpOperand : public IROperand {
+public:
+ OpOperand(Operation *owner) : IROperand(owner) {}
+ OpOperand(Operation *owner, Value value);
+
+ /// Return the current value being used by this operand.
+ Value get();
+
+ /// Set the current value being used by this operand.
+ void set(Value newValue);
+
+ /// Return which operand this is in the operand list of the User.
+ unsigned getOperandNumber();
+};
+
/// A reference to a value, suitable for use as an operand of an operation,
/// operation, etc. IRValueTy is the root type to use for values this tracks,
/// and SSAUserTy is the type that will contain operands.
diff --git a/third_party/mlir/include/mlir/IR/Value.h b/third_party/mlir/include/mlir/IR/Value.h
index 34c74c8..c4356b1 100644
--- a/third_party/mlir/include/mlir/IR/Value.h
+++ b/third_party/mlir/include/mlir/IR/Value.h
@@ -1,19 +1,10 @@
//===- Value.h - Base of the SSA Value hierarchy ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines generic Value type and manipulation utilities.
//
@@ -28,29 +19,107 @@
namespace mlir {
class Block;
+class BlockArgument;
class Operation;
+class OpResult;
class Region;
class Value;
-/// Operands contain a Value.
-using OpOperand = IROperandImpl<Value>;
+namespace detail {
+/// The internal implementation of a Value.
+class ValueImpl : public IRObjectWithUseList {
+protected:
+ /// This enumerates all of the SSA value kinds.
+ enum class Kind {
+ BlockArgument,
+ OpResult,
+ };
-/// This is the common base class for all SSA values in the MLIR system,
-/// representing a computable value that has a type and a set of users.
+ ValueImpl(Kind kind, Type type) : typeAndKind(type, kind) {}
+
+private:
+ /// The type of the value and its kind.
+ llvm::PointerIntPair<Type, 1, Kind> typeAndKind;
+
+ /// Allow access to 'typeAndKind'.
+ friend Value;
+};
+
+/// The internal implementation of a BlockArgument.
+class BlockArgumentImpl : public ValueImpl {
+ BlockArgumentImpl(Type type, Block *owner)
+ : ValueImpl(Kind::BlockArgument, type), owner(owner) {}
+
+ /// The owner of this argument.
+ Block *owner;
+
+ /// Allow access to owner and constructor.
+ friend BlockArgument;
+};
+
+class OpResultImpl : public ValueImpl {
+ OpResultImpl(Type type, Operation *owner)
+ : ValueImpl(Kind::OpResult, type), owner(owner) {}
+
+ /// The owner of this result.
+ Operation *owner;
+
+ /// Allow access to owner and the constructor.
+ friend OpResult;
+};
+} // end namespace detail
+
+/// This class represents an instance of an SSA value in the MLIR system,
+/// representing a computable value that has a type and a set of users. An SSA
+/// value is either a BlockArgument or the result of an operation. Note: This
+/// class has value-type semantics and is just a simple wrapper around a
+/// ValueImpl that is either owned by a block (in the case of a BlockArgument) or
+/// an Operation (in the case of an OpResult).
///
-class Value : public IRObjectWithUseList {
+class Value {
public:
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class Kind {
- BlockArgument, // block argument
- OpResult, // operation result
+ BlockArgument,
+ OpResult,
};
+ Value(std::nullptr_t) : impl(nullptr) {}
+ Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}
+ Value(const Value &) = default;
+ Value &operator=(const Value &) = default;
~Value() {}
- Kind getKind() const { return typeAndKind.getInt(); }
+ template <typename U> bool isa() const {
+ assert(impl && "isa<> used on a null type.");
+ return U::classof(*this);
+ }
+ template <typename U> U dyn_cast() const {
+ return isa<U>() ? U(impl) : U(nullptr);
+ }
+ template <typename U> U dyn_cast_or_null() const {
+ return (impl && isa<U>()) ? U(impl) : U(nullptr);
+ }
+ template <typename U> U cast() const {
+ assert(isa<U>());
+ return U(impl);
+ }
- Type getType() const { return typeAndKind.getPointer(); }
+ /// Temporary methods to enable transition of Value to being used as a
+ /// value-type.
+ /// TODO(riverriddle) Remove these when all usages have been removed.
+ Value operator*() const { return *this; }
+ Value *operator->() const { return (Value *)this; }
+
+ operator bool() const { return impl; }
+ bool operator==(const Value &other) const { return impl == other.impl; }
+ bool operator!=(const Value &other) const { return !(*this == other); }
+
+ /// Return the kind of this value.
+ Kind getKind() const { return (Kind)impl->typeAndKind.getInt(); }
+
+ /// Return the type of this value.
+ Type getType() const { return impl->typeAndKind.getPointer(); }
/// Utility to get the associated MLIRContext that this value is defined in.
MLIRContext *getContext() const { return getType().getContext(); }
@@ -61,18 +130,18 @@
/// completely invalid IR very easily. It is strongly recommended that you
/// recreate IR objects with the right types instead of mutating them in
/// place.
- void setType(Type newType) { typeAndKind.setPointer(newType); }
+ void setType(Type newType) { impl->typeAndKind.setPointer(newType); }
/// Replace all uses of 'this' value with the new value, updating anything in
/// the IR that uses 'this' to use the other value instead. When this returns
/// there are zero uses of 'this'.
- void replaceAllUsesWith(Value *newValue) {
- IRObjectWithUseList::replaceAllUsesWith(newValue);
+ void replaceAllUsesWith(Value newValue) const {
+ impl->replaceAllUsesWith(newValue.impl);
}
/// If this value is the result of an operation, return the operation that
/// defines it.
- Operation *getDefiningOp();
+ Operation *getDefiningOp() const;
/// If this value is the result of an operation, use it as a location,
/// otherwise return an unknown location.
@@ -90,24 +159,51 @@
/// Returns a range of all uses, which is useful for iterating over all uses.
inline use_range getUses();
+ using user_iterator = ValueUserIterator<IROperand>;
+ using user_range = iterator_range<user_iterator>;
+
+ user_iterator user_begin() const { return impl->user_begin(); }
+ user_iterator user_end() const { return impl->user_end(); }
+
+ /// Returns a range of all users.
+ user_range getUsers() const { return impl->getUsers(); }
+
+ /// Returns true if this value has no uses.
+ bool use_empty() const { return impl->use_empty(); }
+
+ /// Returns true if this value has exactly one use.
+ bool hasOneUse() const { return impl->hasOneUse(); }
+
+ /// Drop all uses of this object from their respective owners.
+ void dropAllUses() const { impl->dropAllUses(); }
+
void print(raw_ostream &os);
void dump();
-protected:
- Value(Kind kind, Type type) : typeAndKind(type, kind) {}
+ /// Methods for supporting PointerLikeTypeTraits.
+ void *getAsOpaquePointer() const { return static_cast<void *>(impl); }
+ static Value getFromOpaquePointer(const void *pointer) {
+ return reinterpret_cast<detail::ValueImpl *>(const_cast<void *>(pointer));
+ }
-private:
- llvm::PointerIntPair<Type, 1, Kind> typeAndKind;
+ friend ::llvm::hash_code hash_value(Value arg);
+
+protected:
+ /// The internal implementation of this value.
+ mutable detail::ValueImpl *impl;
+
+ /// Allow access to 'impl'.
+ friend OpOperand;
};
-inline raw_ostream &operator<<(raw_ostream &os, Value &value) {
+inline raw_ostream &operator<<(raw_ostream &os, Value value) {
value.print(os);
return os;
}
// Utility functions for iterating through Value uses.
inline auto Value::use_begin() -> use_iterator {
- return use_iterator((OpOperand *)getFirstUse());
+ return use_iterator((OpOperand *)impl->getFirstUse());
}
inline auto Value::use_end() -> use_iterator { return use_iterator(nullptr); }
@@ -119,48 +215,148 @@
/// Block arguments are values.
class BlockArgument : public Value {
public:
- static bool classof(const Value *value) {
- return const_cast<Value *>(value)->getKind() == Kind::BlockArgument;
+ using Value::Value;
+
+ /// Temporary methods to enable transition of Value to being used as a
+ /// value-type.
+ /// TODO(riverriddle) Remove this when all usages have been removed.
+ BlockArgument *operator->() { return this; }
+
+ static bool classof(Value value) {
+ return value.getKind() == Kind::BlockArgument;
}
- Block *getOwner() { return owner; }
+ /// Returns the block that owns this argument.
+ Block *getOwner() const { return getImpl()->owner; }
/// Returns the number of this argument.
- unsigned getArgNumber();
+ unsigned getArgNumber() const;
private:
- friend class Block; // For access to private constructor.
- BlockArgument(Type type, Block *owner)
- : Value(Value::Kind::BlockArgument, type), owner(owner) {}
+ /// Allocate a new argument with the given type and owner.
+ static BlockArgument create(Type type, Block *owner) {
+ return new detail::BlockArgumentImpl(type, owner);
+ }
- /// The owner of this operand.
- /// TODO: can encode this more efficiently to avoid the space hit of this
- /// through bitpacking shenanigans.
- Block *const owner;
+ /// Destroy and deallocate this argument.
+ void destroy() { delete getImpl(); }
+
+ /// Get a raw pointer to the internal implementation.
+ detail::BlockArgumentImpl *getImpl() const {
+ return reinterpret_cast<detail::BlockArgumentImpl *>(impl);
+ }
+
+ /// Allow access to `create` and `destroy`.
+ friend Block;
};
/// This is a value defined by a result of an operation.
class OpResult : public Value {
public:
- OpResult(Type type, Operation *owner)
- : Value(Value::Kind::OpResult, type), owner(owner) {}
+ using Value::Value;
- static bool classof(const Value *value) {
- return const_cast<Value *>(value)->getKind() == Kind::OpResult;
- }
+ /// Temporary methods to enable transition of Value to being used as a
+ /// value-type.
+ /// TODO(riverriddle) Remove these when all usages have been removed.
+ OpResult *operator*() { return this; }
+ OpResult *operator->() { return this; }
- Operation *getOwner() { return owner; }
+ static bool classof(Value value) { return value.getKind() == Kind::OpResult; }
+
+ /// Returns the operation that owns this result.
+ Operation *getOwner() const { return getImpl()->owner; }
/// Returns the number of this result.
- unsigned getResultNumber();
+ unsigned getResultNumber() const;
private:
- /// The owner of this operand.
- /// TODO: can encode this more efficiently to avoid the space hit of this
- /// through bitpacking shenanigans.
- Operation *const owner;
+ /// Allocate a new result with the given type and owner.
+ static OpResult create(Type type, Operation *owner) {
+ return new detail::OpResultImpl(type, owner);
+ }
+
+ /// Destroy and deallocate this result.
+ void destroy() { delete getImpl(); }
+
+ /// Get a raw pointer to the internal implementation.
+ detail::OpResultImpl *getImpl() const {
+ return reinterpret_cast<detail::OpResultImpl *>(impl);
+ }
+
+ /// Allow access to `create` and `destroy`.
+ friend Operation;
};
+/// Make Value hashable.
+inline ::llvm::hash_code hash_value(Value arg) {
+ return ::llvm::hash_value(arg.impl);
+}
+
} // namespace mlir
+namespace llvm {
+
+template <> struct DenseMapInfo<mlir::Value> {
+ static mlir::Value getEmptyKey() {
+ auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::Value(static_cast<mlir::detail::ValueImpl *>(pointer));
+ }
+ static mlir::Value getTombstoneKey() {
+ auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::Value(static_cast<mlir::detail::ValueImpl *>(pointer));
+ }
+ static unsigned getHashValue(mlir::Value val) {
+ return mlir::hash_value(val);
+ }
+ static bool isEqual(mlir::Value LHS, mlir::Value RHS) { return LHS == RHS; }
+};
+
+/// Allow stealing the low bits of a value.
+template <> struct PointerLikeTypeTraits<mlir::Value> {
+public:
+ static inline void *getAsVoidPointer(mlir::Value I) {
+ return const_cast<void *>(I.getAsOpaquePointer());
+ }
+ static inline mlir::Value getFromVoidPointer(void *P) {
+ return mlir::Value::getFromOpaquePointer(P);
+ }
+ enum {
+ NumLowBitsAvailable =
+ PointerLikeTypeTraits<mlir::detail::ValueImpl *>::NumLowBitsAvailable
+ };
+};
+
+template <> struct DenseMapInfo<mlir::BlockArgument> {
+ static mlir::BlockArgument getEmptyKey() {
+ auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::BlockArgument(static_cast<mlir::detail::ValueImpl *>(pointer));
+ }
+ static mlir::BlockArgument getTombstoneKey() {
+ auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::BlockArgument(static_cast<mlir::detail::ValueImpl *>(pointer));
+ }
+ static unsigned getHashValue(mlir::BlockArgument val) {
+ return mlir::hash_value(val);
+ }
+ static bool isEqual(mlir::BlockArgument LHS, mlir::BlockArgument RHS) {
+ return LHS == RHS;
+ }
+};
+
+/// Allow stealing the low bits of a value.
+template <> struct PointerLikeTypeTraits<mlir::BlockArgument> {
+public:
+ static inline void *getAsVoidPointer(mlir::Value I) {
+ return const_cast<void *>(I.getAsOpaquePointer());
+ }
+ static inline mlir::BlockArgument getFromVoidPointer(void *P) {
+ return mlir::Value::getFromOpaquePointer(P).cast<mlir::BlockArgument>();
+ }
+ enum {
+ NumLowBitsAvailable =
+ PointerLikeTypeTraits<mlir::detail::ValueImpl *>::NumLowBitsAvailable
+ };
+};
+} // end namespace llvm
+
#endif
diff --git a/third_party/mlir/include/mlir/IR/Visitors.h b/third_party/mlir/include/mlir/IR/Visitors.h
index 50d6562..aaab933 100644
--- a/third_party/mlir/include/mlir/IR/Visitors.h
+++ b/third_party/mlir/include/mlir/IR/Visitors.h
@@ -1,19 +1,10 @@
//===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines utilities for walking and visiting operations.
//
diff --git a/third_party/mlir/include/mlir/Parser.h b/third_party/mlir/include/mlir/Parser.h
index 3a818ff..cae1e8b 100644
--- a/third_party/mlir/include/mlir/Parser.h
+++ b/third_party/mlir/include/mlir/Parser.h
@@ -1,19 +1,10 @@
//===- Parser.h - MLIR Parser Library Interface -----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file is contains the interface to the MLIR parser library.
//
diff --git a/third_party/mlir/include/mlir/Pass/AnalysisManager.h b/third_party/mlir/include/mlir/Pass/AnalysisManager.h
index e233a4a..471cd01 100644
--- a/third_party/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/third_party/mlir/include/mlir/Pass/AnalysisManager.h
@@ -1,19 +1,10 @@
//===- AnalysisManager.h - Analysis Management Infrastructure ---*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_PASS_ANALYSISMANAGER_H
#define MLIR_PASS_ANALYSISMANAGER_H
diff --git a/third_party/mlir/include/mlir/Pass/Pass.h b/third_party/mlir/include/mlir/Pass/Pass.h
index 380b097..bcb2973 100644
--- a/third_party/mlir/include/mlir/Pass/Pass.h
+++ b/third_party/mlir/include/mlir/Pass/Pass.h
@@ -1,19 +1,10 @@
//===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_PASS_PASS_H
#define MLIR_PASS_PASS_H
@@ -70,12 +61,40 @@
/// this is a generic OperationPass.
Optional<StringRef> getOpName() const { return opName; }
+ //===--------------------------------------------------------------------===//
+ // Options
+ //===--------------------------------------------------------------------===//
+
+  /// A typed pass option that registers itself in the owning pass's options.
+  template <typename DataType>
+  struct Option : public detail::PassOptions::Option<DataType> {
+    template <typename... Args>
+    Option(Pass &parent, StringRef arg, Args &&... args)
+        : detail::PassOptions::Option<DataType>(
+              parent.passOptions, arg, std::forward<Args>(args)...) {}
+    using detail::PassOptions::Option<DataType>::operator=;
+  };
+  /// A pass option holding a list of `DataType` values; it registers itself in
+  /// the owning pass's options.
+  template <typename DataType>
+  struct ListOption : public detail::PassOptions::ListOption<DataType> {
+    template <typename... Args>
+    ListOption(Pass &parent, StringRef arg, Args &&... args)
+        : detail::PassOptions::ListOption<DataType>(
+              parent.passOptions, arg, std::forward<Args>(args)...) {}
+    using detail::PassOptions::ListOption<DataType>::operator=;
+  };
+
+ /// Attempt to initialize the options of this pass from the given string.
+ LogicalResult initializeOptions(StringRef options);
+
/// Prints out the pass in the textual representation of pipelines. If this is
/// an adaptor pass, print with the op_name(sub_pass,...) format.
- /// Note: The default implementation uses the class name and does not respect
- /// options used to construct the pass. Override this method to allow for your
- /// pass to be to be round-trippable to the textual format.
- virtual void printAsTextualPipeline(raw_ostream &os);
+ void printAsTextualPipeline(raw_ostream &os);
+
+ //===--------------------------------------------------------------------===//
+ // Statistics
+ //===--------------------------------------------------------------------===//
/// This class represents a single pass statistic. This statistic functions
/// similarly to an unsigned integer value, and may be updated and incremented
@@ -128,6 +147,10 @@
return getPassState().analysisManager;
}
+ /// Copy the option values from 'other', which is another instance of this
+ /// pass.
+ void copyOptionValuesFrom(const Pass *other);
+
private:
/// Forwarding function to execute this pass on the given operation.
LLVM_NODISCARD
@@ -150,6 +173,9 @@
/// The set of statistics held by this pass.
std::vector<Statistic *> statistics;
+ /// The pass options registered to this pass instance.
+ detail::PassOptions passOptions;
+
/// Allow access to 'clone' and 'run'.
friend class OpPassManager;
};
@@ -213,7 +239,9 @@
/// A clone method to create a copy of this pass.
std::unique_ptr<Pass> clone() const override {
- return std::make_unique<PassT>(*static_cast<const PassT *>(this));
+ auto newInst = std::make_unique<PassT>(*static_cast<const PassT *>(this));
+ newInst->copyOptionValuesFrom(this);
+ return newInst;
}
/// Returns the analysis for the parent operation if it exists.
diff --git a/third_party/mlir/include/mlir/Pass/PassInstrumentation.h b/third_party/mlir/include/mlir/Pass/PassInstrumentation.h
index 4b61850..ef75e56 100644
--- a/third_party/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/third_party/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -1,19 +1,10 @@
//===- PassInstrumentation.h ------------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_PASS_PASSINSTRUMENTATION_H_
#define MLIR_PASS_PASSINSTRUMENTATION_H_
diff --git a/third_party/mlir/include/mlir/Pass/PassManager.h b/third_party/mlir/include/mlir/Pass/PassManager.h
index 9de8ace..d4f3683 100644
--- a/third_party/mlir/include/mlir/Pass/PassManager.h
+++ b/third_party/mlir/include/mlir/Pass/PassManager.h
@@ -1,19 +1,10 @@
//===- PassManager.h - Pass Management Interface ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_PASS_PASSMANAGER_H
#define MLIR_PASS_PASSMANAGER_H
diff --git a/third_party/mlir/include/mlir/Pass/PassOptions.h b/third_party/mlir/include/mlir/Pass/PassOptions.h
new file mode 100644
index 0000000..66f4e86
--- /dev/null
+++ b/third_party/mlir/include/mlir/Pass/PassOptions.h
@@ -0,0 +1,238 @@
+//===- PassOptions.h - Pass Option Utilities --------------------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains utilities for registering options with compiler passes and
+// pipelines.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PASS_PASSOPTIONS_H_
+#define MLIR_PASS_PASSOPTIONS_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include <memory>
+
+namespace mlir {
+namespace detail {
+/// Base container class and manager for all pass options.
+class PassOptions : protected llvm::cl::SubCommand {
+private:
+  /// This is the type-erased option base class. This provides some additional
+  /// hooks into the options that are not available via llvm::cl::Option.
+  class OptionBase {
+  public:
+    virtual ~OptionBase() = default;
+
+    /// Out of line virtual function to provide home for the class.
+    virtual void anchor();
+
+    /// Print the name and value of this option to the given stream.
+    virtual void print(raw_ostream &os) = 0;
+
+    /// Return the argument string of this option.
+    StringRef getArgStr() const { return getOption()->ArgStr; }
+
+  protected:
+    /// Return the main option instance.
+    virtual const llvm::cl::Option *getOption() const = 0;
+
+    /// Copy the value from the given option into this one.
+    virtual void copyValueFrom(const OptionBase &other) = 0;
+
+    /// Allow access to private methods.
+    friend PassOptions;
+  };
+
+  /// This is the parser that is used by pass options that use literal options.
+  /// This is a thin wrapper around the llvm::cl::parser, that exposes some
+  /// additional methods.
+  template <typename DataType>
+  struct GenericOptionParser : public llvm::cl::parser<DataType> {
+    using llvm::cl::parser<DataType>::parser;
+    /// Returns the literal name registered for `value`, or None if absent.
+    /// Values are compared directly rather than through OptionValue::compare.
+    Optional<StringRef> findArgStrForValue(const DataType &value) {
+      for (auto &it : this->Values)
+        if (it.V.hasValue() && it.V.getValue() == value)
+          return it.Name;
+      return llvm::None;
+    }
+  };
+
+  /// The specific parser to use depending on the llvm::cl parser used. This is
+  /// only necessary because we need to provide additional methods for certain
+  /// data type parsers.
+  /// TODO(riverriddle) We should upstream the methods in GenericOptionParser to
+  /// avoid the need to do this.
+  template <typename DataType>
+  using OptionParser =
+      std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base,
+                                         llvm::cl::parser<DataType>>::value,
+                         GenericOptionParser<DataType>,
+                         llvm::cl::parser<DataType>>;
+
+  /// Utility methods for printing option values.
+  template <typename DataT>
+  static void printValue(raw_ostream &os, GenericOptionParser<DataT> &parser,
+                         const DataT &value) {
+    if (Optional<StringRef> argStr = parser.findArgStrForValue(value))
+      os << *argStr;
+    else
+      llvm_unreachable("unknown data value for option");
+  }
+  template <typename DataT, typename ParserT>
+  static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
+    os << value;
+  }
+  template <typename ParserT>
+  static void printValue(raw_ostream &os, ParserT &parser, const bool &value) {
+    os << (value ? StringRef("true") : StringRef("false"));
+  }
+
+public:
+  /// This class represents a specific pass option, with a provided data type.
+  template <typename DataType>
+  class Option : public llvm::cl::opt<DataType, /*ExternalStorage=*/false,
+                                      OptionParser<DataType>>,
+                 public OptionBase {
+  public:
+    template <typename... Args>
+    Option(PassOptions &parent, StringRef arg, Args &&... args)
+        : llvm::cl::opt<DataType, /*ExternalStorage=*/false,
+                        OptionParser<DataType>>(arg, llvm::cl::sub(parent),
+                                                std::forward<Args>(args)...) {
+      assert(!this->isPositional() && !this->isSink() &&
+             "sink and positional options are not supported");
+      parent.options.push_back(this);
+    }
+    using llvm::cl::opt<DataType, /*ExternalStorage=*/false,
+                        OptionParser<DataType>>::operator=;
+    ~Option() override = default;
+
+  private:
+    /// Return the main option instance.
+    const llvm::cl::Option *getOption() const final { return this; }
+
+    /// Print the name and value of this option to the given stream.
+    void print(raw_ostream &os) final {
+      os << this->ArgStr << '=';
+      printValue(os, this->getParser(), this->getValue());
+    }
+
+    /// Copy the value of `other`, which must be an Option<DataType>.
+    void copyValueFrom(const OptionBase &other) final {
+      this->setValue(static_cast<const Option<DataType> &>(other).getValue());
+    }
+  };
+
+  /// This class represents a specific pass option that contains a list of
+  /// values of the provided data type.
+  template <typename DataType>
+  class ListOption : public llvm::cl::list<DataType, /*StorageClass=*/bool,
+                                           OptionParser<DataType>>,
+                     public OptionBase {
+  public:
+    template <typename... Args>
+    ListOption(PassOptions &parent, StringRef arg, Args &&... args)
+        : llvm::cl::list<DataType, /*StorageClass=*/bool,
+                         OptionParser<DataType>>(arg, llvm::cl::sub(parent),
+                                                 std::forward<Args>(args)...) {
+      assert(!this->isPositional() && !this->isSink() &&
+             "sink and positional options are not supported");
+      parent.options.push_back(this);
+    }
+    ~ListOption() override = default;
+
+    /// Allow assigning from an ArrayRef.
+    ListOption<DataType> &operator=(ArrayRef<DataType> values) {
+      (*this)->assign(values.begin(), values.end());
+      return *this;
+    }
+    /// Provide direct access to the underlying std::vector of values.
+    std::vector<DataType> *operator->() { return &*this; }
+
+  private:
+    /// Return the main option instance.
+    const llvm::cl::Option *getOption() const final { return this; }
+
+    /// Print the name and value of this option to the given stream.
+    void print(raw_ostream &os) final {
+      os << this->ArgStr << '=';
+      auto printElementFn = [&](const DataType &value) {
+        printValue(os, this->getParser(), value);
+      };
+      interleave(*this, os, printElementFn, ",");
+    }
+
+    /// Copy the values of `other`, which must be a ListOption<DataType>.
+    void copyValueFrom(const OptionBase &other) final {
+      (*this) = ArrayRef<DataType>((ListOption<DataType> &)other);
+    }
+  };
+
+  PassOptions() = default;
+
+  /// Copy the option values from 'other' into 'this', where 'other' has the
+  /// same options as 'this'.
+  void copyOptionValuesFrom(const PassOptions &other);
+
+  /// Parse options out as key=value pairs that can then be handed off to the
+  /// `llvm::cl` command line parsing infrastructure. Everything is space
+  /// separated.
+  LogicalResult parseFromString(StringRef options);
+
+  /// Print the options held by this struct in a form that can be parsed via
+  /// 'parseFromString'.
+  void print(raw_ostream &os);
+
+private:
+  /// Every option registered with this instance, in registration order.
+  std::vector<OptionBase *> options;
+};
+} // end namespace detail
+
+//===----------------------------------------------------------------------===//
+// PassPipelineOptions
+//===----------------------------------------------------------------------===//
+
+/// Subclasses of PassPipelineOptions provide a set of options that can be used
+/// to initialize a pass pipeline. See PassPipelineRegistration for usage
+/// details.
+///
+/// Usage:
+///
+/// struct MyPipelineOptions : PassPipelineOptions<MyPipelineOptions> {
+///   ListOption<int> someListFlag{
+///       *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
+///       llvm::cl::desc("...")};
+/// };
+template <typename T> class PassPipelineOptions : public detail::PassOptions {
+public:
+  /// Factory that parses `options` into a new T. Returns nullptr if parsing
+  /// fails.
+  static std::unique_ptr<T> createFromString(StringRef options) {
+    auto result = std::make_unique<T>();
+    if (failed(result->parseFromString(options)))
+      return nullptr;
+    return result;
+  }
+};
+
+/// Option struct with no options, for pipelines that take no configuration; it
+/// is the default `Options` argument of PassPipelineRegistration.
+struct EmptyPipelineOptions
+    : public PassPipelineOptions<EmptyPipelineOptions> {};
+
+} // end namespace mlir
+
+#endif // MLIR_PASS_PASSOPTIONS_H_
diff --git a/third_party/mlir/include/mlir/Pass/PassRegistry.h b/third_party/mlir/include/mlir/Pass/PassRegistry.h
index 356b13e..c5604c0 100644
--- a/third_party/mlir/include/mlir/Pass/PassRegistry.h
+++ b/third_party/mlir/include/mlir/Pass/PassRegistry.h
@@ -1,19 +1,10 @@
//===- PassRegistry.h - Pass Registration Utilities -------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains utilities for registering information about compiler
// passes.
@@ -23,14 +14,8 @@
#ifndef MLIR_PASS_PASSREGISTRY_H_
#define MLIR_PASS_PASSREGISTRY_H_
-#include "mlir/Support/LLVM.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Compiler.h"
+#include "mlir/Pass/PassOptions.h"
#include <functional>
-#include <memory>
namespace mlir {
class OpPassManager;
@@ -40,6 +25,7 @@
/// also parse options and return success() if parsing succeeded.
using PassRegistryFunction =
std::function<LogicalResult(OpPassManager &, StringRef options)>;
+using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>;
/// A special type used by transformation passes to provide an address that can
/// act as a unique identifier during pass registration.
@@ -71,7 +57,7 @@
protected:
PassRegistryEntry(StringRef arg, StringRef description,
- PassRegistryFunction builder)
+ const PassRegistryFunction &builder)
: arg(arg), description(description), builder(builder) {}
private:
@@ -89,7 +75,7 @@
class PassPipelineInfo : public PassRegistryEntry {
public:
PassPipelineInfo(StringRef arg, StringRef description,
- PassRegistryFunction builder)
+ const PassRegistryFunction &builder)
: PassRegistryEntry(arg, description, builder) {}
};
@@ -99,8 +85,7 @@
/// PassInfo constructor should not be invoked directly, instead use
/// PassRegistration or registerPass.
PassInfo(StringRef arg, StringRef description, const PassID *passID,
- PassRegistryFunction allocator)
- : PassRegistryEntry(arg, description, allocator) {}
+ const PassAllocatorFunction &allocator);
};
//===----------------------------------------------------------------------===//
@@ -115,140 +100,28 @@
/// Register a specific dialect pass allocator function with the system,
/// typically used through the PassRegistration template.
void registerPass(StringRef arg, StringRef description, const PassID *passID,
- const PassRegistryFunction &function);
-
-namespace detail {
-/// Base class for PassOptions<T> that holds all of the non-CRTP features.
-class PassOptionsBase : protected llvm::cl::SubCommand {
-public:
- /// This class represents a specific pass option, with a provided data type.
- template <typename DataType> struct Option : public llvm::cl::opt<DataType> {
- template <typename... Args>
- Option(PassOptionsBase &parent, StringRef arg, Args &&... args)
- : llvm::cl::opt<DataType>(arg, llvm::cl::sub(parent),
- std::forward<Args>(args)...) {
- assert(!this->isPositional() && !this->isSink() &&
- "sink and positional options are not supported");
- }
- };
-
- /// This class represents a specific pass option that contains a list of
- /// values of the provided data type.
- template <typename DataType> struct List : public llvm::cl::list<DataType> {
- template <typename... Args>
- List(PassOptionsBase &parent, StringRef arg, Args &&... args)
- : llvm::cl::list<DataType>(arg, llvm::cl::sub(parent),
- std::forward<Args>(args)...) {
- assert(!this->isPositional() && !this->isSink() &&
- "sink and positional options are not supported");
- }
- };
-
- /// Parse options out as key=value pairs that can then be handed off to the
- /// `llvm::cl` command line passing infrastructure. Everything is space
- /// separated.
- LogicalResult parseFromString(StringRef options);
-};
-} // end namespace detail
-
-/// Subclasses of PassOptions provide a set of options that can be used to
-/// initialize a pass instance. See PassRegistration for usage details.
-///
-/// Usage:
-///
-/// struct MyPassOptions : PassOptions<MyPassOptions> {
-/// List<int> someListFlag{
-/// *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
-/// llvm::cl::desc("...")};
-/// };
-template <typename T> class PassOptions : public detail::PassOptionsBase {
-public:
- /// Factory that parses the provided options and returns a unique_ptr to the
- /// struct.
- static std::unique_ptr<T> createFromString(StringRef options) {
- auto result = std::make_unique<T>();
- if (failed(result->parseFromString(options)))
- return nullptr;
- return result;
- }
-};
-
-/// A default empty option struct to be used for passes that do not need to take
-/// any options.
-struct EmptyPassOptions : public PassOptions<EmptyPassOptions> {};
-
-namespace detail {
-
-// Calls `pm.addPass(std::move(pass))` to avoid including the PassManager
-// header. Only used in `makePassRegistryFunction`.
-void addPassToPassManager(OpPassManager &pm, std::unique_ptr<Pass> pass);
-
-// Helper function which constructs a PassRegistryFunction that parses options
-// into a struct of type `Options` and then calls constructor(options) to
-// build the pass.
-template <typename Options, typename PassConstructor>
-PassRegistryFunction makePassRegistryFunction(PassConstructor constructor) {
- return [=](OpPassManager &pm, StringRef optionsStr) {
- Options options;
- if (failed(options.parseFromString(optionsStr)))
- return failure();
- addPassToPassManager(pm, constructor(options));
- return success();
- };
-}
-
-} // end namespace detail
+ const PassAllocatorFunction &function);
/// PassRegistration provides a global initializer that registers a Pass
-/// allocation routine for a concrete pass instance. The third argument is
+/// allocation routine for a concrete pass instance. The third argument is
/// optional and provides a callback to construct a pass that does not have
/// a default constructor.
///
/// Usage:
///
-/// // At namespace scope.
+/// /// At namespace scope.
/// static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
///
-/// // Same, but also providing an Options struct.
-/// static PassRegistration<MyPass, MyPassOptions> reg("my-pass", "Docs...");
-template <typename ConcretePass, typename Options = EmptyPassOptions>
-struct PassRegistration {
+template <typename ConcretePass> struct PassRegistration {
PassRegistration(StringRef arg, StringRef description,
- const std::function<std::unique_ptr<Pass>(const Options &)>
- &constructor) {
- registerPass(arg, description, PassID::getID<ConcretePass>(),
- detail::makePassRegistryFunction<Options>(constructor));
+ const PassAllocatorFunction &constructor) {
+ registerPass(arg, description, PassID::getID<ConcretePass>(), constructor);
}
- PassRegistration(StringRef arg, StringRef description) {
- registerPass(
- arg, description, PassID::getID<ConcretePass>(),
- detail::makePassRegistryFunction<Options>([](const Options &options) {
- return std::make_unique<ConcretePass>(options);
- }));
- }
-};
-
-/// Convenience specialization of PassRegistration for EmptyPassOptions that
-/// does not pass an empty options struct to the pass constructor.
-template <typename ConcretePass>
-struct PassRegistration<ConcretePass, EmptyPassOptions> {
- PassRegistration(StringRef arg, StringRef description,
- const std::function<std::unique_ptr<Pass>()> &constructor) {
- registerPass(
- arg, description, PassID::getID<ConcretePass>(),
- detail::makePassRegistryFunction<EmptyPassOptions>(
- [=](const EmptyPassOptions &options) { return constructor(); }));
- }
-
- PassRegistration(StringRef arg, StringRef description) {
- registerPass(arg, description, PassID::getID<ConcretePass>(),
- detail::makePassRegistryFunction<EmptyPassOptions>(
- [](const EmptyPassOptions &options) {
- return std::make_unique<ConcretePass>();
- }));
- }
+ PassRegistration(StringRef arg, StringRef description)
+ : PassRegistration(arg, description,
+ [] { return std::make_unique<ConcretePass>(); }) {}
};
/// PassPipelineRegistration provides a global initializer that registers a Pass
@@ -264,7 +137,8 @@
///
/// static PassPipelineRegistration Unused("unused", "Unused pass",
/// pipelineBuilder);
-template <typename Options = EmptyPassOptions> struct PassPipelineRegistration {
+template <typename Options = EmptyPipelineOptions>
+struct PassPipelineRegistration {
PassPipelineRegistration(
StringRef arg, StringRef description,
std::function<void(OpPassManager &, const Options &options)> builder) {
@@ -281,7 +155,7 @@
/// Convenience specialization of PassPipelineRegistration for EmptyPassOptions
/// that does not pass an empty options struct to the pass builder function.
-template <> struct PassPipelineRegistration<EmptyPassOptions> {
+template <> struct PassPipelineRegistration<EmptyPipelineOptions> {
PassPipelineRegistration(StringRef arg, StringRef description,
std::function<void(OpPassManager &)> builder) {
registerPassPipeline(arg, description,
diff --git a/third_party/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h b/third_party/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h
index 467512f..f27d12d 100644
--- a/third_party/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h
+++ b/third_party/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h
@@ -1,19 +1,10 @@
//===- FxpMathConfig.h - Reference fixed point config -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a TargetConfiguration for reference fixed-point math
// quantization scheme based on the FxpMathOps (plus a small category of
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/Configuration.h b/third_party/mlir/include/mlir/Quantizer/Support/Configuration.h
index 17a472d..3732fba 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/Configuration.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/Configuration.h
@@ -1,19 +1,10 @@
//===- Configuration.h - Configuration object base classes ------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// The quantizer is relatively agnostic to source and target dialects, with
// the specific represented by configuration policy objects derived from
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h b/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h
index 070b3c3..d99db65 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h
@@ -1,19 +1,10 @@
//===- ConstraintAnalysisGraph.h - Graphs type for constraints --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file provides graph-based data structures for representing anchors
// and constraints between them.
@@ -163,7 +154,7 @@
}
virtual Operation *getOp() const = 0;
- virtual Value *getValue() const = 0;
+ virtual Value getValue() const = 0;
static bool classof(const CAGNode *n) {
return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor;
@@ -210,7 +201,7 @@
return n->getKind() == Kind::Anchor || n->getKind() == Kind::OperandAnchor;
}
- Value *getValue() const final { return op->getOperand(operandIdx); }
+ Value getValue() const final { return op->getOperand(operandIdx); }
void printLabel(raw_ostream &os) const override;
@@ -221,7 +212,7 @@
/// An anchor tied to a specific result.
/// Since a result is already anchored to its defining op, result anchors refer
-/// directly to the underlying Value*.
+/// directly to the underlying Value.
class CAGResultAnchor : public CAGAnchorNode {
public:
CAGResultAnchor(Operation *op, unsigned resultIdx);
@@ -231,12 +222,12 @@
}
Operation *getOp() const final { return resultValue->getDefiningOp(); }
- Value *getValue() const final { return resultValue; }
+ Value getValue() const final { return resultValue; }
void printLabel(raw_ostream &os) const override;
private:
- Value *resultValue;
+ Value resultValue;
};
/// Base class for constraint nodes.
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h b/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
index 7e2b61d..35ec85f 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
@@ -1,19 +1,10 @@
//===- ConstraintAnalysisGraphTraits.h - Traits for CAGs --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Provides graph traits for constraint analysis graphs.
//
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/Metadata.h b/third_party/mlir/include/mlir/Quantizer/Support/Metadata.h
index 6c327d9..0545e78 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/Metadata.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/Metadata.h
@@ -1,19 +1,10 @@
//===- Metadata.h - Top level types and metadata ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains top level types needed to construct constraint graphs,
// including context/allocator support and concrete metadata structs for
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/Rules.h b/third_party/mlir/include/mlir/Quantizer/Support/Rules.h
index 9d1e53d..536dd7e 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/Rules.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/Rules.h
@@ -1,19 +1,10 @@
//===- Rules.h - Helpers for declaring facts and rules ----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines helper classes and functions for managing state (facts),
// merging and tracking modification for various data types important for
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/Statistics.h b/third_party/mlir/include/mlir/Quantizer/Support/Statistics.h
index 744c5b6..a24eecd 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/Statistics.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/Statistics.h
@@ -1,19 +1,10 @@
//===- Statistics.h - Collects statistics over tensors ----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines adapters for extracting various (per layer and per axis)
// statistics over tensors.
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/TypeUtils.h b/third_party/mlir/include/mlir/Quantizer/Support/TypeUtils.h
index 074f8b9..64ae5d6 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/TypeUtils.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/TypeUtils.h
@@ -1,19 +1,10 @@
//===- TypeUtils.h - Helper function for manipulating types -----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines various helper functions for manipulating types. The
// process of quantizing typically involves a number of type manipulations
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/UniformConstraints.h b/third_party/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
index 90b5fe1..70c022c 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
@@ -1,19 +1,10 @@
//===- UniformConstraints.h - Constraints for uniform quant -----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a builder that lets you attach constraints necessary to
// perform a variety of uniform quantization conversions to CAG anchors.
diff --git a/third_party/mlir/include/mlir/Quantizer/Support/UniformSolvers.h b/third_party/mlir/include/mlir/Quantizer/Support/UniformSolvers.h
index 98df671..d6bd1a2 100644
--- a/third_party/mlir/include/mlir/Quantizer/Support/UniformSolvers.h
+++ b/third_party/mlir/include/mlir/Quantizer/Support/UniformSolvers.h
@@ -1,19 +1,10 @@
//===- UniformSolvers.h - Uniform type solver algorithms --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines algorithms for solving uniform type parameters for various
// conditions (i.e. fixed-point, affine, scale matching, etc).
diff --git a/third_party/mlir/include/mlir/Quantizer/Transforms/Passes.h b/third_party/mlir/include/mlir/Quantizer/Transforms/Passes.h
index 4fdea58..3490f29 100644
--- a/third_party/mlir/include/mlir/Quantizer/Transforms/Passes.h
+++ b/third_party/mlir/include/mlir/Quantizer/Transforms/Passes.h
@@ -1,19 +1,10 @@
//===- Passes.h - Quantizer passes -----------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines entry points to create passes to perform various kinds
// of quantization related transforms.
diff --git a/third_party/mlir/include/mlir/Support/DebugStringHelper.h b/third_party/mlir/include/mlir/Support/DebugStringHelper.h
index 230ed23..0fa3426 100644
--- a/third_party/mlir/include/mlir/Support/DebugStringHelper.h
+++ b/third_party/mlir/include/mlir/Support/DebugStringHelper.h
@@ -1,19 +1,10 @@
//===- DebugStringHelper.h - helpers to generate debug strings --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Convenience functions to make it easier to get a string representation for
// ops that have a print method. For use in debugging output and errors
diff --git a/third_party/mlir/include/mlir/Support/FileUtilities.h b/third_party/mlir/include/mlir/Support/FileUtilities.h
index 5ce9722..c13b39e 100644
--- a/third_party/mlir/include/mlir/Support/FileUtilities.h
+++ b/third_party/mlir/include/mlir/Support/FileUtilities.h
@@ -1,19 +1,10 @@
//===- FileUtilities.h - utilities for working with files -------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Common utilities for working with files.
//
diff --git a/third_party/mlir/include/mlir/Support/Functional.h b/third_party/mlir/include/mlir/Support/Functional.h
index e8bf394..f18677f 100644
--- a/third_party/mlir/include/mlir/Support/Functional.h
+++ b/third_party/mlir/include/mlir/Support/Functional.h
@@ -1,19 +1,10 @@
//===- Functional.h - Helpers for functional-style Combinators --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_FUNCTIONAL_H_
#define MLIR_SUPPORT_FUNCTIONAL_H_
diff --git a/third_party/mlir/include/mlir/Support/JitRunner.h b/third_party/mlir/include/mlir/Support/JitRunner.h
index 14b66a8..71c1d7d 100644
--- a/third_party/mlir/include/mlir/Support/JitRunner.h
+++ b/third_party/mlir/include/mlir/Support/JitRunner.h
@@ -1,19 +1,10 @@
//===- JitRunner.h - MLIR CPU Execution Driver Library ----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
diff --git a/third_party/mlir/include/mlir/Support/LLVM.h b/third_party/mlir/include/mlir/Support/LLVM.h
index 91d145d..1885ebe 100644
--- a/third_party/mlir/include/mlir/Support/LLVM.h
+++ b/third_party/mlir/include/mlir/Support/LLVM.h
@@ -1,19 +1,10 @@
//===- LLVM.h - Import and forward declare core LLVM types ------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file forward declares and imports various common LLVM datatypes that
// MLIR wants to use unqualified.
diff --git a/third_party/mlir/include/mlir/Support/LogicalResult.h b/third_party/mlir/include/mlir/Support/LogicalResult.h
index a9fc77c..418293c 100644
--- a/third_party/mlir/include/mlir/Support/LogicalResult.h
+++ b/third_party/mlir/include/mlir/Support/LogicalResult.h
@@ -1,19 +1,10 @@
//===- LogicalResult.h - Utilities for handling success/failure -*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_LOGICAL_RESULT_H
#define MLIR_SUPPORT_LOGICAL_RESULT_H
diff --git a/third_party/mlir/include/mlir/Support/MathExtras.h b/third_party/mlir/include/mlir/Support/MathExtras.h
index 767677f..1fd0634 100644
--- a/third_party/mlir/include/mlir/Support/MathExtras.h
+++ b/third_party/mlir/include/mlir/Support/MathExtras.h
@@ -1,19 +1,10 @@
//===- MathExtras.h - Math functions relevant to MLIR -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains math functions relevant to MLIR.
//
diff --git a/third_party/mlir/include/mlir/Support/MlirOptMain.h b/third_party/mlir/include/mlir/Support/MlirOptMain.h
index be8e432..eac5ee7 100644
--- a/third_party/mlir/include/mlir/Support/MlirOptMain.h
+++ b/third_party/mlir/include/mlir/Support/MlirOptMain.h
@@ -1,19 +1,10 @@
//===- MlirOptMain.h - MLIR Optimizer Driver main ---------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Main entry function for mlir-opt for when built as standalone binary.
//
diff --git a/third_party/mlir/include/mlir/Support/STLExtras.h b/third_party/mlir/include/mlir/Support/STLExtras.h
index 9bae7ac..9a12861 100644
--- a/third_party/mlir/include/mlir/Support/STLExtras.h
+++ b/third_party/mlir/include/mlir/Support/STLExtras.h
@@ -1,19 +1,10 @@
//===- STLExtras.h - STL-like extensions that are used by MLIR --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains stuff that should be arguably sunk down to the LLVM
// Support/STLExtras.h file over time.
diff --git a/third_party/mlir/include/mlir/Support/StorageUniquer.h b/third_party/mlir/include/mlir/Support/StorageUniquer.h
index fe1f898..f505731 100644
--- a/third_party/mlir/include/mlir/Support/StorageUniquer.h
+++ b/third_party/mlir/include/mlir/Support/StorageUniquer.h
@@ -1,19 +1,10 @@
//===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_STORAGEUNIQUER_H
#define MLIR_SUPPORT_STORAGEUNIQUER_H
diff --git a/third_party/mlir/include/mlir/Support/StringExtras.h b/third_party/mlir/include/mlir/Support/StringExtras.h
index 2f75c8e..5fc6769 100644
--- a/third_party/mlir/include/mlir/Support/StringExtras.h
+++ b/third_party/mlir/include/mlir/Support/StringExtras.h
@@ -1,19 +1,10 @@
//===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains string utility functions used within MLIR.
//
diff --git a/third_party/mlir/include/mlir/Support/ToolUtilities.h b/third_party/mlir/include/mlir/Support/ToolUtilities.h
index 13a3742..3175ebb 100644
--- a/third_party/mlir/include/mlir/Support/ToolUtilities.h
+++ b/third_party/mlir/include/mlir/Support/ToolUtilities.h
@@ -1,19 +1,10 @@
//===- ToolUtilities.h - MLIR Tool Utilities --------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares common utilities for implementing MLIR tools.
//
diff --git a/third_party/mlir/include/mlir/Support/TranslateClParser.h b/third_party/mlir/include/mlir/Support/TranslateClParser.h
index ccd4fb9..822d4b1 100644
--- a/third_party/mlir/include/mlir/Support/TranslateClParser.h
+++ b/third_party/mlir/include/mlir/Support/TranslateClParser.h
@@ -1,19 +1,10 @@
//===- TranslateClParser.h - Translations command line parser ---*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains custom command line parser for translations.
//
diff --git a/third_party/mlir/include/mlir/TableGen/Argument.h b/third_party/mlir/include/mlir/TableGen/Argument.h
index 8390939..6a0787e 100644
--- a/third_party/mlir/include/mlir/TableGen/Argument.h
+++ b/third_party/mlir/include/mlir/TableGen/Argument.h
@@ -1,19 +1,10 @@
//===- Argument.h - Argument definitions ------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file contains definitions for TableGen operation's arguments.
// Operation arguments fall into two categories:
diff --git a/third_party/mlir/include/mlir/TableGen/Attribute.h b/third_party/mlir/include/mlir/TableGen/Attribute.h
index 242376e..747df94 100644
--- a/third_party/mlir/include/mlir/TableGen/Attribute.h
+++ b/third_party/mlir/include/mlir/TableGen/Attribute.h
@@ -1,19 +1,10 @@
//===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Attribute wrapper to simplify using TableGen Record defining a MLIR
// Attribute.
diff --git a/third_party/mlir/include/mlir/TableGen/Constraint.h b/third_party/mlir/include/mlir/TableGen/Constraint.h
index 17b60da..fb7c1d7 100644
--- a/third_party/mlir/include/mlir/TableGen/Constraint.h
+++ b/third_party/mlir/include/mlir/TableGen/Constraint.h
@@ -1,19 +1,10 @@
//===- Constraint.h - Constraint class --------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Constraint wrapper to simplify using TableGen Record for constraints.
//
diff --git a/third_party/mlir/include/mlir/TableGen/Dialect.h b/third_party/mlir/include/mlir/TableGen/Dialect.h
index 6861da4..56d17f4 100644
--- a/third_party/mlir/include/mlir/TableGen/Dialect.h
+++ b/third_party/mlir/include/mlir/TableGen/Dialect.h
@@ -1,18 +1,9 @@
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
//
diff --git a/third_party/mlir/include/mlir/TableGen/Format.h b/third_party/mlir/include/mlir/TableGen/Format.h
index 6f02c28..160ba5f 100644
--- a/third_party/mlir/include/mlir/TableGen/Format.h
+++ b/third_party/mlir/include/mlir/TableGen/Format.h
@@ -1,19 +1,10 @@
//===- Format.h - Utilities for String Format -------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares utilities for formatting strings. They are specially
// tailored to the needs of TableGen'ing op definitions and rewrite rules,
diff --git a/third_party/mlir/include/mlir/TableGen/GenInfo.h b/third_party/mlir/include/mlir/TableGen/GenInfo.h
index 0b0bd19..3c732c2 100644
--- a/third_party/mlir/include/mlir/TableGen/GenInfo.h
+++ b/third_party/mlir/include/mlir/TableGen/GenInfo.h
@@ -1,19 +1,10 @@
//===- GenInfo.h - Generator info -------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_GENINFO_H_
#define MLIR_TABLEGEN_GENINFO_H_
diff --git a/third_party/mlir/include/mlir/TableGen/GenNameParser.h b/third_party/mlir/include/mlir/TableGen/GenNameParser.h
index 7b1e8a3..65f4a8c 100644
--- a/third_party/mlir/include/mlir/TableGen/GenNameParser.h
+++ b/third_party/mlir/include/mlir/TableGen/GenNameParser.h
@@ -1,19 +1,10 @@
//===- GenNameParser.h - Command line parser for generators -----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// The GenNameParser class adds all passes linked in to the system that are
// creatable to the tool.
diff --git a/third_party/mlir/include/mlir/TableGen/OpInterfaces.h b/third_party/mlir/include/mlir/TableGen/OpInterfaces.h
index 0959f6b..9bf1816 100644
--- a/third_party/mlir/include/mlir/TableGen/OpInterfaces.h
+++ b/third_party/mlir/include/mlir/TableGen/OpInterfaces.h
@@ -1,19 +1,10 @@
//===- OpInterfaces.h - OpInterfaces wrapper class --------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// OpInterfaces wrapper to simplify using TableGen OpInterfaces.
//
diff --git a/third_party/mlir/include/mlir/TableGen/OpTrait.h b/third_party/mlir/include/mlir/TableGen/OpTrait.h
index c3ea9a7..59fc7ac 100644
--- a/third_party/mlir/include/mlir/TableGen/OpTrait.h
+++ b/third_party/mlir/include/mlir/TableGen/OpTrait.h
@@ -1,19 +1,10 @@
//===- OpTrait.h - OpTrait wrapper class ------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// OpTrait wrapper to simplify using TableGen Record defining an MLIR OpTrait.
//
diff --git a/third_party/mlir/include/mlir/TableGen/Operator.h b/third_party/mlir/include/mlir/TableGen/Operator.h
index 89fd4ed..dd5ff35 100644
--- a/third_party/mlir/include/mlir/TableGen/Operator.h
+++ b/third_party/mlir/include/mlir/TableGen/Operator.h
@@ -1,19 +1,10 @@
//===- Operator.h - Operator class ------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
//
diff --git a/third_party/mlir/include/mlir/TableGen/Pattern.h b/third_party/mlir/include/mlir/TableGen/Pattern.h
index 8bd1c91..bf89f6e 100644
--- a/third_party/mlir/include/mlir/TableGen/Pattern.h
+++ b/third_party/mlir/include/mlir/TableGen/Pattern.h
@@ -1,19 +1,10 @@
//===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
// Pattern.
diff --git a/third_party/mlir/include/mlir/TableGen/Predicate.h b/third_party/mlir/include/mlir/TableGen/Predicate.h
index 49f7ebc..045b7fe 100644
--- a/third_party/mlir/include/mlir/TableGen/Predicate.h
+++ b/third_party/mlir/include/mlir/TableGen/Predicate.h
@@ -1,19 +1,10 @@
//===- Predicate.h - Predicate class ----------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Wrapper around predicates defined in TableGen.
//
diff --git a/third_party/mlir/include/mlir/TableGen/Region.h b/third_party/mlir/include/mlir/TableGen/Region.h
index 21dffe6..778f686 100644
--- a/third_party/mlir/include/mlir/TableGen/Region.h
+++ b/third_party/mlir/include/mlir/TableGen/Region.h
@@ -1,19 +1,10 @@
//===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_REGION_H_
#define MLIR_TABLEGEN_REGION_H_
diff --git a/third_party/mlir/include/mlir/TableGen/Type.h b/third_party/mlir/include/mlir/TableGen/Type.h
index 03cbd10..35de70f 100644
--- a/third_party/mlir/include/mlir/TableGen/Type.h
+++ b/third_party/mlir/include/mlir/TableGen/Type.h
@@ -1,19 +1,10 @@
//===- Type.h - Type class --------------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Type wrapper to simplify using TableGen Record defining a MLIR Type.
//
diff --git a/third_party/mlir/include/mlir/Target/LLVMIR.h b/third_party/mlir/include/mlir/Target/LLVMIR.h
index 7ed7b39..1cdc26c 100644
--- a/third_party/mlir/include/mlir/Target/LLVMIR.h
+++ b/third_party/mlir/include/mlir/Target/LLVMIR.h
@@ -1,19 +1,10 @@
//===- LLVMIR.h - MLIR to LLVM IR conversion --------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the entry point for the MLIR to LLVM IR conversion.
//
diff --git a/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index da2670a..d0b13a6 100644
--- a/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/third_party/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -1,19 +1,10 @@
//===- ModuleTranslation.h - MLIR to LLVM conversion ------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the translation between an MLIR LLVM dialect module and
// the corresponding LLVMIR module. It only handles core LLVM IR operations.
@@ -87,16 +78,8 @@
llvm::IRBuilder<> &builder);
static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
- // A helper to look up remapped operands in the value remapping table.
- template <typename Range>
- SmallVector<llvm::Value *, 8> lookupValues(Range &&values) {
- SmallVector<llvm::Value *, 8> remapped;
- remapped.reserve(llvm::size(values));
- for (Value *v : values) {
- remapped.push_back(valueMapping.lookup(v));
- }
- return remapped;
- }
+ /// A helper to look up remapped operands in the value remapping table.
+ SmallVector<llvm::Value *, 8> lookupValues(ValueRange values);
private:
/// Check whether the module contains only supported ops directly in its body.
@@ -121,7 +104,7 @@
protected:
// Mappings between original and translated values, used for lookups.
llvm::StringMap<llvm::Function *> functionMapping;
- DenseMap<Value *, llvm::Value *> valueMapping;
+ DenseMap<Value, llvm::Value *> valueMapping;
DenseMap<Block *, llvm::BasicBlock *> blockMapping;
};
diff --git a/third_party/mlir/include/mlir/Target/NVVMIR.h b/third_party/mlir/include/mlir/Target/NVVMIR.h
index ec9858e..377ee16 100644
--- a/third_party/mlir/include/mlir/Target/NVVMIR.h
+++ b/third_party/mlir/include/mlir/Target/NVVMIR.h
@@ -1,19 +1,10 @@
//===- NVVMIR.h - MLIR to LLVM + NVVM IR conversion -------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the entry point for the MLIR to LLVM + NVVM IR conversion.
//
diff --git a/third_party/mlir/include/mlir/Target/ROCDLIR.h b/third_party/mlir/include/mlir/Target/ROCDLIR.h
index fd00e94..25937ee 100644
--- a/third_party/mlir/include/mlir/Target/ROCDLIR.h
+++ b/third_party/mlir/include/mlir/Target/ROCDLIR.h
@@ -1,19 +1,10 @@
//===- ROCDLIR.h - MLIR to LLVM + ROCDL IR conversion -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the entry point for the MLIR to LLVM + ROCDL IR
// conversion.
diff --git a/third_party/mlir/include/mlir/Transforms/DialectConversion.h b/third_party/mlir/include/mlir/Transforms/DialectConversion.h
index 814f220..5cbbcae 100644
--- a/third_party/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/third_party/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1,19 +1,10 @@
//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares a generic pass for converting between MLIR dialects.
//
@@ -60,7 +51,7 @@
/// remaps an existing signature input.
struct InputMapping {
size_t inputNo, size;
- Value *replacementValue;
+ Value replacementValue;
};
/// Return the argument types for the new signature.
@@ -90,7 +81,7 @@
/// Remap an input of the original signature to another `replacement`
/// value. This drops the original argument.
- void remapInput(unsigned origInputNo, Value *replacement);
+ void remapInput(unsigned origInputNo, Value replacement);
private:
/// The remapping information for each of the original arguments.
@@ -143,7 +134,7 @@
/// the conversion has finished.
virtual Operation *materializeConversion(PatternRewriter &rewriter,
Type resultType,
- ArrayRef<Value *> inputs,
+ ArrayRef<Value> inputs,
Location loc) {
llvm_unreachable("expected 'materializeConversion' to be overridden");
}
@@ -172,7 +163,7 @@
/// ConversionPattern ever needs to replace an operation that does not
/// have successors. This function should not fail. If some specific cases of
/// the operation are not supported, these cases should not be matched.
- virtual void rewrite(Operation *op, ArrayRef<Value *> operands,
+ virtual void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
@@ -187,18 +178,18 @@
/// terminator operation that has successors. This function should not fail
/// the pass. If some specific cases of the operation are not supported,
/// these cases should not be matched.
- virtual void rewrite(Operation *op, ArrayRef<Value *> properOperands,
+ virtual void rewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
- ArrayRef<ArrayRef<Value *>> operands,
+ ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite for terminators");
}
/// Hook for derived classes to implement combined matching and rewriting.
virtual PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
- ArrayRef<ArrayRef<Value *>> operands,
+ ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
@@ -208,7 +199,7 @@
/// Hook for derived classes to implement combined matching and rewriting.
virtual PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
@@ -234,27 +225,27 @@
/// Wrappers around the ConversionPattern methods that pass the derived op
/// type.
- void rewrite(Operation *op, ArrayRef<Value *> operands,
+ void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
- void rewrite(Operation *op, ArrayRef<Value *> properOperands,
+ void rewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
- ArrayRef<ArrayRef<Value *>> operands,
+ ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), properOperands, destinations, operands,
rewriter);
}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
- ArrayRef<ArrayRef<Value *>> operands,
+ ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), properOperands, destinations,
operands, rewriter);
}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
@@ -264,22 +255,22 @@
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
- virtual void rewrite(SourceOp op, ArrayRef<Value *> operands,
+ virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
- virtual void rewrite(SourceOp op, ArrayRef<Value *> properOperands,
+ virtual void rewrite(SourceOp op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
- ArrayRef<ArrayRef<Value *>> operands,
+ ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite for terminators");
}
virtual PatternMatchResult
- matchAndRewrite(SourceOp op, ArrayRef<Value *> properOperands,
+ matchAndRewrite(SourceOp op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
- ArrayRef<ArrayRef<Value *>> operands,
+ ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
@@ -288,7 +279,7 @@
}
virtual PatternMatchResult
- matchAndRewrite(SourceOp op, ArrayRef<Value *> operands,
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
@@ -330,11 +321,11 @@
TypeConverter::SignatureConversion &conversion);
/// Replace all the uses of the block argument `from` with value `to`.
- void replaceUsesOfBlockArgument(BlockArgument *from, Value *to);
+ void replaceUsesOfBlockArgument(BlockArgument from, Value to);
/// Return the converted value that replaces 'key'. Return 'key' if there is
/// no such a converted value.
- Value *getRemappedValue(Value *key);
+ Value getRemappedValue(Value key);
//===--------------------------------------------------------------------===//
// PatternRewriter Hooks
@@ -374,7 +365,16 @@
Operation *insert(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
- void notifyRootUpdated(Operation *op) override;
+ /// Note: These methods only track updates to the top-level operation itself,
+ /// and not nested regions. Updates to regions will still require notification
+ /// through other more specific hooks above.
+ void startRootUpdate(Operation *op) override;
+
+ /// PatternRewriter hook for updating the root operation in-place.
+ void finalizeRootUpdate(Operation *op) override;
+
+ /// PatternRewriter hook for updating the root operation in-place.
+ void cancelRootUpdate(Operation *op) override;
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
diff --git a/third_party/mlir/include/mlir/Transforms/FoldUtils.h b/third_party/mlir/include/mlir/Transforms/FoldUtils.h
index bdf88d3..6b0e827 100644
--- a/third_party/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/third_party/mlir/include/mlir/Transforms/FoldUtils.h
@@ -1,19 +1,10 @@
//===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file declares various operation folding utilities. These
// utilities are intended to be used by passes to unify and simply their logic.
@@ -82,7 +73,7 @@
/// and immediately try to fold it. This function populates 'results' with
/// the results after folding the operation.
template <typename OpTy, typename... Args>
- void create(OpBuilder &builder, SmallVectorImpl<Value *> &results,
+ void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
Location location, Args &&... args) {
Operation *op = builder.create<OpTy>(location, std::forward<Args>(args)...);
if (failed(tryToFold(op, results)))
@@ -94,9 +85,9 @@
/// Overload to create or fold a single result operation.
template <typename OpTy, typename... Args>
typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
- Value *>::type
+ Value>::type
create(OpBuilder &builder, Location location, Args &&... args) {
- SmallVector<Value *, 1> results;
+ SmallVector<Value, 1> results;
create<OpTy>(builder, results, location, std::forward<Args>(args)...);
return results.front();
}
@@ -107,7 +98,7 @@
OpTy>::type
create(OpBuilder &builder, Location location, Args &&... args) {
auto op = builder.create<OpTy>(location, std::forward<Args>(args)...);
- SmallVector<Value *, 0> unused;
+ SmallVector<Value, 0> unused;
(void)tryToFold(op.getOperation(), unused);
// Folding cannot remove a zero-result operation, so for convenience we
@@ -126,7 +117,7 @@
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult tryToFold(
- Operation *op, SmallVectorImpl<Value *> &results,
+ Operation *op, SmallVectorImpl<Value> &results,
function_ref<void(Operation *)> processGeneratedConstants = nullptr);
/// Try to get or create a new constant entry. On success this returns the
diff --git a/third_party/mlir/include/mlir/Transforms/InliningUtils.h b/third_party/mlir/include/mlir/Transforms/InliningUtils.h
index 590b46a..e3631c2 100644
--- a/third_party/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/third_party/mlir/include/mlir/Transforms/InliningUtils.h
@@ -1,19 +1,10 @@
//===- InliningUtils.h - Inliner utilities ----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines interfaces for various inlining utility methods.
//
@@ -105,7 +96,7 @@
/// operation). The given 'op' will be removed by the caller, after this
/// function has been called.
virtual void handleTerminator(Operation *op,
- ArrayRef<Value *> valuesToReplace) const {
+ ArrayRef<Value> valuesToReplace) const {
llvm_unreachable(
"must implement handleTerminator in the case of one inlined block");
}
@@ -125,7 +116,7 @@
/// ... = foo.call @foo(%input : i32) -> i16
///
/// NOTE: This hook may be invoked before the 'isLegal' checks above.
- virtual Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+ virtual Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type resultType,
Location conversionLoc) const {
return nullptr;
@@ -165,7 +156,7 @@
virtual void handleTerminator(Operation *op, Block *newDest) const;
virtual void handleTerminator(Operation *op,
- ArrayRef<Value *> valuesToRepl) const;
+ ArrayRef<Value> valuesToRepl) const;
};
//===----------------------------------------------------------------------===//
@@ -187,7 +178,7 @@
/// be cloned into the 'inlinePoint' or spliced directly.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint, BlockAndValueMapping &mapper,
- ArrayRef<Value *> resultsToReplace,
+ ArrayRef<Value> resultsToReplace,
Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true);
@@ -196,8 +187,8 @@
/// in-favor of the region arguments when inlining.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
- ArrayRef<Value *> inlinedOperands,
- ArrayRef<Value *> resultsToReplace,
+ ArrayRef<Value> inlinedOperands,
+ ArrayRef<Value> resultsToReplace,
Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true);
diff --git a/third_party/mlir/include/mlir/Transforms/LoopFusionUtils.h b/third_party/mlir/include/mlir/Transforms/LoopFusionUtils.h
index af84b89..4c307ff 100644
--- a/third_party/mlir/include/mlir/Transforms/LoopFusionUtils.h
+++ b/third_party/mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -1,19 +1,10 @@
//===- LoopFusionUtils.h - Loop fusion utilities ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various loop fusion utility
// methods: these are not passes by themselves but are used either by passes,
diff --git a/third_party/mlir/include/mlir/Transforms/LoopLikeInterface.h b/third_party/mlir/include/mlir/Transforms/LoopLikeInterface.h
index a8bc0d1..cba9ae7 100644
--- a/third_party/mlir/include/mlir/Transforms/LoopLikeInterface.h
+++ b/third_party/mlir/include/mlir/Transforms/LoopLikeInterface.h
@@ -1,19 +1,10 @@
//===- LoopLikeInterface.h - Loop-like operations interface ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the operation interface for loop like operations.
//
diff --git a/third_party/mlir/include/mlir/Transforms/LoopLikeInterface.td b/third_party/mlir/include/mlir/Transforms/LoopLikeInterface.td
index 5c324b7..c110b19 100644
--- a/third_party/mlir/include/mlir/Transforms/LoopLikeInterface.td
+++ b/third_party/mlir/include/mlir/Transforms/LoopLikeInterface.td
@@ -1,19 +1,10 @@
//===- LoopLikeInterface.td - LoopLike interface -----------*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines the interface for loop-like operations as used by LICM.
//
@@ -38,7 +29,7 @@
explicit capture of dependencies, an implementation could check whether
the value corresponds to a captured dependency.
}],
- "bool", "isDefinedOutsideOfLoop", (ins "Value *":$value)
+ "bool", "isDefinedOutsideOfLoop", (ins "Value ":$value)
>,
InterfaceMethod<[{
Returns the region that makes up the body of the loop and should be
diff --git a/third_party/mlir/include/mlir/Transforms/LoopUtils.h b/third_party/mlir/include/mlir/Transforms/LoopUtils.h
index 5ca3f7f..402a336 100644
--- a/third_party/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/third_party/mlir/include/mlir/Transforms/LoopUtils.h
@@ -1,19 +1,10 @@
//===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various loop transformation utility
// methods: these are not passes by themselves but are used either by passes,
@@ -84,8 +75,7 @@
/// operands or a null map when the trip count can't be expressed as an affine
/// expression.
void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
- AffineMap *map,
- SmallVectorImpl<Value *> *operands,
+ AffineMap *map, SmallVectorImpl<Value> *operands,
OpBuilder &builder);
/// Skew the operations in the body of a 'affine.for' operation with the
@@ -139,8 +129,7 @@
SmallVector<SmallVector<AffineForOp, 8>, 8> tile(ArrayRef<AffineForOp> forOps,
ArrayRef<uint64_t> sizes,
ArrayRef<AffineForOp> targets);
-SmallVector<Loops, 8> tile(ArrayRef<loop::ForOp> forOps,
- ArrayRef<Value *> sizes,
+SmallVector<Loops, 8> tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value> sizes,
ArrayRef<loop::ForOp> targets);
/// Performs tiling (with interchange) by strip-mining the `forOps` by `sizes`
@@ -149,7 +138,7 @@
/// `target`.
SmallVector<AffineForOp, 8> tile(ArrayRef<AffineForOp> forOps,
ArrayRef<uint64_t> sizes, AffineForOp target);
-Loops tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value *> sizes,
+Loops tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value> sizes,
loop::ForOp target);
/// Tile a nest of loop::ForOp loops rooted at `rootForOp` with the given
@@ -157,7 +146,7 @@
/// runtime. If more sizes than loops are provided, discard the trailing values
/// in sizes. Assumes the loop nest is permutable.
/// Returns the newly created intra-tile loops.
-Loops tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef<Value *> sizes);
+Loops tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef<Value> sizes);
/// Explicit copy / DMA generation options for mlir::affineDataCopyGenerate.
struct AffineCopyOptions {
@@ -229,8 +218,8 @@
/// ...
/// }
/// ```
-void mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value *> processorId,
- ArrayRef<Value *> numProcessors);
+void mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value> processorId,
+ ArrayRef<Value> numProcessors);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_LOOP_UTILS_H
diff --git a/third_party/mlir/include/mlir/Transforms/Passes.h b/third_party/mlir/include/mlir/Transforms/Passes.h
index 5480a9a..1ea8f06 100644
--- a/third_party/mlir/include/mlir/Transforms/Passes.h
+++ b/third_party/mlir/include/mlir/Transforms/Passes.h
@@ -1,19 +1,10 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes that expose pass constructors in the loop
// transformation library.
diff --git a/third_party/mlir/include/mlir/Transforms/RegionUtils.h b/third_party/mlir/include/mlir/Transforms/RegionUtils.h
index 48080b2..bd71553 100644
--- a/third_party/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/third_party/mlir/include/mlir/Transforms/RegionUtils.h
@@ -1,19 +1,10 @@
//===- RegionUtils.h - Region-related transformation utilities --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_REGIONUTILS_H_
#define MLIR_TRANSFORMS_REGIONUTILS_H_
@@ -30,15 +21,14 @@
/// of `limit`.
template <typename Range>
bool areValuesDefinedAbove(Range values, Region &limit) {
- for (Value *v : values)
+ for (Value v : values)
if (!v->getParentRegion()->isProperAncestor(&limit))
return false;
return true;
}
/// Replace all uses of `orig` within the given region with `replacement`.
-void replaceAllUsesInRegionWith(Value *orig, Value *replacement,
- Region &region);
+void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region);
/// Calls `callback` for each use of a value within `region` or its descendants
/// that was defined at the ancestors of the `limit`.
@@ -53,12 +43,12 @@
/// Fill `values` with a list of values defined at the ancestors of the `limit`
/// region and used within `region` or its descendants.
void getUsedValuesDefinedAbove(Region &region, Region &limit,
- llvm::SetVector<Value *> &values);
+ llvm::SetVector<Value> &values);
/// Fill `values` with a list of values used within any of the regions provided
/// but defined in one of the ancestors.
void getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
- llvm::SetVector<Value *> &values);
+ llvm::SetVector<Value> &values);
/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
diff --git a/third_party/mlir/include/mlir/Transforms/SideEffectsInterface.h b/third_party/mlir/include/mlir/Transforms/SideEffectsInterface.h
index 443596b..69c2a27 100644
--- a/third_party/mlir/include/mlir/Transforms/SideEffectsInterface.h
+++ b/third_party/mlir/include/mlir/Transforms/SideEffectsInterface.h
@@ -1,19 +1,10 @@
//===- SideEffectsInterface.h - dialect interface modeling side effects ---===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file specifies a dialect interface to model side-effects.
//
diff --git a/third_party/mlir/include/mlir/Transforms/Utils.h b/third_party/mlir/include/mlir/Transforms/Utils.h
index c682b48..3b7f6cd 100644
--- a/third_party/mlir/include/mlir/Transforms/Utils.h
+++ b/third_party/mlir/include/mlir/Transforms/Utils.h
@@ -1,19 +1,10 @@
//===- Utils.h - General transformation utilities ---------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various transformation utilities for
// memref's and non-loop IR structures. These are not passes by themselves but
@@ -66,22 +57,22 @@
// extra operands, note that 'indexRemap' would just be applied to existing
// indices (%i, %j).
// TODO(bondhugula): allow extraIndices to be added at any position.
-LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
- ArrayRef<Value *> extraIndices = {},
+LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
+ ArrayRef<Value> extraIndices = {},
AffineMap indexRemap = AffineMap(),
- ArrayRef<Value *> extraOperands = {},
- ArrayRef<Value *> symbolOperands = {},
+ ArrayRef<Value> extraOperands = {},
+ ArrayRef<Value> symbolOperands = {},
Operation *domInstFilter = nullptr,
Operation *postDomInstFilter = nullptr);
/// Performs the same replacement as the other version above but only for the
/// dereferencing uses of `oldMemRef` in `op`.
-LogicalResult replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
Operation *op,
- ArrayRef<Value *> extraIndices = {},
+ ArrayRef<Value> extraIndices = {},
AffineMap indexRemap = AffineMap(),
- ArrayRef<Value *> extraOperands = {},
- ArrayRef<Value *> symbolOperands = {});
+ ArrayRef<Value> extraOperands = {},
+ ArrayRef<Value> symbolOperands = {});
/// Rewrites the memref defined by this alloc op to have an identity layout map
/// and updates all its indexing uses. Returns failure if any of its uses
@@ -96,9 +87,9 @@
/// The final results of the composed AffineApplyOp are returned in output
/// parameter 'results'. Returns the affine apply op created.
Operation *createComposedAffineApplyOp(OpBuilder &builder, Location loc,
- ArrayRef<Value *> operands,
+ ArrayRef<Value> operands,
ArrayRef<Operation *> affineApplyOps,
- SmallVectorImpl<Value *> *results);
+ SmallVectorImpl<Value> *results);
/// Given an operation, inserts one or more single result affine apply
/// operations, results of which are exclusively used by this operation.
diff --git a/third_party/mlir/include/mlir/Transforms/ViewOpGraph.h b/third_party/mlir/include/mlir/Transforms/ViewOpGraph.h
index 41f5eb5..c178208 100644
--- a/third_party/mlir/include/mlir/Transforms/ViewOpGraph.h
+++ b/third_party/mlir/include/mlir/Transforms/ViewOpGraph.h
@@ -1,19 +1,10 @@
//===- ViewOpGraph.h - View/write op graphviz graphs ------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines interface to produce Graphviz outputs of MLIR op within block.
//
diff --git a/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h b/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h
index 4378d38..e8c4750 100644
--- a/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h
+++ b/third_party/mlir/include/mlir/Transforms/ViewRegionGraph.h
@@ -1,19 +1,10 @@
//===- ViewRegionGraph.h - View/write graphviz graphs -----------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines interface to produce Graphviz outputs of MLIR Regions.
//
diff --git a/third_party/mlir/include/mlir/Translation.h b/third_party/mlir/include/mlir/Translation.h
index 0bf8178..9244b97 100644
--- a/third_party/mlir/include/mlir/Translation.h
+++ b/third_party/mlir/include/mlir/Translation.h
@@ -1,19 +1,10 @@
//===- Translation.h - Translation registry ---------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Registry for user-provided translations.
//
diff --git a/third_party/mlir/lib/Analysis/AffineAnalysis.cpp b/third_party/mlir/lib/Analysis/AffineAnalysis.cpp
index 97868a5..3358bb4 100644
--- a/third_party/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/third_party/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -1,19 +1,10 @@
//===- AffineAnalysis.cpp - Affine structures analysis routines -----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for affine structures
// (expressions, maps, sets), and other utilities relying on such analysis.
@@ -48,15 +39,15 @@
// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
// the AffineApplyOp into any user AffineApplyOps.
void mlir::getReachableAffineApplyOps(
- ArrayRef<Value *> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
+ ArrayRef<Value> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
struct State {
// The ssa value for this node in the DFS traversal.
- Value *value;
+ Value value;
// The operand index of 'value' to explore next during DFS traversal.
unsigned operandIndex;
};
SmallVector<State, 4> worklist;
- for (auto *operand : operands) {
+ for (auto operand : operands) {
worklist.push_back({operand, 0});
}
@@ -77,7 +68,7 @@
if (state.operandIndex < opInst->getNumOperands()) {
// Visit: Add next 'affineApplyOp' operand to worklist.
// Get next operand to visit at 'operandIndex'.
- auto *nextOperand = opInst->getOperand(state.operandIndex);
+ auto nextOperand = opInst->getOperand(state.operandIndex);
// Increment 'operandIndex' in 'state'.
++state.operandIndex;
// Add 'nextOperand' to worklist.
@@ -99,7 +90,7 @@
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
FlatAffineConstraints *domain) {
- SmallVector<Value *, 4> indices;
+ SmallVector<Value, 4> indices;
extractForInductionVars(forOps, &indices);
// Reset while associated Values in 'indices' to the domain.
domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
@@ -146,25 +137,25 @@
// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
class ValuePositionMap {
public:
- void addSrcValue(Value *value) {
+ void addSrcValue(Value value) {
if (addValueAt(value, &srcDimPosMap, numSrcDims))
++numSrcDims;
}
- void addDstValue(Value *value) {
+ void addDstValue(Value value) {
if (addValueAt(value, &dstDimPosMap, numDstDims))
++numDstDims;
}
- void addSymbolValue(Value *value) {
+ void addSymbolValue(Value value) {
if (addValueAt(value, &symbolPosMap, numSymbols))
++numSymbols;
}
- unsigned getSrcDimOrSymPos(Value *value) const {
+ unsigned getSrcDimOrSymPos(Value value) const {
return getDimOrSymPos(value, srcDimPosMap, 0);
}
- unsigned getDstDimOrSymPos(Value *value) const {
+ unsigned getDstDimOrSymPos(Value value) const {
return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
}
- unsigned getSymPos(Value *value) const {
+ unsigned getSymPos(Value value) const {
auto it = symbolPosMap.find(value);
assert(it != symbolPosMap.end());
return numSrcDims + numDstDims + it->second;
@@ -176,7 +167,7 @@
unsigned getNumSymbols() const { return numSymbols; }
private:
- bool addValueAt(Value *value, DenseMap<Value *, unsigned> *posMap,
+ bool addValueAt(Value value, DenseMap<Value, unsigned> *posMap,
unsigned position) {
auto it = posMap->find(value);
if (it == posMap->end()) {
@@ -185,8 +176,8 @@
}
return false;
}
- unsigned getDimOrSymPos(Value *value,
- const DenseMap<Value *, unsigned> &dimPosMap,
+ unsigned getDimOrSymPos(Value value,
+ const DenseMap<Value, unsigned> &dimPosMap,
unsigned dimPosOffset) const {
auto it = dimPosMap.find(value);
if (it != dimPosMap.end()) {
@@ -200,9 +191,9 @@
unsigned numSrcDims = 0;
unsigned numDstDims = 0;
unsigned numSymbols = 0;
- DenseMap<Value *, unsigned> srcDimPosMap;
- DenseMap<Value *, unsigned> dstDimPosMap;
- DenseMap<Value *, unsigned> symbolPosMap;
+ DenseMap<Value, unsigned> srcDimPosMap;
+ DenseMap<Value, unsigned> dstDimPosMap;
+ DenseMap<Value, unsigned> symbolPosMap;
};
// Builds a map from Value to identifier position in a new merged identifier
@@ -219,9 +210,9 @@
const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
FlatAffineConstraints *dependenceConstraints) {
- auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
+ auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
- auto *value = values[i];
+ auto value = values[i];
if (!isForInductionVar(values[i])) {
assert(isValidSymbol(values[i]) &&
"access operand has to be either a loop IV or a symbol");
@@ -234,7 +225,7 @@
}
};
- SmallVector<Value *, 4> srcValues, destValues;
+ SmallVector<Value, 4> srcValues, destValues;
srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues);
dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues);
// Update value position map with identifiers from src iteration domain.
@@ -273,7 +264,7 @@
numLocals);
// Set values corresponding to dependence constraint identifiers.
- SmallVector<Value *, 4> srcLoopIVs, dstLoopIVs;
+ SmallVector<Value, 4> srcLoopIVs, dstLoopIVs;
srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);
@@ -282,8 +273,8 @@
srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
// Set values for the symbolic identifier dimensions.
- auto setSymbolIds = [&](ArrayRef<Value *> values) {
- for (auto *value : values) {
+ auto setSymbolIds = [&](ArrayRef<Value> values) {
+ for (auto value : values) {
if (!isForInductionVar(value)) {
assert(isValidSymbol(value) && "expected symbol");
dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
@@ -294,7 +285,7 @@
setSymbolIds(srcAccessMap.getOperands());
setSymbolIds(dstAccessMap.getOperands());
- SmallVector<Value *, 8> srcSymbolValues, dstSymbolValues;
+ SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
srcDomain.getIdValues(srcDomain.getNumDimIds(),
srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
dstDomain.getIdValues(dstDomain.getNumDimIds(),
@@ -398,10 +389,10 @@
unsigned numResults = srcMap.getNumResults();
unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
- ArrayRef<Value *> srcOperands = srcAccessMap.getOperands();
+ ArrayRef<Value> srcOperands = srcAccessMap.getOperands();
unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
- ArrayRef<Value *> dstOperands = dstAccessMap.getOperands();
+ ArrayRef<Value> dstOperands = dstAccessMap.getOperands();
std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
std::vector<SmallVector<int64_t, 8>> destFlatExprs;
@@ -457,11 +448,11 @@
}
// Add equality constraints for any operands that are defined by constant ops.
- auto addEqForConstOperands = [&](ArrayRef<Value *> operands) {
+ auto addEqForConstOperands = [&](ArrayRef<Value> operands) {
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (isForInductionVar(operands[i]))
continue;
- auto *symbol = operands[i];
+ auto symbol = operands[i];
assert(isValidSymbol(symbol));
// Check if the symbol is a constant.
if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol->getDefiningOp()))
@@ -553,7 +544,7 @@
}
return block;
}
- auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
+ auto commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
auto forOp = getForInductionVarOwner(commonForValue);
assert(forOp && "commonForValue was not an induction variable");
return forOp.getBody();
@@ -675,7 +666,7 @@
map = loadOp.getAffineMap();
else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst))
map = storeOp.getAffineMap();
- SmallVector<Value *, 8> operands(indices.begin(), indices.end());
+ SmallVector<Value, 8> operands(indices.begin(), indices.end());
fullyComposeAffineMapAndOperands(&map, &operands);
map = simplifyAffineMap(map);
canonicalizeMapAndOperands(&map, &operands);
diff --git a/third_party/mlir/lib/Analysis/AffineStructures.cpp b/third_party/mlir/lib/Analysis/AffineStructures.cpp
index d678355..78a8698 100644
--- a/third_party/mlir/lib/Analysis/AffineStructures.cpp
+++ b/third_party/mlir/lib/Analysis/AffineStructures.cpp
@@ -1,19 +1,10 @@
//===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Structures for affine/polyhedral analysis of MLIR functions.
//
@@ -204,8 +195,8 @@
// AffineValueMap.
//===----------------------------------------------------------------------===//
-AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value *> operands,
- ArrayRef<Value *> results)
+AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value> operands,
+ ArrayRef<Value> results)
: map(map), operands(operands.begin(), operands.end()),
results(results.begin(), results.end()) {}
@@ -219,8 +210,8 @@
: map(bound.getMap()),
operands(bound.operand_begin(), bound.operand_end()) {}
-void AffineValueMap::reset(AffineMap map, ArrayRef<Value *> operands,
- ArrayRef<Value *> results) {
+void AffineValueMap::reset(AffineMap map, ArrayRef<Value> operands,
+ ArrayRef<Value> results) {
this->map.reset(map);
this->operands.assign(operands.begin(), operands.end());
this->results.assign(results.begin(), results.end());
@@ -232,14 +223,14 @@
// Fully compose A's map + operands.
auto aMap = a.getAffineMap();
- SmallVector<Value *, 4> aOperands(a.getOperands().begin(),
- a.getOperands().end());
+ SmallVector<Value, 4> aOperands(a.getOperands().begin(),
+ a.getOperands().end());
fullyComposeAffineMapAndOperands(&aMap, &aOperands);
// Use the affine apply normalizer to get B's map into A's coordinate space.
AffineApplyNormalizer normalizer(aMap, aOperands);
- SmallVector<Value *, 4> bOperands(b.getOperands().begin(),
- b.getOperands().end());
+ SmallVector<Value, 4> bOperands(b.getOperands().begin(),
+ b.getOperands().end());
auto bMap = b.getAffineMap();
normalizer.normalize(&bMap, &bOperands);
@@ -263,7 +254,7 @@
// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
-static bool findIndex(Value *valueToMatch, ArrayRef<Value *> valuesToSearch,
+static bool findIndex(Value valueToMatch, ArrayRef<Value> valuesToSearch,
unsigned indexStart, unsigned *indexOfMatch) {
unsigned size = valuesToSearch.size();
for (unsigned i = indexStart; i < size; ++i) {
@@ -281,7 +272,7 @@
/// This method uses the invariant that operands are always positionally aligned
/// with the AffineDimExpr in the underlying AffineMap.
-bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const {
+bool AffineValueMap::isFunctionOf(unsigned idx, Value value) const {
unsigned index;
if (!findIndex(value, operands, /*indexStart=*/0, &index)) {
return false;
@@ -292,12 +283,12 @@
return expr.isFunctionOfDim(index);
}
-Value *AffineValueMap::getOperand(unsigned i) const {
- return static_cast<Value *>(operands[i]);
+Value AffineValueMap::getOperand(unsigned i) const {
+ return static_cast<Value>(operands[i]);
}
-ArrayRef<Value *> AffineValueMap::getOperands() const {
- return ArrayRef<Value *>(operands);
+ArrayRef<Value> AffineValueMap::getOperands() const {
+ return ArrayRef<Value>(operands);
}
AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); }
@@ -378,7 +369,7 @@
unsigned newNumReservedCols,
unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals,
- ArrayRef<Value *> idArgs) {
+ ArrayRef<Value> idArgs) {
assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
"minimum 1 column");
numReservedCols = newNumReservedCols;
@@ -401,7 +392,7 @@
void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals,
- ArrayRef<Value *> idArgs) {
+ ArrayRef<Value> idArgs) {
reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
newNumSymbols, newNumLocals, idArgs);
}
@@ -428,17 +419,17 @@
addId(IdKind::Local, pos);
}
-void FlatAffineConstraints::addDimId(unsigned pos, Value *id) {
+void FlatAffineConstraints::addDimId(unsigned pos, Value id) {
addId(IdKind::Dimension, pos, id);
}
-void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) {
+void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) {
addId(IdKind::Symbol, pos, id);
}
/// Adds a dimensional identifier. The added column is initialized to
/// zero.
-void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
+void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) {
if (kind == IdKind::Dimension) {
assert(pos <= getNumDimIds());
} else if (kind == IdKind::Symbol) {
@@ -527,7 +518,7 @@
/// Checks if the SSA values associated with `cst''s identifiers are unique.
static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineConstraints &cst) {
- SmallPtrSet<Value *, 8> uniqueIds;
+ SmallPtrSet<Value, 8> uniqueIds;
for (auto id : cst.getIds()) {
if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
return false;
@@ -571,11 +562,11 @@
assert(std::all_of(A->getIds().begin() + offset,
A->getIds().begin() + A->getNumDimAndSymbolIds(),
- [](Optional<Value *> id) { return id.hasValue(); }));
+ [](Optional<Value> id) { return id.hasValue(); }));
assert(std::all_of(B->getIds().begin() + offset,
B->getIds().begin() + B->getNumDimAndSymbolIds(),
- [](Optional<Value *> id) { return id.hasValue(); }));
+ [](Optional<Value> id) { return id.hasValue(); }));
// Place local id's of A after local id's of B.
for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
@@ -586,13 +577,13 @@
A->addLocalId(A->getNumLocalIds());
}
- SmallVector<Value *, 4> aDimValues, aSymValues;
+ SmallVector<Value, 4> aDimValues, aSymValues;
A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
{
// Merge dims from A into B.
unsigned d = offset;
- for (auto *aDimValue : aDimValues) {
+ for (auto aDimValue : aDimValues) {
unsigned loc;
if (B->findId(*aDimValue, &loc)) {
assert(loc >= offset && "A's dim appears in B's aligned range");
@@ -615,7 +606,7 @@
{
// Merge symbols: merge A's symbols into B first.
unsigned s = B->getNumDimIds();
- for (auto *aSymValue : aSymValues) {
+ for (auto aSymValue : aSymValues) {
unsigned loc;
if (B->findId(*aSymValue, &loc)) {
assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
@@ -785,7 +776,7 @@
}
// Turn a dimension into a symbol.
-static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) {
+static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
unsigned pos;
if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
swapId(cst, pos, cst->getNumDimIds() - 1);
@@ -794,7 +785,7 @@
}
// Turn a symbol into a dimension.
-static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value &id) {
+static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
unsigned pos;
if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
pos < cst->getNumDimAndSymbolIds()) {
@@ -806,18 +797,18 @@
// Changes all symbol identifiers which are loop IVs to dim identifiers.
void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
// Gather all symbols which are loop IVs.
- SmallVector<Value *, 4> loopIVs;
+ SmallVector<Value, 4> loopIVs;
for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
loopIVs.push_back(ids[i].getValue());
}
// Turn each symbol in 'loopIVs' into a dim identifier.
- for (auto *iv : loopIVs) {
+ for (auto iv : loopIVs) {
turnSymbolIntoDim(this, *iv);
}
}
-void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) {
+void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
if (containsId(*id))
return;
@@ -876,8 +867,8 @@
addConstantLowerBound(pos, forOp.getConstantLowerBound());
} else {
// Non-constant lower bound case.
- SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands().begin(),
- forOp.getLowerBoundOperands().end());
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands().begin(),
+ forOp.getLowerBoundOperands().end());
if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands,
/*eq=*/false, /*lower=*/true)))
return failure();
@@ -888,8 +879,8 @@
return success();
}
// Non-constant upper bound case.
- SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands().begin(),
- forOp.getUpperBoundOperands().end());
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands().begin(),
+ forOp.getUpperBoundOperands().end());
return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands,
/*eq=*/false, /*lower=*/false);
}
@@ -1757,7 +1748,7 @@
LogicalResult
FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
- ArrayRef<Value *> boundOperands,
+ ArrayRef<Value> boundOperands,
bool eq, bool lower) {
assert(pos < getNumDimAndSymbolIds() && "invalid position");
// Equality follows the logic of lower bound except that we add an equality
@@ -1769,11 +1760,11 @@
// Fully compose map and operands; canonicalize and simplify so that we
// transitively get to terminal symbols or loop IVs.
auto map = boundMap;
- SmallVector<Value *, 4> operands(boundOperands.begin(), boundOperands.end());
+ SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
fullyComposeAffineMapAndOperands(&map, &operands);
map = simplifyAffineMap(map);
canonicalizeMapAndOperands(&map, &operands);
- for (auto *operand : operands)
+ for (auto operand : operands)
addInductionVarOrTerminalSymbol(operand);
FlatAffineConstraints localVarCst;
@@ -1787,7 +1778,7 @@
if (localVarCst.getNumLocalIds() > 0) {
// Set values for localVarCst.
localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
- for (auto *operand : operands) {
+ for (auto operand : operands) {
unsigned pos;
if (findId(*operand, &pos)) {
if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
@@ -1807,7 +1798,7 @@
// this here since the constraint system changes after a bound is added.
SmallVector<unsigned, 8> positions;
unsigned numOperands = operands.size();
- for (auto *operand : operands) {
+ for (auto operand : operands) {
unsigned pos;
if (!findId(*operand, &pos))
assert(0 && "expected to be found");
@@ -1847,9 +1838,10 @@
// Note that both lower/upper bounds use operands from 'operands'.
// Returns failure for unimplemented cases such as semi-affine expressions or
// expressions with mod/floordiv.
-LogicalResult FlatAffineConstraints::addSliceBounds(
- ArrayRef<Value *> values, ArrayRef<AffineMap> lbMaps,
- ArrayRef<AffineMap> ubMaps, ArrayRef<Value *> operands) {
+LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
+ ArrayRef<AffineMap> lbMaps,
+ ArrayRef<AffineMap> ubMaps,
+ ArrayRef<Value> operands) {
assert(values.size() == lbMaps.size());
assert(lbMaps.size() == ubMaps.size());
@@ -1971,10 +1963,10 @@
addInequality(bound);
}
-bool FlatAffineConstraints::findId(Value &id, unsigned *pos) const {
+bool FlatAffineConstraints::findId(Value id, unsigned *pos) const {
unsigned i = 0;
for (const auto &mayBeId : ids) {
- if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
+ if (mayBeId.hasValue() && mayBeId.getValue() == id) {
*pos = i;
return true;
}
@@ -1983,9 +1975,9 @@
return false;
}
-bool FlatAffineConstraints::containsId(Value &id) const {
- return llvm::any_of(ids, [&](const Optional<Value *> &mayBeId) {
- return mayBeId.hasValue() && mayBeId.getValue() == &id;
+bool FlatAffineConstraints::containsId(Value id) const {
+ return llvm::any_of(ids, [&](const Optional<Value> &mayBeId) {
+ return mayBeId.hasValue() && mayBeId.getValue() == id;
});
}
@@ -2008,7 +2000,7 @@
/// Sets the specified identifier to a constant value; asserts if the id is not
/// found.
-void FlatAffineConstraints::setIdToConstant(Value &id, int64_t val) {
+void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) {
unsigned pos;
if (!findId(id, &pos))
// This is a pre-condition for this method.
@@ -2573,7 +2565,7 @@
unsigned newNumDims = dimsSymbols.first;
unsigned newNumSymbols = dimsSymbols.second;
- SmallVector<Optional<Value *>, 8> newIds;
+ SmallVector<Optional<Value>, 8> newIds;
newIds.reserve(numIds - 1);
newIds.append(ids.begin(), ids.begin() + pos);
newIds.append(ids.begin() + pos + 1, ids.end());
@@ -2709,7 +2701,7 @@
normalizeConstraintsByGCD();
}
-void FlatAffineConstraints::projectOut(Value *id) {
+void FlatAffineConstraints::projectOut(Value id) {
unsigned pos;
bool ret = findId(*id, &pos);
assert(ret);
diff --git a/third_party/mlir/lib/Analysis/CallGraph.cpp b/third_party/mlir/lib/Analysis/CallGraph.cpp
index 93017ca..c35421d 100644
--- a/third_party/mlir/lib/Analysis/CallGraph.cpp
+++ b/third_party/mlir/lib/Analysis/CallGraph.cpp
@@ -1,19 +1,10 @@
//===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains interfaces and analyses for defining a nested callgraph.
//
@@ -188,7 +179,7 @@
callee = SymbolTable::lookupNearestSymbolFrom(from,
symbolRef.getRootReference());
else
- callee = callable.get<Value *>()->getDefiningOp();
+ callee = callable.get<Value>()->getDefiningOp();
// If the callee is non-null and is a valid callable object, try to get the
// called region from it.
diff --git a/third_party/mlir/lib/Analysis/Dominance.cpp b/third_party/mlir/lib/Analysis/Dominance.cpp
index c422578..e4af4c0 100644
--- a/third_party/mlir/lib/Analysis/Dominance.cpp
+++ b/third_party/mlir/lib/Analysis/Dominance.cpp
@@ -1,19 +1,10 @@
//===- Dominance.cpp - Dominator analysis for CFGs ------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Implementation of dominance related classes and instantiations of extern
// templates.
@@ -127,7 +118,7 @@
}
/// Return true if value A properly dominates operation B.
-bool DominanceInfo::properlyDominates(Value *a, Operation *b) {
+bool DominanceInfo::properlyDominates(Value a, Operation *b) {
if (auto *aOp = a->getDefiningOp()) {
// The values defined by an operation do *not* dominate any nested
// operations.
@@ -138,7 +129,7 @@
// block arguments properly dominate all operations in their own block, so
// we use a dominates check here, not a properlyDominates check.
- return dominates(cast<BlockArgument>(a)->getOwner(), b->getBlock());
+ return dominates(a.cast<BlockArgument>()->getOwner(), b->getBlock());
}
DominanceInfoNode *DominanceInfo::getNode(Block *a) {
diff --git a/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp b/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp
index cbbd446..2e52de2 100644
--- a/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp
+++ b/third_party/mlir/lib/Analysis/InferTypeOpInterface.cpp
@@ -1,19 +1,10 @@
//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
diff --git a/third_party/mlir/lib/Analysis/Liveness.cpp b/third_party/mlir/lib/Analysis/Liveness.cpp
index 6aaec4c..7ba3136 100644
--- a/third_party/mlir/lib/Analysis/Liveness.cpp
+++ b/third_party/mlir/lib/Analysis/Liveness.cpp
@@ -1,19 +1,10 @@
//===- Liveness.cpp - Liveness analysis for MLIR --------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Implementation of the liveness analysis.
//
@@ -40,13 +31,13 @@
/// Fills the block builder with initial liveness information.
BlockInfoBuilder(Block *block) : block(block) {
// Mark all block arguments (phis) as defined.
- for (BlockArgument *argument : block->getArguments())
+ for (BlockArgument argument : block->getArguments())
defValues.insert(argument);
// Check all result values and whether their uses
// are inside this block or not (see outValues).
for (Operation &operation : *block)
- for (Value *result : operation.getResults()) {
+ for (Value result : operation.getResults()) {
defValues.insert(result);
// Check whether this value will be in the outValues
@@ -63,7 +54,7 @@
// Check all operations for used operands.
for (Operation &operation : block->getOperations())
- for (Value *operand : operation.getOperands()) {
+ for (Value operand : operation.getOperands()) {
// If the operand is already defined in the scope of this
// block, we can skip the value in the use set.
if (!defValues.count(operand))
@@ -173,7 +164,7 @@
}
/// Gets liveness info (if any) for the given value.
-Liveness::OperationListT Liveness::resolveLiveness(Value *value) const {
+Liveness::OperationListT Liveness::resolveLiveness(Value value) const {
OperationListT result;
SmallPtrSet<Block *, 32> visited;
SmallVector<Block *, 8> toProcess;
@@ -183,7 +174,7 @@
if (Operation *defOp = value->getDefiningOp())
currentBlock = defOp->getBlock();
else
- currentBlock = cast<BlockArgument>(value)->getOwner();
+ currentBlock = value.cast<BlockArgument>()->getOwner();
toProcess.push_back(currentBlock);
visited.insert(currentBlock);
@@ -238,7 +229,7 @@
/// Returns true if the given operation represent the last use of the
/// given value.
-bool Liveness::isLastUse(Value *value, Operation *operation) const {
+bool Liveness::isLastUse(Value value, Operation *operation) const {
Block *block = operation->getBlock();
const LivenessBlockInfo *blockInfo = getLiveness(block);
@@ -263,25 +254,25 @@
// Builds unique block/value mappings for testing purposes.
DenseMap<Block *, size_t> blockIds;
DenseMap<Operation *, size_t> operationIds;
- DenseMap<Value *, size_t> valueIds;
+ DenseMap<Value, size_t> valueIds;
for (Region ®ion : operation->getRegions())
for (Block &block : region) {
blockIds.insert({&block, blockIds.size()});
- for (BlockArgument *argument : block.getArguments())
+ for (BlockArgument argument : block.getArguments())
valueIds.insert({argument, valueIds.size()});
for (Operation &operation : block) {
operationIds.insert({&operation, operationIds.size()});
- for (Value *result : operation.getResults())
+ for (Value result : operation.getResults())
valueIds.insert({result, valueIds.size()});
}
}
// Local printing helpers
- auto printValueRef = [&](Value *value) {
+ auto printValueRef = [&](Value value) {
if (Operation *defOp = value->getDefiningOp())
os << "val_" << defOp->getName();
else {
- auto blockArg = cast<BlockArgument>(value);
+ auto blockArg = value.cast<BlockArgument>();
os << "arg" << blockArg->getArgNumber() << "@"
<< blockIds[blockArg->getOwner()];
}
@@ -289,12 +280,12 @@
};
auto printValueRefs = [&](const ValueSetT &values) {
- std::vector<Value *> orderedValues(values.begin(), values.end());
+ std::vector<Value> orderedValues(values.begin(), values.end());
std::sort(orderedValues.begin(), orderedValues.end(),
- [&](Value *left, Value *right) {
+ [&](Value left, Value right) {
return valueIds[left] < valueIds[right];
});
- for (Value *value : orderedValues)
+ for (Value value : orderedValues)
printValueRef(value);
};
@@ -315,7 +306,7 @@
if (op.getNumResults() < 1)
continue;
os << "\n";
- for (Value *result : op.getResults()) {
+ for (Value result : op.getResults()) {
os << "// ";
printValueRef(result);
os << ":";
@@ -340,18 +331,18 @@
//===----------------------------------------------------------------------===//
/// Returns true if the given value is in the live-in set.
-bool LivenessBlockInfo::isLiveIn(Value *value) const {
+bool LivenessBlockInfo::isLiveIn(Value value) const {
return inValues.count(value);
}
/// Returns true if the given value is in the live-out set.
-bool LivenessBlockInfo::isLiveOut(Value *value) const {
+bool LivenessBlockInfo::isLiveOut(Value value) const {
return outValues.count(value);
}
/// Gets the start operation for the given value
/// (must be referenced in this block).
-Operation *LivenessBlockInfo::getStartOperation(Value *value) const {
+Operation *LivenessBlockInfo::getStartOperation(Value value) const {
Operation *definingOp = value->getDefiningOp();
// The given value is either live-in or is defined
// in the scope of this block.
@@ -362,7 +353,7 @@
/// Gets the end operation for the given value using the start operation
/// provided (must be referenced in this block).
-Operation *LivenessBlockInfo::getEndOperation(Value *value,
+Operation *LivenessBlockInfo::getEndOperation(Value value,
Operation *startOperation) const {
// The given value is either dying in this block or live-out.
if (isLiveOut(value))
diff --git a/third_party/mlir/lib/Analysis/LoopAnalysis.cpp b/third_party/mlir/lib/Analysis/LoopAnalysis.cpp
index a811165..18c86dc 100644
--- a/third_party/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/third_party/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -1,19 +1,10 @@
//===- LoopAnalysis.cpp - Misc loop analysis routines //-------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop analysis routines.
//
@@ -43,7 +34,7 @@
// be more powerful (since both inequalities and equalities will be considered).
void mlir::buildTripCountMapAndOperands(
AffineForOp forOp, AffineMap *tripCountMap,
- SmallVectorImpl<Value *> *tripCountOperands) {
+ SmallVectorImpl<Value> *tripCountOperands) {
int64_t loopSpan;
int64_t step = forOp.getStep();
@@ -65,8 +56,8 @@
*tripCountMap = AffineMap();
return;
}
- SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
- SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
// Difference of each upper bound expression from the single lower bound
// expression (divided by the step) provides the expressions for the trip
@@ -98,7 +89,7 @@
// works with analysis structures (FlatAffineConstraints) and thus doesn't
// update the IR.
Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) {
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
AffineMap map;
buildTripCountMapAndOperands(forOp, &map, &operands);
@@ -124,7 +115,7 @@
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) {
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
AffineMap map;
buildTripCountMapAndOperands(forOp, &map, &operands);
@@ -173,7 +164,7 @@
///
/// Returns false in cases with more than one AffineApplyOp, this is
/// conservative.
-static bool isAccessIndexInvariant(Value *iv, Value *index) {
+static bool isAccessIndexInvariant(Value iv, Value index) {
assert(isForInductionVar(iv) && "iv must be a AffineForOp");
assert(index->getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<Operation *, 4> affineApplyOps;
@@ -197,11 +188,10 @@
return !(AffineValueMap(composeOp).isFunctionOf(0, iv));
}
-DenseSet<Value *> mlir::getInvariantAccesses(Value *iv,
- ArrayRef<Value *> indices) {
- DenseSet<Value *> res;
+DenseSet<Value> mlir::getInvariantAccesses(Value iv, ArrayRef<Value> indices) {
+ DenseSet<Value> res;
for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
- auto *val = indices[idx];
+ auto val = indices[idx];
if (isAccessIndexInvariant(iv, val)) {
res.insert(val);
}
@@ -229,7 +219,7 @@
///
// TODO(ntv): check strides.
template <typename LoadOrStoreOp>
-static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp,
+static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
int *memRefDim) {
static_assert(std::is_same<LoadOrStoreOp, AffineLoadOp>::value ||
std::is_same<LoadOrStoreOp, AffineStoreOp>::value,
@@ -250,11 +240,11 @@
int uniqueVaryingIndexAlongIv = -1;
auto accessMap = memoryOp.getAffineMap();
- SmallVector<Value *, 4> mapOperands(memoryOp.getMapOperands());
+ SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands());
unsigned numDims = accessMap.getNumDims();
for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
// Gather map operands used result expr 'i' in 'exprOperands'.
- SmallVector<Value *, 4> exprOperands;
+ SmallVector<Value, 4> exprOperands;
auto resultExpr = accessMap.getResult(i);
resultExpr.walk([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
@@ -263,7 +253,7 @@
exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
});
// Check access invariance of each operand in 'exprOperands'.
- for (auto *exprOperand : exprOperands) {
+ for (auto exprOperand : exprOperands) {
if (!isAccessIndexInvariant(iv, exprOperand)) {
if (uniqueVaryingIndexAlongIv != -1) {
// 2+ varying indices -> do not vectorize along iv.
@@ -382,7 +372,7 @@
// Validate the results of this operation if it were to be shifted.
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
- Value *result = op.getResult(i);
+ Value result = op.getResult(i);
for (auto *user : result->getUsers()) {
// If an ancestor operation doesn't lie in the block of forOp,
// there is no shift to check.
diff --git a/third_party/mlir/lib/Analysis/MemRefBoundCheck.cpp b/third_party/mlir/lib/Analysis/MemRefBoundCheck.cpp
index 4696ce6..1f7c1a1 100644
--- a/third_party/mlir/lib/Analysis/MemRefBoundCheck.cpp
+++ b/third_party/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -1,19 +1,10 @@
//===- MemRefBoundCheck.cpp - MLIR Affine Structures Class ----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to check memref accesses for out of bound
// accesses.
diff --git a/third_party/mlir/lib/Analysis/NestedMatcher.cpp b/third_party/mlir/lib/Analysis/NestedMatcher.cpp
index 5f2be48..97eaafd 100644
--- a/third_party/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/third_party/mlir/lib/Analysis/NestedMatcher.cpp
@@ -1,19 +1,10 @@
//===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
diff --git a/third_party/mlir/lib/Analysis/OpStats.cpp b/third_party/mlir/lib/Analysis/OpStats.cpp
index 1c9f621..dbd9387 100644
--- a/third_party/mlir/lib/Analysis/OpStats.cpp
+++ b/third_party/mlir/lib/Analysis/OpStats.cpp
@@ -1,19 +1,10 @@
//===- OpStats.cpp - Prints stats of operations in module -----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
diff --git a/third_party/mlir/lib/Analysis/SliceAnalysis.cpp b/third_party/mlir/lib/Analysis/SliceAnalysis.cpp
index 700321e..89ee613 100644
--- a/third_party/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/third_party/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -1,19 +1,10 @@
//===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements Analysis functions specific to slicing in Function.
//
@@ -104,8 +95,8 @@
}
for (auto en : llvm::enumerate(op->getOperands())) {
- auto *operand = en.value();
- if (auto *blockArg = dyn_cast<BlockArgument>(operand)) {
+ auto operand = en.value();
+ if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
if (auto affIv = getForInductionVarOwner(operand)) {
auto *affOp = affIv.getOperation();
if (backwardSlice->count(affOp) == 0)
diff --git a/third_party/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/third_party/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp
index 80a579d..c6d7519 100644
--- a/third_party/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp
+++ b/third_party/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp
@@ -1,19 +1,10 @@
//===- TestMemRefDependenceCheck.cpp - Test dep analysis ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to run pair-wise memref access dependence checks.
//
diff --git a/third_party/mlir/lib/Analysis/TestParallelismDetection.cpp b/third_party/mlir/lib/Analysis/TestParallelismDetection.cpp
index a9f9ea9..6cfc543 100644
--- a/third_party/mlir/lib/Analysis/TestParallelismDetection.cpp
+++ b/third_party/mlir/lib/Analysis/TestParallelismDetection.cpp
@@ -1,19 +1,10 @@
//===- ParallelismDetection.cpp - Parallelism Detection pass ------------*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to detect parallel affine 'affine.for' ops.
//
diff --git a/third_party/mlir/lib/Analysis/Utils.cpp b/third_party/mlir/lib/Analysis/Utils.cpp
index 3ba27bb..8ddf2e2 100644
--- a/third_party/mlir/lib/Analysis/Utils.cpp
+++ b/third_party/mlir/lib/Analysis/Utils.cpp
@@ -1,19 +1,10 @@
//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
@@ -60,7 +51,7 @@
// Adds operands (dst ivs and symbols) as symbols in 'cst'.
unsigned numSymbols = lbOperands[0].size();
- SmallVector<Value *, 4> values(ivs);
+ SmallVector<Value, 4> values(ivs);
// Append 'ivs' then 'operands' to 'values'.
values.append(lbOperands[0].begin(), lbOperands[0].end());
cst->reset(numDims, numSymbols, 0, values);
@@ -185,7 +176,7 @@
if (rank == 0) {
SmallVector<AffineForOp, 4> ivs;
getLoopIVs(*op, &ivs);
- SmallVector<Value *, 8> regionSymbols;
+ SmallVector<Value, 8> regionSymbols;
extractForInductionVars(ivs, ®ionSymbols);
// A rank 0 memref has a 0-d region.
cst.reset(rank, loopDepth, 0, regionSymbols);
@@ -201,7 +192,7 @@
unsigned numSymbols = accessMap.getNumSymbols();
unsigned numOperands = accessValueMap.getNumOperands();
// Merge operands with slice operands.
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
operands.resize(numOperands);
for (unsigned i = 0; i < numOperands; ++i)
operands[i] = accessValueMap.getOperand(i);
@@ -224,7 +215,7 @@
// Add equality constraints.
// Add inequalities for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
- auto *operand = operands[i];
+ auto operand = operands[i];
if (auto loop = getForInductionVarOwner(operand)) {
// Note that cst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
@@ -234,7 +225,7 @@
return failure();
} else {
// Has to be a valid symbol.
- auto *symbol = operand;
+ auto symbol = operand;
assert(isValidSymbol(symbol));
// Check if the symbol is a constant.
if (auto *op = symbol->getDefiningOp()) {
@@ -278,9 +269,9 @@
getLoopIVs(*op, &enclosingIVs);
assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
enclosingIVs.resize(loopDepth);
- SmallVector<Value *, 4> ids;
+ SmallVector<Value, 4> ids;
cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
- for (auto *id : ids) {
+ for (auto id : ids) {
AffineForOp iv;
if ((iv = getForInductionVarOwner(id)) &&
llvm::is_contained(enclosingIVs, iv) == false) {
@@ -345,9 +336,9 @@
// Indices to use for the DmaStart op.
// Indices for the original memref being DMAed from/to.
- SmallVector<Value *, 4> memIndices;
+ SmallVector<Value, 4> memIndices;
// Indices for the faster buffer being DMAed into/from.
- SmallVector<Value *, 4> bufIndices;
+ SmallVector<Value, 4> bufIndices;
// Compute the extents of the buffer.
Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
@@ -480,10 +471,10 @@
}
// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
-LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value *, 8> &ivs,
+LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
FlatAffineConstraints *cst) {
for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
- auto *value = cst->getIdValue(i);
+ auto value = cst->getIdValue(i);
if (ivs.count(value) == 0) {
assert(isForInductionVar(value));
auto loop = getForInductionVarOwner(value);
@@ -596,10 +587,10 @@
// Pre-constraint id alignment: record loop IVs used in each constraint
// system.
- SmallPtrSet<Value *, 8> sliceUnionIVs;
+ SmallPtrSet<Value, 8> sliceUnionIVs;
for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
- SmallPtrSet<Value *, 8> tmpSliceIVs;
+ SmallPtrSet<Value, 8> tmpSliceIVs;
for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));
@@ -659,7 +650,7 @@
&sliceUnion->ubs);
// Add slice bound operands of union.
- SmallVector<Value *, 4> sliceBoundOperands;
+ SmallVector<Value, 4> sliceBoundOperands;
sliceUnionCst.getIdValues(numSliceLoopIVs,
sliceUnionCst.getNumDimAndSymbolIds(),
&sliceBoundOperands);
@@ -725,7 +716,7 @@
&sliceState->lbs, &sliceState->ubs);
// Set up bound operands for the slice's lower and upper bounds.
- SmallVector<Value *, 4> sliceBoundOperands;
+ SmallVector<Value, 4> sliceBoundOperands;
unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds();
for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
if (i < offset || i >= offset + numSliceLoopIVs) {
@@ -743,7 +734,7 @@
isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
: std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
- llvm::SmallDenseSet<Value *, 8> sequentialLoops;
+ llvm::SmallDenseSet<Value, 8> sequentialLoops;
if (isa<AffineLoadOp>(depSourceOp) && isa<AffineLoadOp>(depSinkOp)) {
// For read-read access pairs, clear any slice bounds on sequential loops.
// Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
@@ -758,7 +749,7 @@
return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
};
for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
- Value *iv = getSliceLoop(i).getInductionVar();
+ Value iv = getSliceLoop(i).getInductionVar();
if (sequentialLoops.count(iv) == 0 &&
getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr)
continue;
@@ -846,7 +837,7 @@
opInst = loadOrStoreOpInst;
auto loadMemrefType = loadOp.getMemRefType();
indices.reserve(loadMemrefType.getRank());
- for (auto *index : loadOp.getMapOperands()) {
+ for (auto index : loadOp.getMapOperands()) {
indices.push_back(index);
}
} else {
@@ -856,7 +847,7 @@
memref = storeOp.getMemRef();
auto storeMemrefType = storeOp.getMemRefType();
indices.reserve(storeMemrefType.getRank());
- for (auto *index : storeOp.getMapOperands()) {
+ for (auto index : storeOp.getMapOperands()) {
indices.push_back(index);
}
}
@@ -919,7 +910,7 @@
Block::iterator start,
Block::iterator end,
int memorySpace) {
- SmallDenseMap<Value *, std::unique_ptr<MemRefRegion>, 4> regions;
+ SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
// Walk this 'affine.for' operation to gather all memory regions.
auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
@@ -969,8 +960,8 @@
/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
/// at 'forOp'.
-void mlir::getSequentialLoops(
- AffineForOp forOp, llvm::SmallDenseSet<Value *, 8> *sequentialLoops) {
+void mlir::getSequentialLoops(AffineForOp forOp,
+ llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
forOp.getOperation()->walk([&](Operation *op) {
if (auto innerFor = dyn_cast<AffineForOp>(op))
if (!isLoopParallel(innerFor))
diff --git a/third_party/mlir/lib/Analysis/VectorAnalysis.cpp b/third_party/mlir/lib/Analysis/VectorAnalysis.cpp
index 42d3f10..1c7dbed 100644
--- a/third_party/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/third_party/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -1,19 +1,10 @@
//===- VectorAnalysis.cpp - Analysis for Vectorization --------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -109,7 +100,7 @@
/// Examples can be found in the documentation of `makePermutationMap`, in the
/// header file.
static AffineMap makePermutationMap(
- ArrayRef<Value *> indices,
+ ArrayRef<Value> indices,
const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
if (enclosingLoopToVectorDim.empty())
return AffineMap();
@@ -167,7 +158,7 @@
}
AffineMap mlir::makePermutationMap(
- Operation *op, ArrayRef<Value *> indices,
+ Operation *op, ArrayRef<Value> indices,
const DenseMap<Operation *, unsigned> &loopToVectorDim) {
DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingforOps(op);
diff --git a/third_party/mlir/lib/Analysis/Verifier.cpp b/third_party/mlir/lib/Analysis/Verifier.cpp
index 82f5aa5..d4861b1 100644
--- a/third_party/mlir/lib/Analysis/Verifier.cpp
+++ b/third_party/mlir/lib/Analysis/Verifier.cpp
@@ -1,19 +1,10 @@
//===- Verifier.cpp - MLIR Verifier Implementation ------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the verify() methods on the various IR types, performing
// (potentially expensive) checks on the holistic structure of the code. This
@@ -138,7 +129,7 @@
}
LogicalResult OperationVerifier::verifyBlock(Block &block) {
- for (auto *arg : block.getArguments())
+ for (auto arg : block.getArguments())
if (arg->getOwner() != &block)
return emitError(block, "block argument not owned by block");
@@ -175,7 +166,7 @@
LogicalResult OperationVerifier::verifyOperation(Operation &op) {
// Check that operands are non-nil and structurally ok.
- for (auto *operand : op.getOperands())
+ for (auto operand : op.getOperands())
if (!operand)
return op.emitError("null operand found");
@@ -244,7 +235,7 @@
// Check that operands properly dominate this use.
for (unsigned operandNo = 0, e = op.getNumOperands(); operandNo != e;
++operandNo) {
- auto *operand = op.getOperand(operandNo);
+ auto operand = op.getOperand(operandNo);
if (domInfo->properlyDominates(operand, &op))
continue;
diff --git a/third_party/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/third_party/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 9208ce8..e9a9ca8 100644
--- a/third_party/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/third_party/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -1,19 +1,10 @@
//===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file lowers affine constructs (If and For statements, AffineApply
// operations) within a function into their standard If and For equivalent ops.
@@ -42,16 +33,16 @@
// that correspond to it. Visitation functions return an Value of the
// expression subtree they visited or `nullptr` on error.
class AffineApplyExpander
- : public AffineExprVisitor<AffineApplyExpander, Value *> {
+ : public AffineExprVisitor<AffineApplyExpander, Value> {
public:
// This internal class expects arguments to be non-null, checks must be
// performed at the call site.
- AffineApplyExpander(OpBuilder &builder, ArrayRef<Value *> dimValues,
- ArrayRef<Value *> symbolValues, Location loc)
+ AffineApplyExpander(OpBuilder &builder, ArrayRef<Value> dimValues,
+ ArrayRef<Value> symbolValues, Location loc)
: builder(builder), dimValues(dimValues), symbolValues(symbolValues),
loc(loc) {}
- template <typename OpTy> Value *buildBinaryExpr(AffineBinaryOpExpr expr) {
+ template <typename OpTy> Value buildBinaryExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
@@ -60,11 +51,11 @@
return op.getResult();
}
- Value *visitAddExpr(AffineBinaryOpExpr expr) {
+ Value visitAddExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<AddIOp>(expr);
}
- Value *visitMulExpr(AffineBinaryOpExpr expr) {
+ Value visitMulExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<MulIOp>(expr);
}
@@ -77,7 +68,7 @@
// let remainder = srem a, b;
// negative = a < 0 in
// select negative, remainder + b, remainder.
- Value *visitModExpr(AffineBinaryOpExpr expr) {
+ Value visitModExpr(AffineBinaryOpExpr expr) {
auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
if (!rhsConst) {
emitError(
@@ -94,13 +85,13 @@
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value *remainder = builder.create<RemISOp>(loc, lhs, rhs);
- Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
- Value *isRemainderNegative =
+ Value remainder = builder.create<SignedRemIOp>(loc, lhs, rhs);
+ Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+ Value isRemainderNegative =
builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst);
- Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
- Value *result = builder.create<SelectOp>(loc, isRemainderNegative,
- correctedRemainder, remainder);
+ Value correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
+ Value result = builder.create<SelectOp>(loc, isRemainderNegative,
+ correctedRemainder, remainder);
return result;
}
@@ -114,7 +105,7 @@
// let absolute = negative ? -a - 1 : a in
// let quotient = absolute / b in
// negative ? -quotient - 1 : quotient
- Value *visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
if (!rhsConst) {
emitError(
@@ -131,16 +122,16 @@
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
- Value *noneCst = builder.create<ConstantIndexOp>(loc, -1);
- Value *negative =
+ Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+ Value noneCst = builder.create<ConstantIndexOp>(loc, -1);
+ Value negative =
builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst);
- Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
- Value *dividend =
+ Value negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
+ Value dividend =
builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
- Value *quotient = builder.create<DivISOp>(loc, dividend, rhs);
- Value *correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient);
- Value *result =
+ Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs);
+ Value correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient);
+ Value result =
builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
return result;
}
@@ -155,7 +146,7 @@
// let absolute = negative ? -a : a - 1 in
// let quotient = absolute / b in
// negative ? -quotient : quotient + 1
- Value *visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
if (!rhsConst) {
emitError(loc) << "semi-affine expressions (division by non-const) are "
@@ -170,23 +161,23 @@
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
- Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
- Value *oneCst = builder.create<ConstantIndexOp>(loc, 1);
- Value *nonPositive =
+ Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
+ Value oneCst = builder.create<ConstantIndexOp>(loc, 1);
+ Value nonPositive =
builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst);
- Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs);
- Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst);
- Value *dividend =
+ Value negated = builder.create<SubIOp>(loc, zeroCst, lhs);
+ Value decremented = builder.create<SubIOp>(loc, lhs, oneCst);
+ Value dividend =
builder.create<SelectOp>(loc, nonPositive, negated, decremented);
- Value *quotient = builder.create<DivISOp>(loc, dividend, rhs);
- Value *negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient);
- Value *incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst);
- Value *result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
- incrementedQuotient);
+ Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs);
+ Value negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient);
+ Value incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst);
+ Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
+ incrementedQuotient);
return result;
}
- Value *visitConstantExpr(AffineConstantExpr expr) {
+ Value visitConstantExpr(AffineConstantExpr expr) {
auto valueAttr =
builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
auto op =
@@ -194,13 +185,13 @@
return op.getResult();
}
- Value *visitDimExpr(AffineDimExpr expr) {
+ Value visitDimExpr(AffineDimExpr expr) {
assert(expr.getPosition() < dimValues.size() &&
"affine dim position out of range");
return dimValues[expr.getPosition()];
}
- Value *visitSymbolExpr(AffineSymbolExpr expr) {
+ Value visitSymbolExpr(AffineSymbolExpr expr) {
assert(expr.getPosition() < symbolValues.size() &&
"symbol dim position out of range");
return symbolValues[expr.getPosition()];
@@ -208,8 +199,8 @@
private:
OpBuilder &builder;
- ArrayRef<Value *> dimValues;
- ArrayRef<Value *> symbolValues;
+ ArrayRef<Value> dimValues;
+ ArrayRef<Value> symbolValues;
Location loc;
};
@@ -217,18 +208,17 @@
// Create a sequence of operations that implement the `expr` applied to the
// given dimension and symbol values.
-mlir::Value *mlir::expandAffineExpr(OpBuilder &builder, Location loc,
- AffineExpr expr,
- ArrayRef<Value *> dimValues,
- ArrayRef<Value *> symbolValues) {
+mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc,
+ AffineExpr expr, ArrayRef<Value> dimValues,
+ ArrayRef<Value> symbolValues) {
return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
}
// Create a sequence of operations that implement the `affineMap` applied to
// the given `operands` (as it it were an AffineApplyOp).
-Optional<SmallVector<Value *, 8>> static expandAffineMap(
+Optional<SmallVector<Value, 8>> static expandAffineMap(
OpBuilder &builder, Location loc, AffineMap affineMap,
- ArrayRef<Value *> operands) {
+ ArrayRef<Value> operands) {
auto numDims = affineMap.getNumDims();
auto expanded = functional::map(
[numDims, &builder, loc, operands](AffineExpr expr) {
@@ -237,7 +227,7 @@
operands.drop_front(numDims));
},
affineMap.getResults());
- if (llvm::all_of(expanded, [](Value *v) { return v; }))
+ if (llvm::all_of(expanded, [](Value v) { return v; }))
return expanded;
return None;
}
@@ -253,13 +243,13 @@
// Multiple values are scanned in a linear sequence. This creates a data
// dependences that wouldn't exist in a tree reduction, but is easier to
// recognize as a reduction by the subsequent passes.
-static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
- ArrayRef<Value *> values,
- OpBuilder &builder) {
+static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
+ ArrayRef<Value> values,
+ OpBuilder &builder) {
assert(!llvm::empty(values) && "empty min/max chain");
auto valueIt = values.begin();
- Value *value = *valueIt++;
+ Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
@@ -271,8 +261,8 @@
// Emit instructions that correspond to the affine map in the lower bound
// applied to the respective operands, and compute the maximum value across
// the results.
-Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
- SmallVector<Value *, 8> boundOperands(op.getLowerBoundOperands());
+Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
+ SmallVector<Value, 8> boundOperands(op.getLowerBoundOperands());
auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(),
boundOperands);
if (!lbValues)
@@ -284,8 +274,8 @@
// Emit instructions that correspond to the affine map in the upper bound
// applied to the respective operands, and compute the minimum value across
// the results.
-Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
- SmallVector<Value *, 8> boundOperands(op.getUpperBoundOperands());
+Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
+ SmallVector<Value, 8> boundOperands(op.getUpperBoundOperands());
auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(),
boundOperands);
if (!ubValues)
@@ -314,9 +304,9 @@
PatternMatchResult matchAndRewrite(AffineForOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Value *lowerBound = lowerAffineLowerBound(op, rewriter);
- Value *upperBound = lowerAffineUpperBound(op, rewriter);
- Value *step = rewriter.create<ConstantIndexOp>(loc, op.getStep());
+ Value lowerBound = lowerAffineLowerBound(op, rewriter);
+ Value upperBound = lowerAffineUpperBound(op, rewriter);
+ Value step = rewriter.create<ConstantIndexOp>(loc, op.getStep());
auto f = rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step);
f.region().getBlocks().clear();
rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
@@ -335,25 +325,25 @@
// Now we just have to handle the condition logic.
auto integerSet = op.getIntegerSet();
- Value *zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0);
- SmallVector<Value *, 8> operands(op.getOperands());
+ Value zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0);
+ SmallVector<Value, 8> operands(op.getOperands());
auto operandsRef = llvm::makeArrayRef(operands);
// Calculate cond as a conjunction without short-circuiting.
- Value *cond = nullptr;
+ Value cond = nullptr;
for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
AffineExpr constraintExpr = integerSet.getConstraint(i);
bool isEquality = integerSet.isEq(i);
// Build and apply an affine expression
auto numDims = integerSet.getNumDims();
- Value *affResult = expandAffineExpr(rewriter, loc, constraintExpr,
- operandsRef.take_front(numDims),
- operandsRef.drop_front(numDims));
+ Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
+ operandsRef.take_front(numDims),
+ operandsRef.drop_front(numDims));
if (!affResult)
return matchFailure();
auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
- Value *cmpVal =
+ Value cmpVal =
rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
cond =
cond ? rewriter.create<AndOp>(loc, cond, cmpVal).getResult() : cmpVal;
@@ -404,7 +394,7 @@
PatternMatchResult matchAndRewrite(AffineLoadOp op,
PatternRewriter &rewriter) const override {
// Expand affine map from 'affineLoadOp'.
- SmallVector<Value *, 8> indices(op.getMapOperands());
+ SmallVector<Value, 8> indices(op.getMapOperands());
auto resultOperands =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!resultOperands)
@@ -426,7 +416,7 @@
PatternMatchResult matchAndRewrite(AffinePrefetchOp op,
PatternRewriter &rewriter) const override {
// Expand affine map from 'affinePrefetchOp'.
- SmallVector<Value *, 8> indices(op.getMapOperands());
+ SmallVector<Value, 8> indices(op.getMapOperands());
auto resultOperands =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!resultOperands)
@@ -450,7 +440,7 @@
PatternMatchResult matchAndRewrite(AffineStoreOp op,
PatternRewriter &rewriter) const override {
// Expand affine map from 'affineStoreOp'.
- SmallVector<Value *, 8> indices(op.getMapOperands());
+ SmallVector<Value, 8> indices(op.getMapOperands());
auto maybeExpandedMap =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!maybeExpandedMap)
@@ -472,7 +462,7 @@
PatternMatchResult matchAndRewrite(AffineDmaStartOp op,
PatternRewriter &rewriter) const override {
- SmallVector<Value *, 8> operands(op.getOperands());
+ SmallVector<Value, 8> operands(op.getOperands());
auto operandsRef = llvm::makeArrayRef(operands);
// Expand affine map for DMA source memref.
@@ -513,7 +503,7 @@
PatternMatchResult matchAndRewrite(AffineDmaWaitOp op,
PatternRewriter &rewriter) const override {
// Expand affine map for DMA tag memref.
- SmallVector<Value *, 8> indices(op.getTagIndices());
+ SmallVector<Value, 8> indices(op.getTagIndices());
auto maybeExpandedTagMap =
expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
if (!maybeExpandedTagMap)
diff --git a/third_party/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/third_party/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
index 6a1a580..63bc151 100644
--- a/third_party/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/third_party/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -1,19 +1,10 @@
//===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
#define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
@@ -57,11 +48,11 @@
// Convert the kernel arguments to an LLVM type, preserve the rest.
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto dialect = lowering.getDialect();
- Value *newOp;
+ Value newOp;
switch (dimensionToIndex(cast<Op>(op))) {
case X:
newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
diff --git a/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 23bfa30..b75c1bf 100644
--- a/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/third_party/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -1,19 +1,10 @@
//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
@@ -44,7 +35,7 @@
f32Func(f32Func), f64Func(f64Func) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
using LLVM::LLVMFuncOp;
using LLVM::LLVMType;
@@ -69,10 +60,10 @@
private:
LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
- ArrayRef<Value *> operands) const {
+ ArrayRef<Value> operands) const {
using LLVM::LLVMType;
SmallVector<LLVMType, 1> operandTypes;
- for (Value *operand : operands) {
+ for (Value operand : operands) {
operandTypes.push_back(operand->getType().cast<LLVMType>());
}
return LLVMType::getFunctionTy(resultType, operandTypes,
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
index a91c43e..66a2e66 100644
--- a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
@@ -1,19 +1,10 @@
//===- ConvertKernelFuncToCubin.cpp - MLIR GPU lowering passes ------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu kernel functions into a
// corresponding binary blob that can be executed on a CUDA GPU. Currently
diff --git a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index f342083..19dabcd 100644
--- a/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -1,19 +1,10 @@
//===- ConvertLaunchFuncToCudaCalls.cpp - MLIR CUDA lowering passes -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu.launch_func op into a sequence of
// CUDA runtime calls. As the CUDA runtime does not have a stable published ABI,
@@ -114,7 +105,7 @@
}
// Allocate a void pointer on the stack.
- Value *allocatePointer(OpBuilder &builder, Location loc) {
+ Value allocatePointer(OpBuilder &builder, Location loc) {
auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
builder.getI32IntegerAttr(1));
return builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), one,
@@ -122,9 +113,9 @@
}
void declareCudaFunctions(Location loc);
- Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
- Value *generateKernelNameConstant(StringRef name, Location loc,
- OpBuilder &builder);
+ Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
+ Value generateKernelNameConstant(StringRef name, Location loc,
+ OpBuilder &builder);
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
public:
@@ -248,9 +239,8 @@
// for (i : [0, NumKernelOperands))
// %array[i] = cast<void*>(KernelOperand[i])
// return %array
-Value *
-GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
- OpBuilder &builder) {
+Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
+ OpBuilder &builder) {
auto numKernelOperands = launchOp.getNumKernelOperands();
Location loc = launchOp.getLoc();
auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
@@ -264,7 +254,7 @@
for (unsigned idx = 0; idx < numKernelOperands; ++idx) {
auto operand = launchOp.getKernelOperand(idx);
auto llvmType = operand->getType().cast<LLVM::LLVMType>();
- Value *memLocation = builder.create<LLVM::AllocaOp>(
+ Value memLocation = builder.create<LLVM::AllocaOp>(
loc, llvmType.getPointerTo(), one, /*alignment=*/1);
builder.create<LLVM::StoreOp>(loc, operand, memLocation);
auto casted =
@@ -280,12 +270,12 @@
getModule().lookupSymbol<LLVM::LLVMFuncOp>(kMcuMemHostRegister);
auto nullPtr = builder.create<LLVM::NullOp>(loc, llvmType.getPointerTo());
auto gep = builder.create<LLVM::GEPOp>(loc, llvmType.getPointerTo(),
- ArrayRef<Value *>{nullPtr, one});
+ ArrayRef<Value>{nullPtr, one});
auto size = builder.create<LLVM::PtrToIntOp>(loc, getInt64Type(), gep);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
builder.getSymbolRefAttr(registerFunc),
- ArrayRef<Value *>{casted, size});
- Value *memLocation = builder.create<LLVM::AllocaOp>(
+ ArrayRef<Value>{casted, size});
+ Value memLocation = builder.create<LLVM::AllocaOp>(
loc, getPointerPointerType(), one, /*alignment=*/1);
builder.create<LLVM::StoreOp>(loc, casted, memLocation);
casted =
@@ -295,7 +285,7 @@
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(idx));
auto gep = builder.create<LLVM::GEPOp>(loc, getPointerPointerType(), array,
- ArrayRef<Value *>{index});
+ ArrayRef<Value>{index});
builder.create<LLVM::StoreOp>(loc, casted, gep);
}
return array;
@@ -311,7 +301,7 @@
// %1 = llvm.constant (0 : index)
// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
-Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
+Value GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
StringRef name, Location loc, OpBuilder &builder) {
// Make sure the trailing zero is included in the constant.
std::vector<char> kernelName(name.begin(), name.end());
@@ -367,7 +357,7 @@
assert(kernelModule.getName() && "expected a named module");
SmallString<128> nameBuffer(*kernelModule.getName());
nameBuffer.append(kCubinStorageSuffix);
- Value *data = LLVM::createGlobalString(
+ Value data = LLVM::createGlobalString(
loc, builder, nameBuffer.str(), cubinAttr.getValue(),
LLVM::Linkage::Internal, getLLVMDialect());
@@ -378,7 +368,7 @@
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleLoad),
- ArrayRef<Value *>{cuModule, data});
+ ArrayRef<Value>{cuModule, data});
// Get the function from the module. The name corresponds to the name of
// the kernel function.
auto cuOwningModuleRef =
@@ -390,13 +380,13 @@
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleGetFunction),
- ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
+ ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName});
// Grab the global stream needed for execution.
auto cuGetStreamHelper =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
- builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value *>{});
+ builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{});
// Invoke the function with required arguments.
auto cuLaunchKernel =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
@@ -408,19 +398,19 @@
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuLaunchKernel),
- ArrayRef<Value *>{cuFunctionRef, launchOp.getOperand(0),
- launchOp.getOperand(1), launchOp.getOperand(2),
- launchOp.getOperand(3), launchOp.getOperand(4),
- launchOp.getOperand(5), zero, /* sharedMemBytes */
- cuStream.getResult(0), /* stream */
- paramsArray, /* kernel params */
- nullpointer /* extra */});
+ ArrayRef<Value>{cuFunctionRef, launchOp.getOperand(0),
+ launchOp.getOperand(1), launchOp.getOperand(2),
+ launchOp.getOperand(3), launchOp.getOperand(4),
+ launchOp.getOperand(5), zero, /* sharedMemBytes */
+ cuStream.getResult(0), /* stream */
+ paramsArray, /* kernel params */
+ nullpointer /* extra */});
// Sync on the stream to make it synchronous.
auto cuStreamSync =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuStreamSync),
- ArrayRef<Value *>(cuStream.getResult(0)));
+ ArrayRef<Value>(cuStream.getResult(0)));
launchOp.erase();
}
diff --git a/third_party/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td b/third_party/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td
index 8c27ba4..0a6aec0 100644
--- a/third_party/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td
+++ b/third_party/mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td
@@ -1,19 +1,10 @@
//==-- GPUToNVVM.td - GPU Ops to NVVM Patterns ---------------*- tablegen -*==//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Defines Patterns to lower GPU ops to NVVM.
//
diff --git a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 78fe15d..08c18c1 100644
--- a/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -1,19 +1,10 @@
//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
@@ -60,8 +51,8 @@
/// Converts all_reduce op to LLVM/NVVM ops.
struct GPUAllReduceOpLowering : public LLVMOpLowering {
- using AccumulatorFactory = std::function<Value *(
- Location, Value *, Value *, ConversionPatternRewriter &)>;
+ using AccumulatorFactory =
+ std::function<Value(Location, Value, Value, ConversionPatternRewriter &)>;
explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
: LLVMOpLowering(gpu::AllReduceOp::getOperationName(),
@@ -69,10 +60,10 @@
int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- Value *operand = operands.front();
+ Value operand = operands.front();
// TODO(csigg): Generalize to other types of accumulation.
assert(op->getOperand(0)->getType().isIntOrFloat());
@@ -81,7 +72,7 @@
AccumulatorFactory factory =
getFactory(cast<gpu::AllReduceOp>(op), operand);
assert(factory && "failed to create accumulator factory");
- Value *result = createBlockReduce(loc, operand, factory, rewriter);
+ Value result = createBlockReduce(loc, operand, factory, rewriter);
rewriter.replaceOp(op, {result});
return matchSuccess();
@@ -91,7 +82,7 @@
/// Returns an accumulator factory using either the op attribute or the body
/// region.
AccumulatorFactory getFactory(gpu::AllReduceOp allReduce,
- Value *operand) const {
+ Value operand) const {
if (!allReduce.body().empty()) {
return getFactory(allReduce.body());
}
@@ -106,7 +97,7 @@
/// block is expected to have 2 arguments. The gpu.yield return the
/// accumulated value of the same type.
AccumulatorFactory getFactory(Region &body) const {
- return AccumulatorFactory([&](Location loc, Value *lhs, Value *rhs,
+ return AccumulatorFactory([&](Location loc, Value lhs, Value rhs,
ConversionPatternRewriter &rewriter) {
Block *block = rewriter.getInsertionBlock();
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
@@ -120,7 +111,7 @@
// Add branch before inserted body, into body.
block = block->getNextNode();
- rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>{},
+ rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{},
llvm::makeArrayRef(block), ValueRange());
// Replace all gpu.yield ops with branch out of body.
@@ -130,7 +121,7 @@
continue;
rewriter.setInsertionPointToEnd(block);
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
- terminator, ArrayRef<Value *>{}, llvm::makeArrayRef(split),
+ terminator, ArrayRef<Value>{}, llvm::makeArrayRef(split),
ValueRange(terminator->getOperand(0)));
}
@@ -161,7 +152,7 @@
/// Returns an accumulator factory that creates an op of type T.
template <typename T> AccumulatorFactory getFactory() const {
- return [](Location loc, Value *lhs, Value *rhs,
+ return [](Location loc, Value lhs, Value rhs,
ConversionPatternRewriter &rewriter) {
return rewriter.create<T>(loc, lhs->getType(), lhs, rhs);
};
@@ -203,60 +194,60 @@
/// %result = llvm.load %result_ptr
/// return %result
///
- Value *createBlockReduce(Location loc, Value *operand,
- AccumulatorFactory &accumFactory,
- ConversionPatternRewriter &rewriter) const {
+ Value createBlockReduce(Location loc, Value operand,
+ AccumulatorFactory &accumFactory,
+ ConversionPatternRewriter &rewriter) const {
auto type = operand->getType().cast<LLVM::LLVMType>();
// Create shared memory array to store the warp reduction.
auto module = operand->getDefiningOp()->getParentOfType<ModuleOp>();
assert(module && "op must belong to a module");
- Value *sharedMemPtr =
+ Value sharedMemPtr =
createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
- Value *zero = rewriter.create<LLVM::ConstantOp>(
+ Value zero = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(0u));
- Value *laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
- Value *isFirstLane = rewriter.create<LLVM::ICmpOp>(
+ Value laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
+ Value isFirstLane = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::eq, laneId, zero);
- Value *threadIdx = getLinearThreadIndex(loc, rewriter);
- Value *blockSize = getBlockSize(loc, rewriter);
- Value *activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
+ Value threadIdx = getLinearThreadIndex(loc, rewriter);
+ Value blockSize = getBlockSize(loc, rewriter);
+ Value activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
// Reduce elements within each warp to produce the intermediate results.
- Value *warpReduce = createWarpReduce(loc, activeWidth, laneId, operand,
- accumFactory, rewriter);
+ Value warpReduce = createWarpReduce(loc, activeWidth, laneId, operand,
+ accumFactory, rewriter);
// Write the intermediate results to shared memory, using the first lane of
// each warp.
createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
- Value *warpId = getDivideByWarpSize(threadIdx, rewriter);
- Value *storeDst = rewriter.create<LLVM::GEPOp>(
- loc, type, sharedMemPtr, ArrayRef<Value *>({zero, warpId}));
+ Value warpId = getDivideByWarpSize(threadIdx, rewriter);
+ Value storeDst = rewriter.create<LLVM::GEPOp>(
+ loc, type, sharedMemPtr, ArrayRef<Value>({zero, warpId}));
rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
});
rewriter.create<NVVM::Barrier0Op>(loc);
- Value *numWarps = getNumWarps(loc, blockSize, rewriter);
- Value *isValidWarp = rewriter.create<LLVM::ICmpOp>(
+ Value numWarps = getNumWarps(loc, blockSize, rewriter);
+ Value isValidWarp = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
- Value *resultPtr = rewriter.create<LLVM::GEPOp>(
- loc, type, sharedMemPtr, ArrayRef<Value *>({zero, zero}));
+ Value resultPtr = rewriter.create<LLVM::GEPOp>(
+ loc, type, sharedMemPtr, ArrayRef<Value>({zero, zero}));
// Use the first numWarps threads to reduce the intermediate results from
// shared memory. The final result is written to shared memory again.
createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
- Value *loadSrc = rewriter.create<LLVM::GEPOp>(
- loc, type, sharedMemPtr, ArrayRef<Value *>({zero, threadIdx}));
- Value *value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
- Value *result = createWarpReduce(loc, numWarps, laneId, value,
- accumFactory, rewriter);
+ Value loadSrc = rewriter.create<LLVM::GEPOp>(
+ loc, type, sharedMemPtr, ArrayRef<Value>({zero, threadIdx}));
+ Value value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
+ Value result = createWarpReduce(loc, numWarps, laneId, value,
+ accumFactory, rewriter);
rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
});
rewriter.create<NVVM::Barrier0Op>(loc);
// Load and return result from shared memory.
- Value *result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
+ Value result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
return result;
}
@@ -274,7 +265,7 @@
///
template <typename ThenOpsFactory, typename ElseOpsFactory>
void createIf(Location loc, ConversionPatternRewriter &rewriter,
- Value *condition, ThenOpsFactory &&thenOpsFactory,
+ Value condition, ThenOpsFactory &&thenOpsFactory,
ElseOpsFactory &&elseOpsFactory) const {
Block *currentBlock = rewriter.getInsertionBlock();
auto currentPoint = rewriter.getInsertionPoint();
@@ -288,7 +279,7 @@
ArrayRef<Block *>{thenBlock, elseBlock});
auto addBranch = [&](ValueRange operands) {
- rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value *>{},
+ rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{},
llvm::makeArrayRef(continueBlock),
llvm::makeArrayRef(operands));
};
@@ -303,32 +294,32 @@
assert(thenOperands.size() == elseOperands.size());
rewriter.setInsertionPointToStart(continueBlock);
- for (auto *operand : thenOperands)
+ for (auto operand : thenOperands)
continueBlock->addArgument(operand->getType());
}
/// Shortcut for createIf with empty else block and no block operands.
template <typename Factory>
void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
- Value *condition,
+ Value condition,
Factory &&predicatedOpsFactory) const {
createIf(
loc, rewriter, condition,
[&] {
predicatedOpsFactory();
- return ArrayRef<Value *>();
+ return ArrayRef<Value>();
},
- [&] { return ArrayRef<Value *>(); });
+ [&] { return ArrayRef<Value>(); });
}
/// Creates a reduction across the first activeWidth lanes of a warp.
/// The first lane returns the result, all others return values are undefined.
- Value *createWarpReduce(Location loc, Value *activeWidth, Value *laneId,
- Value *operand, AccumulatorFactory accumFactory,
- ConversionPatternRewriter &rewriter) const {
- Value *warpSize = rewriter.create<LLVM::ConstantOp>(
+ Value createWarpReduce(Location loc, Value activeWidth, Value laneId,
+ Value operand, AccumulatorFactory accumFactory,
+ ConversionPatternRewriter &rewriter) const {
+ Value warpSize = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
- Value *isPartialWarp = rewriter.create<LLVM::ICmpOp>(
+ Value isPartialWarp = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
auto type = operand->getType().cast<LLVM::LLVMType>();
@@ -336,16 +327,16 @@
loc, rewriter, isPartialWarp,
// Generate reduction over a (potentially) partial warp.
[&] {
- Value *value = operand;
- Value *one = rewriter.create<LLVM::ConstantOp>(
+ Value value = operand;
+ Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(1));
// Bit mask of active lanes: `(1 << activeWidth) - 1`.
- Value *activeMask = rewriter.create<LLVM::SubOp>(
+ Value activeMask = rewriter.create<LLVM::SubOp>(
loc, int32Type,
rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
one);
// Clamp lane: `activeWidth - 1`
- Value *maskAndClamp =
+ Value maskAndClamp =
rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one);
auto dialect = lowering.getDialect();
auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
@@ -356,53 +347,53 @@
// lane is within the active range. All lanes contain the final
// result, but only the first lane's result is used.
for (int i = 1; i < kWarpSize; i <<= 1) {
- Value *offset = rewriter.create<LLVM::ConstantOp>(
+ Value offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
- Value *shfl = rewriter.create<NVVM::ShflBflyOp>(
+ Value shfl = rewriter.create<NVVM::ShflBflyOp>(
loc, shflTy, activeMask, value, offset, maskAndClamp,
returnValueAndIsValidAttr);
- Value *isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
+ Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
// Skip the accumulation if the shuffle op read from a lane outside
// of the active range.
createIf(
loc, rewriter, isActiveSrcLane,
[&] {
- Value *shflValue = rewriter.create<LLVM::ExtractValueOp>(
+ Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
loc, type, shfl, rewriter.getIndexArrayAttr(0));
- return SmallVector<Value *, 1>{
+ return SmallVector<Value, 1>{
accumFactory(loc, value, shflValue, rewriter)};
},
[&] { return llvm::makeArrayRef(value); });
value = rewriter.getInsertionBlock()->getArgument(0);
}
- return SmallVector<Value *, 1>{value};
+ return SmallVector<Value, 1>{value};
},
// Generate a reduction over the entire warp. This is a specialization
// of the above reduction with unconditional accumulation.
[&] {
- Value *value = operand;
- Value *activeMask = rewriter.create<LLVM::ConstantOp>(
+ Value value = operand;
+ Value activeMask = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(~0u));
- Value *maskAndClamp = rewriter.create<LLVM::ConstantOp>(
+ Value maskAndClamp = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
for (int i = 1; i < kWarpSize; i <<= 1) {
- Value *offset = rewriter.create<LLVM::ConstantOp>(
+ Value offset = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(i));
- Value *shflValue = rewriter.create<NVVM::ShflBflyOp>(
+ Value shflValue = rewriter.create<NVVM::ShflBflyOp>(
loc, type, activeMask, value, offset, maskAndClamp,
/*return_value_and_is_valid=*/UnitAttr());
value = accumFactory(loc, value, shflValue, rewriter);
}
- return SmallVector<Value *, 1>{value};
+ return SmallVector<Value, 1>{value};
});
return rewriter.getInsertionBlock()->getArgument(0);
}
/// Creates a global array stored in shared memory.
- Value *createSharedMemoryArray(Location loc, ModuleOp module,
- LLVM::LLVMType elementType, int numElements,
- ConversionPatternRewriter &rewriter) const {
+ Value createSharedMemoryArray(Location loc, ModuleOp module,
+ LLVM::LLVMType elementType, int numElements,
+ ConversionPatternRewriter &rewriter) const {
OpBuilder builder(module.getBodyRegion());
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
@@ -416,31 +407,31 @@
}
/// Returns the index of the thread within the block.
- Value *getLinearThreadIndex(Location loc,
- ConversionPatternRewriter &rewriter) const {
- Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
- Value *dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
- Value *idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
- Value *idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
- Value *idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
- Value *tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
- Value *tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
- Value *tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
+ Value getLinearThreadIndex(Location loc,
+ ConversionPatternRewriter &rewriter) const {
+ Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
+ Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
+ Value idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
+ Value idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
+ Value idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
+ Value tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
+ Value tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
+ Value tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
}
/// Returns the number of threads in the block.
- Value *getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
- Value *dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
- Value *dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
- Value *dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
- Value *dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
+ Value getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
+ Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
+ Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
+ Value dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
+ Value dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
}
/// Returns the number of warps in the block.
- Value *getNumWarps(Location loc, Value *blockSize,
- ConversionPatternRewriter &rewriter) const {
+ Value getNumWarps(Location loc, Value blockSize,
+ ConversionPatternRewriter &rewriter) const {
auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
@@ -449,19 +440,19 @@
}
/// Returns the number of active threads in the warp, not clamped to 32.
- Value *getActiveWidth(Location loc, Value *threadIdx, Value *blockSize,
- ConversionPatternRewriter &rewriter) const {
- Value *threadIdxMask = rewriter.create<LLVM::ConstantOp>(
+ Value getActiveWidth(Location loc, Value threadIdx, Value blockSize,
+ ConversionPatternRewriter &rewriter) const {
+ Value threadIdxMask = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
- Value *numThreadsWithSmallerWarpId =
+ Value numThreadsWithSmallerWarpId =
rewriter.create<LLVM::AndOp>(loc, threadIdx, threadIdxMask);
return rewriter.create<LLVM::SubOp>(loc, blockSize,
numThreadsWithSmallerWarpId);
}
/// Returns value divided by the warp size (i.e. 32).
- Value *getDivideByWarpSize(Value *value,
- ConversionPatternRewriter &rewriter) const {
+ Value getDivideByWarpSize(Value value,
+ ConversionPatternRewriter &rewriter) const {
auto loc = value->getLoc();
auto warpSize = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
@@ -473,6 +464,64 @@
static constexpr int kWarpSize = 32;
};
+struct GPUShuffleOpLowering : public LLVMOpLowering {
+ explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(gpu::ShuffleOp::getOperationName(),
+ lowering_.getDialect()->getContext(), lowering_) {}
+
+ /// Lowers a shuffle to the corresponding NVVM op.
+ ///
+ /// Convert the `width` argument into an activeMask (a bitmask which specifies
+ /// which threads participate in the shuffle) and a maskAndClamp (specifying
+ /// the highest lane which participates in the shuffle).
+ ///
+ /// %one = llvm.constant(1 : i32) : !llvm.i32
+ /// %shl = llvm.shl %one, %width : !llvm.i32
+ /// %active_mask = llvm.sub %shl, %one : !llvm.i32
+ /// %mask_and_clamp = llvm.sub %width, %one : !llvm.i32
+ /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
+ /// %mask_and_clamp : !llvm<"{ float, i1 }">
+ /// %shfl_value = llvm.extractvalue %shfl[0 : index] :
+ /// !llvm<"{ float, i1 }">
+ /// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
+ /// !llvm<"{ float, i1 }">
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ gpu::ShuffleOpOperandAdaptor adaptor(operands);
+
+ auto dialect = lowering.getDialect();
+ auto valueTy = adaptor.value()->getType().cast<LLVM::LLVMType>();
+ auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
+ auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
+ auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
+
+ Value one = rewriter.create<LLVM::ConstantOp>(
+ loc, int32Type, rewriter.getI32IntegerAttr(1));
+ // Bit mask of active lanes: `(1 << activeWidth) - 1`.
+ Value activeMask = rewriter.create<LLVM::SubOp>(
+ loc, int32Type,
+ rewriter.create<LLVM::ShlOp>(loc, int32Type, one, adaptor.width()),
+ one);
+ // Clamp lane: `activeWidth - 1`
+ Value maskAndClamp =
+ rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one);
+
+ auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
+ Value shfl = rewriter.create<NVVM::ShflBflyOp>(
+ loc, resultTy, activeMask, adaptor.value(), adaptor.offset(),
+ maskAndClamp, returnValueAndIsValidAttr);
+ Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
+ loc, valueTy, shfl, rewriter.getIndexArrayAttr(0));
+ Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
+ loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
+
+ rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
+ return matchSuccess();
+ }
+};
+
struct GPUFuncOpLowering : LLVMOpLowering {
explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter)
: LLVMOpLowering(gpu::GPUFuncOp::getOperationName(),
@@ -480,7 +529,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.empty() && "func op is not expected to have operands");
auto gpuFuncOp = cast<gpu::GPUFuncOp>(op);
@@ -489,7 +538,7 @@
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
- Value *attribution = en.value();
+ Value attribution = en.value();
auto type = attribution->getType().dyn_cast<MemRefType>();
assert(type && type.hasStaticShape() && "unexpected type in attribution");
@@ -546,23 +595,23 @@
unsigned numProperArguments = gpuFuncOp.getNumArguments();
auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
- Value *zero = nullptr;
+ Value zero = nullptr;
if (!workgroupBuffers.empty())
zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
rewriter.getI32IntegerAttr(0));
for (auto en : llvm::enumerate(workgroupBuffers)) {
LLVM::GlobalOp global = en.value();
- Value *address = rewriter.create<LLVM::AddressOfOp>(loc, global);
+ Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
auto elementType = global.getType().getArrayElementType();
- Value *memory = rewriter.create<LLVM::GEPOp>(
+ Value memory = rewriter.create<LLVM::GEPOp>(
loc, elementType.getPointerTo(global.addr_space().getZExtValue()),
- address, ArrayRef<Value *>{zero, zero});
+ address, ArrayRef<Value>{zero, zero});
// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
// otherwise necessary given that memref sizes are fixed, but we can try
// and canonicalize that away later.
- Value *attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
+ Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
auto type = attribution->getType().cast<MemRefType>();
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
type, memory);
@@ -574,7 +623,7 @@
gpuFuncOp.getNumWorkgroupAttributions();
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
- Value *attribution = en.value();
+ Value attribution = en.value();
auto type = attribution->getType().cast<MemRefType>();
assert(type && type.hasStaticShape() &&
"unexpected type in attribution");
@@ -585,10 +634,10 @@
auto ptrType = lowering.convertType(type.getElementType())
.cast<LLVM::LLVMType>()
.getPointerTo();
- Value *numElements = rewriter.create<LLVM::ConstantOp>(
+ Value numElements = rewriter.create<LLVM::ConstantOp>(
gpuFuncOp.getLoc(), int64Ty,
rewriter.getI64IntegerAttr(type.getNumElements()));
- Value *allocated = rewriter.create<LLVM::AllocaOp>(
+ Value allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
type, allocated);
@@ -616,8 +665,8 @@
!en.value().isa<UnrankedMemRefType>())
continue;
- BlockArgument *arg = block.getArgument(en.index());
- Value *loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
+ BlockArgument arg = block.getArgument(en.index());
+ Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
rewriter.replaceUsesOfBlockArgument(arg, loaded);
}
}
@@ -634,7 +683,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands,
ArrayRef<Block *>());
@@ -688,8 +737,8 @@
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>,
- GPUAllReduceOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(
- converter);
+ GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
+ GPUReturnOpLowering>(converter);
patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
"__nv_exp");
}
diff --git a/third_party/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/third_party/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 59892db..8377064 100644
--- a/third_party/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -1,19 +1,10 @@
//===- LowerGpuOpsToROCDLOps.cpp - MLIR GPU to ROCDL lowering passes ------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// GPU operations.
diff --git a/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index 42483a6..509457d 100644
--- a/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -1,19 +1,10 @@
//===- ConvertGPUToSPIRV.cpp - Convert GPU ops to SPIR-V dialect ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the conversion patterns from GPU ops to SPIR-V dialect.
//
@@ -36,7 +27,7 @@
using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
+ matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -48,7 +39,7 @@
using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(SourceOp op, ArrayRef<Value *> operands,
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -65,7 +56,7 @@
}
PatternMatchResult
- matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value *> operands,
+ matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
private:
@@ -79,7 +70,7 @@
using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value *> operands,
+ matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -92,7 +83,7 @@
using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
+ matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -103,7 +94,7 @@
using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value *> operands,
+ matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -114,7 +105,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult
-ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
+ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// loop::ForOp can be lowered to the structured control flow represented by
// spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
@@ -135,7 +126,7 @@
loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
// Create the new induction variable to use.
- BlockArgument *newIndVar =
+ BlockArgument newIndVar =
header->addArgument(forOperands.lowerBound()->getType());
Block *body = forOp.getBody();
@@ -166,7 +157,7 @@
auto cmpOp = rewriter.create<spirv::SLessThanOp>(
loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
rewriter.create<spirv::BranchConditionalOp>(
- loc, cmpOp, body, ArrayRef<Value *>(), mergeBlock, ArrayRef<Value *>());
+ loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
// Generate instructions to increment the step of the induction variable and
// branch to the header.
@@ -174,7 +165,7 @@
rewriter.setInsertionPointToEnd(continueBlock);
// Add the step to the induction variable and branch to the header.
- Value *updatedIndVar = rewriter.create<spirv::IAddOp>(
+ Value updatedIndVar = rewriter.create<spirv::IAddOp>(
loc, newIndVar->getType(), newIndVar, forOperands.step());
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
@@ -188,7 +179,7 @@
template <typename SourceOp, spirv::BuiltIn builtin>
PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
- SourceOp op, ArrayRef<Value *> operands,
+ SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto dimAttr =
op.getOperation()->template getAttrOfType<StringAttr>("dimension");
@@ -267,7 +258,7 @@
PatternMatchResult
KernelFnConversion::matchAndRewrite(gpu::GPUFuncOp funcOp,
- ArrayRef<Value *> operands,
+ ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!gpu::GPUDialect::isKernel(funcOp)) {
return matchFailure();
@@ -297,7 +288,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult KernelModuleConversion::matchAndRewrite(
- ModuleOp moduleOp, ArrayRef<Value *> operands,
+ ModuleOp moduleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!moduleOp.getAttrOfType<UnitAttr>(
gpu::GPUDialect::getKernelModuleAttrName())) {
@@ -327,7 +318,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
- ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
+ ModuleTerminatorOp terminatorOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp);
return matchSuccess();
@@ -338,7 +329,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult GPUReturnOpConversion::matchAndRewrite(
- gpu::ReturnOp returnOp, ArrayRef<Value *> operands,
+ gpu::ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!operands.empty())
return matchFailure();
diff --git a/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index b8fe27e..68392c3 100644
--- a/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/third_party/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -1,19 +1,10 @@
//===- ConvertGPUToSPIRVPass.cpp - GPU to SPIR-V dialect lowering passes --===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert a kernel function in the GPU Dialect
// into a spv.module operation
@@ -44,17 +35,17 @@
/// 2) Lower the body of the spirv::ModuleOp.
class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
public:
- GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize)
- : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
+ GPUToSPIRVPass() = default;
+ GPUToSPIRVPass(const GPUToSPIRVPass &) {}
+ GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
+ this->workGroupSize = workGroupSize;
+ }
+
void runOnModule() override;
private:
- SmallVector<int64_t, 3> workGroupSize;
-};
-
-/// Command line option to specify the workgroup size.
-struct GPUToSPIRVPassOptions : public PassOptions<GPUToSPIRVPassOptions> {
- List<unsigned> workGroupSize{
+ /// Command line option to specify the workgroup size.
+ ListOption<int64_t> workGroupSize{
*this, "workgroup-size",
llvm::cl::desc(
"Workgroup Sizes in the SPIR-V module for x, followed by y, followed "
@@ -101,11 +92,5 @@
return std::make_unique<GPUToSPIRVPass>(workGroupSize);
}
-static PassRegistration<GPUToSPIRVPass, GPUToSPIRVPassOptions>
- pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect",
- [](const GPUToSPIRVPassOptions &passOptions) {
- SmallVector<int64_t, 3> workGroupSize;
- workGroupSize.assign(passOptions.workGroupSize.begin(),
- passOptions.workGroupSize.end());
- return std::make_unique<GPUToSPIRVPass>(workGroupSize);
- });
+static PassRegistration<GPUToSPIRVPass>
+ pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
diff --git a/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 3eb23c1..2a034fd 100644
--- a/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -1,19 +1,10 @@
//===- LinalgToLLVM.cpp - conversion from Linalg to LLVM dialect ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
@@ -120,21 +111,21 @@
BaseViewConversionHelper(Type type)
: d(MemRefDescriptor::undef(rewriter(), loc(), type)) {}
- BaseViewConversionHelper(Value *v) : d(v) {}
+ BaseViewConversionHelper(Value v) : d(v) {}
/// Wrappers around MemRefDescriptor that use EDSC builder and location.
- Value *allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
- void setAllocatedPtr(Value *v) { d.setAllocatedPtr(rewriter(), loc(), v); }
- Value *alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
- void setAlignedPtr(Value *v) { d.setAlignedPtr(rewriter(), loc(), v); }
- Value *offset() { return d.offset(rewriter(), loc()); }
- void setOffset(Value *v) { d.setOffset(rewriter(), loc(), v); }
- Value *size(unsigned i) { return d.size(rewriter(), loc(), i); }
- void setSize(unsigned i, Value *v) { d.setSize(rewriter(), loc(), i, v); }
- Value *stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
- void setStride(unsigned i, Value *v) { d.setStride(rewriter(), loc(), i, v); }
+ Value allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); }
+ void setAllocatedPtr(Value v) { d.setAllocatedPtr(rewriter(), loc(), v); }
+ Value alignedPtr() { return d.alignedPtr(rewriter(), loc()); }
+ void setAlignedPtr(Value v) { d.setAlignedPtr(rewriter(), loc(), v); }
+ Value offset() { return d.offset(rewriter(), loc()); }
+ void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
+ Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
+ void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
+ Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
+ void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
- operator Value *() { return d; }
+ operator Value() { return d; }
private:
OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
@@ -151,7 +142,7 @@
: LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
auto rangeDescriptorTy =
@@ -161,7 +152,7 @@
// Fill in an aggregate value of the descriptor.
RangeOpOperandAdaptor adaptor(operands);
- Value *desc = llvm_undef(rangeDescriptorTy);
+ Value desc = llvm_undef(rangeDescriptorTy);
desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
@@ -184,7 +175,7 @@
: LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
SliceOpOperandAdaptor adaptor(operands);
@@ -198,7 +189,7 @@
BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
// TODO(ntv): extract sizes and emit asserts.
- SmallVector<Value *, 4> strides(memRefType.getRank());
+ SmallVector<Value, 4> strides(memRefType.getRank());
for (int i = 0, e = memRefType.getRank(); i < e; ++i)
strides[i] = baseDesc.stride(i);
@@ -207,10 +198,10 @@
};
// Compute base offset.
- Value *baseOffset = baseDesc.offset();
+ Value baseOffset = baseDesc.offset();
for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
- Value *indexing = adaptor.indexings()[i];
- Value *min = indexing;
+ Value indexing = adaptor.indexings()[i];
+ Value min = indexing;
if (sliceOp.indexing(i)->getType().isa<RangeType>())
min = extractvalue(int64Ty, indexing, pos(0));
baseOffset = add(baseOffset, mul(min, strides[i]));
@@ -227,29 +218,29 @@
if (sliceOp.getViewType().getRank() == 0)
return rewriter.replaceOp(op, {desc}), matchSuccess();
- Value *zero =
+ Value zero =
constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
// Compute and insert view sizes (max - min along the range) and strides.
// Skip the non-range operands as they will be projected away from the view.
int numNewDims = 0;
for (auto en : llvm::enumerate(sliceOp.indexings())) {
- Value *indexing = en.value();
+ Value indexing = en.value();
if (indexing->getType().isa<RangeType>()) {
int rank = en.index();
- Value *rangeDescriptor = adaptor.indexings()[rank];
- Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
- Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
- Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
- Value *baseSize = baseDesc.size(rank);
+ Value rangeDescriptor = adaptor.indexings()[rank];
+ Value min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+ Value max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+ Value step = extractvalue(int64Ty, rangeDescriptor, pos(2));
+ Value baseSize = baseDesc.size(rank);
// Bound upper by base view upper bound.
max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
baseSize);
- Value *size = sub(max, min);
+ Value size = sub(max, min);
// Bound lower by zero.
size =
llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
- Value *stride = mul(strides[rank], step);
+ Value stride = mul(strides[rank], step);
desc.setSize(numNewDims, size);
desc.setStride(numNewDims, stride);
++numNewDims;
@@ -275,7 +266,7 @@
: LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
edsc::ScopedContext context(rewriter, op->getLoc());
@@ -318,7 +309,7 @@
: LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return matchSuccess();
@@ -453,7 +444,7 @@
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
auto indexedGenericOp = cast<IndexedGenericOp>(op);
auto numLoops = indexedGenericOp.getNumLoops();
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
operands.reserve(numLoops + op.getNumOperands());
for (unsigned i = 0; i < numLoops; ++i) {
operands.push_back(zero);
@@ -477,7 +468,7 @@
PatternMatchResult matchAndRewrite(CopyOp op,
PatternRewriter &rewriter) const override {
- Value *in = op.input(), *out = op.output();
+ Value in = op.input(), out = op.output();
// If either inputPerm or outputPerm are non-identities, insert transposes.
auto inputPerm = op.inputPermutation();
diff --git a/third_party/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/third_party/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
index ff93ce5..b257e9b 100644
--- a/third_party/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
+++ b/third_party/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
@@ -1,19 +1,10 @@
//===- ConvertLoopToStandard.cpp - ControlFlow to CFG conversion ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert loop.for, loop.if and loop.terminator
// ops into standard CFG ops.
@@ -182,22 +173,22 @@
rewriter.splitBlock(conditionBlock, conditionBlock->begin());
auto *lastBodyBlock = &forOp.region().back();
rewriter.inlineRegionBefore(forOp.region(), endBlock);
- auto *iv = conditionBlock->getArgument(0);
+ auto iv = conditionBlock->getArgument(0);
// Append the induction variable stepping logic to the last body block and
// branch back to the condition block. Construct an expression f :
// (x -> x+step) and apply this expression to the induction variable.
rewriter.setInsertionPointToEnd(lastBodyBlock);
- auto *step = forOp.step();
- auto *stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
+ auto step = forOp.step();
+ auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
if (!stepped)
return matchFailure();
rewriter.create<BranchOp>(loc, conditionBlock, stepped);
// Compute loop bounds before branching to the condition.
rewriter.setInsertionPointToEnd(initBlock);
- Value *lowerBound = forOp.lowerBound();
- Value *upperBound = forOp.upperBound();
+ Value lowerBound = forOp.lowerBound();
+ Value upperBound = forOp.upperBound();
if (!lowerBound || !upperBound)
return matchFailure();
rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
@@ -208,8 +199,7 @@
rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound);
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
- ArrayRef<Value *>(), endBlock,
- ArrayRef<Value *>());
+ ArrayRef<Value>(), endBlock, ArrayRef<Value>());
// Ok, we're done!
rewriter.eraseOp(forOp);
return matchSuccess();
@@ -248,8 +238,8 @@
rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<CondBranchOp>(loc, ifOp.condition(), thenBlock,
- /*trueArgs=*/ArrayRef<Value *>(), elseBlock,
- /*falseArgs=*/ArrayRef<Value *>());
+ /*trueArgs=*/ArrayRef<Value>(), elseBlock,
+ /*falseArgs=*/ArrayRef<Value>());
// Ok, we're done!
rewriter.eraseOp(ifOp);
diff --git a/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
index c269dc5..e500d10 100644
--- a/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -1,19 +1,10 @@
//===- LoopsToGPU.cpp - Convert an affine loop nest to a GPU kernel -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This implements a straightforward conversion of an loop nest into a GPU
// kernel. The caller is expected to guarantee that the conversion is correct
@@ -43,7 +34,7 @@
using llvm::seq;
// Extract an indexed value from KernelDim3.
-static Value *getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) {
+static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) {
switch (pos) {
case 0:
return dim3.x;
@@ -61,8 +52,8 @@
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp) {
return forOp.getLowerBoundOperands();
}
-static SmallVector<Value *, 1> getLowerBoundOperands(ForOp forOp) {
- SmallVector<Value *, 1> bounds(1, forOp.lowerBound());
+static SmallVector<Value, 1> getLowerBoundOperands(ForOp forOp) {
+ SmallVector<Value, 1> bounds(1, forOp.lowerBound());
return bounds;
}
@@ -70,33 +61,33 @@
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp) {
return forOp.getUpperBoundOperands();
}
-static SmallVector<Value *, 1> getUpperBoundOperands(ForOp forOp) {
- SmallVector<Value *, 1> bounds(1, forOp.upperBound());
+static SmallVector<Value, 1> getUpperBoundOperands(ForOp forOp) {
+ SmallVector<Value, 1> bounds(1, forOp.upperBound());
return bounds;
}
// Get a Value that corresponds to the loop step. If the step is an attribute,
// materialize a corresponding constant using builder.
-static Value *getOrCreateStep(AffineForOp forOp, OpBuilder &builder) {
+static Value getOrCreateStep(AffineForOp forOp, OpBuilder &builder) {
return builder.create<ConstantIndexOp>(forOp.getLoc(), forOp.getStep());
}
-static Value *getOrCreateStep(ForOp forOp, OpBuilder &) { return forOp.step(); }
+static Value getOrCreateStep(ForOp forOp, OpBuilder &) { return forOp.step(); }
// Get a Value for the loop lower bound. If the value requires computation,
// materialize the instructions using builder.
-static Value *getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) {
+static Value getOrEmitLowerBound(AffineForOp forOp, OpBuilder &builder) {
return lowerAffineLowerBound(forOp, builder);
}
-static Value *getOrEmitLowerBound(ForOp forOp, OpBuilder &) {
+static Value getOrEmitLowerBound(ForOp forOp, OpBuilder &) {
return forOp.lowerBound();
}
// Get a Value for the loop upper bound. If the value requires computation,
// materialize the instructions using builder.
-static Value *getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) {
+static Value getOrEmitUpperBound(AffineForOp forOp, OpBuilder &builder) {
return lowerAffineUpperBound(forOp, builder);
}
-static Value *getOrEmitUpperBound(ForOp forOp, OpBuilder &) {
+static Value getOrEmitUpperBound(ForOp forOp, OpBuilder &) {
return forOp.upperBound();
}
@@ -212,18 +203,18 @@
unsigned numThreadDims);
// Ranges of the loops mapped to blocks or threads.
- SmallVector<Value *, 6> dims;
+ SmallVector<Value, 6> dims;
// Lower bounds of the loops mapped to blocks or threads.
- SmallVector<Value *, 6> lbs;
+ SmallVector<Value, 6> lbs;
// Induction variables of the loops mapped to blocks or threads.
- SmallVector<Value *, 6> ivs;
+ SmallVector<Value, 6> ivs;
// Steps of the loops mapped to blocks or threads.
- SmallVector<Value *, 6> steps;
+ SmallVector<Value, 6> steps;
};
} // namespace
// Return true if the value is obviously a constant "one".
-static bool isConstantOne(Value *value) {
+static bool isConstantOne(Value value) {
if (auto def = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
return def.getValue() == 1;
return false;
@@ -244,17 +235,17 @@
steps.reserve(numLoops);
OpTy currentLoop = forOp;
for (unsigned i = 0; i < numLoops; ++i) {
- Value *lowerBound = getOrEmitLowerBound(currentLoop, builder);
- Value *upperBound = getOrEmitUpperBound(currentLoop, builder);
+ Value lowerBound = getOrEmitLowerBound(currentLoop, builder);
+ Value upperBound = getOrEmitUpperBound(currentLoop, builder);
if (!lowerBound || !upperBound) {
return llvm::None;
}
- Value *range =
+ Value range =
builder.create<SubIOp>(currentLoop.getLoc(), upperBound, lowerBound);
- Value *step = getOrCreateStep(currentLoop, builder);
+ Value step = getOrCreateStep(currentLoop, builder);
if (!isConstantOne(step))
- range = builder.create<DivISOp>(currentLoop.getLoc(), range, step);
+ range = builder.create<SignedDivIOp>(currentLoop.getLoc(), range, step);
dims.push_back(range);
lbs.push_back(lowerBound);
@@ -274,8 +265,8 @@
/// `nids`. The innermost loop is mapped to the x-dimension, followed by the
/// next innermost loop to y-dimension, followed by z-dimension.
template <typename OpTy>
-OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value *> ids,
- ArrayRef<Value *> nids) {
+OpTy createGPULaunchLoops(OpTy rootForOp, ArrayRef<Value> ids,
+ ArrayRef<Value> nids) {
auto nDims = ids.size();
assert(nDims == nids.size());
for (auto dim : llvm::seq<unsigned>(0, nDims)) {
@@ -295,11 +286,11 @@
/// each workgroup/workitem and number of workgroup/workitems along a dimension
/// of the launch into a container.
void packIdAndNumId(gpu::KernelDim3 kernelIds, gpu::KernelDim3 kernelNids,
- unsigned nDims, SmallVectorImpl<Value *> &ids,
- SmallVectorImpl<Value *> &nids) {
+ unsigned nDims, SmallVectorImpl<Value> &ids,
+ SmallVectorImpl<Value> &nids) {
assert(nDims <= 3 && "invalid number of launch dimensions");
- SmallVector<Value *, 3> allIds = {kernelIds.z, kernelIds.y, kernelIds.x};
- SmallVector<Value *, 3> allNids = {kernelNids.z, kernelNids.y, kernelNids.x};
+ SmallVector<Value, 3> allIds = {kernelIds.z, kernelIds.y, kernelIds.x};
+ SmallVector<Value, 3> allNids = {kernelNids.z, kernelNids.y, kernelNids.x};
ids.clear();
ids.append(std::next(allIds.begin(), allIds.size() - nDims), allIds.end());
nids.clear();
@@ -317,7 +308,7 @@
auto returnOp = builder.create<gpu::ReturnOp>(launchOp.getLoc());
rootForOp.getOperation()->moveBefore(returnOp);
- SmallVector<Value *, 3> workgroupID, numWorkGroups;
+ SmallVector<Value, 3> workgroupID, numWorkGroups;
packIdAndNumId(launchOp.getBlockIds(), launchOp.getGridSize(), numBlockDims,
workgroupID, numWorkGroups);
@@ -333,7 +324,7 @@
}
}
- SmallVector<Value *, 3> workItemID, workGroupSize;
+ SmallVector<Value, 3> workItemID, workGroupSize;
packIdAndNumId(launchOp.getThreadIds(), launchOp.getBlockSize(),
numThreadDims, workItemID, workGroupSize);
for (auto &loopOp : threadRootForOps) {
@@ -346,18 +337,17 @@
// Convert the computation rooted at the `rootForOp`, into a GPU kernel with the
// given workgroup size and number of workgroups.
template <typename OpTy>
-LogicalResult createLaunchFromOp(OpTy rootForOp,
- ArrayRef<Value *> numWorkGroups,
- ArrayRef<Value *> workGroupSizes) {
+LogicalResult createLaunchFromOp(OpTy rootForOp, ArrayRef<Value> numWorkGroups,
+ ArrayRef<Value> workGroupSizes) {
OpBuilder builder(rootForOp.getOperation());
if (numWorkGroups.size() > 3) {
return rootForOp.emitError("invalid ")
<< numWorkGroups.size() << "-D workgroup specification";
}
auto loc = rootForOp.getLoc();
- Value *one = builder.create<ConstantOp>(
+ Value one = builder.create<ConstantOp>(
loc, builder.getIntegerAttr(builder.getIndexType(), 1));
- SmallVector<Value *, 3> numWorkGroups3D(3, one), workGroupSize3D(3, one);
+ SmallVector<Value, 3> numWorkGroups3D(3, one), workGroupSize3D(3, one);
for (auto numWorkGroup : enumerate(numWorkGroups)) {
numWorkGroups3D[numWorkGroup.index()] = numWorkGroup.value();
}
@@ -367,7 +357,7 @@
// Get the values used within the region of the rootForOp but defined above
// it.
- llvm::SetVector<Value *> valuesToForwardSet;
+ llvm::SetVector<Value> valuesToForwardSet;
getUsedValuesDefinedAbove(rootForOp.region(), rootForOp.region(),
valuesToForwardSet);
// Also add the values used for the lb, ub, and step of the rootForOp.
@@ -387,8 +377,8 @@
// defined outside. They all are replaced with kernel arguments.
for (const auto &pair :
llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
- Value *from = std::get<0>(pair);
- Value *to = std::get<1>(pair);
+ Value from = std::get<0>(pair);
+ Value to = std::get<1>(pair);
replaceAllUsesInRegionWith(from, to, launchOp.body());
}
return success();
@@ -408,22 +398,22 @@
OpBuilder builder(rootForOp.getOperation());
// Prepare the grid and block sizes for the launch operation. If there is
// no loop mapped to a specific dimension, use constant "1" as its size.
- Value *constOne = (numBlockDims < 3 || numThreadDims < 3)
- ? builder.create<ConstantIndexOp>(rootForOp.getLoc(), 1)
- : nullptr;
- Value *gridSizeX = dims[0];
- Value *gridSizeY = numBlockDims > 1 ? dims[1] : constOne;
- Value *gridSizeZ = numBlockDims > 2 ? dims[2] : constOne;
- Value *blockSizeX = dims[numBlockDims];
- Value *blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne;
- Value *blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne;
+ Value constOne = (numBlockDims < 3 || numThreadDims < 3)
+ ? builder.create<ConstantIndexOp>(rootForOp.getLoc(), 1)
+ : nullptr;
+ Value gridSizeX = dims[0];
+ Value gridSizeY = numBlockDims > 1 ? dims[1] : constOne;
+ Value gridSizeZ = numBlockDims > 2 ? dims[2] : constOne;
+ Value blockSizeX = dims[numBlockDims];
+ Value blockSizeY = numThreadDims > 1 ? dims[numBlockDims + 1] : constOne;
+ Value blockSizeZ = numThreadDims > 2 ? dims[numBlockDims + 2] : constOne;
// Create a launch op and move the body region of the innermost loop to the
// launch op. Pass the values defined outside the outermost loop and used
// inside the innermost loop and loop lower bounds as kernel data arguments.
// Still assuming perfect nesting so there are no values other than induction
// variables that are defined in one loop and used in deeper loops.
- llvm::SetVector<Value *> valuesToForwardSet;
+ llvm::SetVector<Value> valuesToForwardSet;
getUsedValuesDefinedAbove(innermostForOp.region(), rootForOp.region(),
valuesToForwardSet);
auto valuesToForward = valuesToForwardSet.takeVector();
@@ -457,15 +447,15 @@
originallyForwardedValues);
auto stepArgumentIt = std::next(lbArgumentIt, lbs.size());
for (auto en : llvm::enumerate(ivs)) {
- Value *id =
+ Value id =
en.index() < numBlockDims
? getDim3Value(launchOp.getBlockIds(), en.index())
: getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims);
- Value *step = steps[en.index()];
+ Value step = steps[en.index()];
if (!isConstantOne(step))
id = builder.create<MulIOp>(rootForOp.getLoc(), step, id);
- Value *ivReplacement =
+ Value ivReplacement =
builder.create<AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id);
en.value()->replaceAllUsesWith(ivReplacement);
replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt,
@@ -479,8 +469,8 @@
// trailing positions, make sure we don't touch those.
for (const auto &pair :
llvm::zip_first(valuesToForward, launchOp.getKernelArguments())) {
- Value *from = std::get<0>(pair);
- Value *to = std::get<1>(pair);
+ Value from = std::get<0>(pair);
+ Value to = std::get<1>(pair);
replaceAllUsesInRegionWith(from, to, launchOp.body());
}
@@ -510,8 +500,8 @@
// nested. The workgroup size and num workgroups is provided as input
template <typename OpTy>
static LogicalResult convertLoopToGPULaunch(OpTy forOp,
- ArrayRef<Value *> numWorkGroups,
- ArrayRef<Value *> workGroupSize) {
+ ArrayRef<Value> numWorkGroups,
+ ArrayRef<Value> workGroupSize) {
if (failed(checkLoopOpMappable(forOp, numWorkGroups.size(),
workGroupSize.size()))) {
return failure();
@@ -532,7 +522,7 @@
}
LogicalResult mlir::convertLoopToGPULaunch(loop::ForOp forOp,
- ArrayRef<Value *> numWorkGroups,
- ArrayRef<Value *> workGroupSizes) {
+ ArrayRef<Value> numWorkGroups,
+ ArrayRef<Value> workGroupSizes) {
return ::convertLoopToGPULaunch(forOp, numWorkGroups, workGroupSizes);
}
diff --git a/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
index 21abc3c..c3bbf27 100644
--- a/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
+++ b/third_party/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
@@ -1,19 +1,10 @@
//===- LoopsToGPUPass.cpp - Convert a loop nest to a GPU kernel -----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h"
#include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
@@ -98,7 +89,7 @@
// pass is only used for testing.
FuncOp funcOp = getFunction();
OpBuilder builder(funcOp.getOperation()->getRegion(0));
- SmallVector<Value *, 3> numWorkGroupsVal, workGroupSizeVal;
+ SmallVector<Value, 3> numWorkGroupsVal, workGroupSizeVal;
for (auto val : numWorkGroups) {
auto constOp = builder.create<ConstantOp>(
funcOp.getLoc(), builder.getIntegerAttr(builder.getIndexType(), val));
diff --git a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index ea8501b..0c96cc5 100644
--- a/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1,19 +1,10 @@
//===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion-----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the LLVM IR dialect.
@@ -256,20 +247,20 @@
/*============================================================================*/
/* StructBuilder implementation */
/*============================================================================*/
-StructBuilder::StructBuilder(Value *v) : value(v) {
+StructBuilder::StructBuilder(Value v) : value(v) {
assert(value != nullptr && "value cannot be null");
structType = value->getType().cast<LLVM::LLVMType>();
}
-Value *StructBuilder::extractPtr(OpBuilder &builder, Location loc,
- unsigned pos) {
+Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
+ unsigned pos) {
Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
return builder.create<LLVM::ExtractValueOp>(loc, type, value,
builder.getI64ArrayAttr(pos));
}
void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
- Value *ptr) {
+ Value ptr) {
value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
builder.getI64ArrayAttr(pos));
}
@@ -278,7 +269,7 @@
/*============================================================================*/
/// Construct a helper for the given descriptor value.
-MemRefDescriptor::MemRefDescriptor(Value *descriptor)
+MemRefDescriptor::MemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {
assert(value != nullptr && "value cannot be null");
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
@@ -289,7 +280,7 @@
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
Type descriptorType) {
- Value *descriptor =
+ Value descriptor =
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
return MemRefDescriptor(descriptor);
}
@@ -300,7 +291,7 @@
MemRefDescriptor
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
- MemRefType type, Value *memory) {
+ MemRefType type, Value memory) {
assert(type.hasStaticShape() && "unexpected dynamic shape");
assert(type.getAffineMaps().empty() && "unexpected layout map");
@@ -325,37 +316,37 @@
}
/// Builds IR extracting the allocated pointer from the descriptor.
-Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
+Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the allocated pointer into the descriptor.
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
- Value *ptr) {
+ Value ptr) {
setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
}
/// Builds IR extracting the aligned pointer from the descriptor.
-Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
+Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
}
/// Builds IR inserting the aligned pointer into the descriptor.
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
- Value *ptr) {
+ Value ptr) {
setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
}
// Creates a constant Op producing a value of `resultType` from an index-typed
// integer attribute.
-static Value *createIndexAttrConstant(OpBuilder &builder, Location loc,
- Type resultType, int64_t value) {
+static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
+ Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
}
/// Builds IR extracting the offset from the descriptor.
-Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
+Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
@@ -363,7 +354,7 @@
/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
- Value *offset) {
+ Value offset) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, offset,
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
@@ -377,7 +368,7 @@
}
/// Builds IR extracting the pos-th size from the descriptor.
-Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
+Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
@@ -385,7 +376,7 @@
/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
- Value *size) {
+ Value size) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, size,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
@@ -399,8 +390,7 @@
}
/// Builds IR extracting the pos-th size from the descriptor.
-Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
- unsigned pos) {
+Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
@@ -408,7 +398,7 @@
/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
- Value *stride) {
+ Value stride) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, stride,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
@@ -431,30 +421,30 @@
/*============================================================================*/
/// Construct a helper for the given descriptor value.
-UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value *descriptor)
+UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {}
/// Builds IR creating an `undef` value of the descriptor type.
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
Location loc,
Type descriptorType) {
- Value *descriptor =
+ Value descriptor =
builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
return UnrankedMemRefDescriptor(descriptor);
}
-Value *UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
+Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
- Value *v) {
+ Value v) {
setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
}
-Value *UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
- Location loc) {
+Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
+ Location loc) {
return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
}
void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
- Location loc, Value *v) {
+ Location loc, Value v) {
setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
}
namespace {
@@ -495,8 +485,8 @@
}
// Create an LLVM IR pseudo-operation defining the given index constant.
- Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
- uint64_t value) const {
+ Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
+ uint64_t value) const {
return createIndexAttrConstant(builder, loc, getIndexType(), value);
}
@@ -508,7 +498,7 @@
using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
@@ -556,8 +546,8 @@
Block *firstBlock = &newFuncOp.getBody().front();
rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
for (unsigned idx : promotedArgIndices) {
- BlockArgument *arg = firstBlock->getArgument(idx);
- Value *loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
+ BlockArgument arg = firstBlock->getArgument(idx);
+ Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
rewriter.replaceUsesOfBlockArgument(arg, loaded);
}
}
@@ -656,7 +646,7 @@
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numResults = op->getNumResults();
@@ -680,7 +670,7 @@
// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
- SmallVector<Value *, 4> results;
+ SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto type = this->lowering.convertType(op->getResult(i)->getType());
@@ -721,7 +711,7 @@
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ValidateOpCount<SourceOp, OpCount>();
static_assert(
@@ -732,7 +722,7 @@
"expected same operands and result type");
// Cannot convert ops if their operands are not of LLVM type.
- for (Value *operand : operands) {
+ for (Value operand : operands) {
if (!operand || !operand->getType().isa<LLVM::LLVMType>())
return this->matchFailure();
}
@@ -755,16 +745,16 @@
if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
return this->matchFailure();
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
- SmallVector<Value *, OpCount> extractedOperands;
+ SmallVector<Value, OpCount> extractedOperands;
for (unsigned i = 0; i < OpCount; ++i) {
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, llvmVectorTy, operands[i], position));
}
- Value *newVal = rewriter.create<TargetOp>(
+ Value newVal = rewriter.create<TargetOp>(
loc, llvmVectorTy, extractedOperands, op->getAttrs());
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
newVal, position);
@@ -814,16 +804,20 @@
struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
using Super::Super;
};
-struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> {
+struct SignedDivIOpLowering
+ : public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
using Super::Super;
};
-struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
+struct UnsignedDivIOpLowering
+ : public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
using Super::Super;
};
-struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> {
+struct SignedRemIOpLowering
+ : public BinaryOpLLVMOpLowering<SignedRemIOp, LLVM::SRemOp> {
using Super::Super;
};
-struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> {
+struct UnsignedRemIOpLowering
+ : public BinaryOpLLVMOpLowering<UnsignedRemIOp, LLVM::URemOp> {
using Super::Super;
};
struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
@@ -862,6 +856,18 @@
: public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
using Super::Super;
};
+struct ShiftLeftOpLowering
+ : public OneToOneLLVMOpLowering<ShiftLeftOp, LLVM::ShlOp> {
+ using Super::Super;
+};
+struct SignedShiftRightOpLowering
+ : public OneToOneLLVMOpLowering<SignedShiftRightOp, LLVM::AShrOp> {
+ using Super::Super;
+};
+struct UnsignedShiftRightOpLowering
+ : public OneToOneLLVMOpLowering<UnsignedShiftRightOp, LLVM::LShrOp> {
+ using Super::Super;
+};
// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
@@ -911,7 +917,7 @@
return matchSuccess();
}
- void rewrite(Operation *op, ArrayRef<Value *> operands,
+ void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto allocOp = cast<AllocOp>(op);
@@ -920,7 +926,7 @@
// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
- SmallVector<Value *, 4> sizes;
+ SmallVector<Value, 4> sizes;
sizes.reserve(type.getRank());
unsigned i = 0;
for (int64_t s : type.getShape())
@@ -930,10 +936,10 @@
sizes.push_back(createIndexConstant(rewriter, loc, 1));
// Compute the total number of memref elements.
- Value *cumulativeSize = sizes.front();
+ Value cumulativeSize = sizes.front();
for (unsigned i = 1, e = sizes.size(); i < e; ++i)
cumulativeSize = rewriter.create<LLVM::MulOp>(
- loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, sizes[i]});
+ loc, getIndexType(), ArrayRef<Value>{cumulativeSize, sizes[i]});
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
@@ -946,17 +952,17 @@
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto one = createIndexConstant(rewriter, loc, 1);
auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType,
- ArrayRef<Value *>{nullPtr, one});
+ ArrayRef<Value>{nullPtr, one});
auto elementSize =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
cumulativeSize = rewriter.create<LLVM::MulOp>(
- loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
+ loc, getIndexType(), ArrayRef<Value>{cumulativeSize, elementSize});
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
- Value *allocated = nullptr;
+ Value allocated = nullptr;
int alignment = 0;
- Value *alignmentValue = nullptr;
+ Value alignmentValue = nullptr;
if (auto alignAttr = allocOp.alignment())
alignment = alignAttr.getValue().getSExtValue();
@@ -992,8 +998,8 @@
auto structElementType = lowering.convertType(elementType);
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
type.getMemorySpace());
- Value *bitcastAllocated = rewriter.create<LLVM::BitcastOp>(
- loc, elementPtrType, ArrayRef<Value *>(allocated));
+ Value bitcastAllocated = rewriter.create<LLVM::BitcastOp>(
+ loc, elementPtrType, ArrayRef<Value>(allocated));
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -1015,22 +1021,21 @@
memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated);
// Field 2: Actual aligned pointer to payload.
- Value *bitcastAligned = bitcastAllocated;
+ Value bitcastAligned = bitcastAllocated;
if (!useAlloca && alignment != 0) {
assert(alignmentValue);
// offset = (align - (ptr % align))% align
- Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
+ Value intVal = rewriter.create<LLVM::PtrToIntOp>(
loc, this->getIndexType(), allocated);
- Value *ptrModAlign =
+ Value ptrModAlign =
rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue);
- Value *subbed =
+ Value subbed =
rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
- Value *offset =
- rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
- Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
- allocated, offset);
+ Value offset = rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
+ Value aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
+ allocated, offset);
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
- loc, elementPtrType, ArrayRef<Value *>(aligned));
+ loc, elementPtrType, ArrayRef<Value>(aligned));
}
memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned);
@@ -1045,10 +1050,10 @@
// Fields 4 and 5: Sizes and strides of the strided MemRef.
// Store all sizes in the descriptor. Only dynamic sizes are passed in as
// operands to AllocOp.
- Value *runningStride = nullptr;
+ Value runningStride = nullptr;
// Iterate strides in reverse order, compute runningStride and strideValues.
auto nStrides = strides.size();
- SmallVector<Value *, 4> strideValues(nStrides, nullptr);
+ SmallVector<Value, 4> strideValues(nStrides, nullptr);
for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) {
int64_t index = nStrides - 1 - indexedStride.index();
if (strides[index] == MemRefType::getDynamicStrideOrOffset())
@@ -1085,7 +1090,7 @@
using Base = LLVMLegalizationPattern<CallOpType>;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<CallOpType> transformed(operands);
auto callOp = cast<CallOpType>(op);
@@ -1123,7 +1128,7 @@
// TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around
// a particular interaction between MemRefType and CallOp lowering. Find a
// way to avoid special casing.
- SmallVector<Value *, 4> results;
+ SmallVector<Value, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto type = this->lowering.convertType(op->getResult(i)->getType());
@@ -1157,7 +1162,7 @@
useAlloca(useAlloca) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (useAlloca)
return rewriter.eraseOp(op), matchSuccess();
@@ -1177,7 +1182,7 @@
}
MemRefDescriptor memref(transformed.memref());
- Value *casted = rewriter.create<LLVM::BitcastOp>(
+ Value casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(),
memref.allocatedPtr(rewriter, op->getLoc()));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
@@ -1193,7 +1198,7 @@
using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
using LLVMFuncOpT = LLVM::LLVMFuncOp;
@@ -1267,7 +1272,7 @@
: matchFailure();
}
- void rewrite(Operation *op, ArrayRef<Value *> operands,
+ void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
OperandAdaptor<MemRefCastOp> transformed(operands);
@@ -1308,7 +1313,7 @@
memRefDesc.setRank(rewriter, loc, rankVal);
// d2 = InsertValueOp d1, voidptr, 1
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
- rewriter.replaceOp(op, (Value *)memRefDesc);
+ rewriter.replaceOp(op, (Value)memRefDesc);
} else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
// Casting from unranked type to ranked.
@@ -1339,7 +1344,7 @@
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
OperandAdaptor<DimOp> transformed(operands);
@@ -1381,43 +1386,42 @@
// by accumulating the running linearized value.
// Note that `indices` and `allocSizes` are passed in the same order as they
// appear in load/store operations and memref type declarations.
- Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
- ArrayRef<Value *> indices,
- ArrayRef<Value *> allocSizes) const {
+ Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
+ ArrayRef<Value> indices,
+ ArrayRef<Value> allocSizes) const {
assert(indices.size() == allocSizes.size() &&
"mismatching number of indices and allocation sizes");
assert(!indices.empty() && "cannot linearize a 0-dimensional access");
- Value *linearized = indices.front();
+ Value linearized = indices.front();
for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
linearized = builder.create<LLVM::MulOp>(
loc, this->getIndexType(),
- ArrayRef<Value *>{linearized, allocSizes[i]});
+ ArrayRef<Value>{linearized, allocSizes[i]});
linearized = builder.create<LLVM::AddOp>(
- loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]});
+ loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
}
return linearized;
}
// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
- Value *getStridedElementPtr(Location loc, Type elementTypePtr,
- Value *descriptor, ArrayRef<Value *> indices,
- ArrayRef<int64_t> strides, int64_t offset,
- ConversionPatternRewriter &rewriter) const {
+ Value getStridedElementPtr(Location loc, Type elementTypePtr,
+ Value descriptor, ArrayRef<Value> indices,
+ ArrayRef<int64_t> strides, int64_t offset,
+ ConversionPatternRewriter &rewriter) const {
MemRefDescriptor memRefDescriptor(descriptor);
- Value *base = memRefDescriptor.alignedPtr(rewriter, loc);
- Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
- ? memRefDescriptor.offset(rewriter, loc)
- : this->createIndexConstant(rewriter, loc, offset);
+ Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+ Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.offset(rewriter, loc)
+ : this->createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
- Value *stride =
- strides[i] == MemRefType::getDynamicStrideOrOffset()
- ? memRefDescriptor.stride(rewriter, loc, i)
- : this->createIndexConstant(rewriter, loc, strides[i]);
- Value *additionalOffset =
+ Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.stride(rewriter, loc, i)
+ : this->createIndexConstant(rewriter, loc, strides[i]);
+ Value additionalOffset =
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
offsetValue =
rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
@@ -1425,10 +1429,9 @@
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
}
- Value *getDataPtr(Location loc, MemRefType type, Value *memRefDesc,
- ArrayRef<Value *> indices,
- ConversionPatternRewriter &rewriter,
- llvm::Module &module) const {
+ Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
+ ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
+ llvm::Module &module) const {
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -1446,14 +1449,14 @@
using Base::Base;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
OperandAdaptor<LoadOp> transformed(operands);
auto type = loadOp.getMemRefType();
- Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return matchSuccess();
}
@@ -1465,13 +1468,13 @@
using Base::Base;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
OperandAdaptor<StoreOp> transformed(operands);
- Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
return matchSuccess();
@@ -1484,14 +1487,14 @@
using Base::Base;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto prefetchOp = cast<PrefetchOp>(op);
OperandAdaptor<PrefetchOp> transformed(operands);
auto type = prefetchOp.getMemRefType();
- Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter, getModule());
+ Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter, getModule());
// Replace with llvm.prefetch.
auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32));
@@ -1519,7 +1522,7 @@
using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOpOperandAdaptor transformed(operands);
auto indexCastOp = cast<IndexCastOp>(op);
@@ -1554,7 +1557,7 @@
using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpiOp = cast<CmpIOp>(op);
CmpIOpOperandAdaptor transformed(operands);
@@ -1573,7 +1576,7 @@
using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpfOp = cast<CmpFOp>(op);
CmpFOpOperandAdaptor transformed(operands);
@@ -1625,9 +1628,9 @@
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
- ArrayRef<ArrayRef<Value *>> operands,
+ ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const override {
SmallVector<ValueRange, 2> operandRanges(operands.begin(), operands.end());
rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
@@ -1646,19 +1649,19 @@
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op->getNumOperands();
// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
- op, ArrayRef<Value *>(), ArrayRef<Block *>(), op->getAttrs());
+ op, ArrayRef<Value>(), ArrayRef<Block *>(), op->getAttrs());
return matchSuccess();
}
if (numArguments == 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
- op, ArrayRef<Value *>(operands.front()), ArrayRef<Block *>(),
+ op, ArrayRef<Value>(operands.front()), ArrayRef<Block *>(),
op->getAttrs());
return matchSuccess();
}
@@ -1668,7 +1671,7 @@
auto packedType =
lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
- Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
+ Value packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
packed = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), packedType, packed, operands[i],
@@ -1696,7 +1699,7 @@
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
@@ -1705,7 +1708,7 @@
// First insert it into an undef vector so we can shuffle it.
auto vectorType = lowering.convertType(splatOp.getType());
- Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
+ Value undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
auto zero = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
@@ -1730,7 +1733,7 @@
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
OperandAdaptor<SplatOp> adaptor(operands);
@@ -1747,16 +1750,16 @@
return matchFailure();
// Construct returned value.
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
// Construct a 1-D vector with the splatted value that we insert in all the
// places within the returned descriptor.
- Value *vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
+ Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, lowering.convertType(rewriter.getIntegerType(32)),
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
- Value *v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
- adaptor.input(), zero);
+ Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
+ adaptor.input(), zero);
// Shuffle the value across the desired number of elements.
int64_t width = resultType.getDimSize(resultType.getRank() - 1);
@@ -1784,21 +1787,21 @@
using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto viewOp = cast<SubViewOp>(op);
// TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support
// having multiple variadic operands where each operand can have different
// number of entries, clean all of this up.
- SmallVector<Value *, 2> dynamicOffsets(
+ SmallVector<Value, 2> dynamicOffsets(
std::next(operands.begin()),
std::next(operands.begin(), 1 + viewOp.getNumOffsets()));
- SmallVector<Value *, 2> dynamicSizes(
+ SmallVector<Value, 2> dynamicSizes(
std::next(operands.begin(), 1 + viewOp.getNumOffsets()),
std::next(operands.begin(),
1 + viewOp.getNumOffsets() + viewOp.getNumSizes()));
- SmallVector<Value *, 2> dynamicStrides(
+ SmallVector<Value, 2> dynamicStrides(
std::next(operands.begin(),
1 + viewOp.getNumOffsets() + viewOp.getNumSizes()),
operands.end());
@@ -1835,8 +1838,8 @@
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Copy the buffer pointer from the old descriptor to the new one.
- Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
- Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
+ Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
@@ -1846,7 +1849,7 @@
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
// Extract strides needed to compute offset.
- SmallVector<Value *, 4> strideValues;
+ SmallVector<Value, 4> strideValues;
strideValues.reserve(viewMemRefType.getRank());
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
@@ -1863,9 +1866,9 @@
}
// Offset.
- Value *baseOffset = sourceMemRef.offset(rewriter, loc);
+ Value baseOffset = sourceMemRef.offset(rewriter, loc);
for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
- Value *min = dynamicOffsets[i];
+ Value min = dynamicOffsets[i];
baseOffset = rewriter.create<LLVM::AddOp>(
loc, baseOffset,
rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
@@ -1875,7 +1878,7 @@
// Update sizes and strides.
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]);
- Value *newStride;
+ Value newStride;
if (dynamicStrides.empty())
newStride = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i]));
@@ -1900,9 +1903,9 @@
// Build and return the value for the idx^th shape dimension, either by
// returning the constant shape dimension or counting the proper dynamic size.
- Value *getSize(ConversionPatternRewriter &rewriter, Location loc,
- ArrayRef<int64_t> shape, ArrayRef<Value *> dynamicSizes,
- unsigned idx) const {
+ Value getSize(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<int64_t> shape, ArrayRef<Value> dynamicSizes,
+ unsigned idx) const {
assert(idx < shape.size());
if (!ShapedType::isDynamic(shape[idx]))
return createIndexConstant(rewriter, loc, shape[idx]);
@@ -1917,9 +1920,9 @@
// or by computing the dynamic stride from the current `runningStride` and
// `nextSize`. The caller should keep a running stride and update it with the
// result returned by this function.
- Value *getStride(ConversionPatternRewriter &rewriter, Location loc,
- ArrayRef<int64_t> strides, Value *nextSize,
- Value *runningStride, unsigned idx) const {
+ Value getStride(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<int64_t> strides, Value nextSize,
+ Value runningStride, unsigned idx) const {
assert(idx < strides.size());
if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
return createIndexConstant(rewriter, loc, strides[idx]);
@@ -1932,7 +1935,7 @@
}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto viewOp = cast<ViewOp>(op);
@@ -1959,8 +1962,8 @@
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
// Field 1: Copy the allocated pointer, used for malloc/free.
- Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc);
- Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
+ Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, targetElementTy.getPointerTo(), extracted);
targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
@@ -1977,10 +1980,10 @@
auto sizeAndOffsetOperands = adaptor.operands();
assert(llvm::size(sizeAndOffsetOperands) ==
numDynamicSizes + (hasDynamicOffset ? 1 : 0));
- Value *baseOffset = !hasDynamicOffset
- ? createIndexConstant(rewriter, loc, offset)
- // TODO(ntv): better adaptor.
- : sizeAndOffsetOperands.front();
+ Value baseOffset = !hasDynamicOffset
+ ? createIndexConstant(rewriter, loc, offset)
+ // TODO(ntv): better adaptor.
+ : sizeAndOffsetOperands.front();
targetMemRef.setOffset(rewriter, loc, baseOffset);
// Early exit for 0-D corner case.
@@ -1991,14 +1994,14 @@
if (strides.back() != 1)
return op->emitWarning("cannot cast to non-contiguous shape"),
matchFailure();
- Value *stride = nullptr, *nextSize = nullptr;
+ Value stride = nullptr, nextSize = nullptr;
// Drop the dynamic stride from the operand list, if present.
- ArrayRef<Value *> sizeOperands(sizeAndOffsetOperands);
+ ArrayRef<Value> sizeOperands(sizeAndOffsetOperands);
if (hasDynamicOffset)
sizeOperands = sizeOperands.drop_front();
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
// Update size.
- Value *size =
+ Value size =
getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i);
targetMemRef.setSize(rewriter, loc, i, size);
// Update stride.
@@ -2042,7 +2045,7 @@
auto *dummyBlock = new Block();
bb.getParent()->push_back(dummyBlock);
auto builder = OpBuilder(dummyBlock);
- SmallVector<Value *, 8> operands(
+ SmallVector<Value, 8> operands(
terminator->getSuccessorOperands(*position));
builder.create<BranchOp>(terminator->getLoc(), successor.first, operands);
terminator->setSuccessor(dummyBlock, *position);
@@ -2082,8 +2085,6 @@
CosOpLowering,
ConstLLVMOpLowering,
DivFOpLowering,
- DivISOpLowering,
- DivIUOpLowering,
ExpOpLowering,
LogOpLowering,
Log10OpLowering,
@@ -2097,18 +2098,23 @@
OrOpLowering,
PrefetchOpLowering,
RemFOpLowering,
- RemISOpLowering,
- RemIUOpLowering,
ReturnOpLowering,
SIToFPLowering,
SelectOpLowering,
+ ShiftLeftOpLowering,
SignExtendIOpLowering,
+ SignedDivIOpLowering,
+ SignedRemIOpLowering,
+ SignedShiftRightOpLowering,
SplatOpLowering,
SplatNdOpLowering,
SubFOpLowering,
SubIOpLowering,
TanhOpLowering,
TruncateIOpLowering,
+ UnsignedDivIOpLowering,
+ UnsignedRemIOpLowering,
+ UnsignedShiftRightOpLowering,
XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
// clang-format on
@@ -2160,33 +2166,32 @@
return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
}
-Value *LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc,
- Value *operand,
- OpBuilder &builder) {
+Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
+ OpBuilder &builder) {
auto *context = builder.getContext();
auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect());
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo();
- Value *one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
- IntegerAttr::get(indexType, 1));
- Value *allocated =
+ Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
+ IntegerAttr::get(indexType, 1));
+ Value allocated =
builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
// Store into the alloca'ed descriptor.
builder.create<LLVM::StoreOp>(loc, operand, allocated);
return allocated;
}
-SmallVector<Value *, 4>
+SmallVector<Value, 4>
LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
ValueRange operands,
OpBuilder &builder) {
- SmallVector<Value *, 4> promotedOperands;
+ SmallVector<Value, 4> promotedOperands;
promotedOperands.reserve(operands.size());
for (auto it : llvm::zip(opOperands, operands)) {
- auto *operand = std::get<0>(it);
- auto *llvmOperand = std::get<1>(it);
+ auto operand = std::get<0>(it);
+ auto llvmOperand = std::get<1>(it);
if (!operand->getType().isa<MemRefType>() &&
!operand->getType().isa<UnrankedMemRefType>()) {
promotedOperands.push_back(operand);
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index e87bd4e..a02dee4 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -1,19 +1,10 @@
//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Standard Ops to the SPIR-V dialect.
//
@@ -44,7 +35,7 @@
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value *> operands,
+ matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -54,7 +45,7 @@
using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value *> operands,
+ matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -70,7 +61,7 @@
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(StdOp operation, ArrayRef<Value *> operands,
+ matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType =
this->typeConverter.convertType(operation.getResult()->getType());
@@ -89,7 +80,7 @@
using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
+ matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -100,7 +91,7 @@
using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(ReturnOp returnOp, ArrayRef<Value *> operands,
+ matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -110,7 +101,7 @@
public:
using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(SelectOp op, ArrayRef<Value *> operands,
+ matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -123,7 +114,7 @@
using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
+ matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -141,7 +132,7 @@
spirv::AccessChainOp getElementPtr(OpBuilder &builder,
SPIRVTypeConverter &typeConverter,
Location loc, MemRefType origBaseType,
- Value *basePtr, ArrayRef<Value *> indices) {
+ Value basePtr, ArrayRef<Value> indices) {
// Get base and offset of the MemRefType and verify they are static.
int64_t offset;
SmallVector<int64_t, 4> strides;
@@ -152,18 +143,17 @@
auto indexType = typeConverter.getIndexType(builder.getContext());
- Value *ptrLoc = nullptr;
+ Value ptrLoc = nullptr;
assert(indices.size() == strides.size());
for (auto index : enumerate(indices)) {
- Value *strideVal = builder.create<spirv::ConstantOp>(
+ Value strideVal = builder.create<spirv::ConstantOp>(
loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
- Value *update =
- builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+ Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
ptrLoc =
(ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
: update);
}
- SmallVector<Value *, 2> linearizedIndices;
+ SmallVector<Value, 2> linearizedIndices;
// Add a '0' at the start to index into the struct.
linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
loc, indexType, IntegerAttr::get(indexType, 0)));
@@ -176,7 +166,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
- ConstantOp constIndexOp, ArrayRef<Value *> operands,
+ ConstantOp constIndexOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
return matchFailure();
@@ -210,7 +200,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult
-CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value *> operands,
+CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpIOpOperandAdaptor cmpIOpOperands(operands);
@@ -242,7 +232,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult
-LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
+LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LoadOpOperandAdaptor loadOperands(operands);
auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
@@ -259,8 +249,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult
-ReturnOpConversion::matchAndRewrite(ReturnOp returnOp,
- ArrayRef<Value *> operands,
+ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (returnOp.getNumOperands()) {
return matchFailure();
@@ -274,7 +263,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult
-SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value *> operands,
+SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
SelectOpOperandAdaptor selectOperands(operands);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
@@ -288,7 +277,7 @@
//===----------------------------------------------------------------------===//
PatternMatchResult
-StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
+StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
StoreOpOperandAdaptor storeOperands(operands);
auto storePtr =
@@ -316,8 +305,8 @@
patterns.insert<ConstantIndexOpConversion, CmpIOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>,
- IntegerOpConversion<DivISOp, spirv::SDivOp>,
- IntegerOpConversion<RemISOp, spirv::SModOp>,
+ IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
+ IntegerOpConversion<SignedRemIOp, spirv::SModOp>,
IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
ReturnOpConversion, SelectOpConversion, StoreOpConversion>(
context, typeConverter);
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
index c0c56a3..52456b6 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -1,19 +1,10 @@
//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert MLIR standard ops into the SPIR-V
// ops.
@@ -37,7 +28,7 @@
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -49,7 +40,7 @@
} // namespace
PatternMatchResult
-FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
@@ -63,13 +54,12 @@
signatureConverter.addInputs(argType.index(), convertedType);
}
}
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(rewriter.getFunctionType(
- signatureConverter.getConvertedTypes(), llvm::None));
- rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
- rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
+ });
return matchSuccess();
}
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index 4469c28..a658356 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -1,19 +1,10 @@
//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This transformation pass legalizes operations before the conversion to SPIR-V
// dialect to handle ops that cannot be lowered directly.
@@ -69,7 +60,7 @@
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
SubViewOp subViewOp, ValueRange indices,
- SmallVectorImpl<Value *> &sourceIndices) {
+ SmallVectorImpl<Value> &sourceIndices) {
// TODO: Aborting when the offsets are static. There might be a way to fold
// the subview op with load even if the offsets have been canonicalized
// away.
@@ -77,7 +68,7 @@
return failure();
ValueRange opOffsets = subViewOp.offsets();
- SmallVector<Value *, 2> opStrides;
+ SmallVector<Value, 2> opStrides;
if (subViewOp.getNumStrides()) {
// If the strides are dynamic, get the stride operands.
opStrides = llvm::to_vector<2>(subViewOp.strides());
@@ -124,7 +115,7 @@
if (!subViewOp) {
return matchFailure();
}
- SmallVector<Value *, 4> sourceIndices;
+ SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
loadOp.indices(), sourceIndices)))
return matchFailure();
@@ -146,7 +137,7 @@
if (!subViewOp) {
return matchFailure();
}
- SmallVector<Value *, 4> sourceIndices;
+ SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
storeOp.indices(), sourceIndices)))
return matchFailure();
diff --git a/third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 416cb4c..b48930c 100644
--- a/third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/third_party/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1,19 +1,10 @@
//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
@@ -62,9 +53,9 @@
}
// Helper that picks the proper sequence for inserting.
-static Value *insertOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering, Location loc, Value *val1,
- Value *val2, Type llvmType, int64_t rank, int64_t pos) {
+static Value insertOne(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &lowering, Location loc, Value val1,
+ Value val2, Type llvmType, int64_t rank, int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
@@ -78,9 +69,9 @@
}
// Helper that picks the proper sequence for extracting.
-static Value *extractOne(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &lowering, Location loc, Value *val,
- Type llvmType, int64_t rank, int64_t pos) {
+static Value extractOne(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &lowering, Location loc, Value val,
+ Type llvmType, int64_t rank, int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
@@ -101,7 +92,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto broadcastOp = cast<vector::BroadcastOp>(op);
VectorType dstVectorType = broadcastOp.getVectorType();
@@ -129,9 +120,9 @@
// ops once all insert/extract/shuffle operations
// are available with lowering implemention.
//
- Value *expandRanks(Value *value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType,
- ConversionPatternRewriter &rewriter) const {
+ Value expandRanks(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType,
+ ConversionPatternRewriter &rewriter) const {
assert((dstVectorType != nullptr) && "invalid result type in broadcast");
// Determine rank of source and destination.
int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
@@ -168,23 +159,22 @@
// becomes:
// x = [s,s]
// v = [x,x,x,x]
- Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType, int64_t rank, int64_t dim,
- ConversionPatternRewriter &rewriter) const {
+ Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType, int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
Type llvmType = lowering.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
if (rank == 1) {
- Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
- Value *expand =
+ Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value expand =
insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
}
- Value *expand =
- expandRanks(value, loc, srcVectorType,
- reducedVectorTypeFront(dstVectorType), rewriter);
- Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value expand = expandRanks(value, loc, srcVectorType,
+ reducedVectorTypeFront(dstVectorType), rewriter);
+ Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (int64_t d = 0; d < dim; ++d) {
result =
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
@@ -209,19 +199,19 @@
// y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
// a = [x, y]
// etc.
- Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType,
- VectorType dstVectorType, int64_t rank, int64_t dim,
- ConversionPatternRewriter &rewriter) const {
+ Value stretchOneRank(Value value, Location loc, VectorType srcVectorType,
+ VectorType dstVectorType, int64_t rank, int64_t dim,
+ ConversionPatternRewriter &rewriter) const {
Type llvmType = lowering.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
- Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
bool atStretch = dim != srcVectorType.getDimSize(0);
if (rank == 1) {
assert(atStretch);
Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
- Value *one =
+ Value one =
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0);
- Value *expand =
+ Value expand =
insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
@@ -232,9 +222,9 @@
Type redLlvmType = lowering.convertType(redSrcType);
for (int64_t d = 0; d < dim; ++d) {
int64_t pos = atStretch ? 0 : d;
- Value *one =
+ Value one =
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos);
- Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
+ Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
result =
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
}
@@ -250,7 +240,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
@@ -274,23 +264,23 @@
// For rank 1, where both operands have *exactly* the same vector type,
// there is direct shuffle support in LLVM. Use it!
if (rank == 1 && v1Type == v2Type) {
- Value *shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+ Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
rewriter.replaceOp(op, shuffle);
return matchSuccess();
}
// For all other cases, insert the individual values individually.
- Value *insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+ Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
int64_t insPos = 0;
for (auto en : llvm::enumerate(maskArrayAttr)) {
int64_t extPos = en.value().cast<IntegerAttr>().getInt();
- Value *value = adaptor.v1();
+ Value value = adaptor.v1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
value = adaptor.v2();
}
- Value *extract =
+ Value extract =
extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos);
insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
rank, insPos++);
@@ -308,7 +298,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
@@ -333,7 +323,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
@@ -349,7 +339,7 @@
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
- Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
+ Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
@@ -357,7 +347,7 @@
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
- Value *extracted = adaptor.vector();
+ Value extracted = adaptor.vector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
auto oneDVectorType = reducedVectorTypeBack(vectorType);
@@ -370,8 +360,8 @@
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
- auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
- auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
+ auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
@@ -388,7 +378,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
@@ -413,7 +403,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::InsertOpOperandAdaptor(operands);
@@ -429,7 +419,7 @@
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
- Value *inserted = rewriter.create<LLVM::InsertValueOp>(
+ Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
rewriter.replaceOp(op, inserted);
@@ -438,7 +428,7 @@
// Potential extraction of 1-D vector from array.
auto *context = op->getContext();
- Value *extracted = adaptor.dest();
+ Value extracted = adaptor.dest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = positionAttrs.back().cast<IntegerAttr>();
auto oneDVectorType = destVectorType;
@@ -452,9 +442,9 @@
}
// Insertion of an element into a 1-D LLVM vector.
- auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
- auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
- Value *inserted = rewriter.create<LLVM::InsertElementOp>(
+ auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
+ auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
+ Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
constant);
@@ -480,7 +470,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
@@ -491,10 +481,10 @@
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
cast<vector::OuterProductOp>(op).getResult()->getType());
- Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
- Value *a = adaptor.lhs(), *b = adaptor.rhs();
- Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
- SmallVector<Value *, 8> lhs, accs;
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
+ Value a = adaptor.lhs(), b = adaptor.rhs();
+ Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
+ SmallVector<Value, 8> lhs, accs;
lhs.reserve(rankLHS);
accs.reserve(rankLHS);
for (unsigned d = 0, e = rankLHS; d < e; ++d) {
@@ -502,7 +492,7 @@
auto attr = rewriter.getI32IntegerAttr(d);
SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
- Value *aD = nullptr, *accD = nullptr;
+ Value aD = nullptr, accD = nullptr;
// 1. Broadcast the element a[d] into vector aD.
aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
// 2. If acc is present, extract 1-d vector acc[d] into accD.
@@ -510,7 +500,7 @@
accD = rewriter.create<LLVM::ExtractValueOp>(
loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
// 3. Compute aD outer b (plus accD, if relevant).
- Value *aOuterbD =
+ Value aOuterbD =
accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
.getResult()
: rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
@@ -532,7 +522,7 @@
typeConverter) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
@@ -581,12 +571,12 @@
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
- Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
+ Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
- Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
+ Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
@@ -632,7 +622,7 @@
// TODO(ajcbik): rely solely on libc in future? something else?
//
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto printOp = cast<vector::PrintOp>(op);
auto adaptor = vector::PrintOpOperandAdaptor(operands);
@@ -662,7 +652,7 @@
private:
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
- Value *value, VectorType vectorType, Operation *printer,
+ Value value, VectorType vectorType, Operation *printer,
int64_t rank) const {
Location loc = op->getLoc();
if (rank == 0) {
@@ -678,7 +668,7 @@
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
auto llvmType = lowering.convertType(
rank > 1 ? reducedType : vectorType.getElementType());
- Value *nestedVal =
+ Value nestedVal =
extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
if (d != dim - 1)
diff --git a/third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
index 33778e4..3ed031b 100644
--- a/third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
+++ b/third_party/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
@@ -1,19 +1,10 @@
//===- VectorToLoops.cpp - Conversion from Vector to mix of Loops and Std -===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements target-dependent lowering of vector transfer operations.
//
diff --git a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index 6768aa5..5f4cc2e 100644
--- a/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -1,19 +1,10 @@
//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
@@ -115,8 +106,8 @@
/// A utility function to check if a value is defined at the top level of a
/// function. A value of index type defined at the top level is always a valid
/// symbol.
-bool mlir::isTopLevelValue(Value *value) {
- if (auto *arg = dyn_cast<BlockArgument>(value))
+bool mlir::isTopLevelValue(Value value) {
+ if (auto arg = value.dyn_cast<BlockArgument>())
return isFunctionRegion(arg->getOwner()->getParent());
return isFunctionRegion(value->getDefiningOp()->getParentRegion());
}
@@ -124,7 +115,7 @@
// Value can be used as a dimension id if it is valid as a symbol, or
// it is an induction variable, or it is a result of affine apply operation
// with dimension id arguments.
-bool mlir::isValidDim(Value *value) {
+bool mlir::isValidDim(Value value) {
// The value must be an index type.
if (!value->getType().isIndex())
return false;
@@ -142,8 +133,9 @@
return isTopLevelValue(dimOp.getOperand());
return false;
}
- // This value is a block argument (which also includes 'affine.for' loop IVs).
- return true;
+ // This value has to be a block argument for a FuncOp or an affine.for.
+ auto *parentOp = value.cast<BlockArgument>()->getOwner()->getParentOp();
+ return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp);
}
/// Returns true if the 'index' dimension of the `memref` defined by
@@ -183,7 +175,7 @@
// the top level, or it is a result of affine apply operation with symbol
// arguments, or a result of the dim op on a memref satisfying certain
// constraints.
-bool mlir::isValidSymbol(Value *value) {
+bool mlir::isValidSymbol(Value value) {
// The value must be an index type.
if (!value->getType().isIndex())
return false;
@@ -206,7 +198,7 @@
// Returns true if 'value' is a valid index to an affine operation (e.g.
// affine.load, affine.store, affine.dma_start, affine.dma_wait).
// Returns false otherwise.
-static bool isValidAffineIndexOperand(Value *value) {
+static bool isValidAffineIndexOperand(Value value) {
return isValidDim(value) || isValidSymbol(value);
}
@@ -220,7 +212,7 @@
verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
unsigned numDims) {
unsigned opIt = 0;
- for (auto *operand : operands) {
+ for (auto operand : operands) {
if (opIt++ < numDims) {
if (!isValidDim(operand))
return op.emitOpError("operand cannot be used as a dimension id");
@@ -301,19 +293,18 @@
return success();
}
-// The result of the affine apply operation can be used as a dimension id if it
-// is a CFG value or if it is an Value, and all the operands are valid
-// dimension ids.
+// The result of the affine apply operation can be used as a dimension id if all
+// its operands are valid dimension ids.
bool AffineApplyOp::isValidDim() {
return llvm::all_of(getOperands(),
- [](Value *op) { return mlir::isValidDim(op); });
+ [](Value op) { return mlir::isValidDim(op); });
}
-// The result of the affine apply operation can be used as a symbol if it is
-// a CFG value or if it is an Value, and all the operands are symbols.
+// The result of the affine apply operation can be used as a symbol if all its
+// operands are symbols.
bool AffineApplyOp::isValidSymbol() {
return llvm::all_of(getOperands(),
- [](Value *op) { return mlir::isValidSymbol(op); });
+ [](Value op) { return mlir::isValidSymbol(op); });
}
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
@@ -333,8 +324,8 @@
return result[0];
}
-AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
- DenseMap<Value *, unsigned>::iterator iterPos;
+AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
+ DenseMap<Value, unsigned>::iterator iterPos;
bool inserted = false;
std::tie(iterPos, inserted) =
dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
@@ -347,7 +338,7 @@
AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
SmallVector<AffineExpr, 8> dimRemapping;
- for (auto *v : other.reorderedDims) {
+ for (auto v : other.reorderedDims) {
auto kvp = other.dimValueToPosition.find(v);
if (dimRemapping.size() <= kvp->second)
dimRemapping.resize(kvp->second + 1);
@@ -371,7 +362,7 @@
// Gather the positions of the operands that are produced by an AffineApplyOp.
static llvm::SetVector<unsigned>
-indicesFromAffineApplyOp(ArrayRef<Value *> operands) {
+indicesFromAffineApplyOp(ArrayRef<Value> operands) {
llvm::SetVector<unsigned> res;
for (auto en : llvm::enumerate(operands))
if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp()))
@@ -393,13 +384,13 @@
// results in better simplifications and foldings. But we should evaluate
// whether this behavior is what we really want after using more.
static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
- ArrayRef<Value *> symbols) {
+ ArrayRef<Value> symbols) {
if (symbols.empty()) {
return map;
}
// Sanity check on symbols.
- for (auto *sym : symbols) {
+ for (auto sym : symbols) {
assert(isValidSymbol(sym) && "Expected only valid symbols");
(void)sym;
}
@@ -446,7 +437,7 @@
/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
///
/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
-/// applied to the same mlir::Value* for both s0 and s1.
+/// applied to the same mlir::Value for both s0 and s1.
/// As a consequence mathematical composition of AffineMap always concatenates
/// symbols.
///
@@ -462,7 +453,7 @@
/// benefit potentially big: simpler and more maintainable code for a
/// non-trivial, recursive, procedure.
AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
- ArrayRef<Value *> operands)
+ ArrayRef<Value> operands)
: AffineApplyNormalizer() {
static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
assert(map.getNumInputs() == operands.size() &&
@@ -495,7 +486,7 @@
if (!furtherCompose) {
// 1. Only dispatch dims or symbols.
for (auto en : llvm::enumerate(operands)) {
- auto *t = en.value();
+ auto t = en.value();
assert(t->getType().isIndex());
bool isDim = (en.index() < map.getNumDims());
if (isDim) {
@@ -511,14 +502,14 @@
assert(numDimsBeforeRewrite <= operands.size());
// 2. Compose AffineApplyOps and dispatch dims or symbols.
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- auto *t = operands[i];
+ auto t = operands[i];
auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp());
if (affineApply) {
// a. Compose affine.apply operations.
LLVM_DEBUG(affineApply.getOperation()->print(
dbgs() << "\nCompose AffineApplyOp recursively: "));
AffineMap affineApplyMap = affineApply.getAffineMap();
- SmallVector<Value *, 8> affineApplyOperands(
+ SmallVector<Value, 8> affineApplyOperands(
affineApply.getOperands().begin(), affineApply.getOperands().end());
AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);
@@ -570,7 +561,7 @@
}
void AffineApplyNormalizer::normalize(AffineMap *otherMap,
- SmallVectorImpl<Value *> *otherOperands) {
+ SmallVectorImpl<Value> *otherOperands) {
AffineApplyNormalizer other(*otherMap, *otherOperands);
*otherMap = renumber(other);
@@ -584,7 +575,7 @@
/// on `map` and `operands` without creating an AffineApplyOp that needs to be
/// immediately deleted.
static void composeAffineMapAndOperands(AffineMap *map,
- SmallVectorImpl<Value *> *operands) {
+ SmallVectorImpl<Value> *operands) {
AffineApplyNormalizer normalizer(*map, *operands);
auto normalizedMap = normalizer.getAffineMap();
auto normalizedOperands = normalizer.getOperands();
@@ -594,9 +585,9 @@
assert(*map);
}
-void mlir::fullyComposeAffineMapAndOperands(
- AffineMap *map, SmallVectorImpl<Value *> *operands) {
- while (llvm::any_of(*operands, [](Value *v) {
+void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
+ SmallVectorImpl<Value> *operands) {
+ while (llvm::any_of(*operands, [](Value v) {
return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp());
})) {
composeAffineMapAndOperands(map, operands);
@@ -605,9 +596,9 @@
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
AffineMap map,
- ArrayRef<Value *> operands) {
+ ArrayRef<Value> operands) {
AffineMap normalizedMap = map;
- SmallVector<Value *, 8> normalizedOperands(operands.begin(), operands.end());
+ SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
assert(normalizedMap);
return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
@@ -617,7 +608,7 @@
// canonicalizes dims that are valid symbols into actual symbols.
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
- SmallVectorImpl<Value *> *operands) {
+ SmallVectorImpl<Value> *operands) {
if (!mapOrSet || operands->empty())
return;
@@ -625,9 +616,9 @@
"map/set inputs must match number of operands");
auto *context = mapOrSet->getContext();
- SmallVector<Value *, 8> resultOperands;
+ SmallVector<Value, 8> resultOperands;
resultOperands.reserve(operands->size());
- SmallVector<Value *, 8> remappedSymbols;
+ SmallVector<Value, 8> remappedSymbols;
remappedSymbols.reserve(operands->size());
unsigned nextDim = 0;
unsigned nextSym = 0;
@@ -659,9 +650,8 @@
// Works for either an affine map or an integer set.
template <class MapOrSet>
-static void
-canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
- SmallVectorImpl<Value *> *operands) {
+static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
+ SmallVectorImpl<Value> *operands) {
static_assert(std::is_same<MapOrSet, AffineMap>::value ||
std::is_same<MapOrSet, IntegerSet>::value,
"Argument must be either of AffineMap or IntegerSet type");
@@ -686,10 +676,10 @@
auto *context = mapOrSet->getContext();
- SmallVector<Value *, 8> resultOperands;
+ SmallVector<Value, 8> resultOperands;
resultOperands.reserve(operands->size());
- llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
+ llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
unsigned nextDim = 0;
for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
@@ -705,7 +695,7 @@
}
}
}
- llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
+ llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
unsigned nextSym = 0;
for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
@@ -738,12 +728,12 @@
}
void mlir::canonicalizeMapAndOperands(AffineMap *map,
- SmallVectorImpl<Value *> *operands) {
+ SmallVectorImpl<Value> *operands) {
canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
- SmallVectorImpl<Value *> *operands) {
+ SmallVectorImpl<Value> *operands) {
canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
@@ -758,7 +748,7 @@
/// Replace the affine op with another instance of it with the supplied
/// map and mapOperands.
void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
- AffineMap map, ArrayRef<Value *> mapOperands) const;
+ AffineMap map, ArrayRef<Value> mapOperands) const;
PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
PatternRewriter &rewriter) const override {
@@ -770,7 +760,7 @@
auto map = affineOp.getAffineMap();
AffineMap oldMap = map;
auto oldOperands = affineOp.getMapOperands();
- SmallVector<Value *, 8> resultOperands(oldOperands);
+ SmallVector<Value, 8> resultOperands(oldOperands);
composeAffineMapAndOperands(&map, &resultOperands);
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
resultOperands.begin()))
@@ -786,14 +776,14 @@
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
- ArrayRef<Value *> mapOperands) const {
+ ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
- ArrayRef<Value *> mapOperands) const {
+ ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
prefetch, prefetch.memref(), map, mapOperands,
prefetch.localityHint().getZExtValue(), prefetch.isWrite(),
@@ -802,14 +792,14 @@
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
- ArrayRef<Value *> mapOperands) const {
+ ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
- ArrayRef<Value *> mapOperands) const {
+ ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, mapOperands);
}
} // end anonymous namespace.
@@ -844,12 +834,12 @@
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
void AffineDmaStartOp::build(Builder *builder, OperationState &result,
- Value *srcMemRef, AffineMap srcMap,
- ValueRange srcIndices, Value *destMemRef,
+ Value srcMemRef, AffineMap srcMap,
+ ValueRange srcIndices, Value destMemRef,
AffineMap dstMap, ValueRange destIndices,
- Value *tagMemRef, AffineMap tagMap,
- ValueRange tagIndices, Value *numElements,
- Value *stride, Value *elementsPerStride) {
+ Value tagMemRef, AffineMap tagMap,
+ ValueRange tagIndices, Value numElements,
+ Value stride, Value elementsPerStride) {
result.addOperands(srcMemRef);
result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
result.addOperands(srcIndices);
@@ -980,19 +970,19 @@
return emitOpError("incorrect number of operands");
}
- for (auto *idx : getSrcIndices()) {
+ for (auto idx : getSrcIndices()) {
if (!idx->getType().isIndex())
return emitOpError("src index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("src index must be a dimension or symbol identifier");
}
- for (auto *idx : getDstIndices()) {
+ for (auto idx : getDstIndices()) {
if (!idx->getType().isIndex())
return emitOpError("dst index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx))
return emitOpError("dst index must be a dimension or symbol identifier");
}
- for (auto *idx : getTagIndices()) {
+ for (auto idx : getTagIndices()) {
if (!idx->getType().isIndex())
return emitOpError("tag index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx))
@@ -1013,8 +1003,8 @@
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
void AffineDmaWaitOp::build(Builder *builder, OperationState &result,
- Value *tagMemRef, AffineMap tagMap,
- ValueRange tagIndices, Value *numElements) {
+ Value tagMemRef, AffineMap tagMap,
+ ValueRange tagIndices, Value numElements) {
result.addOperands(tagMemRef);
result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
result.addOperands(tagIndices);
@@ -1023,7 +1013,7 @@
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
p << "affine.dma_wait " << *getTagMemRef() << '[';
- SmallVector<Value *, 2> operands(getTagIndices());
+ SmallVector<Value, 2> operands(getTagIndices());
p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
p << "], ";
p.printOperand(getNumElements());
@@ -1068,7 +1058,7 @@
LogicalResult AffineDmaWaitOp::verify() {
if (!getOperand(0)->getType().isa<MemRefType>())
return emitOpError("expected DMA tag to be of memref type");
- for (auto *idx : getTagIndices()) {
+ for (auto idx : getTagIndices()) {
if (!idx->getType().isIndex())
return emitOpError("index to dma_wait must have 'index' type");
if (!isValidAffineIndexOperand(idx))
@@ -1368,7 +1358,7 @@
SmallVector<Attribute, 8> operandConstants;
auto boundOperands =
lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
- for (auto *operand : boundOperands) {
+ for (auto operand : boundOperands) {
Attribute operandCst;
matchPattern(operand, m_Constant(&operandCst));
operandConstants.push_back(operandCst);
@@ -1408,8 +1398,8 @@
/// Canonicalize the bounds of the given loop.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
- SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
- SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
@@ -1474,7 +1464,7 @@
assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
- SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
+ SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
auto ubOperands = getUpperBoundOperands();
newOperands.append(ubOperands.begin(), ubOperands.end());
@@ -1487,7 +1477,7 @@
assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
- SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
+ SmallVector<Value, 4> newOperands(getLowerBoundOperands());
newOperands.append(ubOperands.begin(), ubOperands.end());
getOperation()->setOperands(newOperands);
@@ -1553,7 +1543,7 @@
unsigned numOperands = lbMap.getNumInputs();
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
- // Compare Value *'s.
+ // Compare Value 's.
if (getOperand(i) != getOperand(numOperands + i))
return false;
}
@@ -1562,7 +1552,7 @@
Region &AffineForOp::getLoopBody() { return region(); }
-bool AffineForOp::isDefinedOutsideOfLoop(Value *value) {
+bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
return !region().isAncestor(value->getParentRegion());
}
@@ -1573,14 +1563,14 @@
}
/// Returns if the provided value is the induction variable of a AffineForOp.
-bool mlir::isForInductionVar(Value *val) {
+bool mlir::isForInductionVar(Value val) {
return getForInductionVarOwner(val) != AffineForOp();
}
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
-AffineForOp mlir::getForInductionVarOwner(Value *val) {
- auto *ivArg = dyn_cast<BlockArgument>(val);
+AffineForOp mlir::getForInductionVarOwner(Value val) {
+ auto ivArg = val.dyn_cast<BlockArgument>();
if (!ivArg || !ivArg->getOwner())
return AffineForOp();
auto *containingInst = ivArg->getOwner()->getParent()->getParentOp();
@@ -1590,7 +1580,7 @@
/// Extracts the induction variables from a list of AffineForOps and returns
/// them.
void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
- SmallVectorImpl<Value *> *ivs) {
+ SmallVectorImpl<Value> *ivs) {
ivs->reserve(forInsts.size());
for (auto forInst : forInsts)
ivs->push_back(forInst.getInductionVar());
@@ -1729,7 +1719,7 @@
LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
auto set = getIntegerSet();
- SmallVector<Value *, 4> operands(getOperands());
+ SmallVector<Value, 4> operands(getOperands());
canonicalizeSetAndOperands(&set, &operands);
// Any canonicalization change always leads to either a reduction in the
@@ -1758,8 +1748,8 @@
result.types.push_back(memrefType.getElementType());
}
-void AffineLoadOp::build(Builder *builder, OperationState &result,
- Value *memref, AffineMap map, ValueRange mapOperands) {
+void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
+ AffineMap map, ValueRange mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(memref);
result.addOperands(mapOperands);
@@ -1768,8 +1758,8 @@
result.types.push_back(memrefType.getElementType());
}
-void AffineLoadOp::build(Builder *builder, OperationState &result,
- Value *memref, ValueRange indices) {
+void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
+ ValueRange indices) {
auto memrefType = memref->getType().cast<MemRefType>();
auto rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
@@ -1825,7 +1815,7 @@
"expects the number of subscripts to be equal to memref rank");
}
- for (auto *idx : getMapOperands()) {
+ for (auto idx : getMapOperands()) {
if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
if (!isValidAffineIndexOperand(idx))
@@ -1851,7 +1841,7 @@
//===----------------------------------------------------------------------===//
void AffineStoreOp::build(Builder *builder, OperationState &result,
- Value *valueToStore, Value *memref, AffineMap map,
+ Value valueToStore, Value memref, AffineMap map,
ValueRange mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(valueToStore);
@@ -1862,7 +1852,7 @@
// Use identity map.
void AffineStoreOp::build(Builder *builder, OperationState &result,
- Value *valueToStore, Value *memref,
+ Value valueToStore, Value memref,
ValueRange indices) {
auto memrefType = memref->getType().cast<MemRefType>();
auto rank = memrefType.getRank();
@@ -1923,7 +1913,7 @@
"expects the number of subscripts to be equal to memref rank");
}
- for (auto *idx : getMapOperands()) {
+ for (auto idx : getMapOperands()) {
if (!idx->getType().isIndex())
return emitOpError("index to store must have 'index' type");
if (!isValidAffineIndexOperand(idx))
@@ -2072,7 +2062,7 @@
p << AffinePrefetchOp::getOperationName() << " " << *op.memref() << '[';
AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
if (mapAttr) {
- SmallVector<Value *, 2> operands(op.getMapOperands());
+ SmallVector<Value, 2> operands(op.getMapOperands());
p.printAffineMapOfSSAIds(mapAttr, operands);
}
p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
@@ -2099,7 +2089,7 @@
return op.emitOpError("too few operands");
}
- for (auto *idx : op.getMapOperands()) {
+ for (auto idx : op.getMapOperands()) {
if (!isValidAffineIndexOperand(idx))
return op.emitOpError("index must be a dimension or symbol identifier");
}
diff --git a/third_party/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp
index 9197e3c..775e25e 100644
--- a/third_party/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp
@@ -1,19 +1,10 @@
//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AffineOps/AffineOps.h"
using namespace mlir;
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
index aa6782e..57d5ae8 100644
--- a/third_party/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/IR/DialectRegistration.cpp
@@ -1,19 +1,10 @@
//===- DialectRegistration.cpp - Register FxpMathOps dialect --------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp b/third_party/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
index 18c07b0..30e7dc0 100644
--- a/third_party/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/IR/FxpMathOps.cpp
@@ -1,19 +1,10 @@
//===- FxpMathOps.cpp - Op implementation for FxpMathOps ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index 3982a6a..df6015d 100644
--- a/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -1,19 +1,10 @@
//===- LowerUniformRealMath.cpp ------------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "UniformKernelUtils.h"
@@ -46,9 +37,9 @@
// Dequantize
//===----------------------------------------------------------------------===//
-static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
- UniformQuantizedType elementType,
- PatternRewriter &rewriter) {
+static Value emitUniformPerLayerDequantize(Location loc, Value input,
+ UniformQuantizedType elementType,
+ PatternRewriter &rewriter) {
// Pre-conditions.
if (!elementType.isSigned()) {
// TODO: Support unsigned storage type.
@@ -71,7 +62,7 @@
// Apply zero-point offset.
if (elementType.getZeroPoint() != 0) {
- Value *negZeroPointConst = rewriter.create<ConstantOp>(
+ Value negZeroPointConst = rewriter.create<ConstantOp>(
loc, broadcastScalarConstIntValue(intermediateType,
-elementType.getZeroPoint()));
input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
@@ -81,14 +72,14 @@
input = rewriter.create<ConvertISToFOp>(loc, realType, input);
// Mul by scale.
- Value *scaleConst = rewriter.create<ConstantOp>(
+ Value scaleConst = rewriter.create<ConstantOp>(
loc, broadcastScalarConstFloatValue(realType,
APFloat(elementType.getScale())));
return rewriter.create<MulFOp>(loc, input, scaleConst);
}
-static Value *
-emitUniformPerAxisDequantize(Location loc, Value *input,
+static Value
+emitUniformPerAxisDequantize(Location loc, Value input,
UniformQuantizedPerAxisType elementType,
PatternRewriter &rewriter) {
// TODO: Support per-axis dequantize.
@@ -97,8 +88,8 @@
return nullptr;
}
-static Value *emitDequantize(Location loc, Value *input,
- PatternRewriter &rewriter) {
+static Value emitDequantize(Location loc, Value input,
+ PatternRewriter &rewriter) {
Type inputType = input->getType();
QuantizedType qElementType =
QuantizedType::getQuantizedElementType(inputType);
@@ -133,7 +124,7 @@
return matchFailure();
}
- Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
+ Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
return matchFailure();
}
@@ -170,14 +161,14 @@
castElementType(info.resultStorageType, intermediateElementType);
// Cast operands to storage type.
- Value *lhsValue = rewriter
- .create<StorageCastOp>(info.op->getLoc(),
- info.lhsStorageType, info.lhs)
- .getResult();
- Value *rhsValue = rewriter
- .create<StorageCastOp>(info.op->getLoc(),
- info.rhsStorageType, info.rhs)
- .getResult();
+ Value lhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.lhsStorageType, info.lhs)
+ .getResult();
+ Value rhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.rhsStorageType, info.rhs)
+ .getResult();
// Cast to the intermediate sized type.
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
@@ -186,7 +177,7 @@
rhsValue);
// Add.
- Value *resultValue =
+ Value resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
// Zero point offset adjustment.
@@ -194,7 +185,7 @@
// zpOffset = -zp
int zpOffset = -1 * info.resultType.getZeroPoint();
if (zpOffset != 0) {
- Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ Value zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(),
broadcastScalarConstIntValue(intermediateType, zpOffset));
resultValue =
@@ -246,14 +237,14 @@
castElementType(info.resultStorageType, intermediateElementType);
// Cast operands to storage type.
- Value *lhsValue = rewriter
- .create<StorageCastOp>(info.op->getLoc(),
- info.lhsStorageType, info.lhs)
- .getResult();
- Value *rhsValue = rewriter
- .create<StorageCastOp>(info.op->getLoc(),
- info.rhsStorageType, info.rhs)
- .getResult();
+ Value lhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.lhsStorageType, info.lhs)
+ .getResult();
+ Value rhsValue = rewriter
+ .create<StorageCastOp>(info.op->getLoc(),
+ info.rhsStorageType, info.rhs)
+ .getResult();
// Cast to the intermediate sized type.
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
@@ -263,7 +254,7 @@
// Apply argument zeroPoints.
if (info.lhsType.getZeroPoint() != 0) {
- Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ Value zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(), broadcastScalarConstIntValue(
intermediateType, -info.lhsType.getZeroPoint()));
lhsValue =
@@ -271,7 +262,7 @@
}
if (info.rhsType.getZeroPoint() != 0) {
- Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ Value zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(), broadcastScalarConstIntValue(
intermediateType, -info.rhsType.getZeroPoint()));
rhsValue =
@@ -279,7 +270,7 @@
}
// Mul.
- Value *resultValue =
+ Value resultValue =
rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
// Scale output.
@@ -293,7 +284,7 @@
// Zero point offset adjustment.
if (info.resultType.getZeroPoint() != 0) {
- Value *zpOffsetConst = rewriter.create<ConstantOp>(
+ Value zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(),
broadcastScalarConstIntValue(intermediateType,
info.resultType.getZeroPoint()));
diff --git a/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
index 955e2ec..8cea97c 100644
--- a/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
+++ b/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -1,19 +1,10 @@
//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
@@ -59,7 +50,7 @@
/// Helper class for operating on binary operations where all operands
/// and the result are a UniformQuantizedType.
struct UniformBinaryOpInfo {
- UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
+ UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs,
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
lhsType(getUniformElementType(lhs->getType())),
@@ -128,8 +119,8 @@
}
Operation *op;
- Value *lhs;
- Value *rhs;
+ Value lhs;
+ Value rhs;
Optional<APFloat> clampMin;
Optional<APFloat> clampMax;
diff --git a/third_party/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp
index af50d02..511c69e 100644
--- a/third_party/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/GPU/IR/DialectRegistration.cpp
@@ -1,19 +1,10 @@
//===- DialectRegistration.cpp - MLIR GPU dialect registration ------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/GPUDialect.h"
diff --git a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7324b96..bda8032 100644
--- a/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1,19 +1,10 @@
//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
@@ -145,7 +136,7 @@
if (!allReduce.body().empty()) {
if (allReduce.body().front().getNumArguments() != 2)
return allReduce.emitError("expected two region arguments");
- for (auto *argument : allReduce.body().front().getArguments()) {
+ for (auto argument : allReduce.body().front().getArguments()) {
if (argument->getType() != allReduce.getType())
return allReduce.emitError("incorrect region argument type");
}
@@ -165,6 +156,47 @@
return success();
}
+static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
+ auto type = shuffleOp.value()->getType();
+ if (shuffleOp.result()->getType() != type) {
+ return shuffleOp.emitOpError()
+ << "requires the same type for value operand and result";
+ }
+ if (!type.isIntOrFloat() || type.getIntOrFloatBitWidth() != 32) {
+ return shuffleOp.emitOpError()
+ << "requires value operand type to be f32 or i32";
+ }
+ return success();
+}
+
+static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
+ p << ShuffleOp::getOperationName() << ' ';
+ p.printOperands(op.getOperands());
+ p << ' ' << op.mode() << " : ";
+ p.printType(op.value()->getType());
+}
+
+static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
+ SmallVector<OpAsmParser::OperandType, 3> operandInfo;
+ if (parser.parseOperandList(operandInfo, 3))
+ return failure();
+
+ StringRef mode;
+ if (parser.parseKeyword(&mode))
+ return failure();
+ state.addAttribute("mode", parser.getBuilder().getStringAttr(mode));
+
+ Type valueType;
+ Type int32Type = parser.getBuilder().getIntegerType(32);
+ Type int1Type = parser.getBuilder().getI1Type();
+ if (parser.parseColonType(valueType) ||
+ parser.resolveOperands(operandInfo, {valueType, int32Type, int32Type},
+ parser.getCurrentLocation(), state.operands) ||
+ parser.addTypesToList({valueType, int1Type}, state.types))
+ return failure();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//
@@ -172,15 +204,14 @@
static SmallVector<Type, 4> getValueTypes(ValueRange values) {
SmallVector<Type, 4> types;
types.reserve(values.size());
- for (Value *v : values)
+ for (Value v : values)
types.push_back(v->getType());
return types;
}
-void LaunchOp::build(Builder *builder, OperationState &result, Value *gridSizeX,
- Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
- Value *blockSizeY, Value *blockSizeZ,
- ValueRange operands) {
+void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
+ Value gridSizeY, Value gridSizeZ, Value blockSizeX,
+ Value blockSizeY, Value blockSizeZ, ValueRange operands) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands(
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
@@ -440,7 +471,8 @@
PatternMatchResult matchAndRewrite(LaunchOp launchOp,
PatternRewriter &rewriter) const override {
- auto origInsertionPoint = rewriter.saveInsertionPoint();
+ rewriter.startRootUpdate(launchOp);
+ PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&launchOp.body().front());
// Traverse operands passed to kernel and check if some of them are known
@@ -448,31 +480,29 @@
// and use it instead of passing the value from the parent region. Perform
// the traversal in the inverse order to simplify index arithmetics when
// dropping arguments.
- SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
- launchOp.getKernelOperandValues().end());
- SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
- launchOp.getKernelArguments().end());
+ auto operands = launchOp.getKernelOperandValues();
+ auto kernelArgs = launchOp.getKernelArguments();
bool found = false;
for (unsigned i = operands.size(); i > 0; --i) {
unsigned index = i - 1;
- Value *operand = operands[index];
- if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
+ Value operand = operands[index];
+ if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp()))
continue;
- }
found = true;
- Value *internalConstant =
+ Value internalConstant =
rewriter.clone(*operand->getDefiningOp())->getResult(0);
- Value *kernelArg = kernelArgs[index];
+ Value kernelArg = *std::next(kernelArgs.begin(), index);
kernelArg->replaceAllUsesWith(internalConstant);
launchOp.eraseKernelArgument(index);
}
- rewriter.restoreInsertionPoint(origInsertionPoint);
- if (!found)
+ if (!found) {
+ rewriter.cancelRootUpdate(launchOp);
return matchFailure();
+ }
- rewriter.updatedRootInPlace(launchOp);
+ rewriter.finalizeRootUpdate(launchOp);
return matchSuccess();
}
};
@@ -488,10 +518,9 @@
//===----------------------------------------------------------------------===//
void LaunchFuncOp::build(Builder *builder, OperationState &result,
- GPUFuncOp kernelFunc, Value *gridSizeX,
- Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
- Value *blockSizeY, Value *blockSizeZ,
- ValueRange kernelOperands) {
+ GPUFuncOp kernelFunc, Value gridSizeX, Value gridSizeY,
+ Value gridSizeZ, Value blockSizeX, Value blockSizeY,
+ Value blockSizeZ, ValueRange kernelOperands) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands(
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
@@ -524,7 +553,7 @@
.getRootReference();
}
-Value *LaunchFuncOp::getKernelOperand(unsigned i) {
+Value LaunchFuncOp::getKernelOperand(unsigned i) {
return getOperation()->getOperand(i + kNumConfigOperands);
}
@@ -687,13 +716,13 @@
}
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
- ArrayRef<BlockArgument *> values) {
+ ArrayRef<BlockArgument> values) {
if (values.empty())
return;
p << ' ' << keyword << '(';
interleaveComma(values, p,
- [&p](BlockArgument *v) { p << *v << " : " << v->getType(); });
+ [&p](BlockArgument v) { p << *v << " : " << v->getType(); });
p << ')';
}
@@ -740,9 +769,9 @@
}
static LogicalResult verifyAttributions(Operation *op,
- ArrayRef<BlockArgument *> attributions,
+ ArrayRef<BlockArgument> attributions,
unsigned memorySpace) {
- for (Value *v : attributions) {
+ for (Value v : attributions) {
auto type = v->getType().dyn_cast<MemRefType>();
if (!type)
return op->emitOpError() << "expected memref type in attribution";
diff --git a/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 0a6a591..2d00ac0 100644
--- a/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/third_party/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -1,19 +1,10 @@
//===- KernelOutlining.cpp - Implementation of GPU kernel outlining -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the GPU dialect kernel outlining pass.
//
@@ -31,10 +22,10 @@
template <typename OpTy>
static void createForAllDimensions(OpBuilder &builder, Location loc,
- SmallVectorImpl<Value *> &values) {
+ SmallVectorImpl<Value> &values) {
for (StringRef dim : {"x", "y", "z"}) {
- Value *v = builder.create<OpTy>(loc, builder.getIndexType(),
- builder.getStringAttr(dim));
+ Value v = builder.create<OpTy>(loc, builder.getIndexType(),
+ builder.getStringAttr(dim));
values.push_back(v);
}
}
@@ -46,7 +37,7 @@
OpBuilder builder(loc->getContext());
Block &firstBlock = body.front();
builder.setInsertionPointToStart(&firstBlock);
- SmallVector<Value *, 12> indexOps;
+ SmallVector<Value, 12> indexOps;
createForAllDimensions<gpu::BlockIdOp>(builder, loc, indexOps);
createForAllDimensions<gpu::ThreadIdOp>(builder, loc, indexOps);
createForAllDimensions<gpu::GridDimOp>(builder, loc, indexOps);
@@ -69,7 +60,7 @@
gpu::LaunchFuncOp launch) {
OpBuilder kernelBuilder(kernelFunc.getBody());
auto &firstBlock = kernelFunc.getBody().front();
- SmallVector<Value *, 8> newLaunchArgs;
+ SmallVector<Value, 8> newLaunchArgs;
BlockAndValueMapping map;
for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) {
map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i));
@@ -82,7 +73,7 @@
}
// Only inline operations that do not create new arguments.
if (!llvm::all_of(operandOp->getOperands(),
- [map](Value *value) { return map.contains(value); })) {
+ [map](Value value) { return map.contains(value); })) {
continue;
}
auto clone = kernelBuilder.clone(*operandOp, map);
diff --git a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index abbc4e0..71b7064 100644
--- a/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1,19 +1,10 @@
//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the LLVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
@@ -415,7 +406,7 @@
// Expects vector to be of wrapped LLVM vector type and position to be of
// wrapped LLVM i32 type.
void LLVM::ExtractElementOp::build(Builder *b, OperationState &result,
- Value *vector, Value *position,
+ Value vector, Value position,
ArrayRef<NamedAttribute> attrs) {
auto wrappedVectorType = vector->getType().cast<LLVM::LLVMType>();
auto llvmType = wrappedVectorType.getVectorElementType();
@@ -681,7 +672,7 @@
// attribute-dict?
static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) {
Block *dest;
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
if (parser.parseSuccessorAndUseList(dest, operands) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
@@ -708,8 +699,8 @@
static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) {
Block *trueDest;
Block *falseDest;
- SmallVector<Value *, 4> trueOperands;
- SmallVector<Value *, 4> falseOperands;
+ SmallVector<Value, 4> trueOperands;
+ SmallVector<Value, 4> falseOperands;
OpAsmParser::OperandType condition;
Builder &builder = parser.getBuilder();
@@ -1066,8 +1057,8 @@
//===----------------------------------------------------------------------===//
// Expects vector to be of wrapped LLVM vector type and position to be of
// wrapped LLVM i32 type.
-void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value *v1,
- Value *v2, ArrayAttr mask,
+void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1,
+ Value v2, ArrayAttr mask,
ArrayRef<NamedAttribute> attrs) {
auto wrappedContainerType1 = v1->getType().cast<LLVM::LLVMType>();
auto vType = LLVMType::getVectorTy(
@@ -1115,9 +1106,23 @@
}
//===----------------------------------------------------------------------===//
-// Builder, printer and verifier for LLVM::LLVMFuncOp.
+// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
+// Add the entry block to the function.
+Block *LLVMFuncOp::addEntryBlock() {
+ assert(empty() && "function already has an entry block");
+ assert(!isVarArg() && "unimplemented: non-external variadic functions");
+
+ auto *entry = new Block;
+ push_back(entry);
+
+ LLVMType type = getType();
+ for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i)
+ entry->addArgument(type.getFunctionParamType(i));
+ return entry;
+}
+
void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name,
LLVMType type, LLVM::Linkage linkage,
ArrayRef<NamedAttribute> attrs,
@@ -1650,10 +1655,10 @@
// Utility functions.
//===----------------------------------------------------------------------===//
-Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
- StringRef name, StringRef value,
- LLVM::Linkage linkage,
- LLVM::LLVMDialect *llvmDialect) {
+Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
+ StringRef name, StringRef value,
+ LLVM::Linkage linkage,
+ LLVM::LLVMDialect *llvmDialect) {
assert(builder.getInsertionBlock() &&
builder.getInsertionBlock()->getParentOp() &&
"expected builder to point to a block constrained in an op");
@@ -1670,13 +1675,13 @@
builder.getStringAttr(value));
// Get the pointer to the first character in the global string.
- Value *globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
- Value *cst0 = builder.create<LLVM::ConstantOp>(
+ Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
+ Value cst0 = builder.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
builder.getIntegerAttr(builder.getIndexType(), 0));
- return builder.create<LLVM::GEPOp>(
- loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
- ArrayRef<Value *>({cst0, cst0}));
+ return builder.create<LLVM::GEPOp>(loc,
+ LLVM::LLVMType::getInt8PtrTy(llvmDialect),
+ globalPtr, ArrayRef<Value>({cst0, cst0}));
}
bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
diff --git a/third_party/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/third_party/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e4708fb..3a8e84e 100644
--- a/third_party/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1,19 +1,10 @@
//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
diff --git a/third_party/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/third_party/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 30c55b5..c11572c 100644
--- a/third_party/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/third_party/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -1,19 +1,10 @@
//===- ROCDLDialect.cpp - ROCDL IR Ops and Dialect registration -----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the ROCDL IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
diff --git a/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index d7e4d08..e8667f0 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -1,19 +1,10 @@
//===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements view-based alias and dependence analyses.
//
@@ -49,8 +40,8 @@
llvm_unreachable("Unexpected DependenceType");
}
-Value *Aliases::find(Value *v) {
- if (isa<BlockArgument>(v))
+Value Aliases::find(Value v) {
+ if (v.isa<BlockArgument>())
return v;
auto it = aliases.find(v);
@@ -60,7 +51,7 @@
}
while (true) {
- if (isa<BlockArgument>(v))
+ if (v.isa<BlockArgument>())
return v;
if (auto alloc = dyn_cast_or_null<AllocOp>(v->getDefiningOp())) {
if (isStrided(alloc.getType()))
@@ -147,9 +138,9 @@
}
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
- for (auto *srcView : src.getOutputs()) { // W
+ for (auto srcView : src.getOutputs()) { // W
// RAW graph
- for (auto *dstView : dst.getInputs()) { // R
+ for (auto dstView : dst.getInputs()) { // R
if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
addDependenceElem(DependenceType::RAW,
LinalgOpView{src.getOperation(), srcView},
@@ -157,7 +148,7 @@
}
}
// WAW graph
- for (auto *dstView : dst.getOutputs()) { // W
+ for (auto dstView : dst.getOutputs()) { // W
if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
addDependenceElem(DependenceType::WAW,
LinalgOpView{src.getOperation(), srcView},
@@ -165,9 +156,9 @@
}
}
}
- for (auto *srcView : src.getInputs()) { // R
+ for (auto srcView : src.getInputs()) { // R
// RAR graph
- for (auto *dstView : dst.getInputs()) { // R
+ for (auto dstView : dst.getInputs()) { // R
if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
addDependenceElem(DependenceType::RAR,
LinalgOpView{src.getOperation(), srcView},
@@ -175,7 +166,7 @@
}
}
// WAR graph
- for (auto *dstView : dst.getOutputs()) { // W
+ for (auto dstView : dst.getOutputs()) { // W
if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
addDependenceElem(DependenceType::WAR,
LinalgOpView{src.getOperation(), srcView},
@@ -194,14 +185,14 @@
}
SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
- LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view) const {
+ LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
return findOperationsWithCoveringDependences(
srcLinalgOp, dstLinalgOp, view,
{DependenceType::WAW, DependenceType::WAR});
}
SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
- LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view) const {
+ LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
return findOperationsWithCoveringDependences(
srcLinalgOp, dstLinalgOp, view,
{DependenceType::RAR, DependenceType::RAW});
@@ -209,7 +200,7 @@
SmallVector<Operation *, 8>
LinalgDependenceGraph::findOperationsWithCoveringDependences(
- LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view,
+ LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
ArrayRef<DependenceType> types) const {
auto *src = srcLinalgOp.getOperation();
auto *dst = dstLinalgOp.getOperation();
diff --git a/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt b/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt
index 9d2b0cd..2ca5da3 100644
--- a/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt
+++ b/third_party/mlir/lib/Dialect/Linalg/CMakeLists.txt
@@ -23,7 +23,7 @@
MLIRAnalysis
MLIREDSC
MLIRLinalgOpsIncGen
- MLIRLinalgLibraryOpsIncGen
+ MLIRLinalgStructuredOpsIncGen
MLIRLinalgTransformPatternsIncGen
MLIRStandardOps
MLIRStandardToLLVM
diff --git a/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index ba96186..37c63b7 100644
--- a/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -1,19 +1,10 @@
//===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
@@ -44,8 +35,8 @@
Operation *mlir::edsc::makeLinalgGenericOp(
ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
ArrayRef<StructuredIndexed> outputs,
- function_ref<void(ArrayRef<BlockArgument *>)> regionBuilder,
- ArrayRef<Value *> otherValues, ArrayRef<Attribute> otherAttributes) {
+ function_ref<void(ArrayRef<BlockArgument>)> regionBuilder,
+ ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
auto &builder = edsc::ScopedContext::getBuilder();
auto *ctx = builder.getContext();
unsigned nInputs = inputs.size();
@@ -66,7 +57,7 @@
AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
unsigned nViews = nInputs + nOutputs;
- SmallVector<Value *, 4> values;
+ SmallVector<Value, 4> values;
values.reserve(nViews);
values.append(inputs.begin(), inputs.end());
values.append(outputs.begin(), outputs.end());
@@ -109,7 +100,7 @@
return op;
}
-void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument *> args) {
+void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
using edsc::op::operator+;
using edsc::op::operator*;
assert(args.size() == 3 && "expected 3 block arguments");
@@ -122,7 +113,7 @@
StructuredIndexed O) {
SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
edsc::IterType::Parallel);
- auto fun = [&unaryOp](ArrayRef<BlockArgument *> args) {
+ auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
assert(args.size() == 2 && "expected 2 block arguments");
ValueHandle a(args[0]);
linalg_yield(unaryOp(a));
@@ -134,8 +125,7 @@
StructuredIndexed O) {
;
using edsc::intrinsics::tanh;
- UnaryPointwiseOpBuilder unOp(
- [](ValueHandle a) -> Value * { return tanh(a); });
+ UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); });
return linalg_pointwise(unOp, I, O);
}
@@ -146,7 +136,7 @@
StructuredIndexed O) {
SmallVector<edsc::IterType, 4> iterTypes(O.getExprs().size(),
edsc::IterType::Parallel);
- auto fun = [&binaryOp](ArrayRef<BlockArgument *> args) {
+ auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
assert(args.size() == 3 && "expected 3 block arguments");
ValueHandle a(args[0]), b(args[1]);
linalg_yield(binaryOp(a, b));
@@ -159,14 +149,14 @@
StructuredIndexed O) {
using edsc::op::operator+;
BinaryPointwiseOpBuilder binOp(
- [](ValueHandle a, ValueHandle b) -> Value * { return a + b; });
+ [](ValueHandle a, ValueHandle b) -> Value { return a + b; });
return linalg_pointwise(binOp, I1, I2, O);
}
Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1,
StructuredIndexed I2,
StructuredIndexed O) {
- BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value * {
+ BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value {
using edsc::intrinsics::select;
using edsc::op::operator>;
return select(a > b, a, b).getValue();
diff --git a/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3154516..0f9f8f8 100644
--- a/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1,19 +1,10 @@
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a the Linalg operations.
//
@@ -318,7 +309,7 @@
// SliceOp
//===----------------------------------------------------------------------===//
void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
- Value *base, ValueRange indexings) {
+ Value base, ValueRange indexings) {
result.addOperands(base);
result.addOperands(indexings);
@@ -394,7 +385,7 @@
// TransposeOp
//===----------------------------------------------------------------------===//
void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
- Value *view, AffineMapAttr permutation,
+ Value view, AffineMapAttr permutation,
ArrayRef<NamedAttribute> attrs) {
auto permutationMap = permutation.getValue();
assert(permutationMap);
@@ -507,10 +498,10 @@
/////// Operations corresponding to library calls defined with Tablegen ////////
// For such operations correspond to library calls (i.e. defined in
-// LinalgLibraryOps.td), we define an overloaded `print` function and a
+// LinalgStructuredOps.td), we define an overloaded `print` function and a
// parse`className` function.
-// A LinalgLibraryOp prints as:
+// A LinalgStructuredOp prints as:
//
// ```mlir
// concrete_op_name (ssa-inputs, ssa-outputs) : view-types
@@ -526,15 +517,15 @@
// ```
//
// Where %0, %1 and %2 are ssa-values of type MemRefType with strides.
-static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) {
+static void printLinalgStructuredOp(OpAsmPrinter &p, Operation *op) {
assert(op->getAbstractOperation() && "unregistered operation");
p << op->getName().getStringRef() << "(" << op->getOperands() << ")";
p.printOptionalAttrDict(op->getAttrs());
p << " : " << op->getOperandTypes();
}
-static ParseResult parseLinalgLibraryOp(OpAsmParser &parser,
- OperationState &result) {
+static ParseResult parseLinalgStructuredOp(OpAsmParser &parser,
+ OperationState &result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return failure(
@@ -621,13 +612,13 @@
namespace mlir {
namespace linalg {
-#include "mlir/Dialect/Linalg/IR/LinalgLibraryOpInterfaces.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
#define GET_OP_CLASSES
-#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
} // namespace linalg
} // namespace mlir
diff --git a/third_party/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/third_party/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index 9fbb83b..32b1620 100644
--- a/third_party/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -1,19 +1,10 @@
//===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the Linalg dialect types and dialect.
//
@@ -42,7 +33,7 @@
>();
addOperations<
#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>();
}
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
diff --git a/third_party/mlir/lib/Dialect/Linalg/LinalgRegistration.cpp b/third_party/mlir/lib/Dialect/Linalg/LinalgRegistration.cpp
index df21ffa..768b18b 100644
--- a/third_party/mlir/lib/Dialect/Linalg/LinalgRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/LinalgRegistration.cpp
@@ -1,19 +1,10 @@
//===- LinalgRegistration.cpp - Register the linalg dialect statically ----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 453daba..9df7bce 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -1,19 +1,10 @@
//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
@@ -77,16 +68,16 @@
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
ArrayRef<SubViewOp::Range> loopRanges) {
auto maps = loopToOperandRangesMaps(op);
- SmallVector<Value *, 8> clonedViews;
+ SmallVector<Value, 8> clonedViews;
clonedViews.reserve(op.getNumInputsAndOutputs());
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- SmallVector<Value *, 8> ios(op.getInputsAndOutputs());
+ SmallVector<Value, 8> ios(op.getInputsAndOutputs());
for (auto en : llvm::enumerate(ios)) {
unsigned idx = en.index();
auto map = maps[idx];
LLVM_DEBUG(dbgs() << "map: " << map << "\n");
- Value *view = en.value();
+ Value view = en.value();
SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
for (auto en2 : llvm::enumerate(map.getResults())) {
unsigned d = en2.index();
@@ -99,7 +90,7 @@
}
// Construct a new subview for the tile.
unsigned rank = viewRanges.size();
- SmallVector<Value *, 4> offsets, sizes, strides;
+ SmallVector<Value, 4> offsets, sizes, strides;
offsets.reserve(rank);
sizes.reserve(rank);
strides.reserve(rank);
@@ -117,7 +108,7 @@
}
struct ViewDimension {
- Value *view;
+ Value view;
unsigned dimension;
};
@@ -130,14 +121,14 @@
auto maps = loopToOperandRangesMaps(op);
// Iterate over the inputs and outputs in order.
// Extract the subranges from the linearized ranges.
- SmallVector<Value *, 8> ios(op.getInputsAndOutputs());
+ SmallVector<Value, 8> ios(op.getInputsAndOutputs());
for (auto en : llvm::enumerate(ios)) {
unsigned idx = en.index();
auto map = maps[idx];
LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
- Value *view = en.value();
- SmallVector<Value *, 8> viewRanges(map.getNumResults(), nullptr);
+ Value view = en.value();
+ SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
for (auto en2 : llvm::enumerate(map.getResults())) {
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
@@ -151,7 +142,7 @@
llvm_unreachable("Expect to be able to extract a view defining loop range");
}
-static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer,
+static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx, unsigned producerIdx,
OperationFolder *folder) {
auto subView = dyn_cast_or_null<SubViewOp>(
@@ -205,8 +196,7 @@
// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
-static bool isStructurallyFusableProducer(LinalgOp producer,
- Value *consumedView,
+static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
LinalgOp consumer) {
if (producer.getNumOutputs() != 1) {
LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
@@ -226,7 +216,7 @@
bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
LinalgOp consumer,
- Value *consumedView,
+ Value consumedView,
LinalgOp producer) {
// Make some simple structural checks that alleviate the need for more
// complex analyses.
@@ -245,7 +235,7 @@
}
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
- LinalgOp consumer, Value *consumedView,
+ LinalgOp consumer, Value consumedView,
LinalgOp producer) {
if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
return false;
@@ -272,13 +262,13 @@
auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
// Check that the dependence is indeed on the input `consumerIdx` view.
- auto *consumedView = dependence.indexingView;
+ auto consumedView = dependence.indexingView;
if (consumer.getInput(consumerIdx) != consumedView)
continue;
// Consumer consumes this view, `isStructurallyFusableProducer` also checks
// whether it is a strict subview of the producer view.
- auto *producedView = dependence.dependentOpView.view;
+ auto producedView = dependence.dependentOpView.view;
auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
// `consumerIdx` and `producerIdx` exist by construction.
LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 96a8a21..d7cc4a8 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -1,19 +1,10 @@
//===- LowerToLoops.cpp - conversion from Linalg library ops to loops------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -49,7 +40,7 @@
static SmallVector<ValueHandle, 8>
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
- ArrayRef<Value *> vals) {
+ ArrayRef<Value> vals) {
assert(map.getNumSymbols() == 0);
assert(map.getNumInputs() == vals.size());
SmallVector<ValueHandle, 8> res;
@@ -57,35 +48,34 @@
auto dims = map.getNumDims();
for (auto e : map.getResults()) {
auto exprMap = AffineMap::get(dims, 0, e);
- SmallVector<Value *, 4> operands(vals.begin(), vals.end());
+ SmallVector<Value, 4> operands(vals.begin(), vals.end());
canonicalizeMapAndOperands(&exprMap, &operands);
res.push_back(affine_apply(exprMap, operands));
}
return res;
}
-static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
- Optional<AffineMap> permutation) {
+static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
+ Optional<AffineMap> permutation) {
return permutation ? applyMapToValues(ScopedContext::getBuilder(),
ScopedContext::getLocation(),
permutation.getValue(), ivs)
- : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
+ : SmallVector<Value, 4>(ivs.begin(), ivs.end());
}
// Creates a number of ranges equal to the number of results in `map`.
// The returned ranges correspond to the loop ranges, in the proper order, for
// which new loops will be created.
-static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
- AffineMap map,
- ArrayRef<Value *> allViewSizes);
-SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
- AffineMap map,
- ArrayRef<Value *> allViewSizes) {
+static SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc,
+ AffineMap map,
+ ArrayRef<Value> allViewSizes);
+SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
+ ArrayRef<Value> allViewSizes) {
// Apply `map` to get view sizes in loop order.
auto sizes = applyMapToValues(b, loc, map, allViewSizes);
// Create a new range with the applied tile sizes.
ScopedContext scope(b, loc);
- SmallVector<Value *, 4> res;
+ SmallVector<Value, 4> res;
for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
res.push_back(range(constant_index(0), sizes[idx], constant_index(1)));
}
@@ -98,8 +88,7 @@
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, CopyOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs,
- CopyOp copyOp) {
+ static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
auto nPar = copyOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto inputIvs =
@@ -121,8 +110,7 @@
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, FillOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs,
- FillOp fillOp) {
+ static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
auto nPar = fillOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto ivs =
@@ -138,7 +126,7 @@
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, DotOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp) {
+ static void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
assert(allIvs.size() == 1);
IndexHandle r_i(allIvs[0]);
IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
@@ -151,7 +139,7 @@
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, MatvecOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+ static void emitScalarImplementation(ArrayRef<Value> allIvs,
MatvecOp matvecOp) {
assert(allIvs.size() == 2);
IndexHandle i(allIvs[0]), r_j(allIvs[1]);
@@ -165,7 +153,7 @@
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, MatmulOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+ static void emitScalarImplementation(ArrayRef<Value> allIvs,
MatmulOp matmulOp) {
assert(allIvs.size() == 3);
IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
@@ -179,8 +167,7 @@
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, ConvOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs,
- ConvOp convOp) {
+ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
auto maps = loopToOperandRangesMaps(convOp);
@@ -229,14 +216,14 @@
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, GenericOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+ static void emitScalarImplementation(ArrayRef<Value> allIvs,
GenericOp genericOp) {
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
using edsc::intrinsics::detail::ValueHandleArray;
unsigned nInputs = genericOp.getNumInputs();
unsigned nOutputs = genericOp.getNumOutputs();
- SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
+ SmallVector<Value, 4> indexedValues(nInputs + nOutputs);
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
@@ -324,7 +311,7 @@
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
public:
- static void emitScalarImplementation(ArrayRef<Value *> allIvs,
+ static void emitScalarImplementation(ArrayRef<Value> allIvs,
IndexedGenericOp indexedGenericOp) {
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
@@ -332,7 +319,7 @@
unsigned nInputs = indexedGenericOp.getNumInputs();
unsigned nOutputs = indexedGenericOp.getNumOutputs();
unsigned nLoops = allIvs.size();
- SmallVector<Value *, 4> indexedValues(nLoops + nInputs + nOutputs);
+ SmallVector<Value, 4> indexedValues(nLoops + nInputs + nOutputs);
for (unsigned i = 0; i < nLoops; ++i) {
indexedValues[i] = allIvs[i];
@@ -488,7 +475,7 @@
void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewritePatternList<LoopType, IndexedValueType,
#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>::build(patterns, ctx);
}
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index f436492..eb23a8c 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -1,19 +1,10 @@
//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements logic for transforming Linalg operations.
//
@@ -99,7 +90,7 @@
}
bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
- Operation *consumerOp, Value *consumedView,
+ Operation *consumerOp, Value consumedView,
function_ref<bool(Operation *)> isaOpType) {
LinalgOp consumer = dyn_cast<LinalgOp>(consumerOp);
if (!consumer)
@@ -175,7 +166,7 @@
return failure();
// TODO(ntv): non-identity layout.
- auto isStaticMemRefWithIdentityLayout = [](Value *v) {
+ auto isStaticMemRefWithIdentityLayout = [](Value v) {
auto m = v->getType().dyn_cast<MemRefType>();
if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
return false;
@@ -235,7 +226,7 @@
LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
Operation *op) {
LinalgOp linOp = dyn_cast<LinalgOp>(op);
- SetVector<Value *> subViews;
+ SetVector<Value> subViews;
for (auto it : linOp.getInputsAndOutputs())
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
subViews.insert(sv);
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index c7fbebc..b8b2795 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -1,19 +1,10 @@
//===- Promotion.cpp - Implementation of linalg Promotion -----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Promotion pass.
//
@@ -55,14 +46,14 @@
llvm::cl::desc("Test generation of dynamic promoted buffers"),
llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
-static Value *allocBuffer(Type elementType, Value *size, bool dynamicBuffers) {
+static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
auto *ctx = size->getContext();
auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
if (!dynamicBuffers)
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
return alloc(
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
- Value *mul = muli(constant_index(width), size);
+ Value mul = muli(constant_index(width), size);
return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
}
@@ -92,20 +83,20 @@
auto viewType = subView.getType();
auto rank = viewType.getRank();
- Value *allocSize = one;
- SmallVector<Value *, 8> fullRanges, partialRanges;
+ Value allocSize = one;
+ SmallVector<Value, 8> fullRanges, partialRanges;
fullRanges.reserve(rank);
partialRanges.reserve(rank);
for (auto en : llvm::enumerate(subView.getRanges())) {
auto rank = en.index();
auto rangeValue = en.value();
- Value *d = rangeValue.size;
+ Value d = rangeValue.size;
allocSize = muli(folder, allocSize, d).getValue();
fullRanges.push_back(d);
partialRanges.push_back(range(folder, zero, dim(subView, rank), one));
}
SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
- auto *buffer =
+ auto buffer =
allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
auto fullLocalView = view(
MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
@@ -115,7 +106,7 @@
SmallVector<PromotionInfo, 8>
mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
- ArrayRef<Value *> subViews, bool dynamicBuffers,
+ ArrayRef<Value> subViews, bool dynamicBuffers,
OperationFolder *folder) {
if (subViews.empty())
return {};
@@ -123,8 +114,8 @@
ScopedContext scope(b, loc);
SmallVector<PromotionInfo, 8> res;
res.reserve(subViews.size());
- DenseMap<Value *, PromotionInfo> promotionInfoMap;
- for (auto *v : subViews) {
+ DenseMap<Value, PromotionInfo> promotionInfoMap;
+ for (auto v : subViews) {
SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
auto viewType = subView.getType();
// TODO(ntv): support more cases than just float.
@@ -136,7 +127,7 @@
res.push_back(promotionInfo);
}
- for (auto *v : subViews) {
+ for (auto v : subViews) {
SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
auto info = promotionInfoMap.find(v);
if (info == promotionInfoMap.end())
@@ -144,14 +135,14 @@
// TODO(ntv): value to fill with should be related to the operation.
// For now, just use APFloat(0.0f).
auto t = subView.getType().getElementType().cast<FloatType>();
- Value *fillVal = constant_float(folder, APFloat(0.0f), t);
+ Value fillVal = constant_float(folder, APFloat(0.0f), t);
// TODO(ntv): fill is only necessary if `promotionInfo` has a full local
// view that is different from the partial local view and we are on the
// boundary.
fill(info->second.fullLocalView, fillVal);
}
- for (auto *v : subViews) {
+ for (auto v : subViews) {
auto info = promotionInfoMap.find(v);
if (info == promotionInfoMap.end())
continue;
@@ -161,19 +152,19 @@
}
LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
- SetVector<Value *> subViews,
+ SetVector<Value> subViews,
bool dynamicBuffers,
OperationFolder *folder) {
// 1. Promote the specified views and use them in the new op.
ScopedContext scope(b, op.getLoc());
auto promotedBufferAndViews = promoteSubViews(
b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
- SmallVector<Value *, 8> opViews;
+ SmallVector<Value, 8> opViews;
opViews.reserve(op.getNumInputsAndOutputs());
- SmallVector<std::pair<Value *, Value *>, 8> writebackViews;
+ SmallVector<std::pair<Value, Value>, 8> writebackViews;
writebackViews.reserve(subViews.size());
unsigned promotedIdx = 0;
- for (auto *view : op.getInputsAndOutputs()) {
+ for (auto view : op.getInputsAndOutputs()) {
if (subViews.count(view) != 0) {
opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
writebackViews.emplace_back(std::make_pair(
@@ -214,7 +205,7 @@
f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
// nothing.
- SetVector<Value *> subViews;
+ SetVector<Value> subViews;
OpBuilder b(op);
for (auto it : op.getInputsAndOutputs())
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
diff --git a/third_party/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/third_party/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4d8a24c..964f540 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -1,19 +1,10 @@
//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Tiling pass.
//
@@ -53,7 +44,7 @@
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::cat(clOptionsCategory));
-static bool isZero(Value *v) {
+static bool isZero(Value v) {
return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
}
@@ -71,12 +62,12 @@
// indices of newly created loops.
static std::tuple<SmallVector<SubViewOp::Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
- ArrayRef<Value *> allViewSizes,
- ArrayRef<Value *> allTileSizes, OperationFolder *folder) {
+ ArrayRef<Value> allViewSizes, ArrayRef<Value> allTileSizes,
+ OperationFolder *folder) {
assert(allTileSizes.size() == map.getNumResults());
// Apply `map` to get view sizes in loop order.
auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder);
- SmallVector<Value *, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
+ SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
// Traverse the tile sizes, which are in loop order, erase zeros everywhere.
LoopIndexToRangeIndexMap loopIndexToRangeIndex;
@@ -110,8 +101,7 @@
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
- TileCheck(ArrayRef<Value *> tileSizes)
- : isTiled(false), tileSizes(tileSizes) {}
+ TileCheck(ArrayRef<Value> tileSizes) : isTiled(false), tileSizes(tileSizes) {}
void visitDimExpr(AffineDimExpr expr) {
isTiled |= !isZero(tileSizes[expr.getPosition()]);
@@ -124,7 +114,7 @@
"nonpositive multiplying coefficient");
}
bool isTiled;
- ArrayRef<Value *> tileSizes;
+ ArrayRef<Value> tileSizes;
};
} // namespace
@@ -206,11 +196,11 @@
auto rangeIndex = loopIndexToRangeIndex.find(i);
if (rangeIndex == loopIndexToRangeIndex.end())
continue;
- Value *oldIndex = block.getArgument(i);
+ Value oldIndex = block.getArgument(i);
// Offset the index argument `i` by the value of the corresponding induction
// variable and replace all uses of the previous value.
- Value *newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
- pivs[rangeIndex->second]->getValue());
+ Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
+ pivs[rangeIndex->second]->getValue());
for (auto &use : oldIndex->getUses()) {
if (use.getOwner() == newIndex->getDefiningOp())
continue;
@@ -219,7 +209,7 @@
}
}
-static bool isTiled(AffineExpr expr, ArrayRef<Value *> tileSizes) {
+static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) {
if (!expr)
return false;
TileCheck t(tileSizes);
@@ -229,7 +219,7 @@
// Checks whether the view with index `viewIndex` within `linalgOp` varies with
// respect to a non-zero `tileSize`.
-static bool isTiled(AffineMap map, ArrayRef<Value *> tileSizes) {
+static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) {
if (!map)
return false;
for (unsigned r = 0; r < map.getNumResults(); ++r)
@@ -238,13 +228,13 @@
return false;
}
-static SmallVector<Value *, 4>
+static SmallVector<Value, 4>
makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
- ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes,
- ArrayRef<Value *> viewSizes, OperationFolder *folder) {
+ ArrayRef<Value> ivs, ArrayRef<Value> tileSizes,
+ ArrayRef<Value> viewSizes, OperationFolder *folder) {
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
llvm::make_range(tileSizes.begin(), tileSizes.end()),
- [](Value *v) { return !isZero(v); })) &&
+ [](Value v) { return !isZero(v); })) &&
"expected as many ivs as non-zero sizes");
using edsc::intrinsics::select;
@@ -253,21 +243,21 @@
// Construct (potentially temporary) mins and maxes on which to apply maps
// that define tile subviews.
- SmallVector<Value *, 8> lbs, subViewSizes;
+ SmallVector<Value, 8> lbs, subViewSizes;
for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
bool isTiled = !isZero(tileSizes[idx]);
- lbs.push_back(isTiled ? ivs[idxIvs++] : (Value *)constant_index(folder, 0));
+ lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)constant_index(folder, 0));
subViewSizes.push_back(isTiled ? tileSizes[idx] : viewSizes[idx]);
}
auto *op = linalgOp.getOperation();
- SmallVector<Value *, 4> res;
+ SmallVector<Value, 4> res;
res.reserve(op->getNumOperands());
auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin();
for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
++viewIndex) {
- Value *view = *(viewIteratorBegin + viewIndex);
+ Value view = *(viewIteratorBegin + viewIndex);
unsigned rank = view->getType().cast<MemRefType>().getRank();
auto map = loopToOperandRangesMaps(linalgOp)[viewIndex];
// If the view is not tiled, we can use it as is.
@@ -277,7 +267,7 @@
}
// Construct a new subview for the tile.
- SmallVector<Value *, 4> offsets, sizes, strides;
+ SmallVector<Value, 4> offsets, sizes, strides;
offsets.reserve(rank);
sizes.reserve(rank);
strides.reserve(rank);
@@ -292,9 +282,9 @@
// Tiling creates a new slice at the proper index, the slice step is 1
// (i.e. the slice view does not subsample, stepping occurs in the loop).
auto m = map.getSubMap({r});
- auto *offset = applyMapToValues(b, loc, m, lbs, folder).front();
+ auto offset = applyMapToValues(b, loc, m, lbs, folder).front();
offsets.push_back(offset);
- auto *size = applyMapToValues(b, loc, m, subViewSizes, folder).front();
+ auto size = applyMapToValues(b, loc, m, subViewSizes, folder).front();
sizes.push_back(size);
strides.push_back(constant_index(folder, 1));
}
@@ -308,16 +298,17 @@
// This is a special type of folding that we only apply when `folder` is
// defined.
if (folder)
- for (auto *v : llvm::concat<Value *>(lbs, subViewSizes))
+ for (auto v : llvm::concat<Value>(lbs, subViewSizes))
if (v->use_empty())
v->getDefiningOp()->erase();
return res;
}
-Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
- OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes,
- ArrayRef<unsigned> permutation, OperationFolder *folder) {
+Optional<TiledLinalgOp>
+mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
+ ArrayRef<unsigned> permutation,
+ OperationFolder *folder) {
// 1. Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.
@@ -360,7 +351,7 @@
LoopNestRangeBuilder(pivs, loopRanges)([&] {
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
- SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
+ SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
// If we have to apply a permutation to the tiled loop nest, we have to
// reorder the induction variables This permutation is the right one
@@ -411,7 +402,7 @@
ScopedContext scope(b, op.getLoc());
// Materialize concrete tile size values to pass the generic tiling function.
- SmallVector<Value *, 8> tileSizeValues;
+ SmallVector<Value, 8> tileSizeValues;
tileSizeValues.reserve(tileSizes.size());
for (auto ts : tileSizes)
tileSizeValues.push_back(constant_index(folder, ts));
diff --git a/third_party/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/third_party/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index eb501f9..560a023 100644
--- a/third_party/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/third_party/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -1,19 +1,10 @@
//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
@@ -92,7 +83,7 @@
}
mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
- ArrayRef<ValueHandle *> ivs, ArrayRef<Value *> ranges)
+ ArrayRef<ValueHandle *> ivs, ArrayRef<Value> ranges)
: LoopNestRangeBuilder(
ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
@@ -106,26 +97,26 @@
return ValueHandle::null();
}
-static Value *emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
- AffineMap map,
- ArrayRef<Value *> operandsRef,
- OperationFolder *folder) {
- SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
+static Value emitOrFoldComposedAffineApply(OpBuilder &b, Location loc,
+ AffineMap map,
+ ArrayRef<Value> operandsRef,
+ OperationFolder *folder) {
+ SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
fullyComposeAffineMapAndOperands(&map, &operands);
canonicalizeMapAndOperands(&map, &operands);
return folder ? folder->create<AffineApplyOp>(b, loc, map, operands)
: b.create<AffineApplyOp>(loc, map, operands);
}
-SmallVector<Value *, 4>
-mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, AffineMap map,
- ArrayRef<Value *> values,
- OperationFolder *folder) {
- SmallVector<Value *, 4> res;
+SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
+ AffineMap map,
+ ArrayRef<Value> values,
+ OperationFolder *folder) {
+ SmallVector<Value, 4> res;
res.reserve(map.getNumResults());
unsigned numDims = map.getNumDims();
// For each `expr` in `map`, applies the `expr` to the values extracted from
- // ranges. If the resulting application can be folded into a Value*, the
+ // ranges. If the resulting application can be folded into a Value, the
// folding occurs eagerly. Otherwise, an affine.apply operation is emitted.
for (auto expr : map.getResults()) {
AffineMap map = AffineMap::get(numDims, 0, expr);
@@ -137,12 +128,12 @@
/// Returns all the operands of `linalgOp` that are not views.
/// Asserts that these operands are value types to allow transformations like
/// tiling to just use the values when cloning `linalgOp`.
-SmallVector<Value *, 4>
+SmallVector<Value, 4>
mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
auto *op = linalgOp.getOperation();
unsigned numViews = linalgOp.getNumInputsAndOutputs();
unsigned nOperands = op->getNumOperands() - numViews;
- SmallVector<Value *, 4> res;
+ SmallVector<Value, 4> res;
res.reserve(nOperands);
for (unsigned i = 0; i < nOperands; ++i) {
res.push_back(op->getOperand(numViews + i));
diff --git a/third_party/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp
index 5724402..6564e78 100644
--- a/third_party/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/LoopOps/DialectRegistration.cpp
@@ -1,19 +1,10 @@
//===- DialectRegistration.cpp - Register loop dialect --------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LoopOps/LoopOps.h"
using namespace mlir;
diff --git a/third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index fc8832e..acbab01 100644
--- a/third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -1,19 +1,10 @@
//===- Ops.cpp - Loop MLIR Operations -------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
@@ -69,8 +60,8 @@
// ForOp
//===----------------------------------------------------------------------===//
-void ForOp::build(Builder *builder, OperationState &result, Value *lb,
- Value *ub, Value *step) {
+void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
+ Value step) {
result.addOperands({lb, ub, step});
Region *bodyRegion = result.addRegion();
ForOp::ensureTerminator(*bodyRegion, *builder, result.location);
@@ -134,7 +125,7 @@
Region &ForOp::getLoopBody() { return region(); }
-bool ForOp::isDefinedOutsideOfLoop(Value *value) {
+bool ForOp::isDefinedOutsideOfLoop(Value value) {
return !region().isAncestor(value->getParentRegion());
}
@@ -144,8 +135,8 @@
return success();
}
-ForOp mlir::loop::getForInductionVarOwner(Value *val) {
- auto *ivArg = dyn_cast<BlockArgument>(val);
+ForOp mlir::loop::getForInductionVarOwner(Value val) {
+ auto ivArg = val.dyn_cast<BlockArgument>();
if (!ivArg)
return ForOp();
assert(ivArg->getOwner() && "unlinked block argument");
@@ -157,7 +148,7 @@
// IfOp
//===----------------------------------------------------------------------===//
-void IfOp::build(Builder *builder, OperationState &result, Value *cond,
+void IfOp::build(Builder *builder, OperationState &result, Value cond,
bool withElseRegion) {
result.addOperands(cond);
Region *thenRegion = result.addRegion();
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp
index b071248..1738d6d 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/DialectRegistration.cpp
@@ -1,19 +1,10 @@
//===- DialectRegistration.cpp - Register Quantization dialect ------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/QuantOps.h"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index 51f1994..faeff24 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -1,19 +1,10 @@
//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "TypeDetail.h"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
index bc8290c..2e33963 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp
@@ -1,19 +1,10 @@
//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/QuantTypes.h"
#include "TypeDetail.h"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h b/third_party/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h
index 13a88da..801a0de 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/TypeDetail.h
@@ -1,19 +1,10 @@
//===- TypeDetail.h - QuantOps Type detail ----------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef TYPE_DETAIL_H_
#define TYPE_DETAIL_H_
diff --git a/third_party/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp b/third_party/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
index 2bdde1f..2689a2d 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
@@ -1,19 +1,10 @@
//===- TypeParser.h - Quantization Type Parser ------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
index 61636dc..08a5ec5 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp
@@ -1,19 +1,10 @@
//===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/Passes.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index 83fa923..2a4c14f 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -1,19 +1,10 @@
//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
#include "mlir/Dialect/QuantOps/Passes.h"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index f4256cf..cbd4315 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -1,19 +1,10 @@
//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/third_party/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
index 56e2cba..094fefe 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
@@ -1,19 +1,10 @@
//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/QuantizeUtils.h"
#include "mlir/Dialect/QuantOps/UniformSupport.h"
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
index 34e767d..df00233 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
@@ -1,19 +1,10 @@
//===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/QuantOps/UniformSupport.h"
#include "mlir/IR/StandardTypes.h"
diff --git a/third_party/mlir/lib/Dialect/SDBM/SDBM.cpp b/third_party/mlir/lib/Dialect/SDBM/SDBM.cpp
index 510e13e..03ffe3f 100644
--- a/third_party/mlir/lib/Dialect/SDBM/SDBM.cpp
+++ b/third_party/mlir/lib/Dialect/SDBM/SDBM.cpp
@@ -1,19 +1,10 @@
//===- SDBM.cpp - MLIR SDBM implementation --------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// A striped difference-bound matrix (SDBM) is a set in Z^N (or R^N) defined
// as {(x_1, ... x_n) | f(x_1, ... x_n) >= 0} where f is an SDBM expression.
diff --git a/third_party/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/third_party/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
index d3d895f..fab9463 100644
--- a/third_party/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
+++ b/third_party/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
@@ -1,19 +1,10 @@
//===- SDBMDialect.cpp - Dialect for striped difference-bound matrices ----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SDBM/SDBMDialect.h"
diff --git a/third_party/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/third_party/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
index 44cdd18..68e3e1c 100644
--- a/third_party/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
+++ b/third_party/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
@@ -1,19 +1,10 @@
//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// A striped difference-bound matrix (SDBM) expression is a constant expression,
// an identifier, a binary expression with constant RHS and +, stripe operators
diff --git a/third_party/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/third_party/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
index 0441200..fb80b45 100644
--- a/third_party/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
+++ b/third_party/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
@@ -1,19 +1,10 @@
//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This holds implementation details of SDBMExpr, in particular underlying
// storage types.
diff --git a/third_party/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp
index 63e9e81..431b40e 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/DialectRegistration.cpp
@@ -1,19 +1,10 @@
//===- DialectRegistration.cpp - MLIR SPIR-V dialect registration ---------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
diff --git a/third_party/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/third_party/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
index 5db478d..a12d04e 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
@@ -1,19 +1,10 @@
//===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements Utilities used to get alignment and layout information
// for types in SPIR-V dialect.
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index def8ee8..144252b 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -94,7 +94,7 @@
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
- ArrayRef<Value *> valuesToRepl) const final {
+ ArrayRef<Value> valuesToRepl) const final {
// Only spv.ReturnValue needs to be handled here.
auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
if (!retValOp)
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 284fe91..0d2348c 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -1,19 +1,10 @@
//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to SPIR-V dialect.
//
@@ -229,9 +220,9 @@
/// Gets the global variable associated with a builtin and add
/// it if it doesn't exist.
-Value *mlir::spirv::getBuiltinVariableValue(Operation *op,
- spirv::BuiltIn builtin,
- OpBuilder &builder) {
+Value mlir::spirv::getBuiltinVariableValue(Operation *op,
+ spirv::BuiltIn builtin,
+ OpBuilder &builder) {
auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
if (!moduleOp) {
op->emitError("expected operation to be within a SPIR-V module");
@@ -239,7 +230,7 @@
}
spirv::GlobalVariableOp varOp =
getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder);
- Value *ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
+ Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
return builder.create<spirv::LoadOp>(op->getLoc(), ptr,
/*memory_access =*/nullptr,
/*alignment =*/nullptr);
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 0df4525..f42c077 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1,19 +1,10 @@
//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the operations in the SPIR-V dialect.
//
@@ -273,8 +264,8 @@
}
template <typename LoadStoreOpTy>
-static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr,
- Value *val) {
+static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
+ Value val) {
// ODS already checks ptr is spirv::PointerType. Just check that the pointee
// type of the pointer and the type of the value are the same
//
@@ -664,8 +655,8 @@
}
static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
- Value *base = op->getOperand(0);
- Value *shift = op->getOperand(1);
+ Value base = op->getOperand(0);
+ Value shift = op->getOperand(1);
printer << op->getName() << ' ' << *base << ", " << *shift << " : "
<< base->getType() << ", " << shift->getType();
}
@@ -742,7 +733,7 @@
}
void spirv::AccessChainOp::build(Builder *builder, OperationState &state,
- Value *basePtr, ValueRange indices) {
+ Value basePtr, ValueRange indices) {
auto type = getElementPtrType(basePtr->getType(), indices, state.location);
assert(type && "Unable to deduce return type based on basePtr and indices");
build(builder, state, type, basePtr, indices);
@@ -782,8 +773,8 @@
}
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
- SmallVector<Value *, 4> indices(accessChainOp.indices().begin(),
- accessChainOp.indices().end());
+ SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
+ accessChainOp.indices().end());
auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(),
indices, accessChainOp.getLoc());
if (!resultType) {
@@ -824,7 +815,7 @@
}
// Combine indices.
- SmallVector<Value *, 4> indices(parentAccessChainOp.indices());
+ SmallVector<Value, 4> indices(parentAccessChainOp.indices());
indices.append(accessChainOp.indices().begin(),
accessChainOp.indices().end());
@@ -1060,7 +1051,7 @@
static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &state) {
Block *dest;
- SmallVector<Value *, 4> destOperands;
+ SmallVector<Value, 4> destOperands;
if (parser.parseSuccessorAndUseList(dest, destOperands))
return failure();
state.addSuccessor(dest, destOperands);
@@ -1089,7 +1080,7 @@
auto &builder = parser.getBuilder();
OpAsmParser::OperandType condInfo;
Block *dest;
- SmallVector<Value *, 4> destOperands;
+ SmallVector<Value, 4> destOperands;
// Parse the condition.
Type boolTy = builder.getI1Type();
@@ -1214,7 +1205,7 @@
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
- SmallVector<Value *, 4> constituents(compositeConstructOp.constituents());
+ SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
if (constituents.size() != cType.getNumElements()) {
return compositeConstructOp.emitError(
"has incorrect number of operands: expected ")
@@ -1239,7 +1230,7 @@
//===----------------------------------------------------------------------===//
void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state,
- Value *composite,
+ Value composite,
ArrayRef<int32_t> indices) {
auto indexAttr = builder->getI32ArrayAttr(indices);
auto elementType =
@@ -1963,7 +1954,7 @@
//===----------------------------------------------------------------------===//
void spirv::LoadOp::build(Builder *builder, OperationState &state,
- Value *basePtr, IntegerAttr memory_access,
+ Value basePtr, IntegerAttr memory_access,
IntegerAttr alignment) {
auto ptrType = basePtr->getType().cast<spirv::PointerType>();
build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
@@ -2496,8 +2487,8 @@
// spv.Select
//===----------------------------------------------------------------------===//
-void spirv::SelectOp::build(Builder *builder, OperationState &state,
- Value *cond, Value *trueValue, Value *falseValue) {
+void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond,
+ Value trueValue, Value falseValue) {
build(builder, state, trueValue->getType(), cond, trueValue, falseValue);
}
@@ -2698,9 +2689,9 @@
return matchFailure();
}
- auto *trueValue = getSrcValue(trueBlock);
- auto *falseValue = getSrcValue(falseBlock);
- auto *ptrValue = getDstPtr(trueBlock);
+ auto trueValue = getSrcValue(trueBlock);
+ auto falseValue = getSrcValue(falseBlock);
+ auto ptrValue = getDstPtr(trueBlock);
auto storeOpAttributes =
cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
@@ -2747,13 +2738,13 @@
}
// Returns a soruce value for the given block.
- Value *getSrcValue(Block *block) const {
+ Value getSrcValue(Block *block) const {
auto storeOp = cast<spirv::StoreOp>(block->front());
return storeOp.value();
}
// Returns a destination value for the given block.
- Value *getDstPtr(Block *block) const {
+ Value getDstPtr(Block *block) const {
auto storeOp = cast<spirv::StoreOp>(block->front());
return storeOp.ptr();
}
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 15621aa..18e027a 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -1,19 +1,10 @@
//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the types in the SPIR-V dialect.
//
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index df9cb47..17ddc48 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1,19 +1,10 @@
//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the SPIR-V binary to MLIR SPIR-V module deserialization.
//
@@ -327,7 +318,7 @@
/// This method materializes normal constants and inserts "casting" ops
/// (`spv._address_of` and `spv._reference_of`) to turn an symbol into a SSA
/// value for handling uses of module scope constants/variables in functions.
- Value *getValue(uint32_t id);
+ Value getValue(uint32_t id);
/// Slices the first instruction out of `binary` and returns its opcode and
/// operands via `opcode` and `operands` respectively. Returns failure if
@@ -446,7 +437,7 @@
DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
// Result <id> to value mapping.
- DenseMap<uint32_t, Value *> valueMap;
+ DenseMap<uint32_t, Value> valueMap;
// Mapping from result <id> to undef value of a type.
DenseMap<uint32_t, Type> undefMap;
@@ -1520,7 +1511,7 @@
"false label, and optionally two branch weights");
}
- auto *condition = getValue(operands[0]);
+ auto condition = getValue(operands[0]);
auto *trueBlock = getOrCreateBlock(operands[1]);
auto *falseBlock = getOrCreateBlock(operands[2]);
@@ -1531,8 +1522,8 @@
opBuilder.create<spirv::BranchConditionalOp>(
unknownLoc, condition, trueBlock,
- /*trueArguments=*/ArrayRef<Value *>(), falseBlock,
- /*falseArguments=*/ArrayRef<Value *>(), weights);
+ /*trueArguments=*/ArrayRef<Value>(), falseBlock,
+ /*falseArguments=*/ArrayRef<Value>(), weights);
return success();
}
@@ -1626,7 +1617,7 @@
// Create a block argument for this OpPhi instruction.
Type blockArgType = getType(operands[0]);
- BlockArgument *blockArg = curBlock->addArgument(blockArgType);
+ BlockArgument blockArg = curBlock->addArgument(blockArgType);
valueMap[operands[1]] = blockArg;
LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg
<< " id = " << operands[1] << " of type "
@@ -1783,8 +1774,8 @@
LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock
<< " from block " << block << "\n");
if (!isFnEntryBlock(block)) {
- for (BlockArgument *blockArg : block->getArguments()) {
- auto *newArg = newBlock->addArgument(blockArg->getType());
+ for (BlockArgument blockArg : block->getArguments()) {
+ auto newArg = newBlock->addArgument(blockArg->getType());
mapper.map(blockArg, newArg);
LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
<< " to " << newArg << '\n');
@@ -1801,10 +1792,10 @@
// Go through all ops and remap the operands.
auto remapOperands = [&](Operation *op) {
for (auto &operand : op->getOpOperands())
- if (auto *mappedOp = mapper.lookupOrNull(operand.get()))
+ if (auto mappedOp = mapper.lookupOrNull(operand.get()))
operand.set(mappedOp);
for (auto &succOp : op->getBlockOperands())
- if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
+ if (auto mappedOp = mapper.lookupOrNull(succOp.get()))
succOp.set(mappedOp);
};
for (auto &block : body) {
@@ -1824,13 +1815,13 @@
// we place the selection/loop op inside the old merge block, we need to
// make sure the old merge block has the same block argument list.
assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
- for (BlockArgument *blockArg : headerBlock->getArguments()) {
+ for (BlockArgument blockArg : headerBlock->getArguments()) {
mergeBlock->addArgument(blockArg->getType());
}
// If the loop header block has block arguments, make sure the spv.branch op
// matches.
- SmallVector<Value *, 4> blockArgs;
+ SmallVector<Value, 4> blockArgs;
if (!headerBlock->args_empty())
blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
@@ -1838,7 +1829,7 @@
// loop header block.
builder.setInsertionPointToEnd(&body.front());
builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
- ArrayRef<Value *>(blockArgs));
+ ArrayRef<Value>(blockArgs));
}
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
@@ -1924,10 +1915,10 @@
auto *op = block->getTerminator();
opBuilder.setInsertionPoint(op);
- SmallVector<Value *, 4> blockArgs;
+ SmallVector<Value, 4> blockArgs;
blockArgs.reserve(phiInfo.size());
for (uint32_t valueId : phiInfo) {
- if (Value *value = getValue(valueId)) {
+ if (Value value = getValue(valueId)) {
blockArgs.push_back(value);
LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value
<< " id = " << valueId << '\n');
@@ -1996,7 +1987,7 @@
// Instruction
//===----------------------------------------------------------------------===//
-Value *Deserializer::getValue(uint32_t id) {
+Value Deserializer::getValue(uint32_t id) {
if (auto constInfo = getConstant(id)) {
// Materialize a `spv.constant` op at every use site.
return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
@@ -2192,7 +2183,7 @@
}
}
valueID = words[wordIndex++];
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
SmallVector<NamedAttribute, 4> attributes;
if (wordIndex < words.size()) {
auto arg = getValue(words[wordIndex]);
@@ -2366,9 +2357,9 @@
auto functionName = getFunctionSymbol(functionID);
- SmallVector<Value *, 4> arguments;
+ SmallVector<Value, 4> arguments;
for (auto operand : llvm::drop_begin(operands, 3)) {
- auto *value = getValue(operand);
+ auto value = getValue(operand);
if (!value) {
return emitError(unknownLoc, "unknown <id> ")
<< operand << " used by OpFunctionCall";
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp
index ba383b2..13405c9 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp
@@ -1,19 +1,10 @@
//===- SPIRVBinaryUtils.cpp - MLIR SPIR-V Binary Module Utilities ---------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines common utilities for SPIR-V binary module.
//
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 4baac53..0cdcc25 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1,19 +1,10 @@
//===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the MLIR SPIR-V module to SPIR-V binary serialization.
//
@@ -323,7 +314,7 @@
uint32_t opcode,
ArrayRef<uint32_t> operands);
- uint32_t getValueID(Value *val) const { return valueIDMap.lookup(val); }
+ uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); }
LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
@@ -414,7 +405,7 @@
DenseMap<Type, uint32_t> undefValIDMap;
/// Map from results of normal operations to their <id>s.
- DenseMap<Value *, uint32_t> valueIDMap;
+ DenseMap<Value, uint32_t> valueIDMap;
/// Map from extended instruction set name to <id>s.
llvm::StringMap<uint32_t> extendedInstSetIDMap;
@@ -457,7 +448,7 @@
/// placed inside `functions`) here. And then after emitting all blocks, we
/// replace the dummy <id> 0 with the real result <id> by overwriting
/// `functions[offset]`.
- DenseMap<Value *, SmallVector<size_t, 1>> deferredPhiValues;
+ DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues;
};
} // namespace
@@ -513,12 +504,12 @@
void Serializer::printValueIDMap(raw_ostream &os) {
os << "\n= Value <id> Map =\n\n";
for (auto valueIDPair : valueIDMap) {
- Value *val = valueIDPair.first;
+ Value val = valueIDPair.first;
os << " " << val << " "
<< "id = " << valueIDPair.second << ' ';
if (auto *op = val->getDefiningOp()) {
os << "from op '" << op->getName() << "'";
- } else if (auto *arg = dyn_cast<BlockArgument>(val)) {
+ } else if (auto arg = val.dyn_cast<BlockArgument>()) {
Block *block = arg->getOwner();
os << "from argument of block " << block << ' ';
os << " in op '" << block->getParentOp()->getName() << "'";
@@ -752,7 +743,7 @@
// There might be OpPhi instructions who have value references needing to fix.
for (auto deferredValue : deferredPhiValues) {
- Value *value = deferredValue.first;
+ Value value = deferredValue.first;
uint32_t id = getValueID(value);
LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
<< " to id = " << id << '\n');
@@ -1402,7 +1393,7 @@
// Then create OpPhi instruction for each of the block argument.
for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
- BlockArgument *arg = block->getArgument(argIndex);
+ BlockArgument arg = block->getArgument(argIndex);
// Get the type <id> and result <id> for this OpPhi instruction.
uint32_t phiTypeID = 0;
@@ -1418,7 +1409,7 @@
phiArgs.push_back(phiID);
for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
- Value *value = *(predecessors[predIndex].second + argIndex);
+ Value value = *(predecessors[predIndex].second + argIndex);
uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
<< ") value " << value << ' ');
@@ -1784,7 +1775,7 @@
auto funcCallID = getNextID();
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
- for (auto *value : op.arguments()) {
+ for (auto value : op.arguments()) {
auto valueID = getValueID(value);
assert(valueID && "cannot find a value for spv.FunctionCall");
operands.push_back(valueID);
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
index e9b4f23..750710f 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
@@ -1,19 +1,10 @@
//===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a translation from SPIR-V binary module to MLIR SPIR-V
// ModuleOp.
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/third_party/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
index be486f8..07621d6 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
@@ -1,19 +1,10 @@
//===- DecorateSPIRVCompositeTypeLayoutPass.cpp - Decorate composite type -===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to decorate the composite types used by
// composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index d48b31f..d7194da 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -1,19 +1,10 @@
//===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to lower attributes that specify the shader ABI
// for the functions in the generated SPIR-V module.
@@ -140,7 +131,7 @@
public:
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@@ -153,7 +144,7 @@
} // namespace
PatternMatchResult
-FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
spirv::getEntryPointABIAttrName())) {
@@ -183,7 +174,7 @@
OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
rewriter.setInsertionPointToStart(&funcOp.front());
// Insert spirv::AddressOf and spirv::AccessChain operations.
- Value *replacement =
+ Value replacement =
rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
// Check if the arg is a scalar or vector type. In that case, the value
// needs to be loaded into registers.
@@ -206,13 +197,11 @@
}
// Creates a new function with the update signature.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(rewriter.getFunctionType(
- signatureConverter.getConvertedTypes(), llvm::None));
- rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
- rewriter.eraseOp(funcOp.getOperation());
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
+ });
return matchSuccess();
}
diff --git a/third_party/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp
index 6b5578f..6848060 100644
--- a/third_party/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/StandardOps/DialectRegistration.cpp
@@ -1,19 +1,10 @@
//===- DialectRegistration.cpp - Register standard Op dialect -------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Ops.h"
using namespace mlir;
diff --git a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
index d0fd185..831c78a 100644
--- a/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1,19 +1,10 @@
//===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Ops.h"
@@ -81,7 +72,7 @@
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
- ArrayRef<Value *> valuesToRepl) const final {
+ ArrayRef<Value> valuesToRepl) const final {
// Only "std.return" needs to be handled here.
auto returnOp = cast<ReturnOp>(op);
@@ -184,7 +175,7 @@
// dimension operands parsed.
// Returns 'false' on success and 'true' on error.
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
- SmallVectorImpl<Value *> &operands,
+ SmallVectorImpl<Value> &operands,
unsigned &numDims) {
SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
@@ -325,7 +316,7 @@
PatternRewriter &rewriter) const override {
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
- if (llvm::none_of(alloc.getOperands(), [](Value *operand) {
+ if (llvm::none_of(alloc.getOperands(), [](Value operand) {
return matchPattern(operand, m_ConstantIndex());
}))
return matchFailure();
@@ -336,8 +327,8 @@
// and keep track of the resultant memref type to build.
SmallVector<int64_t, 4> newShapeConstants;
newShapeConstants.reserve(memrefType.getRank());
- SmallVector<Value *, 4> newOperands;
- SmallVector<Value *, 4> droppedOperands;
+ SmallVector<Value, 4> newOperands;
+ SmallVector<Value, 4> droppedOperands;
unsigned dynamicDimPos = 0;
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
@@ -429,7 +420,7 @@
static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
Block *dest;
- SmallVector<Value *, 4> destOperands;
+ SmallVector<Value, 4> destOperands;
if (parser.parseSuccessorAndUseList(dest, destOperands))
return failure();
result.addSuccessor(dest, destOperands);
@@ -623,7 +614,7 @@
//===----------------------------------------------------------------------===//
static void buildCmpIOp(Builder *build, OperationState &result,
- CmpIPredicate predicate, Value *lhs, Value *rhs) {
+ CmpIPredicate predicate, Value lhs, Value rhs) {
result.addOperands({lhs, rhs});
result.types.push_back(getI1SameShape(build, lhs->getType()));
result.addAttribute(
@@ -777,7 +768,7 @@
}
static void buildCmpFOp(Builder *build, OperationState &result,
- CmpFPredicate predicate, Value *lhs, Value *rhs) {
+ CmpFPredicate predicate, Value lhs, Value rhs) {
result.addOperands({lhs, rhs});
result.types.push_back(getI1SameShape(build, lhs->getType()));
result.addAttribute(
@@ -946,7 +937,7 @@
static ParseResult parseCondBranchOp(OpAsmParser &parser,
OperationState &result) {
- SmallVector<Value *, 4> destOperands;
+ SmallVector<Value, 4> destOperands;
Block *dest;
OpAsmParser::OperandType condInfo;
@@ -1088,7 +1079,7 @@
}
void ConstantOp::getAsmResultNames(
- function_ref<void(Value *, StringRef)> setNameFn) {
+ function_ref<void(Value, StringRef)> setNameFn) {
Type type = getType();
if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
IntegerType intTy = type.dyn_cast<IntegerType>();
@@ -1183,7 +1174,7 @@
PatternMatchResult matchAndRewrite(DeallocOp dealloc,
PatternRewriter &rewriter) const override {
// Check that the memref operand's defining operation is an AllocOp.
- Value *memref = dealloc.memref();
+ Value memref = dealloc.memref();
if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
return matchFailure();
@@ -1320,10 +1311,10 @@
}
//===----------------------------------------------------------------------===//
-// DivISOp
+// SignedDivIOp
//===----------------------------------------------------------------------===//
-OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary operation takes two operands");
// Don't fold if it would overflow or if it requires a division by zero.
@@ -1339,10 +1330,10 @@
}
//===----------------------------------------------------------------------===//
-// DivIUOp
+// UnsignedDivIOp
//===----------------------------------------------------------------------===//
-OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult UnsignedDivIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary operation takes two operands");
// Don't fold if it would require a division by zero.
@@ -1362,11 +1353,10 @@
// ---------------------------------------------------------------------------
void DmaStartOp::build(Builder *builder, OperationState &result,
- Value *srcMemRef, ValueRange srcIndices,
- Value *destMemRef, ValueRange destIndices,
- Value *numElements, Value *tagMemRef,
- ValueRange tagIndices, Value *stride,
- Value *elementsPerStride) {
+ Value srcMemRef, ValueRange srcIndices, Value destMemRef,
+ ValueRange destIndices, Value numElements,
+ Value tagMemRef, ValueRange tagIndices, Value stride,
+ Value elementsPerStride) {
result.addOperands(srcMemRef);
result.addOperands(srcIndices);
result.addOperands(destMemRef);
@@ -1506,9 +1496,8 @@
// DmaWaitOp
// ---------------------------------------------------------------------------
-void DmaWaitOp::build(Builder *builder, OperationState &result,
- Value *tagMemRef, ValueRange tagIndices,
- Value *numElements) {
+void DmaWaitOp::build(Builder *builder, OperationState &result, Value tagMemRef,
+ ValueRange tagIndices, Value numElements) {
result.addOperands(tagMemRef);
result.addOperands(tagIndices);
result.addOperands(numElements);
@@ -1885,11 +1874,11 @@
}
//===----------------------------------------------------------------------===//
-// RemISOp
+// SignedRemIOp
//===----------------------------------------------------------------------===//
-OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "remis takes two operands");
+OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "remi_signed takes two operands");
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!rhs)
@@ -1911,11 +1900,11 @@
}
//===----------------------------------------------------------------------===//
-// RemIUOp
+// UnsignedRemIOp
//===----------------------------------------------------------------------===//
-OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "remiu takes two operands");
+OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "remi_unsigned takes two operands");
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!rhs)
@@ -2025,7 +2014,7 @@
}
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
- auto *condition = getCondition();
+ auto condition = getCondition();
// select true, %0, %1 => %0
if (matchPattern(condition, m_One()))
@@ -2357,7 +2346,7 @@
static void print(OpAsmPrinter &p, ViewOp op) {
p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
- auto *dynamicOffset = op.getDynamicOffset();
+ auto dynamicOffset = op.getDynamicOffset();
if (dynamicOffset != nullptr)
p.printOperand(dynamicOffset);
p << "][" << op.getDynamicSizes() << ']';
@@ -2365,7 +2354,7 @@
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
}
-Value *ViewOp::getDynamicOffset() {
+Value ViewOp::getDynamicOffset() {
int64_t offset;
SmallVector<int64_t, 4> strides;
auto result =
@@ -2440,7 +2429,7 @@
PatternMatchResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
// Return if none of the operands are constants.
- if (llvm::none_of(viewOp.getOperands(), [](Value *operand) {
+ if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
return matchPattern(operand, m_ConstantIndex());
}))
return matchFailure();
@@ -2457,11 +2446,11 @@
if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
return matchFailure();
- SmallVector<Value *, 4> newOperands;
- SmallVector<Value *, 4> droppedOperands;
+ SmallVector<Value, 4> newOperands;
+ SmallVector<Value, 4> droppedOperands;
// Fold dynamic offset operand if it is produced by a constant.
- auto *dynamicOffset = viewOp.getDynamicOffset();
+ auto dynamicOffset = viewOp.getDynamicOffset();
int64_t newOffset = oldOffset;
unsigned dynamicOffsetOperandCount = 0;
if (dynamicOffset != nullptr) {
@@ -2576,7 +2565,7 @@
memRefType.getMemorySpace());
}
-void mlir::SubViewOp::build(Builder *b, OperationState &result, Value *source,
+void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
ValueRange offsets, ValueRange sizes,
ValueRange strides, Type resultType,
ArrayRef<NamedAttribute> attrs) {
@@ -2590,7 +2579,7 @@
}
void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
- Value *source) {
+ Value source) {
build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
resultType);
}
@@ -2826,7 +2815,7 @@
// Follow all or nothing approach for shapes for now. If all the operands
// for sizes are constants then fold it into the type of the result memref.
if (subViewType.hasStaticShape() ||
- llvm::any_of(subViewOp.sizes(), [](Value *operand) {
+ llvm::any_of(subViewOp.sizes(), [](Value operand) {
return !matchPattern(operand, m_ConstantIndex());
})) {
return matchFailure();
@@ -2842,7 +2831,7 @@
subViewType.getMemorySpace());
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
- ArrayRef<Value *>(), subViewOp.strides(), newMemRefType);
+ ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
subViewOp.sizes(), subViewOp, newSubViewOp, subViewOp.getType());
@@ -2871,7 +2860,7 @@
failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
llvm::is_contained(baseStrides,
MemRefType::getDynamicStrideOrOffset()) ||
- llvm::any_of(subViewOp.strides(), [](Value *stride) {
+ llvm::any_of(subViewOp.strides(), [](Value stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
return matchFailure();
@@ -2892,7 +2881,7 @@
layoutMap, subViewType.getMemorySpace());
auto newSubViewOp = rewriter.create<SubViewOp>(
subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
- subViewOp.sizes(), ArrayRef<Value *>(), newMemRefType);
+ subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
subViewOp.strides(), subViewOp, newSubViewOp, subViewOp.getType());
@@ -2922,7 +2911,7 @@
llvm::is_contained(baseStrides,
MemRefType::getDynamicStrideOrOffset()) ||
baseOffset == MemRefType::getDynamicStrideOrOffset() ||
- llvm::any_of(subViewOp.offsets(), [](Value *stride) {
+ llvm::any_of(subViewOp.offsets(), [](Value stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
return matchFailure();
@@ -2943,7 +2932,7 @@
MemRefType::get(subViewType.getShape(), subViewType.getElementType(),
layoutMap, subViewType.getMemorySpace());
auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value *>(),
+ subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
subViewOp.sizes(), subViewOp.strides(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(
diff --git a/third_party/mlir/lib/Dialect/Traits.cpp b/third_party/mlir/lib/Dialect/Traits.cpp
index 0ac07c2..3aea206 100644
--- a/third_party/mlir/lib/Dialect/Traits.cpp
+++ b/third_party/mlir/lib/Dialect/Traits.cpp
@@ -1,19 +1,10 @@
//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"
diff --git a/third_party/mlir/lib/Dialect/VectorOps/DialectRegistration.cpp b/third_party/mlir/lib/Dialect/VectorOps/DialectRegistration.cpp
index 0caa1cf..edd6abb 100644
--- a/third_party/mlir/lib/Dialect/VectorOps/DialectRegistration.cpp
+++ b/third_party/mlir/lib/Dialect/VectorOps/DialectRegistration.cpp
@@ -1,19 +1,10 @@
//===- DialectRegistration.cpp - Register super vectorization dialect -----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/VectorOps/VectorOps.h"
using namespace mlir;
diff --git a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 4ed0902..a3904ef 100644
--- a/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1,19 +1,10 @@
//===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
@@ -58,12 +49,21 @@
return builder.create<ConstantOp>(loc, type, value);
}
+IntegerType vector::getVectorSubscriptType(Builder &builder) {
+ return builder.getIntegerType(64);
+}
+
+ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
+ ArrayRef<int64_t> values) {
+ return builder.getI64ArrayAttr(values);
+}
+
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
void vector::ContractionOp::build(Builder *builder, OperationState &result,
- Value *lhs, Value *rhs, Value *acc,
+ Value lhs, Value rhs, Value acc,
ArrayAttr indexingMaps,
ArrayAttr iteratorTypes) {
result.addOperands({lhs, rhs, acc});
@@ -395,9 +395,9 @@
}
void vector::ExtractOp::build(Builder *builder, OperationState &result,
- Value *source, ArrayRef<int32_t> position) {
+ Value source, ArrayRef<int64_t> position) {
result.addOperands(source);
- auto positionAttr = builder->getI32ArrayAttr(position);
+ auto positionAttr = getVectorSubscriptAttr(*builder, position);
result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
positionAttr));
result.addAttribute(getPositionAttrName(), positionAttr);
@@ -462,12 +462,12 @@
//===----------------------------------------------------------------------===//
void ExtractSlicesOp::build(Builder *builder, OperationState &result,
- TupleType tupleType, Value *vector,
+ TupleType tupleType, Value vector,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
result.addOperands(vector);
- auto sizesAttr = builder->getI64ArrayAttr(sizes);
- auto stridesAttr = builder->getI64ArrayAttr(strides);
+ auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
+ auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
result.addTypes(tupleType);
result.addAttribute(getSizesAttrName(), sizesAttr);
result.addAttribute(getStridesAttrName(), stridesAttr);
@@ -638,10 +638,10 @@
// ShuffleOp
//===----------------------------------------------------------------------===//
-void ShuffleOp::build(Builder *builder, OperationState &result, Value *v1,
- Value *v2, ArrayRef<int32_t> mask) {
+void ShuffleOp::build(Builder *builder, OperationState &result, Value v1,
+ Value v2, ArrayRef<int64_t> mask) {
result.addOperands({v1, v2});
- auto maskAttr = builder->getI32ArrayAttr(mask);
+ auto maskAttr = getVectorSubscriptAttr(*builder, mask);
result.addTypes(v1->getType());
result.addAttribute(getMaskAttrName(), maskAttr);
}
@@ -762,10 +762,10 @@
// InsertOp
//===----------------------------------------------------------------------===//
-void InsertOp::build(Builder *builder, OperationState &result, Value *source,
- Value *dest, ArrayRef<int32_t> position) {
+void InsertOp::build(Builder *builder, OperationState &result, Value source,
+ Value dest, ArrayRef<int64_t> position) {
result.addOperands({source, dest});
- auto positionAttr = builder->getI32ArrayAttr(position);
+ auto positionAttr = getVectorSubscriptAttr(*builder, position);
result.addTypes(dest->getType());
result.addAttribute(getPositionAttrName(), positionAttr);
}
@@ -884,12 +884,12 @@
//===----------------------------------------------------------------------===//
void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
- Value *source, Value *dest,
+ Value source, Value dest,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> strides) {
result.addOperands({source, dest});
- auto offsetsAttr = builder->getI64ArrayAttr(offsets);
- auto stridesAttr = builder->getI64ArrayAttr(strides);
+ auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
+ auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
result.addTypes(dest->getType());
result.addAttribute(getOffsetsAttrName(), offsetsAttr);
result.addAttribute(getStridesAttrName(), stridesAttr);
@@ -1099,6 +1099,123 @@
}
//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, ReshapeOp op) {
+ p << op.getOperationName() << " " << *op.vector() << ", [" << op.input_shape()
+ << "], [" << op.output_shape() << "], " << op.fixed_vector_sizes();
+ SmallVector<StringRef, 2> elidedAttrs = {
+ ReshapeOp::getOperandSegmentSizeAttr(),
+ ReshapeOp::getFixedVectorSizesAttrName()};
+ p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
+ p << " : " << op.getInputVectorType() << " to " << op.getOutputVectorType();
+}
+
+// TODO(b/146516564) Consider passing number of inner vector dimensions that
+// are fixed, instead of their values in 'fixedVectorSizes' array attr.
+//
+// operation ::= ssa-id `=` `vector.reshape` ssa-use, `[` ssa-use-list `]`,
+// `[` ssa-use-list `]`, array-attribute attr-dict?
+// `:` vector-type `to` vector-type
+//
+static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType inputInfo;
+ SmallVector<OpAsmParser::OperandType, 4> inputShapeInfo;
+ SmallVector<OpAsmParser::OperandType, 4> outputShapeInfo;
+ ArrayAttr fixedVectorSizesAttr;
+ StringRef attrName = ReshapeOp::getFixedVectorSizesAttrName();
+ auto indexType = parser.getBuilder().getIndexType();
+ if (parser.parseOperand(inputInfo) || parser.parseComma() ||
+ parser.parseOperandList(inputShapeInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseComma() ||
+ parser.parseOperandList(outputShapeInfo,
+ OpAsmParser::Delimiter::Square) ||
+ parser.parseComma()) {
+ return failure();
+ }
+
+ auto builder = parser.getBuilder();
+ result.addAttribute(
+ ReshapeOp::getOperandSegmentSizeAttr(),
+ builder.getI32VectorAttr({1, static_cast<int32_t>(inputShapeInfo.size()),
+ static_cast<int32_t>(outputShapeInfo.size())}));
+ Type inputType;
+ Type outputType;
+ return failure(
+ parser.parseAttribute(fixedVectorSizesAttr, attrName,
+ result.attributes) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(inputType) ||
+ parser.resolveOperand(inputInfo, inputType, result.operands) ||
+ parser.resolveOperands(inputShapeInfo, indexType, result.operands) ||
+ parser.resolveOperands(outputShapeInfo, indexType, result.operands) ||
+ parser.parseKeywordType("to", outputType) ||
+ parser.addTypeToList(outputType, result.types));
+}
+
+/// Verifies a vector.reshape op. Checks that:
+///  * each vector rank equals its number of shape operands plus the number of
+///    fixed vector sizes;
+///  * the fixed vector sizes equal the trailing dimensions of both the input
+///    and output vector types;
+///  * when every shape operand is a ConstantIndexOp, the products of the input
+///    and output shape sizes agree.
+static LogicalResult verify(ReshapeOp op) {
+  // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
+  auto inputVectorType = op.getInputVectorType();
+  auto outputVectorType = op.getOutputVectorType();
+  int64_t inputShapeRank = op.getNumInputShapeSizes();
+  int64_t outputShapeRank = op.getNumOutputShapeSizes();
+  SmallVector<int64_t, 4> fixedVectorSizes;
+  op.getFixedVectorSizes(fixedVectorSizes);
+  int64_t numFixedVectorSizes = fixedVectorSizes.size();
+
+  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
+    return op.emitError("invalid input shape for vector type ")
+           << inputVectorType;
+
+  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
+    return op.emitError("invalid output shape for vector type ")
+           << outputVectorType;
+
+  // Verify that the 'fixedVectorSizes' match an input/output vector shape
+  // suffix: fixedVectorSizes[i] corresponds to dimension
+  // (rank - numFixedVectorSizes + i). The rank checks above guarantee
+  // rank >= numFixedVectorSizes, so the suffix start is never negative.
+  int64_t inputSuffixStart = inputVectorType.getRank() - numFixedVectorSizes;
+  for (int64_t i = 0; i < numFixedVectorSizes; ++i) {
+    if (fixedVectorSizes[i] !=
+        inputVectorType.getShape()[inputSuffixStart + i])
+      return op.emitError("fixed vector size must match input vector for dim ")
+             << i;
+  }
+
+  int64_t outputSuffixStart = outputVectorType.getRank() - numFixedVectorSizes;
+  for (int64_t i = 0; i < numFixedVectorSizes; ++i) {
+    if (fixedVectorSizes[i] !=
+        outputVectorType.getShape()[outputSuffixStart + i])
+      return op.emitError("fixed vector size must match output vector for dim ")
+             << i;
+  }
+
+  // If all shape operands are produced by constant ops, verify that product
+  // of dimensions for input/output shape match.
+  auto isDefByConstant = [](Value operand) {
+    return isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
+  };
+  if (llvm::all_of(op.input_shape(), isDefByConstant) &&
+      llvm::all_of(op.output_shape(), isDefByConstant)) {
+    int64_t numInputElements = 1;
+    for (auto operand : op.input_shape())
+      numInputElements *=
+          cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
+    int64_t numOutputElements = 1;
+    for (auto operand : op.output_shape())
+      numOutputElements *=
+          cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
+    if (numInputElements != numOutputElements)
+      return op.emitError("product of input and output shape sizes must match");
+  }
+  return success();
+}
+
+void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
+ populateFromInt64AttrArray(fixed_vector_sizes(), results);
+}
+
+//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//
@@ -1121,12 +1238,12 @@
}
void StridedSliceOp::build(Builder *builder, OperationState &result,
- Value *source, ArrayRef<int64_t> offsets,
+ Value source, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
result.addOperands(source);
- auto offsetsAttr = builder->getI64ArrayAttr(offsets);
- auto sizesAttr = builder->getI64ArrayAttr(sizes);
- auto stridesAttr = builder->getI64ArrayAttr(strides);
+ auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
+ auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
+ auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
result.addTypes(
inferStridedSliceOpResultType(source->getType().cast<VectorType>(),
offsetsAttr, sizesAttr, stridesAttr));
@@ -1249,7 +1366,7 @@
// Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
stridedSliceOp, stridedSliceOp.getResult()->getType(),
- rewriter.getI64ArrayAttr(sliceMaskDimSizes));
+ vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
return matchSuccess();
}
};
@@ -1294,6 +1411,59 @@
return success();
}
+static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
+ VectorType vectorType,
+ AffineMap permutationMap) {
+ auto memrefElementType = memrefType.getElementType();
+ if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+ // Memref has vector element type.
+
+ // Check that 'memrefVectorElementType' and vector element types match.
+ if (memrefVectorElementType.getElementType() != vectorType.getElementType())
+ return op->emitOpError(
+ "requires memref and vector types of the same elemental type");
+
+ // Check that memref vector type is a suffix of 'vectorType'.
+ unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+ unsigned resultVecRank = vectorType.getRank();
+ if (memrefVecEltRank > resultVecRank)
+ return op->emitOpError(
+ "requires memref vector element and vector result ranks to match.");
+ // TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h.
+ unsigned rankOffset = resultVecRank - memrefVecEltRank;
+ auto memrefVecEltShape = memrefVectorElementType.getShape();
+ auto resultVecShape = vectorType.getShape();
+ for (unsigned i = 0; i < memrefVecEltRank; ++i)
+ if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
+ return op->emitOpError(
+ "requires memref vector element shape to match suffix of "
+ "vector result shape.");
+ // Check that permutation map results match 'rankOffset' of vector type.
+ if (permutationMap.getNumResults() != rankOffset)
+ return op->emitOpError("requires a permutation_map with result dims of "
+ "the same rank as the vector type");
+ } else {
+ // Memref has scalar element type.
+
+ // Check that memref and vector element types match.
+ if (memrefType.getElementType() != vectorType.getElementType())
+ return op->emitOpError(
+ "requires memref and vector types of the same elemental type");
+
+ // Check that permutation map results match rank of vector type.
+ if (permutationMap.getNumResults() != vectorType.getRank())
+ return op->emitOpError("requires a permutation_map with result dims of "
+ "the same rank as the vector type");
+ }
+
+ if (permutationMap.getNumSymbols() != 0)
+ return op->emitOpError("requires permutation_map without symbols");
+ if (permutationMap.getNumInputs() != memrefType.getRank())
+ return op->emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type");
+ return success();
+}
+
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
<< "], " << op.padding() << " ";
@@ -1333,26 +1503,35 @@
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
- if (memrefType.getElementType() != vectorType.getElementType())
- return op.emitOpError(
- "requires memref and vector types of the same elemental type");
- auto elementalType = op.padding()->getType();
- if (!VectorType::isValidElementType(elementalType))
- return op.emitOpError("requires valid padding vector elemental type");
- if (elementalType != vectorType.getElementType())
- return op.emitOpError(
- "requires formal padding and vector of the same elemental type");
- if (llvm::size(op.indices()) != memrefType.getRank())
- return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+ auto paddingType = op.padding()->getType();
auto permutationMap = op.permutation_map();
- if (permutationMap.getNumSymbols() != 0)
- return op.emitOpError("requires permutation_map without symbols");
- if (permutationMap.getNumInputs() != memrefType.getRank())
- return op.emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type");
- if (permutationMap.getNumResults() != vectorType.getRank())
- return op.emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type");
+ auto memrefElementType = memrefType.getElementType();
+
+ if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
+ return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+
+ if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+ permutationMap)))
+ return failure();
+
+ if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+ // Memref has vector element type.
+ // Check that 'memrefVectorElementType' and 'paddingType' types match.
+ if (memrefVectorElementType != paddingType)
+ return op.emitOpError(
+ "requires memref element type and padding type to match.");
+
+ } else {
+ // Check that 'paddingType' is valid to store in a vector type.
+ if (!VectorType::isValidElementType(paddingType))
+ return op.emitOpError("requires valid padding vector elemental type");
+
+ // Check that padding type and vector element types match.
+ if (paddingType != vectorType.getElementType())
+ return op.emitOpError(
+ "requires formal padding and vector of the same elemental type");
+ }
+
return verifyPermutationMap(permutationMap,
[&op](Twine t) { return op.emitOpError(t); });
}
@@ -1393,24 +1572,15 @@
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
- if (memrefType.getElementType() != vectorType.getElementType())
- return op.emitOpError(
- "requires memref and vector types of the same elemental type");
+ auto permutationMap = op.permutation_map();
+
if (llvm::size(op.indices()) != memrefType.getRank())
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
- // Consistency of AffineMap attribute.
- auto permutationMap = op.permutation_map();
- if (permutationMap.getNumSymbols() != 0)
- return op.emitOpError("requires a symbol-less permutation_map");
- if (permutationMap.getNumInputs() != memrefType.getRank())
- return op.emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type: ")
- << permutationMap.getNumInputs() << " vs " << memrefType;
- if (permutationMap.getNumResults() != vectorType.getRank())
- return op.emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type.")
- << permutationMap.getNumResults() << " vs " << vectorType;
+ if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+ permutationMap)))
+ return failure();
+
return verifyPermutationMap(permutationMap,
[&op](Twine t) { return op.emitOpError(t); });
}
@@ -1423,8 +1593,7 @@
return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()));
}
-void TypeCastOp::build(Builder *builder, OperationState &result,
- Value *source) {
+void TypeCastOp::build(Builder *builder, OperationState &result, Value source) {
result.addOperands(source);
result.addTypes(
inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
@@ -1614,21 +1783,21 @@
PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
// Return if any of 'createMaskOp' operands are not defined by a constant.
- auto is_not_def_by_constant = [](Value *operand) {
+ auto is_not_def_by_constant = [](Value operand) {
return !isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
};
if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
return matchFailure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
- for (auto *operand : createMaskOp.operands()) {
+ for (auto operand : createMaskOp.operands()) {
auto defOp = operand->getDefiningOp();
maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
}
// Replace 'createMaskOp' with ConstantMaskOp.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, createMaskOp.getResult()->getType(),
- rewriter.getI64ArrayAttr(maskDimSizes));
+ vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
return matchSuccess();
}
};
diff --git a/third_party/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/third_party/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 64cacb2..28b803f 100644
--- a/third_party/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/third_party/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -1,19 +1,10 @@
//===- VectorToLoops.cpp - Conversion within the Vector dialect -----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
@@ -106,17 +97,17 @@
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
Location loc, Operation *op,
- ArrayRef<Value *> operands,
+ ArrayRef<Value> operands,
ArrayRef<Type> resultTypes) {
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
op->getAttrs());
return builder.createOperation(res);
}
-static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
- VectorType vt) {
+static Value makeSplatZero(Location loc, PatternRewriter &rewriter,
+ VectorType vt) {
auto t = vt.getElementType();
- Value *f = nullptr;
+ Value f = nullptr;
if (t.isBF16() || t.isF16())
f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
else if (t.isF32())
@@ -190,12 +181,12 @@
SmallVector<int64_t, 4> unrollFactors;
SmallVector<int64_t, 8> basis;
int64_t numInstances;
- Value *slicesTuple;
+ Value slicesTuple;
};
// Populates 'state' with unrolled shape, unroll factors, basis and
// num unrolled instances for 'vectorType'.
-static void initUnrolledVectorState(VectorType vectorType, Value *initValue,
+static void initUnrolledVectorState(VectorType vectorType, Value initValue,
const DenseMap<int64_t, int64_t> &indexMap,
ArrayRef<int64_t> targetShape,
UnrolledVectorState &state,
@@ -239,11 +230,10 @@
// Returns an unrolled vector at 'vectorOffsets' within the vector
// represented by 'state'. The vector is created from a slice of 'initValue'
// if not present in 'cache'.
-static Value *getOrCreateUnrolledVectorSlice(
+static Value getOrCreateUnrolledVectorSlice(
Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
- Value *initValue, SmallVectorImpl<Value *> &cache,
- PatternRewriter &builder) {
+ Value initValue, SmallVectorImpl<Value> &cache, PatternRewriter &builder) {
// Compute slice offsets.
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
getMappedElements(indexMap, offsets, sliceOffsets);
@@ -253,7 +243,7 @@
int64_t sliceLinearIndex =
getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
- auto *valueSlice = cache[sliceLinearIndex];
+ auto valueSlice = cache[sliceLinearIndex];
if (valueSlice == nullptr) {
// Return tuple element at 'sliceLinearIndex'.
auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
@@ -330,12 +320,12 @@
// TODO(andydavis) Generalize this to support structured ops beyond
// vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
-static Value *unrollSingleResultStructuredOp(Operation *op,
- ArrayRef<int64_t> iterationBounds,
- std::vector<VectorState> &vectors,
- unsigned resultIndex,
- ArrayRef<int64_t> targetShape,
- PatternRewriter &builder) {
+static Value unrollSingleResultStructuredOp(Operation *op,
+ ArrayRef<int64_t> iterationBounds,
+ std::vector<VectorState> &vectors,
+ unsigned resultIndex,
+ ArrayRef<int64_t> targetShape,
+ PatternRewriter &builder) {
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape())
assert(false && "Expected a statically shaped result type");
@@ -351,7 +341,7 @@
SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
for (unsigned i = 0; i < numVectors; ++i) {
int64_t operandIndex = vectors[i].operandIndex;
- auto *operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
+ auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
targetShape, unrolledVectorState[i], builder);
}
@@ -364,7 +354,7 @@
shapedType.getElementType());
// Initialize caches for intermediate vector results.
- std::vector<SmallVector<Value *, 4>> caches(numVectors);
+ std::vector<SmallVector<Value, 4>> caches(numVectors);
for (unsigned i = 0; i < numVectors; ++i)
caches[i].resize(unrolledVectorState[i].numInstances);
@@ -376,13 +366,13 @@
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
vectorOffsets, targetShape);
// Get cached slice (or create slice) for each operand at 'offsets'.
- SmallVector<Value *, 3> operands;
+ SmallVector<Value, 3> operands;
operands.resize(op->getNumOperands());
for (unsigned i = 0; i < numVectors; ++i) {
int64_t operandIndex = vectors[i].operandIndex;
if (operandIndex < 0)
continue; // Output
- auto *operand = op->getOperand(operandIndex);
+ auto operand = op->getOperand(operandIndex);
operands[operandIndex] = getOrCreateUnrolledVectorSlice(
op->getLoc(), unrolledVectorState[i], vectorOffsets, offsets,
vectors[i].indexMap, operand, caches[i], builder);
@@ -402,21 +392,21 @@
// Create TupleOp of unrolled result vectors.
SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
- SmallVector<Value *, 4> vectorTupleValues(resultValueState.numInstances);
+ SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>();
vectorTupleValues[i] = caches[resultIndex][i];
}
TupleType tupleType = builder.getTupleType(vectorTupleTypes);
- Value *tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
- vectorTupleValues);
+ Value tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
+ vectorTupleValues);
// Create InsertSlicesOp(Tuple(result_vectors)).
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
- Value *insertSlicesOp = builder.create<vector::InsertSlicesOp>(
+ Value insertSlicesOp = builder.create<vector::InsertSlicesOp>(
op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
builder.getI64ArrayAttr(strides));
return insertSlicesOp;
@@ -487,7 +477,7 @@
}
// Entry point for unrolling declarative pattern rewrites.
-Value *mlir::vector::unrollSingleResultOpMatchingType(
+Value mlir::vector::unrollSingleResultOpMatchingType(
PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
@@ -516,8 +506,8 @@
static void
generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides,
- ArrayRef<Value *> indices, PatternRewriter &rewriter,
- function_ref<void(unsigned, ArrayRef<Value *>)> fn) {
+ ArrayRef<Value> indices, PatternRewriter &rewriter,
+ function_ref<void(unsigned, ArrayRef<Value>)> fn) {
// Compute strides w.r.t. to slice counts in each dimension.
auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
assert(maybeDimSliceCounts.hasValue());
@@ -534,13 +524,13 @@
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
vectorOffsets, sizes);
// Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
- SmallVector<Value *, 4> sliceIndices(numSliceIndices);
+ SmallVector<Value, 4> sliceIndices(numSliceIndices);
for (auto it : llvm::enumerate(indices)) {
auto expr = getAffineDimExpr(0, ctx) +
getAffineConstantExpr(offsets[it.index()], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
- it.value()->getLoc(), map, ArrayRef<Value *>(it.value()));
+ it.value()->getLoc(), map, ArrayRef<Value>(it.value()));
}
// Call 'fn' to generate slice 'i' at 'sliceIndices'.
fn(i, sliceIndices);
@@ -559,7 +549,7 @@
if (!xferReadOp.permutation_map().isIdentity())
return matchFailure();
// Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
- Value *xferReadResult = xferReadOp.getResult();
+ Value xferReadResult = xferReadOp.getResult();
auto extractSlicesOp =
dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin());
if (!xferReadResult->hasOneUse() || !extractSlicesOp)
@@ -576,10 +566,10 @@
Location loc = xferReadOp.getLoc();
int64_t numSlices = resultTupleType.size();
- SmallVector<Value *, 4> vectorTupleValues(numSlices);
- SmallVector<Value *, 4> indices(xferReadOp.indices().begin(),
- xferReadOp.indices().end());
- auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
+ SmallVector<Value, 4> vectorTupleValues(numSlices);
+ SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
+ xferReadOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Get VectorType for slice 'i'.
auto sliceVectorType = resultTupleType.getType(index);
// Create split TransferReadOp for 'sliceUser'.
@@ -591,8 +581,8 @@
indices, rewriter, createSlice);
// Create tuple of splice xfer read operations.
- Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
- vectorTupleValues);
+ Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
+ vectorTupleValues);
// Replace 'xferReadOp' with result 'insertSlicesResult'.
rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
@@ -632,9 +622,9 @@
insertSlicesOp.getStrides(strides);
Location loc = xferWriteOp.getLoc();
- SmallVector<Value *, 4> indices(xferWriteOp.indices().begin(),
- xferWriteOp.indices().end());
- auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
+ SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
+ xferWriteOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
rewriter.create<vector::TransferWriteOp>(
loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
@@ -676,7 +666,7 @@
return matchFailure();
// Forward Value from 'tupleOp' at 'tupleGetOp.index'.
- Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
+ Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
rewriter.replaceOp(tupleGetOp, tupleValue);
return matchSuccess();
}
diff --git a/third_party/mlir/lib/EDSC/Builders.cpp b/third_party/mlir/lib/EDSC/Builders.cpp
index 2956066..7d51cde 100644
--- a/third_party/mlir/lib/EDSC/Builders.cpp
+++ b/third_party/mlir/lib/EDSC/Builders.cpp
@@ -1,19 +1,10 @@
//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/Ops.h"
@@ -90,7 +81,7 @@
ValueHandle
mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map,
- ArrayRef<Value *> operands) {
+ ArrayRef<Value> operands) {
Operation *op =
makeComposedAffineApply(ScopedContext::getBuilder(),
ScopedContext::getLocation(), map, operands)
@@ -118,7 +109,7 @@
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes) {
OperationState state(ScopedContext::getLocation(), name);
- SmallVector<Value *, 4> ops(operands.begin(), operands.end());
+ SmallVector<Value, 4> ops(operands.begin(), operands.end());
state.addOperands(ops);
state.addTypes(resultTypes);
for (const auto &attr : attributes) {
@@ -169,8 +160,8 @@
if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) {
*iv = staticFor.getValue();
} else {
- SmallVector<Value *, 4> lbs(lbHandles.begin(), lbHandles.end());
- SmallVector<Value *, 4> ubs(ubHandles.begin(), ubHandles.end());
+ SmallVector<Value, 4> lbs(lbHandles.begin(), lbHandles.end());
+ SmallVector<Value, 4> ubs(ubHandles.begin(), ubHandles.end());
*iv = ValueHandle::create<AffineForOp>(
lbs, ScopedContext::getBuilder().getMultiDimIdentityMap(lbs.size()),
ubs, ScopedContext::getBuilder().getMultiDimIdentityMap(ubs.size()),
@@ -309,11 +300,11 @@
return ValueHandle::create<Op>(lhs.getValue(), rhs.getValue());
}
-static std::pair<AffineExpr, Value *>
-categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims,
+static std::pair<AffineExpr, Value>
+categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
unsigned &numSymbols) {
AffineExpr d;
- Value *resultVal = nullptr;
+ Value resultVal = nullptr;
if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val->getDefiningOp())) {
d = getAffineConstantExpr(constant.getValue(), context);
} else if (isValidSymbol(val) && !isValidDim(val)) {
@@ -332,12 +323,12 @@
MLIRContext *context = ScopedContext::getContext();
unsigned numDims = 0, numSymbols = 0;
AffineExpr d0, d1;
- Value *v0, *v1;
+ Value v0, v1;
std::tie(d0, v0) =
categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols);
std::tie(d1, v1) =
categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols);
- SmallVector<Value *, 2> operands;
+ SmallVector<Value, 2> operands;
if (v0) {
operands.push_back(v0);
}
@@ -390,14 +381,14 @@
}
ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) {
- return createBinaryHandle<DivISOp, DivFOp>(
+ return createBinaryHandle<SignedDivIOp, DivFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
llvm_unreachable("only exprs of non-index type support operator/");
});
}
ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) {
- return createBinaryHandle<RemISOp, RemFOp>(
+ return createBinaryHandle<SignedRemIOp, RemFOp>(
lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
}
diff --git a/third_party/mlir/lib/EDSC/CoreAPIs.cpp b/third_party/mlir/lib/EDSC/CoreAPIs.cpp
index 46199c2..6f7c172 100644
--- a/third_party/mlir/lib/EDSC/CoreAPIs.cpp
+++ b/third_party/mlir/lib/EDSC/CoreAPIs.cpp
@@ -1,19 +1,10 @@
//===- Types.cpp - Implementations of MLIR Core C APIs --------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir-c/Core.h"
diff --git a/third_party/mlir/lib/EDSC/Helpers.cpp b/third_party/mlir/lib/EDSC/Helpers.cpp
index eeb2866..008948b 100644
--- a/third_party/mlir/lib/EDSC/Helpers.cpp
+++ b/third_party/mlir/lib/EDSC/Helpers.cpp
@@ -1,19 +1,10 @@
//===- Helpers.cpp - MLIR Declarative Helper Functionality ----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/EDSC/Helpers.h"
#include "mlir/Dialect/StandardOps/Ops.h"
@@ -22,7 +13,7 @@
using namespace mlir;
using namespace mlir::edsc;
-static SmallVector<ValueHandle, 8> getMemRefSizes(Value *memRef) {
+static SmallVector<ValueHandle, 8> getMemRefSizes(Value memRef) {
MemRefType memRefType = memRef->getType().cast<MemRefType>();
assert(isStrided(memRefType) && "Expected strided MemRef type");
@@ -39,7 +30,7 @@
return res;
}
-mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) {
+mlir::edsc::MemRefView::MemRefView(Value v) : base(v) {
assert(v->getType().isa<MemRefType>() && "MemRefType expected");
auto memrefSizeValues = getMemRefSizes(v);
@@ -50,7 +41,7 @@
}
}
-mlir::edsc::VectorView::VectorView(Value *v) : base(v) {
+mlir::edsc::VectorView::VectorView(Value v) : base(v) {
auto vectorType = v->getType().cast<VectorType>();
for (auto s : vectorType.getShape()) {
diff --git a/third_party/mlir/lib/EDSC/Intrinsics.cpp b/third_party/mlir/lib/EDSC/Intrinsics.cpp
index 1b19f9a..d339ec0 100644
--- a/third_party/mlir/lib/EDSC/Intrinsics.cpp
+++ b/third_party/mlir/lib/EDSC/Intrinsics.cpp
@@ -1,19 +1,10 @@
//===- Intrinsics.cpp - MLIR Operations for Declarative Builders ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/EDSC/Builders.h"
@@ -29,7 +20,7 @@
(void)o;
assert(o && "Expected already captured ValueHandle");
}
- SmallVector<Value *, 4> ops(operands.begin(), operands.end());
+ SmallVector<Value, 4> ops(operands.begin(), operands.end());
return OperationHandle::create<BranchOp>(bh.getBlock(), ops);
}
static void enforceEmptyCapturesMatchOperands(ArrayRef<ValueHandle *> captures,
@@ -52,7 +43,7 @@
assert(!*bh && "Unexpected already captured BlockHandle");
enforceEmptyCapturesMatchOperands(captures, operands);
BlockBuilder(bh, captures)(/* no body */);
- SmallVector<Value *, 4> ops(operands.begin(), operands.end());
+ SmallVector<Value, 4> ops(operands.begin(), operands.end());
return OperationHandle::create<BranchOp>(bh->getBlock(), ops);
}
@@ -61,8 +52,8 @@
ArrayRef<ValueHandle> trueOperands,
BlockHandle falseBranch,
ArrayRef<ValueHandle> falseOperands) {
- SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
- SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
+ SmallVector<Value, 4> trueOps(trueOperands.begin(), trueOperands.end());
+ SmallVector<Value, 4> falseOps(falseOperands.begin(), falseOperands.end());
return OperationHandle::create<CondBranchOp>(
cond, trueBranch.getBlock(), trueOps, falseBranch.getBlock(), falseOps);
}
@@ -78,8 +69,8 @@
enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands);
BlockBuilder(trueBranch, trueCaptures)(/* no body */);
BlockBuilder(falseBranch, falseCaptures)(/* no body */);
- SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
- SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
+ SmallVector<Value, 4> trueOps(trueOperands.begin(), trueOperands.end());
+ SmallVector<Value, 4> falseOps(falseOperands.begin(), falseOperands.end());
return OperationHandle::create<CondBranchOp>(
cond, trueBranch->getBlock(), trueOps, falseBranch->getBlock(), falseOps);
}
diff --git a/third_party/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/third_party/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index 5098ba8..1537018 100644
--- a/third_party/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/third_party/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -1,19 +1,10 @@
//===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the execution engine for MLIR modules based on LLVM Orc
// JIT engine.
diff --git a/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp b/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp
index dc3bd20..ec2ae5f 100644
--- a/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp
+++ b/third_party/mlir/lib/ExecutionEngine/OptUtils.cpp
@@ -1,19 +1,10 @@
//===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the utility functions to trigger LLVM optimizations from
// MLIR Execution Engine.
diff --git a/third_party/mlir/lib/IR/AffineExpr.cpp b/third_party/mlir/lib/IR/AffineExpr.cpp
index 009c1a1..dd8ce00 100644
--- a/third_party/mlir/lib/IR/AffineExpr.cpp
+++ b/third_party/mlir/lib/IR/AffineExpr.cpp
@@ -1,19 +1,10 @@
//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineExpr.h"
#include "AffineExprDetail.h"
diff --git a/third_party/mlir/lib/IR/AffineExprDetail.h b/third_party/mlir/lib/IR/AffineExprDetail.h
index 214fee6..8824ddd 100644
--- a/third_party/mlir/lib/IR/AffineExprDetail.h
+++ b/third_party/mlir/lib/IR/AffineExprDetail.h
@@ -1,19 +1,10 @@
//===- AffineExprDetail.h - MLIR Affine Expr storage details ----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This holds implementation details of AffineExpr. Ideally it would not be
// exposed and would be kept local to AffineExpr.cpp however, MLIRContext.cpp
diff --git a/third_party/mlir/lib/IR/AffineMap.cpp b/third_party/mlir/lib/IR/AffineMap.cpp
index 6cfef36..50624af 100644
--- a/third_party/mlir/lib/IR/AffineMap.cpp
+++ b/third_party/mlir/lib/IR/AffineMap.cpp
@@ -1,19 +1,10 @@
//===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineMap.h"
#include "AffineMapDetail.h"
diff --git a/third_party/mlir/lib/IR/AffineMapDetail.h b/third_party/mlir/lib/IR/AffineMapDetail.h
index a247783..f00c4ba 100644
--- a/third_party/mlir/lib/IR/AffineMapDetail.h
+++ b/third_party/mlir/lib/IR/AffineMapDetail.h
@@ -1,19 +1,10 @@
//===- AffineMapDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This holds implementation details of AffineMap.
//
diff --git a/third_party/mlir/lib/IR/AsmPrinter.cpp b/third_party/mlir/lib/IR/AsmPrinter.cpp
index e1903d5..881a636 100644
--- a/third_party/mlir/lib/IR/AsmPrinter.cpp
+++ b/third_party/mlir/lib/IR/AsmPrinter.cpp
@@ -1,19 +1,10 @@
//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the MLIR AsmPrinter class, which is used to implement
// the various print() methods on the core IR objects.
@@ -319,7 +310,7 @@
visitType(type);
for (auto ®ion : op->getRegions())
for (auto &block : region)
- for (auto *arg : block.getArguments())
+ for (auto arg : block.getArguments())
visitType(arg->getType());
// Visit each of the attributes.
@@ -1437,7 +1428,7 @@
void printAttribute(Attribute attr) override {
ModulePrinter::printAttribute(attr);
}
- void printOperand(Value *value) override { printValueID(value); }
+ void printOperand(Value value) override { printValueID(value); }
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
@@ -1519,7 +1510,7 @@
void numberValuesInRegion(Region ®ion);
void numberValuesInBlock(Block &block);
void numberValuesInOp(Operation &op);
- void printValueID(Value *value, bool printResultNo = true) const {
+ void printValueID(Value value, bool printResultNo = true) const {
printValueIDImpl(value, printResultNo, os);
}
@@ -1528,13 +1519,13 @@
/// 'lookupValue' and the result of 'result' within that group in
/// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
/// has more than 1 result.
- void getResultIDAndNumber(OpResult *result, Value *&lookupValue,
+ void getResultIDAndNumber(OpResult result, Value &lookupValue,
int &lookupResultNo) const;
- void printValueIDImpl(Value *value, bool printResultNo,
+ void printValueIDImpl(Value value, bool printResultNo,
raw_ostream &stream) const;
/// Set a special value name for the given value.
- void setValueName(Value *value, StringRef name);
+ void setValueName(Value value, StringRef name);
/// Uniques the given value name within the printer. If the given name
/// conflicts, it is automatically renamed.
@@ -1542,8 +1533,8 @@
/// This is the value ID for each SSA value. If this returns ~0, then the
/// valueID has an entry in valueNames.
- DenseMap<Value *, unsigned> valueIDs;
- DenseMap<Value *, StringRef> valueNames;
+ DenseMap<Value, unsigned> valueIDs;
+ DenseMap<Value, StringRef> valueNames;
/// This is a map of operations that contain multiple named result groups,
/// i.e. there may be multiple names for the results of the operation. The key
@@ -1619,13 +1610,28 @@
}
void OperationPrinter::numberValuesInBlock(Block &block) {
+ auto setArgNameFn = [&](Value arg, StringRef name) {
+ assert(!valueIDs.count(arg) && "arg numbered multiple times");
+ assert(arg.cast<BlockArgument>()->getOwner() == &block &&
+ "arg not defined in 'block'");
+ setValueName(arg, name);
+ };
+
bool isEntryBlock = block.isEntryBlock();
+ if (isEntryBlock && state) {
+ if (auto *op = block.getParentOp()) {
+ if (auto dialectAsmInterface = state->getOpAsmInterface(op->getDialect()))
+ dialectAsmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
+ }
+ }
// Number the block arguments. We give entry block arguments a special name
// 'arg'.
SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
llvm::raw_svector_ostream specialName(specialNameBuffer);
- for (auto *arg : block.getArguments()) {
+ for (auto arg : block.getArguments()) {
+ if (valueIDs.count(arg))
+ continue;
if (isEntryBlock) {
specialNameBuffer.resize(strlen("arg"));
specialName << nextArgumentID++;
@@ -1642,17 +1648,17 @@
unsigned numResults = op.getNumResults();
if (numResults == 0)
return;
- Value *resultBegin = op.getResult(0);
+ Value resultBegin = op.getResult(0);
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
- auto setResultNameFn = [&](Value *result, StringRef name) {
+ auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result->getDefiningOp() == &op && "result not defined by 'op'");
setValueName(result, name);
// Record the result number for groups not anchored at 0.
- if (int resultNo = cast<OpResult>(result)->getResultNumber())
+ if (int resultNo = result.cast<OpResult>()->getResultNumber())
resultGroups.push_back(resultNo);
};
@@ -1675,7 +1681,7 @@
}
/// Set a special value name for the given value.
-void OperationPrinter::setValueName(Value *value, StringRef name) {
+void OperationPrinter::setValueName(Value value, StringRef name) {
// If the name is empty, the value uses the default numbering.
if (name.empty()) {
valueIDs[value] = nextValueID++;
@@ -1722,7 +1728,7 @@
// Print the argument list if non-empty.
if (!block->args_empty()) {
os << '(';
- interleaveComma(block->getArguments(), [&](BlockArgument *arg) {
+ interleaveComma(block->getArguments(), [&](BlockArgument arg) {
printValueID(arg);
os << ": ";
printType(arg->getType());
@@ -1773,8 +1779,7 @@
printTrailingLocation(op->getLoc());
}
-void OperationPrinter::getResultIDAndNumber(OpResult *result,
- Value *&lookupValue,
+void OperationPrinter::getResultIDAndNumber(OpResult result, Value &lookupValue,
int &lookupResultNo) const {
Operation *owner = result->getOwner();
if (owner->getNumResults() == 1)
@@ -1812,7 +1817,7 @@
lookupValue = owner->getResult(groupResultNo);
}
-void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
+void OperationPrinter::printValueIDImpl(Value value, bool printResultNo,
raw_ostream &stream) const {
if (!value) {
stream << "<<NULL>>";
@@ -1825,7 +1830,7 @@
// If this is a reference to the result of a multi-result operation or
// operation, print out the # identifier and make sure to map our lookup
// to the first result of the operation.
- if (OpResult *result = dyn_cast<OpResult>(value))
+ if (OpResult result = value.dyn_cast<OpResult>())
getResultIDAndNumber(result, lookupValue, resultNo);
auto it = valueIDs.find(lookupValue);
@@ -1860,11 +1865,11 @@
SmallVector<char, 16> nameStr;
for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
- auto *nameToUse = namesToUse[i];
+ auto nameToUse = namesToUse[i];
if (nameToUse == nullptr)
continue;
- auto *nameToReplace = region.front().getArgument(i);
+ auto nameToReplace = region.front().getArgument(i);
nameStr.clear();
llvm::raw_svector_ostream nameStream(nameStr);
@@ -1936,10 +1941,10 @@
for (unsigned i = 0; i < numSuccessors; ++i)
totalNumSuccessorOperands += op->getNumSuccessorOperands(i);
unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands;
- SmallVector<Value *, 8> properOperands(
+ SmallVector<Value, 8> properOperands(
op->operand_begin(), std::next(op->operand_begin(), numProperOperands));
- interleaveComma(properOperands, [&](Value *value) { printValueID(value); });
+ interleaveComma(properOperands, [&](Value value) { printValueID(value); });
os << ')';
@@ -1982,10 +1987,10 @@
os << '(';
interleaveComma(succOperands,
- [this](Value *operand) { printValueID(operand); });
+ [this](Value operand) { printValueID(operand); });
os << " : ";
interleaveComma(succOperands,
- [this](Value *operand) { printType(operand->getType()); });
+ [this](Value operand) { printType(operand->getType()); });
os << ')';
}
@@ -2057,7 +2062,7 @@
if (auto *op = getDefiningOp())
return op->print(os);
// TODO: Improve this.
- assert(isa<BlockArgument>(*this));
+ assert(isa<BlockArgument>());
os << "<block argument>\n";
}
diff --git a/third_party/mlir/lib/IR/AttributeDetail.h b/third_party/mlir/lib/IR/AttributeDetail.h
index da4aa69..c78d49c 100644
--- a/third_party/mlir/lib/IR/AttributeDetail.h
+++ b/third_party/mlir/lib/IR/AttributeDetail.h
@@ -1,19 +1,10 @@
//===- AttributeDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This holds implementation details of Attribute.
//
diff --git a/third_party/mlir/lib/IR/Attributes.cpp b/third_party/mlir/lib/IR/Attributes.cpp
index bb35a63..3a9c91f 100644
--- a/third_party/mlir/lib/IR/Attributes.cpp
+++ b/third_party/mlir/lib/IR/Attributes.cpp
@@ -1,19 +1,10 @@
//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Attributes.h"
#include "AttributeDetail.h"
diff --git a/third_party/mlir/lib/IR/Block.cpp b/third_party/mlir/lib/IR/Block.cpp
index 4dac32a..b0ada99 100644
--- a/third_party/mlir/lib/IR/Block.cpp
+++ b/third_party/mlir/lib/IR/Block.cpp
@@ -1,19 +1,10 @@
//===- Block.cpp - MLIR Block Class ---------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
@@ -25,10 +16,10 @@
//===----------------------------------------------------------------------===//
/// Returns the number of this argument.
-unsigned BlockArgument::getArgNumber() {
+unsigned BlockArgument::getArgNumber() const {
// Arguments are not stored in place, so we have to find it within the list.
auto argList = getOwner()->getArguments();
- return std::distance(argList.begin(), llvm::find(argList, this));
+ return std::distance(argList.begin(), llvm::find(argList, *this));
}
//===----------------------------------------------------------------------===//
@@ -38,7 +29,8 @@
Block::~Block() {
assert(!verifyOpOrder() && "Expected valid operation ordering.");
clear();
- llvm::DeleteContainerPointers(arguments);
+ for (BlockArgument arg : arguments)
+ arg.destroy();
}
Region *Block::getParent() const { return parentValidOpOrderPair.getPointer(); }
@@ -98,7 +90,7 @@
}
void Block::dropAllDefinedValueUses() {
- for (auto *arg : getArguments())
+ for (auto arg : getArguments())
arg->dropAllUses();
for (auto &op : *this)
op.dropAllDefinedValueUses();
@@ -151,8 +143,8 @@
// Argument list management.
//===----------------------------------------------------------------------===//
-BlockArgument *Block::addArgument(Type type) {
- auto *arg = new BlockArgument(type, this);
+BlockArgument Block::addArgument(Type type) {
+ BlockArgument arg = BlockArgument::create(type, this);
arguments.push_back(arg);
return arg;
}
@@ -172,7 +164,7 @@
assert(index < arguments.size());
// Delete the argument.
- delete arguments[index];
+ arguments[index].destroy();
arguments.erase(arguments.begin() + index);
// If we aren't updating predecessors, there is nothing left to do.
@@ -275,3 +267,8 @@
if ((count = term->getNumSuccessors()))
base = term->getBlockOperands().data();
}
+
+SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) {
+ if ((count = term->getNumSuccessors()))
+ base = term->getBlockOperands().data();
+}
diff --git a/third_party/mlir/lib/IR/Builders.cpp b/third_party/mlir/lib/IR/Builders.cpp
index 691b2ad..5567f87 100644
--- a/third_party/mlir/lib/IR/Builders.cpp
+++ b/third_party/mlir/lib/IR/Builders.cpp
@@ -1,19 +1,10 @@
//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/AffineExpr.h"
@@ -343,7 +334,7 @@
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult OpBuilder::tryFold(Operation *op,
- SmallVectorImpl<Value *> &results) {
+ SmallVectorImpl<Value> &results) {
results.reserve(op->getNumResults());
auto cleanupFailure = [&] {
results.assign(op->result_begin(), op->result_end());
@@ -374,7 +365,7 @@
Dialect *dialect = op->getDialect();
for (auto &it : llvm::enumerate(foldResults)) {
// Normal values get pushed back directly.
- if (auto *value = it.value().dyn_cast<Value *>()) {
+ if (auto value = it.value().dyn_cast<Value>()) {
results.push_back(value);
continue;
}
diff --git a/third_party/mlir/lib/IR/Diagnostics.cpp b/third_party/mlir/lib/IR/Diagnostics.cpp
index 59e16a4..6ec92f0 100644
--- a/third_party/mlir/lib/IR/Diagnostics.cpp
+++ b/third_party/mlir/lib/IR/Diagnostics.cpp
@@ -1,19 +1,10 @@
//===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Attributes.h"
diff --git a/third_party/mlir/lib/IR/Dialect.cpp b/third_party/mlir/lib/IR/Dialect.cpp
index c6266b0..b2485a3 100644
--- a/third_party/mlir/lib/IR/Dialect.cpp
+++ b/third_party/mlir/lib/IR/Dialect.cpp
@@ -1,19 +1,10 @@
//===- Dialect.cpp - Dialect implementation -------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
diff --git a/third_party/mlir/lib/IR/Function.cpp b/third_party/mlir/lib/IR/Function.cpp
index b51c77f..72b5ac4 100644
--- a/third_party/mlir/lib/IR/Function.cpp
+++ b/third_party/mlir/lib/IR/Function.cpp
@@ -1,19 +1,10 @@
//===- Function.cpp - MLIR Function Classes -------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Function.h"
#include "mlir/IR/BlockAndValueMapping.h"
diff --git a/third_party/mlir/lib/IR/FunctionImplementation.cpp b/third_party/mlir/lib/IR/FunctionImplementation.cpp
index 9cec216..79863bc 100644
--- a/third_party/mlir/lib/IR/FunctionImplementation.cpp
+++ b/third_party/mlir/lib/IR/FunctionImplementation.cpp
@@ -1,19 +1,10 @@
//===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Builders.h"
diff --git a/third_party/mlir/lib/IR/IntegerSet.cpp b/third_party/mlir/lib/IR/IntegerSet.cpp
index ce50fa7..835b4c3 100644
--- a/third_party/mlir/lib/IR/IntegerSet.cpp
+++ b/third_party/mlir/lib/IR/IntegerSet.cpp
@@ -1,19 +1,10 @@
//===- IntegerSet.cpp - MLIR Integer Set class ----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/IntegerSet.h"
#include "IntegerSetDetail.h"
diff --git a/third_party/mlir/lib/IR/IntegerSetDetail.h b/third_party/mlir/lib/IR/IntegerSetDetail.h
index b3eda52..54ffd47 100644
--- a/third_party/mlir/lib/IR/IntegerSetDetail.h
+++ b/third_party/mlir/lib/IR/IntegerSetDetail.h
@@ -1,19 +1,10 @@
//===- IntegerSetDetail.h - MLIR IntegerSet storage details -----*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This holds implementation details of IntegerSet.
//
diff --git a/third_party/mlir/lib/IR/Location.cpp b/third_party/mlir/lib/IR/Location.cpp
index 1ea75d5..e23a736 100644
--- a/third_party/mlir/lib/IR/Location.cpp
+++ b/third_party/mlir/lib/IR/Location.cpp
@@ -1,19 +1,10 @@
//===- Location.cpp - MLIR Location Classes -------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Location.h"
#include "LocationDetail.h"
diff --git a/third_party/mlir/lib/IR/LocationDetail.h b/third_party/mlir/lib/IR/LocationDetail.h
index 6ccaa17..a47a211 100644
--- a/third_party/mlir/lib/IR/LocationDetail.h
+++ b/third_party/mlir/lib/IR/LocationDetail.h
@@ -1,19 +1,10 @@
//===- LocationDetail.h - MLIR Location storage details ---------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This holds implementation details of the location attributes.
//
diff --git a/third_party/mlir/lib/IR/MLIRContext.cpp b/third_party/mlir/lib/IR/MLIRContext.cpp
index d3feca1..42d77ae 100644
--- a/third_party/mlir/lib/IR/MLIRContext.cpp
+++ b/third_party/mlir/lib/IR/MLIRContext.cpp
@@ -1,19 +1,10 @@
//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/MLIRContext.h"
#include "AffineExprDetail.h"
diff --git a/third_party/mlir/lib/IR/Module.cpp b/third_party/mlir/lib/IR/Module.cpp
index c52a55b..c5af227 100644
--- a/third_party/mlir/lib/IR/Module.cpp
+++ b/third_party/mlir/lib/IR/Module.cpp
@@ -1,19 +1,10 @@
//===- Module.cpp - MLIR Module Operation ---------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Module.h"
#include "mlir/IR/Builders.h"
diff --git a/third_party/mlir/lib/IR/Operation.cpp b/third_party/mlir/lib/IR/Operation.cpp
index 9df1079..c7baba8 100644
--- a/third_party/mlir/lib/IR/Operation.cpp
+++ b/third_party/mlir/lib/IR/Operation.cpp
@@ -1,19 +1,10 @@
//===- Operation.cpp - Operation support code -----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Operation.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -77,23 +68,29 @@
//===----------------------------------------------------------------------===//
/// Return the result number of this result.
-unsigned OpResult::getResultNumber() {
- // Results are always stored consecutively, so use pointer subtraction to
- // figure out what number this is.
- return this - &getOwner()->getOpResults()[0];
+unsigned OpResult::getResultNumber() const {
+ // Results are not stored in place, so we have to find it within the list.
+ auto resList = getOwner()->getOpResults();
+ return std::distance(resList.begin(), llvm::find(resList, *this));
}
//===----------------------------------------------------------------------===//
// OpOperand
//===----------------------------------------------------------------------===//
-// TODO: This namespace is only required because of a bug in GCC<7.0.
-namespace mlir {
+OpOperand::OpOperand(Operation *owner, Value value)
+ : IROperand(owner, value.impl) {}
+
+/// Return the current value being used by this operand.
+Value OpOperand::get() { return (detail::ValueImpl *)IROperand::get(); }
+
+/// Set the current value being used by this operand.
+void OpOperand::set(Value newValue) { IROperand::set(newValue.impl); }
+
/// Return which operand this is in the operand list.
-template <> unsigned OpOperand::getOperandNumber() {
+unsigned OpOperand::getOperandNumber() {
return this - &getOwner()->getOpOperands()[0];
}
-} // end namespace mlir
//===----------------------------------------------------------------------===//
// BlockOperand
@@ -114,7 +111,7 @@
/// Create a new Operation with the specific fields.
Operation *Operation::create(Location location, OperationName name,
ArrayRef<Type> resultTypes,
- ArrayRef<Value *> operands,
+ ArrayRef<Value> operands,
ArrayRef<NamedAttribute> attributes,
ArrayRef<Block *> successors, unsigned numRegions,
bool resizableOperandList) {
@@ -134,7 +131,7 @@
/// Create a new Operation with the specific fields.
Operation *Operation::create(Location location, OperationName name,
ArrayRef<Type> resultTypes,
- ArrayRef<Value *> operands,
+ ArrayRef<Value> operands,
NamedAttributeList attributes,
ArrayRef<Block *> successors, RegionRange regions,
bool resizableOperandList) {
@@ -151,7 +148,7 @@
/// unnecessarily uniquing a list of attributes.
Operation *Operation::create(Location location, OperationName name,
ArrayRef<Type> resultTypes,
- ArrayRef<Value *> operands,
+ ArrayRef<Value> operands,
NamedAttributeList attributes,
ArrayRef<Block *> successors, unsigned numRegions,
bool resizableOperandList) {
@@ -188,7 +185,7 @@
auto instResults = op->getOpResults();
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
- new (&instResults[i]) OpResult(resultTypes[i], op);
+ new (&instResults[i]) OpResult(OpResult::create(resultTypes[i], op));
auto opOperands = op->getOpOperands();
@@ -265,7 +262,7 @@
getOperandStorage().~OperandStorage();
for (auto &result : getOpResults())
- result.~OpResult();
+ result.destroy();
// Explicitly run the destructors for the successors.
for (auto &successor : getBlockOperands())
@@ -314,7 +311,7 @@
}
/// Replace any uses of 'from' with 'to' within this operation.
-void Operation::replaceUsesOfWith(Value *from, Value *to) {
+void Operation::replaceUsesOfWith(Value from, Value to) {
if (from == to)
return;
for (auto &operand : getOpOperands())
@@ -585,7 +582,7 @@
/// Return true if there are no users of any results of this operation.
bool Operation::use_empty() {
- for (auto *result : getResults())
+ for (auto result : getResults())
if (!result->use_empty())
return false;
return true;
@@ -672,14 +669,14 @@
/// Operands are remapped using `mapper` (if present), and `mapper` is updated
/// to contain the results.
Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
- SmallVector<Value *, 8> operands;
+ SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
operands.reserve(getNumOperands() + getNumSuccessors());
if (getNumSuccessors() == 0) {
// Non-branching operations can just add all the operands.
- for (auto *opValue : getOperands())
+ for (auto opValue : getOperands())
operands.push_back(mapper.lookupOrDefault(opValue));
} else {
// We add the operands separated by nullptr's for each successor.
@@ -699,7 +696,7 @@
operands.push_back(nullptr);
// Remap the successors operands.
- for (auto *operand : getSuccessorOperands(succ))
+ for (auto operand : getSuccessorOperands(succ))
operands.push_back(mapper.lookupOrDefault(operand));
}
}
@@ -1092,8 +1089,8 @@
// These functions are out-of-line implementations of the methods in BinaryOp,
// which avoids them being template instantiated/duplicated.
-void impl::buildBinaryOp(Builder *builder, OperationState &result, Value *lhs,
- Value *rhs) {
+void impl::buildBinaryOp(Builder *builder, OperationState &result, Value lhs,
+ Value rhs) {
assert(lhs->getType() == rhs->getType());
result.addOperands({lhs, rhs});
result.types.push_back(lhs->getType());
@@ -1133,7 +1130,7 @@
// CastOp implementation
//===----------------------------------------------------------------------===//
-void impl::buildCastOp(Builder *builder, OperationState &result, Value *source,
+void impl::buildCastOp(Builder *builder, OperationState &result, Value source,
Type destType) {
result.addOperands(source);
result.addTypes(destType);
@@ -1157,7 +1154,7 @@
<< op->getResult(0)->getType();
}
-Value *impl::foldCastOp(Operation *op) {
+Value impl::foldCastOp(Operation *op) {
// Identity cast
if (op->getOperand(0)->getType() == op->getResult(0)->getType())
return op->getOperand(0);
diff --git a/third_party/mlir/lib/IR/OperationSupport.cpp b/third_party/mlir/lib/IR/OperationSupport.cpp
index 256a261..5dfd3b0 100644
--- a/third_party/mlir/lib/IR/OperationSupport.cpp
+++ b/third_party/mlir/lib/IR/OperationSupport.cpp
@@ -1,19 +1,10 @@
//===- OperationSupport.cpp -----------------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains out-of-line implementations of the support types that
// Operation and related classes build on top of.
@@ -164,7 +155,7 @@
//===----------------------------------------------------------------------===//
// ValueRange
-ValueRange::ValueRange(ArrayRef<Value *> values)
+ValueRange::ValueRange(ArrayRef<Value> values)
: ValueRange(values.data(), values.size()) {}
ValueRange::ValueRange(OperandRange values)
: ValueRange(values.begin().getBase(), values.size()) {}
@@ -178,16 +169,16 @@
return operand + index;
if (OpResult *result = owner.dyn_cast<OpResult *>())
return result + index;
- return owner.get<Value *const *>() + index;
+ return owner.get<const Value *>() + index;
}
/// See `detail::indexed_accessor_range_base` for details.
-Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
+Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
// Operands access the held value via 'get'.
if (OpOperand *operand = owner.dyn_cast<OpOperand *>())
return operand[index].get();
// An OpResult is a value, so we can return it directly.
if (OpResult *result = owner.dyn_cast<OpResult *>())
- return &result[index];
+ return result[index];
// Otherwise, this is a raw value array so just index directly.
- return owner.get<Value *const *>()[index];
+ return owner.get<const Value *>()[index];
}
diff --git a/third_party/mlir/lib/IR/PatternMatch.cpp b/third_party/mlir/lib/IR/PatternMatch.cpp
index 3887a03..50e6eee 100644
--- a/third_party/mlir/lib/IR/PatternMatch.cpp
+++ b/third_party/mlir/lib/IR/PatternMatch.cpp
@@ -1,19 +1,10 @@
//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -179,23 +170,6 @@
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
-/// This method is used as the final notification hook for patterns that end
-/// up modifying the pattern root in place, by changing its operands. This is
-/// a minor efficiency win (it avoids creating a new operation and removing
-/// the old one) but also often allows simpler code in the client.
-///
-/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
-/// should remove if they are dead at this point.
-///
-void PatternRewriter::updatedRootInPlace(Operation *op,
- ValueRange valuesToRemoveIfDead) {
- // Notify the rewriter subclass that we're about to replace this root.
- notifyRootUpdated(op);
-
- // TODO: Process the valuesToRemoveIfDead list, removing things and calling
- // the notifyOperationRemoved hook in the process.
-}
-
//===----------------------------------------------------------------------===//
// PatternMatcher implementation
//===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/lib/IR/Region.cpp b/third_party/mlir/lib/IR/Region.cpp
index 6cec021..1e8abc8 100644
--- a/third_party/mlir/lib/IR/Region.cpp
+++ b/third_party/mlir/lib/IR/Region.cpp
@@ -1,19 +1,10 @@
//===- Region.cpp - MLIR Region Class -------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Region.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -91,7 +82,7 @@
// Clone the block arguments. The user might be deleting arguments to the
// block by specifying them in the mapper. If so, we don't add the
// argument to the cloned block.
- for (auto *arg : block.getArguments())
+ for (auto arg : block.getArguments())
if (!mapper.contains(arg))
mapper.map(arg, newBlock->addArgument(arg->getType()));
@@ -106,7 +97,7 @@
// operands of each of the operations.
auto remapOperands = [&](Operation *op) {
for (auto &operand : op->getOpOperands())
- if (auto *mappedOp = mapper.lookupOrNull(operand.get()))
+ if (auto mappedOp = mapper.lookupOrNull(operand.get()))
operand.set(mappedOp);
for (auto &succOp : op->getBlockOperands())
if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
@@ -143,7 +134,7 @@
while (!pendingRegions.empty()) {
for (Block &block : *pendingRegions.pop_back_val()) {
for (Operation &op : block) {
- for (Value *operand : op.getOperands()) {
+ for (Value operand : op.getOperands()) {
// operand should be non-null here if the IR is well-formed. But
// we don't assert here as this function is called from the verifier
// and so could be called on invalid IR.
diff --git a/third_party/mlir/lib/IR/StandardTypes.cpp b/third_party/mlir/lib/IR/StandardTypes.cpp
index 7c494e2..441b59e 100644
--- a/third_party/mlir/lib/IR/StandardTypes.cpp
+++ b/third_party/mlir/lib/IR/StandardTypes.cpp
@@ -1,19 +1,10 @@
//===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/StandardTypes.h"
#include "TypeDetail.h"
diff --git a/third_party/mlir/lib/IR/SymbolTable.cpp b/third_party/mlir/lib/IR/SymbolTable.cpp
index bd8cb59..83e5802 100644
--- a/third_party/mlir/lib/IR/SymbolTable.cpp
+++ b/third_party/mlir/lib/IR/SymbolTable.cpp
@@ -1,19 +1,10 @@
//===- SymbolTable.cpp - MLIR Symbol Table Class --------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallString.h"
diff --git a/third_party/mlir/lib/IR/TypeDetail.h b/third_party/mlir/lib/IR/TypeDetail.h
index 5bcb0b6..b3e0edd 100644
--- a/third_party/mlir/lib/IR/TypeDetail.h
+++ b/third_party/mlir/lib/IR/TypeDetail.h
@@ -1,19 +1,10 @@
//===- TypeDetail.h - MLIR Type storage details -----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This holds implementation details of Type.
//
diff --git a/third_party/mlir/lib/IR/TypeUtilities.cpp b/third_party/mlir/lib/IR/TypeUtilities.cpp
index 54b1bf6..0bf1627 100644
--- a/third_party/mlir/lib/IR/TypeUtilities.cpp
+++ b/third_party/mlir/lib/IR/TypeUtilities.cpp
@@ -1,19 +1,10 @@
//===- TypeUtilities.cpp - Helper function for type queries ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines generic type utilities.
//
@@ -33,14 +24,10 @@
return type;
}
-Type mlir::getElementTypeOrSelf(Value *val) {
+Type mlir::getElementTypeOrSelf(Value val) {
return getElementTypeOrSelf(val->getType());
}
-Type mlir::getElementTypeOrSelf(Value &val) {
- return getElementTypeOrSelf(val.getType());
-}
-
Type mlir::getElementTypeOrSelf(Attribute attr) {
return getElementTypeOrSelf(attr.getType());
}
@@ -101,18 +88,18 @@
OperandElementTypeIterator::OperandElementTypeIterator(
Operation::operand_iterator it)
- : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value *)>(
+ : llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
it, &unwrap) {}
-Type OperandElementTypeIterator::unwrap(Value *value) {
+Type OperandElementTypeIterator::unwrap(Value value) {
return value->getType().cast<ShapedType>().getElementType();
}
ResultElementTypeIterator::ResultElementTypeIterator(
Operation::result_iterator it)
- : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value *)>(
+ : llvm::mapped_iterator<Operation::result_iterator, Type (*)(Value)>(
it, &unwrap) {}
-Type ResultElementTypeIterator::unwrap(Value *value) {
+Type ResultElementTypeIterator::unwrap(Value value) {
return value->getType().cast<ShapedType>().getElementType();
}
diff --git a/third_party/mlir/lib/IR/Types.cpp b/third_party/mlir/lib/IR/Types.cpp
index 23c80c9..923d6e1 100644
--- a/third_party/mlir/lib/IR/Types.cpp
+++ b/third_party/mlir/lib/IR/Types.cpp
@@ -1,19 +1,10 @@
//===- Types.cpp - MLIR Type Classes --------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Types.h"
#include "TypeDetail.h"
diff --git a/third_party/mlir/lib/IR/Value.cpp b/third_party/mlir/lib/IR/Value.cpp
index 4c2ea5a..ffb9601 100644
--- a/third_party/mlir/lib/IR/Value.cpp
+++ b/third_party/mlir/lib/IR/Value.cpp
@@ -1,19 +1,10 @@
//===- Value.cpp - MLIR Value Classes -------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Value.h"
#include "mlir/IR/Block.h"
@@ -22,8 +13,8 @@
/// If this value is the result of an Operation, return the operation that
/// defines it.
-Operation *Value::getDefiningOp() {
- if (auto *result = dyn_cast<OpResult>(this))
+Operation *Value::getDefiningOp() const {
+ if (auto result = dyn_cast<OpResult>())
return result->getOwner();
return nullptr;
}
@@ -38,7 +29,7 @@
Region *Value::getParentRegion() {
if (auto *op = getDefiningOp())
return op->getParentRegion();
- return cast<BlockArgument>(this)->getOwner()->getParent();
+ return cast<BlockArgument>()->getOwner()->getParent();
}
//===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/lib/IR/Visitors.cpp b/third_party/mlir/lib/IR/Visitors.cpp
index ea2a6d6..404e74a 100644
--- a/third_party/mlir/lib/IR/Visitors.cpp
+++ b/third_party/mlir/lib/IR/Visitors.cpp
@@ -1,19 +1,10 @@
//===- Visitors.cpp - MLIR Visitor Utilties -------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Visitors.h"
#include "mlir/IR/Operation.h"
diff --git a/third_party/mlir/lib/Parser/Lexer.cpp b/third_party/mlir/lib/Parser/Lexer.cpp
index 29104c8..7d8337a9 100644
--- a/third_party/mlir/lib/Parser/Lexer.cpp
+++ b/third_party/mlir/lib/Parser/Lexer.cpp
@@ -1,19 +1,10 @@
//===- Lexer.cpp - MLIR Lexer Implementation ------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the lexer for the MLIR textual form.
//
diff --git a/third_party/mlir/lib/Parser/Lexer.h b/third_party/mlir/lib/Parser/Lexer.h
index a7a2ac4..a760dca 100644
--- a/third_party/mlir/lib/Parser/Lexer.h
+++ b/third_party/mlir/lib/Parser/Lexer.h
@@ -1,19 +1,10 @@
//===- Lexer.h - MLIR Lexer Interface ---------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file declares the MLIR Lexer class.
//
diff --git a/third_party/mlir/lib/Parser/Parser.cpp b/third_party/mlir/lib/Parser/Parser.cpp
index 498a64d..0198a45 100644
--- a/third_party/mlir/lib/Parser/Parser.cpp
+++ b/third_party/mlir/lib/Parser/Parser.cpp
@@ -1,19 +1,10 @@
//===- Parser.cpp - MLIR Parser Implementation ----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR textual form.
//
@@ -3093,7 +3084,7 @@
ParseResult popSSANameScope();
/// Register a definition of a value with the symbol table.
- ParseResult addDefinition(SSAUseInfo useInfo, Value *value);
+ ParseResult addDefinition(SSAUseInfo useInfo, Value value);
/// Parse an optional list of SSA uses into 'results'.
ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
@@ -3103,12 +3094,12 @@
/// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure.
- Value *resolveSSAUse(SSAUseInfo useInfo, Type type);
+ Value resolveSSAUse(SSAUseInfo useInfo, Type type);
ParseResult parseSSADefOrUseAndType(
const std::function<ParseResult(SSAUseInfo, Type)> &action);
- ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value *> &results);
+ ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value> &results);
/// Return the location of the value identified by its name and number if it
/// has been already reference.
@@ -3130,12 +3121,11 @@
/// Parse a single operation successor and its operand list.
ParseResult parseSuccessorAndUseList(Block *&dest,
- SmallVectorImpl<Value *> &operands);
+ SmallVectorImpl<Value> &operands);
/// Parse a comma-separated list of operation successors in brackets.
- ParseResult
- parseSuccessors(SmallVectorImpl<Block *> &destinations,
- SmallVectorImpl<SmallVector<Value *, 4>> &operands);
+ ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations,
+ SmallVectorImpl<SmallVector<Value, 4>> &operands);
/// Parse an operation instance that is in the generic form.
Operation *parseGenericOperation();
@@ -3173,9 +3163,8 @@
ParseResult parseBlockBody(Block *block);
/// Parse a (possibly empty) list of block arguments.
- ParseResult
- parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results,
- Block *owner);
+ ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results,
+ Block *owner);
/// Get the block with the specified name, creating it if it doesn't
/// already exist. The location specified is the point of use, which allows
@@ -3204,14 +3193,14 @@
void recordDefinition(StringRef def);
/// Get the value entry for the given SSA name.
- SmallVectorImpl<std::pair<Value *, SMLoc>> &getSSAValueEntry(StringRef name);
+ SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name);
/// Create a forward reference placeholder value with the given location and
/// result type.
- Value *createForwardRefPlaceholder(SMLoc loc, Type type);
+ Value createForwardRefPlaceholder(SMLoc loc, Type type);
/// Return true if this is a forward reference.
- bool isForwardRefPlaceholder(Value *value) {
+ bool isForwardRefPlaceholder(Value value) {
return forwardRefPlaceholders.count(value);
}
@@ -3236,7 +3225,7 @@
/// This keeps track of all of the SSA values we are tracking for each name
/// scope, indexed by their name. This has one entry per result number.
- llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values;
+ llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values;
/// This keeps track of all of the values defined by a specific name scope.
SmallVector<llvm::StringSet<>, 2> definitionsPerScope;
@@ -3253,7 +3242,7 @@
/// These are all of the placeholders we've made along with the location of
/// their first reference, to allow checking for use of undefined values.
- DenseMap<Value *, SMLoc> forwardRefPlaceholders;
+ DenseMap<Value, SMLoc> forwardRefPlaceholders;
/// The builder used when creating parsed operation instances.
OpBuilder opBuilder;
@@ -3278,7 +3267,7 @@
// Check for any forward references that are left. If we find any, error
// out.
if (!forwardRefPlaceholders.empty()) {
- SmallVector<std::pair<const char *, Value *>, 4> errors;
+ SmallVector<std::pair<const char *, Value>, 4> errors;
// Iteration over the map isn't deterministic, so sort by source location.
for (auto entry : forwardRefPlaceholders)
errors.push_back({entry.second.getPointer(), entry.first});
@@ -3342,7 +3331,7 @@
}
/// Register a definition of a value with the symbol table.
-ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) {
+ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value value) {
auto &entries = getSSAValueEntry(useInfo.name);
// Make sure there is a slot for this value.
@@ -3351,7 +3340,7 @@
// If we already have an entry for this, check to see if it was a definition
// or a forward reference.
- if (auto *existing = entries[useInfo.number].first) {
+ if (auto existing = entries[useInfo.number].first) {
if (!isForwardRefPlaceholder(existing)) {
return emitError(useInfo.loc)
.append("redefinition of SSA value '", useInfo.name, "'")
@@ -3416,12 +3405,12 @@
/// Given an unbound reference to an SSA value and its type, return the value
/// it specifies. This returns null on failure.
-Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
+Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
auto &entries = getSSAValueEntry(useInfo.name);
// If we have already seen a value of this name, return it.
if (useInfo.number < entries.size() && entries[useInfo.number].first) {
- auto *result = entries[useInfo.number].first;
+ auto result = entries[useInfo.number].first;
// Check that the type matches the other uses.
if (result->getType() == type)
return result;
@@ -3447,7 +3436,7 @@
// Otherwise, this is a forward reference. Create a placeholder and remember
// that we did so.
- auto *result = createForwardRefPlaceholder(useInfo.loc, type);
+ auto result = createForwardRefPlaceholder(useInfo.loc, type);
entries[useInfo.number].first = result;
entries[useInfo.number].second = useInfo.loc;
return result;
@@ -3477,7 +3466,7 @@
/// ::= ssa-use-list ':' type-list-no-parens
///
ParseResult OperationParser::parseOptionalSSAUseAndTypeList(
- SmallVectorImpl<Value *> &results) {
+ SmallVectorImpl<Value> &results) {
SmallVector<SSAUseInfo, 4> valueIDs;
if (parseOptionalSSAUseList(valueIDs))
return failure();
@@ -3497,7 +3486,7 @@
results.reserve(valueIDs.size());
for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
- if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
+ if (auto value = resolveSSAUse(valueIDs[i], types[i]))
results.push_back(value);
else
return failure();
@@ -3512,13 +3501,13 @@
}
/// Get the value entry for the given SSA name.
-SmallVectorImpl<std::pair<Value *, SMLoc>> &
+SmallVectorImpl<std::pair<Value, SMLoc>> &
OperationParser::getSSAValueEntry(StringRef name) {
return isolatedNameScopes.back().values[name];
}
/// Create and remember a new placeholder for a forward reference.
-Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
+Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
// Forward references are always created as operations, because we just need
// something with a def/use chain.
//
@@ -3632,7 +3621,7 @@
///
ParseResult
OperationParser::parseSuccessorAndUseList(Block *&dest,
- SmallVectorImpl<Value *> &operands) {
+ SmallVectorImpl<Value> &operands) {
// Verify branch is identifier and get the matching block.
if (!getToken().is(Token::caret_identifier))
return emitError("expected block name");
@@ -3655,13 +3644,13 @@
///
ParseResult OperationParser::parseSuccessors(
SmallVectorImpl<Block *> &destinations,
- SmallVectorImpl<SmallVector<Value *, 4>> &operands) {
+ SmallVectorImpl<SmallVector<Value, 4>> &operands) {
if (parseToken(Token::l_square, "expected '['"))
return failure();
auto parseElt = [this, &destinations, &operands]() {
Block *dest;
- SmallVector<Value *, 4> destOperands;
+ SmallVector<Value, 4> destOperands;
auto res = parseSuccessorAndUseList(dest, destOperands);
destinations.push_back(dest);
operands.push_back(destOperands);
@@ -3718,7 +3707,7 @@
// Parse the successor list but don't add successors to the result yet to
// avoid messing up with the argument order.
SmallVector<Block *, 2> successors;
- SmallVector<SmallVector<Value *, 4>, 2> successorOperands;
+ SmallVector<SmallVector<Value, 4>, 2> successorOperands;
if (getToken().is(Token::l_square)) {
// Check if the operation is a known terminator.
const AbstractOperation *abstractOp = result.name.getAbstractOperation();
@@ -3779,7 +3768,7 @@
// Add the successors, and their operands after the proper operands.
for (const auto &succ : llvm::zip(successors, successorOperands)) {
Block *successor = std::get<0>(succ);
- const SmallVector<Value *, 4> &operands = std::get<1>(succ);
+ const SmallVector<Value, 4> &operands = std::get<1>(succ);
result.addSuccessor(successor, operands);
}
@@ -4129,10 +4118,10 @@
/// Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperand(const OperandType &operand, Type type,
- SmallVectorImpl<Value *> &result) override {
+ SmallVectorImpl<Value> &result) override {
OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
- if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
+ if (auto value = parser.resolveSSAUse(operandInfo, type)) {
result.push_back(value);
return success();
}
@@ -4242,7 +4231,7 @@
/// Parse a single operation successor and its operand list.
ParseResult
parseSuccessorAndUseList(Block *&dest,
- SmallVectorImpl<Value *> &operands) override {
+ SmallVectorImpl<Value> &operands) override {
return parser.parseSuccessorAndUseList(dest, operands);
}
@@ -4470,7 +4459,7 @@
// If an argument list is present, parse it.
if (consumeIf(Token::l_paren)) {
- SmallVector<BlockArgument *, 8> bbArgs;
+ SmallVector<BlockArgument, 8> bbArgs;
if (parseOptionalBlockArgList(bbArgs, block) ||
parseToken(Token::r_paren, "expected ')' to end argument list"))
return failure();
@@ -4534,7 +4523,7 @@
/// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
///
ParseResult OperationParser::parseOptionalBlockArgList(
- SmallVectorImpl<BlockArgument *> &results, Block *owner) {
+ SmallVectorImpl<BlockArgument> &results, Block *owner) {
if (getToken().is(Token::r_brace))
return success();
@@ -4555,7 +4544,7 @@
return emitError("too many arguments specified in argument list");
// Finally, make sure the existing argument has the correct type.
- auto *arg = owner->getArgument(nextArgument++);
+ auto arg = owner->getArgument(nextArgument++);
if (arg->getType() != type)
return emitError("argument and block argument type mismatch");
return addDefinition(useInfo, arg);
diff --git a/third_party/mlir/lib/Parser/Token.cpp b/third_party/mlir/lib/Parser/Token.cpp
index c01d603..84de4c3 100644
--- a/third_party/mlir/lib/Parser/Token.cpp
+++ b/third_party/mlir/lib/Parser/Token.cpp
@@ -1,19 +1,10 @@
//===- Token.cpp - MLIR Token Implementation ------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the Token class for the MLIR textual form.
//
diff --git a/third_party/mlir/lib/Parser/Token.h b/third_party/mlir/lib/Parser/Token.h
index 333c4d2..7487736 100644
--- a/third_party/mlir/lib/Parser/Token.h
+++ b/third_party/mlir/lib/Parser/Token.h
@@ -1,19 +1,10 @@
//===- Token.h - MLIR Token Interface ---------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_LIB_PARSER_TOKEN_H
#define MLIR_LIB_PARSER_TOKEN_H
diff --git a/third_party/mlir/lib/Parser/TokenKinds.def b/third_party/mlir/lib/Parser/TokenKinds.def
index 19cd343..fc9f7821 100644
--- a/third_party/mlir/lib/Parser/TokenKinds.def
+++ b/third_party/mlir/lib/Parser/TokenKinds.def
@@ -1,19 +1,10 @@
//===- TokenKinds.def - MLIR Token Description ------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file is intended to be #include'd multiple times to extract information
// about tokens for various clients in the lexer.
diff --git a/third_party/mlir/lib/Pass/IRPrinting.cpp b/third_party/mlir/lib/Pass/IRPrinting.cpp
index 8e17215..75aadbd 100644
--- a/third_party/mlir/lib/Pass/IRPrinting.cpp
+++ b/third_party/mlir/lib/Pass/IRPrinting.cpp
@@ -1,19 +1,10 @@
//===- IRPrinting.cpp -----------------------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/IR/Module.h"
@@ -48,14 +39,14 @@
for (Region ®ion : op->getRegions()) {
for (Block &block : region) {
addDataToHash(hasher, &block);
- for (BlockArgument *arg : block.getArguments())
+ for (BlockArgument arg : block.getArguments())
addDataToHash(hasher, arg);
}
}
// - Location
addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
// - Operands
- for (Value *operand : op->getOperands())
+ for (Value operand : op->getOperands())
addDataToHash(hasher, operand);
// - Successors
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
diff --git a/third_party/mlir/lib/Pass/Pass.cpp b/third_party/mlir/lib/Pass/Pass.cpp
index f893c7b..8877cc5 100644
--- a/third_party/mlir/lib/Pass/Pass.cpp
+++ b/third_party/mlir/lib/Pass/Pass.cpp
@@ -1,19 +1,10 @@
//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements common pass infrastructure.
//
@@ -45,6 +36,17 @@
/// single .o file.
void Pass::anchor() {}
+/// Attempt to initialize the options of this pass from the given string.
+LogicalResult Pass::initializeOptions(StringRef options) {
+ return passOptions.parseFromString(options);
+}
+
+/// Copy the option values from 'other', which is another instance of this
+/// pass.
+void Pass::copyOptionValuesFrom(const Pass *other) {
+ passOptions.copyOptionValuesFrom(other->passOptions);
+}
+
/// Prints out the pass in the textual representation of pipelines. If this is
/// an adaptor pass, print with the op_name(sub_pass,...) format.
void Pass::printAsTextualPipeline(raw_ostream &os) {
@@ -55,11 +57,14 @@
pm.printAsTextualPipeline(os);
os << ")";
});
- } else if (const PassInfo *info = lookupPassInfo()) {
- os << info->getPassArgument();
- } else {
- os << getName();
+ return;
}
+ // Otherwise, print the pass argument followed by its options.
+ if (const PassInfo *info = lookupPassInfo())
+ os << info->getPassArgument();
+ else
+ os << getName();
+ passOptions.print(os);
}
/// Forwarding function to execute this pass.
diff --git a/third_party/mlir/lib/Pass/PassDetail.h b/third_party/mlir/lib/Pass/PassDetail.h
index d0a2ea6..9a52535 100644
--- a/third_party/mlir/lib/Pass/PassDetail.h
+++ b/third_party/mlir/lib/Pass/PassDetail.h
@@ -1,19 +1,10 @@
//===- PassDetail.h - MLIR Pass details -------------------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef MLIR_PASS_PASSDETAIL_H_
#define MLIR_PASS_PASSDETAIL_H_
diff --git a/third_party/mlir/lib/Pass/PassManagerOptions.cpp b/third_party/mlir/lib/Pass/PassManagerOptions.cpp
index c29e0d0..8748706 100644
--- a/third_party/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/third_party/mlir/lib/Pass/PassManagerOptions.cpp
@@ -1,19 +1,10 @@
//===- PassManagerOptions.cpp - PassManager Command Line Options ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
diff --git a/third_party/mlir/lib/Pass/PassRegistry.cpp b/third_party/mlir/lib/Pass/PassRegistry.cpp
index 1a321d6..1c5193d 100644
--- a/third_party/mlir/lib/Pass/PassRegistry.cpp
+++ b/third_party/mlir/lib/Pass/PassRegistry.cpp
@@ -1,19 +1,10 @@
//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Pass/Pass.h"
@@ -33,10 +24,15 @@
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
passPipelineRegistry;
-// Helper to avoid exposing OpPassManager.
-void mlir::detail::addPassToPassManager(OpPassManager &pm,
- std::unique_ptr<Pass> pass) {
- pm.addPass(std::move(pass));
+/// Utility to create a default registry function from a pass instance.
+static PassRegistryFunction
+buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
+ return [=](OpPassManager &pm, StringRef options) {
+ std::unique_ptr<Pass> pass = allocator();
+ LogicalResult result = pass->initializeOptions(options);
+ pm.addPass(std::move(pass));
+ return result;
+ };
}
//===----------------------------------------------------------------------===//
@@ -55,9 +51,13 @@
// PassInfo
//===----------------------------------------------------------------------===//
+PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
+ const PassAllocatorFunction &allocator)
+ : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator)) {}
+
void mlir::registerPass(StringRef arg, StringRef description,
const PassID *passID,
- const PassRegistryFunction &function) {
+ const PassAllocatorFunction &function) {
PassInfo passInfo(arg, description, passID, function);
bool inserted = passRegistry->try_emplace(passID, passInfo).second;
assert(inserted && "Pass registered multiple times");
@@ -76,7 +76,19 @@
// PassOptions
//===----------------------------------------------------------------------===//
-LogicalResult PassOptionsBase::parseFromString(StringRef options) {
+/// Out of line virtual function to provide home for the class.
+void detail::PassOptions::OptionBase::anchor() {}
+
+/// Copy the option values from 'other'.
+void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
+ assert(options.size() == other.options.size());
+ if (options.empty())
+ return;
+ for (auto optionsIt : llvm::zip(options, other.options))
+ std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
+}
+
+LogicalResult detail::PassOptions::parseFromString(StringRef options) {
// TODO(parkers): Handle escaping strings.
// NOTE: `options` is modified in place to always refer to the unprocessed
// part of the string.
@@ -108,7 +120,6 @@
auto it = OptionsMap.find(key);
if (it == OptionsMap.end()) {
llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
-
return failure();
}
if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
@@ -118,6 +129,28 @@
return success();
}
+/// Print the options held by this struct in a form that can be parsed via
+/// 'parseFromString'.
+void detail::PassOptions::print(raw_ostream &os) {
+ // If there are no options, there is nothing left to do.
+ if (OptionsMap.empty())
+ return;
+
+ // Sort the options to make the ordering deterministic.
+ SmallVector<OptionBase *, 4> orderedOptions(options.begin(), options.end());
+ llvm::array_pod_sort(orderedOptions.begin(), orderedOptions.end(),
+ [](OptionBase *const *lhs, OptionBase *const *rhs) {
+ return (*lhs)->getArgStr().compare(
+ (*rhs)->getArgStr());
+ });
+
+ // Interleave the options with ' '.
+ os << '{';
+ interleave(
+ orderedOptions, os, [&](OptionBase *option) { option->print(os); }, " ");
+ os << '}';
+}
+
//===----------------------------------------------------------------------===//
// TextualPassPipeline Parser
//===----------------------------------------------------------------------===//
diff --git a/third_party/mlir/lib/Pass/PassStatistics.cpp b/third_party/mlir/lib/Pass/PassStatistics.cpp
index 5306974..0ab656c 100644
--- a/third_party/mlir/lib/Pass/PassStatistics.cpp
+++ b/third_party/mlir/lib/Pass/PassStatistics.cpp
@@ -1,19 +1,10 @@
//===- PassStatistics.cpp -------------------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Pass/PassManager.h"
diff --git a/third_party/mlir/lib/Pass/PassTiming.cpp b/third_party/mlir/lib/Pass/PassTiming.cpp
index 113b65a..93e640e 100644
--- a/third_party/mlir/lib/Pass/PassTiming.cpp
+++ b/third_party/mlir/lib/Pass/PassTiming.cpp
@@ -1,19 +1,10 @@
//===- PassTiming.cpp -----------------------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Pass/PassManager.h"
diff --git a/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
index 94e3642..ba9c078 100644
--- a/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
+++ b/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
@@ -1,19 +1,10 @@
//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a TargetConfiguration for reference fixed-point math
// quantization scheme based on the FxpMathOps (plus a small category of
diff --git a/third_party/mlir/lib/Quantizer/Support/Configuration.cpp b/third_party/mlir/lib/Quantizer/Support/Configuration.cpp
index 78a7451..f64cc85 100644
--- a/third_party/mlir/lib/Quantizer/Support/Configuration.cpp
+++ b/third_party/mlir/lib/Quantizer/Support/Configuration.cpp
@@ -1,19 +1,10 @@
//===- Configuration.cpp - Configuration object base classes --------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/Configuration.h"
diff --git a/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp b/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp
index d38c762..3c194bb 100644
--- a/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp
+++ b/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp
@@ -1,19 +1,10 @@
//===- ConstraintAnalysisGraph.cpp - Graphs type for constraints ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
@@ -102,7 +93,7 @@
std::vector<std::pair<CAGAnchorNode *, CAGAnchorNode *>> impliedPairs;
for (auto &resultAnchorPair : resultAnchors) {
CAGResultAnchor *resultAnchor = resultAnchorPair.second;
- Value *resultValue = resultAnchor->getValue();
+ Value resultValue = resultAnchor->getValue();
for (auto &use : resultValue->getUses()) {
Operation *operandOp = use.getOwner();
unsigned operandIdx = use.getOperandNumber();
diff --git a/third_party/mlir/lib/Quantizer/Support/Metadata.cpp b/third_party/mlir/lib/Quantizer/Support/Metadata.cpp
index 89478c4..b7badfd 100644
--- a/third_party/mlir/lib/Quantizer/Support/Metadata.cpp
+++ b/third_party/mlir/lib/Quantizer/Support/Metadata.cpp
@@ -1,19 +1,10 @@
//===- Metadata.cpp - Top level types and metadata ------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/Metadata.h"
diff --git a/third_party/mlir/lib/Quantizer/Support/Statistics.cpp b/third_party/mlir/lib/Quantizer/Support/Statistics.cpp
index 6753898..3c8b041 100644
--- a/third_party/mlir/lib/Quantizer/Support/Statistics.cpp
+++ b/third_party/mlir/lib/Quantizer/Support/Statistics.cpp
@@ -1,19 +1,10 @@
//===- Statistics.cpp - Collects statistics over tensors ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/Statistics.h"
diff --git a/third_party/mlir/lib/Quantizer/Support/TypeUtils.cpp b/third_party/mlir/lib/Quantizer/Support/TypeUtils.cpp
index fab4e56..a1f52c5 100644
--- a/third_party/mlir/lib/Quantizer/Support/TypeUtils.cpp
+++ b/third_party/mlir/lib/Quantizer/Support/TypeUtils.cpp
@@ -1,19 +1,10 @@
//===- TypeUtils.cpp - Helper function for manipulating types -------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/TypeUtils.h"
diff --git a/third_party/mlir/lib/Quantizer/Support/UniformConstraints.cpp b/third_party/mlir/lib/Quantizer/Support/UniformConstraints.cpp
index 1a800da..b202135 100644
--- a/third_party/mlir/lib/Quantizer/Support/UniformConstraints.cpp
+++ b/third_party/mlir/lib/Quantizer/Support/UniformConstraints.cpp
@@ -1,19 +1,10 @@
//===- UniformConstraints.cpp - Constraints for uniform quant -------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/UniformConstraints.h"
diff --git a/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp b/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp
index 77d69be..2f6bb20 100644
--- a/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp
+++ b/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp
@@ -1,19 +1,10 @@
//===- UniformSolvers.cpp - Uniform type solver algorithms ----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Quantizer/Support/UniformSolvers.h"
#include "mlir/Support/LLVM.h"
diff --git a/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
index a32bb2c..a27f09b 100644
--- a/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
+++ b/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
@@ -1,19 +1,10 @@
//===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a testing pass to add default statistics nodes to every
// quantization eligible op. Useful for unit testing.
@@ -74,7 +65,7 @@
auto func = getFunction();
// Insert stats for each argument.
- for (auto *arg : func.getArguments()) {
+ for (auto arg : func.getArguments()) {
if (!config.isHandledType(arg->getType()))
continue;
OpBuilder b(func.getBody());
diff --git a/third_party/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/third_party/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
index 511df0a..5ecb668 100644
--- a/third_party/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
+++ b/third_party/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
@@ -1,19 +1,10 @@
//===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the primary pass for instantiating a CAG, running it to
// convergence on a module to determine eligible quantized type transforms, and
@@ -181,17 +172,17 @@
void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
Type newType) {
- Value *inputValue = anchor->getValue();
+ Value inputValue = anchor->getValue();
Operation *op = anchor->getOp();
OpBuilder b(op->getBlock(), Block::iterator(op));
- SmallVector<Value *, 1> removeValuesIfDead;
+ SmallVector<Value, 1> removeValuesIfDead;
// Because we've already run the result transforms at this phase, it is
// very likely that inputValue points to a dcast op whose input matches
// our type. We detect that situation and route around just to save some
// bulk in the IR.
- Value *newTypedInputValue = inputValue;
+ Value newTypedInputValue = inputValue;
auto inputDcastOp =
dyn_cast_or_null<DequantizeCastOp>(inputValue->getDefiningOp());
if (inputDcastOp && inputDcastOp.arg()->getType() == newType) {
@@ -228,7 +219,7 @@
break;
}
- for (Value *removeValueIfDead : removeValuesIfDead) {
+ for (Value removeValueIfDead : removeValuesIfDead) {
if (removeValueIfDead->use_empty()) {
removeValueIfDead->getDefiningOp()->erase();
}
@@ -237,12 +228,12 @@
void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
Type newType) {
- Value *origResultValue = anchor->getValue();
+ Value origResultValue = anchor->getValue();
Operation *op = origResultValue->getDefiningOp();
OpBuilder b(op->getBlock(), ++Block::iterator(op));
- Value *replacedResultValue = nullptr;
- Value *newResultValue = nullptr;
+ Value replacedResultValue = nullptr;
+ Value newResultValue = nullptr;
switch (anchor->getTypeTransformRule()) {
case CAGAnchorNode::TypeTransformRule::Direct:
origResultValue->setType(newType);
diff --git a/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
index 0266520..da5bd12 100644
--- a/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
+++ b/third_party/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
@@ -1,19 +1,10 @@
//===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a pass to remove any instrumentation ops. It is often one
// of the final steps when performing quantization and is run after any
diff --git a/third_party/mlir/lib/Support/FileUtilities.cpp b/third_party/mlir/lib/Support/FileUtilities.cpp
index 6f0dc93..a56ae57 100644
--- a/third_party/mlir/lib/Support/FileUtilities.cpp
+++ b/third_party/mlir/lib/Support/FileUtilities.cpp
@@ -1,19 +1,10 @@
//===- FileUtilities.cpp - utilities for working with files ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Definitions of common utilities for working with files.
//
diff --git a/third_party/mlir/lib/Support/JitRunner.cpp b/third_party/mlir/lib/Support/JitRunner.cpp
index dcd2343..b327d3d 100644
--- a/third_party/mlir/lib/Support/JitRunner.cpp
+++ b/third_party/mlir/lib/Support/JitRunner.cpp
@@ -1,19 +1,10 @@
//===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
diff --git a/third_party/mlir/lib/Support/MlirOptMain.cpp b/third_party/mlir/lib/Support/MlirOptMain.cpp
index c256e97..4a76801 100644
--- a/third_party/mlir/lib/Support/MlirOptMain.cpp
+++ b/third_party/mlir/lib/Support/MlirOptMain.cpp
@@ -1,19 +1,10 @@
//===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is a utility that runs an optimization pass and prints the result back
// out. It is designed to support unit testing.
diff --git a/third_party/mlir/lib/Support/StorageUniquer.cpp b/third_party/mlir/lib/Support/StorageUniquer.cpp
index cae4dce..d6f6bac 100644
--- a/third_party/mlir/lib/Support/StorageUniquer.cpp
+++ b/third_party/mlir/lib/Support/StorageUniquer.cpp
@@ -1,19 +1,10 @@
//===- StorageUniquer.cpp - Common Storage Class Uniquer ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Support/StorageUniquer.h"
diff --git a/third_party/mlir/lib/Support/ToolUtilities.cpp b/third_party/mlir/lib/Support/ToolUtilities.cpp
index 60d0eee..cd2df78 100644
--- a/third_party/mlir/lib/Support/ToolUtilities.cpp
+++ b/third_party/mlir/lib/Support/ToolUtilities.cpp
@@ -1,19 +1,10 @@
//===- ToolUtilities.cpp - MLIR Tool Utilities ----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines common utilities for implementing MLIR tools.
//
diff --git a/third_party/mlir/lib/Support/TranslateClParser.cpp b/third_party/mlir/lib/Support/TranslateClParser.cpp
index 115c0c0..1f538cb 100644
--- a/third_party/mlir/lib/Support/TranslateClParser.cpp
+++ b/third_party/mlir/lib/Support/TranslateClParser.cpp
@@ -1,19 +1,10 @@
//===- TranslateClParser.h - Translations command line parser -------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains custom command line parser for translations.
//
diff --git a/third_party/mlir/lib/TableGen/Argument.cpp b/third_party/mlir/lib/TableGen/Argument.cpp
index 17dba05..080e717 100644
--- a/third_party/mlir/lib/TableGen/Argument.cpp
+++ b/third_party/mlir/lib/TableGen/Argument.cpp
@@ -1,19 +1,10 @@
//===- Argument.cpp - Argument definitions --------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Argument.h"
#include "llvm/TableGen/Record.h"
diff --git a/third_party/mlir/lib/TableGen/Attribute.cpp b/third_party/mlir/lib/TableGen/Attribute.cpp
index ec946a8..92f5b1f 100644
--- a/third_party/mlir/lib/TableGen/Attribute.cpp
+++ b/third_party/mlir/lib/TableGen/Attribute.cpp
@@ -1,19 +1,10 @@
//===- Attribute.cpp - Attribute wrapper class ----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Attribute wrapper to simplify using TableGen Record defining a MLIR
// Attribute.
diff --git a/third_party/mlir/lib/TableGen/Constraint.cpp b/third_party/mlir/lib/TableGen/Constraint.cpp
index ef3fa52..022c5ad 100644
--- a/third_party/mlir/lib/TableGen/Constraint.cpp
+++ b/third_party/mlir/lib/TableGen/Constraint.cpp
@@ -1,19 +1,10 @@
//===- Constraint.cpp - Constraint class ----------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Constraint wrapper to simplify using TableGen Record for constraints.
//
diff --git a/third_party/mlir/lib/TableGen/Dialect.cpp b/third_party/mlir/lib/TableGen/Dialect.cpp
index ace4ce3..d9e8e2f 100644
--- a/third_party/mlir/lib/TableGen/Dialect.cpp
+++ b/third_party/mlir/lib/TableGen/Dialect.cpp
@@ -1,19 +1,10 @@
//===- Dialect.cpp - Dialect wrapper class --------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Dialect wrapper to simplify using TableGen Record defining a MLIR dialect.
//
diff --git a/third_party/mlir/lib/TableGen/Format.cpp b/third_party/mlir/lib/TableGen/Format.cpp
index 967d51a..07742ab 100644
--- a/third_party/mlir/lib/TableGen/Format.cpp
+++ b/third_party/mlir/lib/TableGen/Format.cpp
@@ -1,19 +1,10 @@
//===- Format.cpp - Utilities for String Format ---------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines utilities for formatting strings. They are specially
// tailored to the needs of TableGen'ing op definitions and rewrite rules,
diff --git a/third_party/mlir/lib/TableGen/OpInterfaces.cpp b/third_party/mlir/lib/TableGen/OpInterfaces.cpp
index 1687f3a..b1e56ef 100644
--- a/third_party/mlir/lib/TableGen/OpInterfaces.cpp
+++ b/third_party/mlir/lib/TableGen/OpInterfaces.cpp
@@ -1,19 +1,10 @@
//===- OpInterfaces.cpp - OpInterfaces class ------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// OpInterfaces wrapper to simplify using TableGen OpInterfaces.
//
diff --git a/third_party/mlir/lib/TableGen/OpTrait.cpp b/third_party/mlir/lib/TableGen/OpTrait.cpp
index 0e436a8..86e34cd 100644
--- a/third_party/mlir/lib/TableGen/OpTrait.cpp
+++ b/third_party/mlir/lib/TableGen/OpTrait.cpp
@@ -1,19 +1,10 @@
//===- OpTrait.cpp - OpTrait class ----------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// OpTrait wrapper to simplify using TableGen Record defining a MLIR OpTrait.
//
diff --git a/third_party/mlir/lib/TableGen/Operator.cpp b/third_party/mlir/lib/TableGen/Operator.cpp
index 4529208..d61eec4 100644
--- a/third_party/mlir/lib/TableGen/Operator.cpp
+++ b/third_party/mlir/lib/TableGen/Operator.cpp
@@ -1,19 +1,10 @@
//===- Operator.cpp - Operator class --------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Operator wrapper to simplify using TableGen Record defining a MLIR Op.
//
@@ -23,6 +14,7 @@
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -293,12 +285,18 @@
results.push_back({name, TypeConstraint(resultDef)});
}
- auto traitListInit = def.getValueAsListInit("traits");
- if (!traitListInit)
- return;
- traits.reserve(traitListInit->size());
- for (auto traitInit : *traitListInit)
- traits.push_back(OpTrait::create(traitInit));
+ // Create list of traits, skipping over duplicates: appending to lists in
+ // tablegen is easy, making them unique less so, so dedupe here.
+ if (auto traitList = def.getValueAsListInit("traits")) {
+ // This is uniquing based on pointers of the trait.
+ SmallPtrSet<const llvm::Init *, 32> traitSet;
+ traits.reserve(traitSet.size());
+ for (auto traitInit : *traitList) {
+ // Keep traits in the same order while skipping over duplicates.
+ if (traitSet.insert(traitInit).second)
+ traits.push_back(OpTrait::create(traitInit));
+ }
+ }
// Handle regions
auto *regionsDag = def.getValueAsDag("regions");
diff --git a/third_party/mlir/lib/TableGen/Pattern.cpp b/third_party/mlir/lib/TableGen/Pattern.cpp
index 098dba3..ada2af8 100644
--- a/third_party/mlir/lib/TableGen/Pattern.cpp
+++ b/third_party/mlir/lib/TableGen/Pattern.cpp
@@ -1,19 +1,10 @@
//===- Pattern.cpp - Pattern wrapper class --------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Pattern wrapper class to simplify using TableGen Record defining a MLIR
// Pattern.
@@ -224,7 +215,7 @@
return formatv("Operation::operand_range {0}(op0->getOperands());\n", name);
}
case Kind::Value: {
- return formatv("ArrayRef<Value *> {0};\n", name);
+ return formatv("ArrayRef<Value> {0};\n", name);
}
case Kind::Result: {
// Use the op itself for captured results.
diff --git a/third_party/mlir/lib/TableGen/Predicate.cpp b/third_party/mlir/lib/TableGen/Predicate.cpp
index f8f23e0..c52e15d 100644
--- a/third_party/mlir/lib/TableGen/Predicate.cpp
+++ b/third_party/mlir/lib/TableGen/Predicate.cpp
@@ -1,19 +1,10 @@
//===- Predicate.cpp - Predicate class ------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Wrapper around predicates defined in TableGen.
//
diff --git a/third_party/mlir/lib/TableGen/Type.cpp b/third_party/mlir/lib/TableGen/Type.cpp
index a558be4..9a309bd 100644
--- a/third_party/mlir/lib/TableGen/Type.cpp
+++ b/third_party/mlir/lib/TableGen/Type.cpp
@@ -1,19 +1,10 @@
//===- Type.cpp - Type class ----------------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Type wrapper to simplify using TableGen Record defining a MLIR Type.
//
diff --git a/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 6cf975b..4466fb5 100644
--- a/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/third_party/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -1,19 +1,10 @@
//===- ConvertFromLLVMIR.cpp - MLIR to LLVM IR conversion -----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a translation between LLVM IR and the MLIR LLVM dialect.
//
@@ -76,7 +67,7 @@
/// `value` is an SSA-use. Return the remapped version of `value` or a
/// placeholder that will be remapped later if this is an instruction that
/// has not yet been visited.
- Value *processValue(llvm::Value *value);
+ Value processValue(llvm::Value *value);
/// Create the most accurate Location possible using a llvm::DebugLoc and
/// possibly an llvm::Instruction to narrow the Location if debug information
/// is unavailable.
@@ -85,14 +76,14 @@
/// `br` branches to `target`. Return the block arguments to attach to the
/// generated branch op. These should be in the same order as the PHIs in
/// `target`.
- SmallVector<Value *, 4> processBranchArgs(llvm::BranchInst *br,
- llvm::BasicBlock *target);
+ SmallVector<Value, 4> processBranchArgs(llvm::BranchInst *br,
+ llvm::BasicBlock *target);
/// Return `value` as an attribute to attach to a GlobalOp.
Attribute getConstantAsAttr(llvm::Constant *value);
/// Return `c` as an MLIR Value. This could either be a ConstantOp, or
/// an expanded sequence of ops in the current function's entry block (for
/// ConstantExprs or ConstantGEPs).
- Value *processConstant(llvm::Constant *c);
+ Value processConstant(llvm::Constant *c);
/// The current builder, pointing at where the next Instruction should be
/// generated.
@@ -120,7 +111,7 @@
/// Remapped blocks, for the current function.
DenseMap<llvm::BasicBlock *, Block *> blocks;
/// Remapped values. These are function-local.
- DenseMap<llvm::Value *, Value *> instMap;
+ DenseMap<llvm::Value *, Value> instMap;
/// Instructions that had not been defined when first encountered as a use.
/// Maps to the dummy Operation that was created in processValue().
DenseMap<llvm::Value *, Operation *> unknownInstMap;
@@ -263,13 +254,13 @@
Region &r = op.getInitializerRegion();
currentEntryBlock = b.createBlock(&r);
b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
- Value *v = processConstant(GV->getInitializer());
- b.create<ReturnOp>(op.getLoc(), ArrayRef<Value *>({v}));
+ Value v = processConstant(GV->getInitializer());
+ b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
}
return globals[GV] = op;
}
-Value *Importer::processConstant(llvm::Constant *c) {
+Value Importer::processConstant(llvm::Constant *c) {
if (Attribute attr = getConstantAsAttr(c)) {
// These constants can be represented as attributes.
OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
@@ -298,7 +289,7 @@
return nullptr;
}
-Value *Importer::processValue(llvm::Value *value) {
+Value Importer::processValue(llvm::Value *value) {
auto it = instMap.find(value);
if (it != instMap.end())
return it->second;
@@ -407,9 +398,9 @@
// `br` branches to `target`. Return the branch arguments to `br`, in the
// same order of the PHIs in `target`.
-SmallVector<Value *, 4> Importer::processBranchArgs(llvm::BranchInst *br,
- llvm::BasicBlock *target) {
- SmallVector<Value *, 4> v;
+SmallVector<Value, 4> Importer::processBranchArgs(llvm::BranchInst *br,
+ llvm::BasicBlock *target) {
+ SmallVector<Value, 4> v;
for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
auto *PN = cast<llvm::PHINode>(&*inst);
v.push_back(processValue(PN->getIncomingValueForBlock(br->getParent())));
@@ -421,7 +412,7 @@
// FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
// flags and call / operand attributes are not supported.
Location loc = processDebugLoc(inst->getDebugLoc(), inst);
- Value *&v = instMap[inst];
+ Value &v = instMap[inst];
assert(!v && "processInstruction must be called only once per instruction!");
switch (inst->getOpcode()) {
default:
@@ -462,7 +453,7 @@
case llvm::Instruction::AddrSpaceCast:
case llvm::Instruction::BitCast: {
OperationState state(loc, opcMap.lookup(inst->getOpcode()));
- SmallVector<Value *, 4> ops;
+ SmallVector<Value, 4> ops;
ops.reserve(inst->getNumOperands());
for (auto *op : inst->operand_values())
ops.push_back(processValue(op));
@@ -484,7 +475,7 @@
auto *brInst = cast<llvm::BranchInst>(inst);
OperationState state(loc,
brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
- SmallVector<Value *, 4> ops;
+ SmallVector<Value, 4> ops;
if (brInst->isConditional())
ops.push_back(processValue(brInst->getCondition()));
state.addOperands(ops);
@@ -500,7 +491,7 @@
}
case llvm::Instruction::Call: {
llvm::CallInst *ci = cast<llvm::CallInst>(inst);
- SmallVector<Value *, 4> ops;
+ SmallVector<Value, 4> ops;
ops.reserve(inst->getNumOperands());
for (auto &op : ci->arg_operands())
ops.push_back(processValue(op.get()));
@@ -523,7 +514,7 @@
case llvm::Instruction::GetElementPtr: {
// FIXME: Support inbounds GEPs.
llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);
- SmallVector<Value *, 4> ops;
+ SmallVector<Value, 4> ops;
for (auto *op : gep->operand_values())
ops.push_back(processValue(op));
v = b.create<GEPOp>(loc, processType(inst->getType()), ops,
@@ -565,8 +556,8 @@
// any unknown uses we encountered are remapped.
for (auto &llvmAndUnknown : unknownInstMap) {
assert(instMap.count(llvmAndUnknown.first));
- Value *newValue = instMap[llvmAndUnknown.first];
- Value *oldValue = llvmAndUnknown.second->getResult(0);
+ Value newValue = instMap[llvmAndUnknown.first];
+ Value oldValue = llvmAndUnknown.second->getResult(0);
oldValue->replaceAllUsesWith(newValue);
llvmAndUnknown.second->erase();
}
diff --git a/third_party/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/third_party/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
index e69dce7..4cc5997 100644
--- a/third_party/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/third_party/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -1,19 +1,10 @@
//===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion -------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR LLVM dialect and LLVM IR.
//
diff --git a/third_party/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/third_party/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
index 8baed98..a599217 100644
--- a/third_party/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
+++ b/third_party/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
@@ -1,19 +1,10 @@
//===- ConvertToNVVMIR.cpp - MLIR to LLVM IR conversion -------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR LLVM + NVVM dialects and
// LLVM IR with NVVM intrinsics and metadata.
diff --git a/third_party/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/third_party/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
index f119b13..881d165 100644
--- a/third_party/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
+++ b/third_party/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
@@ -1,19 +1,10 @@
//===- ConvertToROCDLIR.cpp - MLIR to LLVM IR conversion ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR LLVM + ROCDL dialects and
// LLVM IR with ROCDL intrinsics and metadata.
diff --git a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 6206a88..e3c0768 100644
--- a/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/third_party/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1,19 +1,10 @@
//===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements the translation between an MLIR LLVM dialect module and
// the corresponding LLVMIR module. It only handles core LLVM IR operations.
@@ -248,7 +239,7 @@
auto predecessors = bb.getPredecessors();
unsigned numPredecessors =
std::distance(predecessors.begin(), predecessors.end());
- for (auto *arg : bb.getArguments()) {
+ for (auto arg : bb.getArguments()) {
auto wrappedType = arg->getType().dyn_cast<LLVM::LLVMType>();
if (!wrappedType)
return emitError(bb.front().getLoc(),
@@ -342,8 +333,8 @@
/// Get the SSA value passed to the current block from the terminator operation
/// of its predecessor.
-static Value *getPHISourceValue(Block *current, Block *pred,
- unsigned numArguments, unsigned index) {
+static Value getPHISourceValue(Block *current, Block *pred,
+ unsigned numArguments, unsigned index) {
auto &terminator = *pred->getTerminator();
if (isa<LLVM::BrOp>(terminator)) {
return terminator.getOperand(index);
@@ -420,7 +411,7 @@
unsigned int argIdx = 0;
for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) {
llvm::Argument &llvmArg = std::get<1>(kvp);
- BlockArgument *mlirArg = std::get<0>(kvp);
+ BlockArgument mlirArg = std::get<0>(kvp);
if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.noalias")) {
// NB: Attribute already verified to be boolean, so check if we can indeed
@@ -492,6 +483,16 @@
return success();
}
+/// A helper to look up remapped operands in the value remapping table.
+SmallVector<llvm::Value *, 8>
+ModuleTranslation::lookupValues(ValueRange values) {
+ SmallVector<llvm::Value *, 8> remapped;
+ remapped.reserve(values.size());
+ for (Value v : values)
+ remapped.push_back(valueMapping.lookup(v));
+ return remapped;
+}
+
std::unique_ptr<llvm::Module>
ModuleTranslation::prepareLLVMModule(Operation *m) {
auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
diff --git a/third_party/mlir/lib/Transforms/AffineDataCopyGeneration.cpp b/third_party/mlir/lib/Transforms/AffineDataCopyGeneration.cpp
index 7fb356f..902f5c3 100644
--- a/third_party/mlir/lib/Transforms/AffineDataCopyGeneration.cpp
+++ b/third_party/mlir/lib/Transforms/AffineDataCopyGeneration.cpp
@@ -1,19 +1,10 @@
//===- AffineDataCopyGeneration.cpp - Explicit memref copying pass ------*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to automatically promote accessed memref regions
// to buffers in a faster memory space that is explicitly managed, with the
@@ -130,7 +121,7 @@
bool skipNonUnitStrideLoops;
// Constant zero index to avoid too many duplicates.
- Value *zeroIndex = nullptr;
+ Value zeroIndex = nullptr;
};
} // end anonymous namespace
diff --git a/third_party/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp b/third_party/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp
index f384f6d..24ec2d7 100644
--- a/third_party/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/third_party/mlir/lib/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -1,19 +1,10 @@
//===- AffineLoopInvariantCodeMotion.cpp - Code to perform loop fusion-----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements loop invariant code motion.
//
@@ -58,15 +49,15 @@
} // end anonymous namespace
static bool
-checkInvarianceOfNestedIfOps(Operation *op, Value *indVar,
+checkInvarianceOfNestedIfOps(Operation *op, Value indVar,
SmallPtrSetImpl<Operation *> &definedOps,
SmallPtrSetImpl<Operation *> &opsToHoist);
-static bool isOpLoopInvariant(Operation &op, Value *indVar,
+static bool isOpLoopInvariant(Operation &op, Value indVar,
SmallPtrSetImpl<Operation *> &definedOps,
SmallPtrSetImpl<Operation *> &opsToHoist);
static bool
-areAllOpsInTheBlockListInvariant(Region &blockList, Value *indVar,
+areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
SmallPtrSetImpl<Operation *> &definedOps,
SmallPtrSetImpl<Operation *> &opsToHoist);
@@ -79,7 +70,7 @@
}
// Returns true if the individual op is loop invariant.
-bool isOpLoopInvariant(Operation &op, Value *indVar,
+bool isOpLoopInvariant(Operation &op, Value indVar,
SmallPtrSetImpl<Operation *> &definedOps,
SmallPtrSetImpl<Operation *> &opsToHoist) {
LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;);
@@ -97,9 +88,9 @@
return false;
} else if (!isa<ConstantOp>(op)) {
if (isMemRefDereferencingOp(op)) {
- Value *memref = isa<AffineLoadOp>(op)
- ? cast<AffineLoadOp>(op).getMemRef()
- : cast<AffineStoreOp>(op).getMemRef();
+ Value memref = isa<AffineLoadOp>(op)
+ ? cast<AffineLoadOp>(op).getMemRef()
+ : cast<AffineStoreOp>(op).getMemRef();
for (auto *user : memref->getUsers()) {
// If this memref has a user that is a DMA, give up because these
// operations write to this memref.
@@ -163,7 +154,7 @@
// Checks if all ops in a region (i.e. list of blocks) are loop invariant.
bool areAllOpsInTheBlockListInvariant(
- Region &blockList, Value *indVar, SmallPtrSetImpl<Operation *> &definedOps,
+ Region &blockList, Value indVar, SmallPtrSetImpl<Operation *> &definedOps,
SmallPtrSetImpl<Operation *> &opsToHoist) {
for (auto &b : blockList) {
@@ -178,7 +169,7 @@
}
// Returns true if the affine.if op can be hoisted.
-bool checkInvarianceOfNestedIfOps(Operation *op, Value *indVar,
+bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar,
SmallPtrSetImpl<Operation *> &definedOps,
SmallPtrSetImpl<Operation *> &opsToHoist) {
assert(isa<AffineIfOp>(op));
@@ -199,7 +190,7 @@
void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
auto *loopBody = forOp.getBody();
- auto *indVar = forOp.getInductionVar();
+ auto indVar = forOp.getInductionVar();
SmallPtrSet<Operation *, 8> definedOps;
// This is the place where hoisted instructions would reside.
diff --git a/third_party/mlir/lib/Transforms/CSE.cpp b/third_party/mlir/lib/Transforms/CSE.cpp
index 18f9fce..714fb1d 100644
--- a/third_party/mlir/lib/Transforms/CSE.cpp
+++ b/third_party/mlir/lib/Transforms/CSE.cpp
@@ -1,19 +1,10 @@
//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This transformation pass performs a simple common sub-expression elimination
// algorithm on operations within a function.
diff --git a/third_party/mlir/lib/Transforms/Canonicalizer.cpp b/third_party/mlir/lib/Transforms/Canonicalizer.cpp
index 7dcdeb6..5b3a1eb 100644
--- a/third_party/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/third_party/mlir/lib/Transforms/Canonicalizer.cpp
@@ -1,19 +1,10 @@
//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This transformation pass converts operations into their canonical forms by
// folding constants, applying operation identity transformations etc.
diff --git a/third_party/mlir/lib/Transforms/DialectConversion.cpp b/third_party/mlir/lib/Transforms/DialectConversion.cpp
index 37c918f..5f7fb7a 100644
--- a/third_party/mlir/lib/Transforms/DialectConversion.cpp
+++ b/third_party/mlir/lib/Transforms/DialectConversion.cpp
@@ -1,19 +1,10 @@
//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/IR/Block.h"
@@ -86,13 +77,13 @@
struct ConversionValueMapping {
/// Lookup a mapped value within the map. If a mapping for the provided value
/// does not exist then return the provided value.
- Value *lookupOrDefault(Value *from) const;
+ Value lookupOrDefault(Value from) const;
/// Map a value to the one provided.
- void map(Value *oldVal, Value *newVal) { mapping.map(oldVal, newVal); }
+ void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
/// Drop the last mapping for the given value.
- void erase(Value *value) { mapping.erase(value); }
+ void erase(Value value) { mapping.erase(value); }
private:
/// Current value mappings.
@@ -102,10 +93,10 @@
/// Lookup a mapped value within the map. If a mapping for the provided value
/// does not exist then return the provided value.
-Value *ConversionValueMapping::lookupOrDefault(Value *from) const {
+Value ConversionValueMapping::lookupOrDefault(Value from) const {
// If this value had a valid mapping, unmap that value as well in the case
// that it was also replaced.
- while (auto *mappedValue = mapping.lookupOrNull(from))
+ while (auto mappedValue = mapping.lookupOrNull(from))
from = mappedValue;
return from;
}
@@ -127,7 +118,7 @@
/// been converted.
struct ConvertedArgInfo {
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
- Value *castValue = nullptr)
+ Value castValue = nullptr)
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
/// The start index of in the new argument list that contains arguments that
@@ -139,7 +130,7 @@
/// The cast value that was created to cast from the new arguments to the
/// old. This only used if 'newArgSize' > 1.
- Value *castValue;
+ Value castValue;
};
/// This structure contains information pertaining to a block that has had its
@@ -235,7 +226,7 @@
// Drop all uses of the original arguments and delete the original block.
Block *origBlock = it->second.origBlock;
- for (BlockArgument *arg : origBlock->getArguments())
+ for (BlockArgument arg : origBlock->getArguments())
arg->dropAllUses();
conversionInfo.erase(it);
}
@@ -270,7 +261,7 @@
// Process the remapping for each of the original arguments.
for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
- BlockArgument *origArg = origBlock->getArgument(i);
+ BlockArgument origArg = origBlock->getArgument(i);
// Handle the case of a 1->0 value mapping.
if (!argInfo) {
@@ -305,7 +296,7 @@
}
// Otherwise this is a 1->N value mapping.
- Value *castValue = argInfo->castValue;
+ Value castValue = argInfo->castValue;
assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
// If the argument is still used, replace it with the generated cast.
@@ -344,8 +335,8 @@
Block *newBlock = block->splitBlock(block->begin());
block->replaceAllUsesWith(newBlock);
- SmallVector<Value *, 4> newArgRange(newBlock->addArguments(convertedTypes));
- ArrayRef<Value *> newArgs(newArgRange);
+ SmallVector<Value, 4> newArgRange(newBlock->addArguments(convertedTypes));
+ ArrayRef<Value> newArgs(newArgRange);
// Remap each of the original arguments as determined by the signature
// conversion.
@@ -358,7 +349,7 @@
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap)
continue;
- BlockArgument *origArg = block->getArgument(i);
+ BlockArgument origArg = block->getArgument(i);
// If inputMap->replacementValue is not nullptr, then the argument is
// dropped and a replacement value is provided to be the remappedValue.
@@ -415,14 +406,16 @@
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
- unsigned numBlockActions, unsigned numIgnoredOperations)
- : numCreatedOperations(numCreatedOperations),
- numReplacements(numReplacements), numBlockActions(numBlockActions),
- numIgnoredOperations(numIgnoredOperations) {}
+ RewriterState(unsigned numCreatedOps, unsigned numReplacements,
+ unsigned numBlockActions, unsigned numIgnoredOperations,
+ unsigned numRootUpdates)
+ : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+ numBlockActions(numBlockActions),
+ numIgnoredOperations(numIgnoredOperations),
+ numRootUpdates(numRootUpdates) {}
/// The current number of created operations.
- unsigned numCreatedOperations;
+ unsigned numCreatedOps;
/// The current number of replacements queued.
unsigned numReplacements;
@@ -432,6 +425,41 @@
/// The current number of ignored operations.
unsigned numIgnoredOperations;
+
+ /// The current number of operations that were updated in place.
+ unsigned numRootUpdates;
+};
+
+/// The state of an operation that was updated by a pattern in-place. This
+/// contains all of the necessary information to reconstruct an operation that
+/// was updated in place.
+class OperationTransactionState {
+public:
+ OperationTransactionState() = default;
+ OperationTransactionState(Operation *op)
+ : op(op), loc(op->getLoc()), attrs(op->getAttrList()),
+ operands(op->operand_begin(), op->operand_end()),
+ successors(op->successor_begin(), op->successor_end()) {}
+
+ /// Discard the transaction state and reset the state of the original
+ /// operation.
+ void resetOperation() const {
+ op->setLoc(loc);
+ op->setAttrs(attrs);
+ op->setOperands(operands);
+ for (auto it : llvm::enumerate(successors))
+ op->setSuccessor(it.value(), it.index());
+ }
+
+ /// Return the original operation of this state.
+ Operation *getOperation() const { return op; }
+
+private:
+ Operation *op;
+ LocationAttr loc;
+ NamedAttributeList attrs;
+ SmallVector<Value, 8> operands;
+ SmallVector<Block *, 2> successors;
};
} // end anonymous namespace
@@ -445,7 +473,7 @@
: op(op), newValues(newValues.begin(), newValues.end()) {}
Operation *op;
- SmallVector<Value *, 2> newValues;
+ SmallVector<Value, 2> newValues;
};
/// The kind of the block action performed during the rewrite. Actions can be
@@ -542,7 +570,7 @@
/// Remap the given operands to those with potentially different types.
void remapValues(Operation::operand_range operands,
- SmallVectorImpl<Value *> &remapped);
+ SmallVectorImpl<Value> &remapped);
/// Returns true if the given operation is ignored, and does not need to be
/// converted.
@@ -576,27 +604,43 @@
/// the others. This simplifies the amount of memory needed as we can query if
/// the parent operation was ignored.
llvm::SetVector<Operation *> ignoredOps;
+
+ /// A transaction state for each of operations that were updated in-place.
+ SmallVector<OperationTransactionState, 4> rootUpdates;
+
+#ifndef NDEBUG
+ /// A set of operations that have pending updates. This tracking isn't
+ /// strictly necessary, and is thus only active during debug builds for extra
+ /// verification.
+ SmallPtrSet<Operation *, 1> pendingRootUpdates;
+#endif
};
} // end namespace detail
} // end namespace mlir
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
- blockActions.size(), ignoredOps.size());
+ blockActions.size(), ignoredOps.size(),
+ rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+ // Reset any operations that were updated in place.
+ for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
+ rootUpdates[i].resetOperation();
+ rootUpdates.resize(state.numRootUpdates);
+
// Undo any block actions.
undoBlockActions(state.numBlockActions);
// Reset any replaced operations and undo any saved mappings.
for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
- for (auto *result : repl.op->getResults())
+ for (auto result : repl.op->getResults())
mapping.erase(result);
replacements.resize(state.numReplacements);
// Pop all of the newly created operations.
- while (createdOps.size() != state.numCreatedOperations) {
+ while (createdOps.size() != state.numCreatedOps) {
createdOps.back()->erase();
createdOps.pop_back();
}
@@ -649,6 +693,10 @@
}
void ConversionPatternRewriterImpl::discardRewrites() {
+ // Reset any operations that were updated in place.
+ for (auto &state : rootUpdates)
+ state.resetOperation();
+
undoBlockActions();
// Remove any newly created ops.
@@ -660,7 +708,7 @@
// Apply all of the rewrites replacements requested during conversion.
for (auto &repl : replacements) {
for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) {
- if (auto *newValue = repl.newValues[i])
+ if (auto newValue = repl.newValues[i])
repl.op->getResult(i)->replaceAllUsesWith(
mapping.lookupOrDefault(newValue));
}
@@ -715,7 +763,7 @@
// Create mappings for each of the new result values.
for (unsigned i = 0, e = newValues.size(); i < e; ++i)
- if (auto *repl = newValues[i])
+ if (auto repl = newValues[i])
mapping.map(op->getResult(i), repl);
// Record the requested operation replacement.
@@ -755,9 +803,9 @@
}
void ConversionPatternRewriterImpl::remapValues(
- Operation::operand_range operands, SmallVectorImpl<Value *> &remapped) {
+ Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
remapped.reserve(llvm::size(operands));
- for (Value *operand : operands)
+ for (Value operand : operands)
remapped.push_back(mapping.lookupOrDefault(operand));
}
@@ -803,7 +851,7 @@
void ConversionPatternRewriter::eraseOp(Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName()
<< "\n");
- SmallVector<Value *, 1> nullRepls(op->getNumResults(), nullptr);
+ SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None);
}
@@ -813,8 +861,8 @@
return impl->applySignatureConversion(region, conversion);
}
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from,
- Value *to) {
+void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
+ Value to) {
for (auto &u : from->getUses()) {
if (u.getOwner() == to->getDefiningOp())
continue;
@@ -825,7 +873,7 @@
/// Return the converted value that replaces 'key'. Return 'key' if there is
/// no such a converted value.
-Value *ConversionPatternRewriter::getRemappedValue(Value *key) {
+Value ConversionPatternRewriter::getRemappedValue(Value key) {
return impl->mapping.lookupOrDefault(key);
}
@@ -876,11 +924,34 @@
}
/// PatternRewriter hook for updating the root operation in-place.
-void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
- // The rewriter caches changes to the IR to allow for operating in-place and
- // backtracking. The rewriter is currently not capable of backtracking
- // in-place modifications.
- llvm_unreachable("in-place operation updates are not supported");
+void ConversionPatternRewriter::startRootUpdate(Operation *op) {
+#ifndef NDEBUG
+ impl->pendingRootUpdates.insert(op);
+#endif
+ impl->rootUpdates.emplace_back(op);
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
+ // There is nothing to do here, we only need to track the operation at the
+ // start of the update.
+#ifndef NDEBUG
+ assert(impl->pendingRootUpdates.erase(op) &&
+ "operation did not have a pending in-place update");
+#endif
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
+#ifndef NDEBUG
+ assert(impl->pendingRootUpdates.erase(op) &&
+ "operation did not have a pending in-place update");
+#endif
+ // Erase the last update for this operation.
+ auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
+ auto &rootUpdates = impl->rootUpdates;
+ auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
+ rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
}
/// Return a reference to the internal implementation.
@@ -896,7 +967,7 @@
PatternMatchResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- SmallVector<Value *, 4> operands;
+ SmallVector<Value, 4> operands;
auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
dialectRewriter.getImpl().remapValues(op->getOperands(), operands);
@@ -908,7 +979,7 @@
SmallVector<Block *, 2> destinations;
destinations.reserve(op->getNumSuccessors());
- SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
+ SmallVector<ArrayRef<Value>, 2> operandsPerDestination;
unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
destinations.push_back(op->getSuccessor(i));
@@ -1059,7 +1130,7 @@
RewriterState curState = rewriterImpl.getCurrentState();
// Try to fold the operation.
- SmallVector<Value *, 2> replacementValues;
+ SmallVector<Value, 2> replacementValues;
rewriter.setInsertionPoint(op);
if (failed(rewriter.tryFold(op, replacementValues)))
return failure();
@@ -1068,8 +1139,7 @@
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numCreatedOperations,
- e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
i != e; ++i) {
Operation *cstOp = rewriterImpl.createdOps[i];
if (failed(legalize(cstOp, rewriter))) {
@@ -1111,7 +1181,12 @@
// Try to rewrite with the given pattern.
rewriter.setInsertionPoint(op);
- if (!pattern->matchAndRewrite(op, rewriter)) {
+ auto matchedPattern = pattern->matchAndRewrite(op, rewriter);
+#ifndef NDEBUG
+ assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+#endif
+
+ if (!matchedPattern) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
return cleanupFailure();
}
@@ -1148,12 +1223,32 @@
else
rewriterImpl.ignoredOps.insert(replacedOp);
}
- assert(replacedRoot && "expected pattern to replace the root operation");
+
+ // Check that the root was either updated or replace.
+ auto updatedRootInPlace = [&] {
+ return llvm::any_of(
+ llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates),
+ [op](auto &state) { return state.getOperation() == op; });
+ };
(void)replacedRoot;
+ (void)updatedRootInPlace;
+ assert((replacedRoot || updatedRootInPlace()) &&
+ "expected pattern to replace the root operation");
+
+ // Recursively legalize each of the operations updated in place.
+ for (unsigned i = curState.numRootUpdates,
+ e = rewriterImpl.rootUpdates.size();
+ i != e; ++i) {
+ auto &state = rewriterImpl.rootUpdates[i];
+ if (failed(legalize(state.getOperation(), rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '"
+ << op->getName() << "' was illegal.\n");
+ return cleanupFailure();
+ }
+ }
// Recursively legalize each of the new operations.
- for (unsigned i = curState.numCreatedOperations,
- e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
i != e; ++i) {
Operation *op = rewriterImpl.createdOps[i];
if (failed(legalize(op, rewriter))) {
@@ -1459,7 +1554,7 @@
/// Remap an input of the original signature to another `replacementValue`
/// value. This would make the signature converter drop this argument.
void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
- Value *replacementValue) {
+ Value replacementValue) {
assert(!remappedInputs[origInputNo] && "input has already been remapped");
remappedInputs[origInputNo] =
InputMapping{origInputNo, /*size=*/0, replacementValue};
@@ -1528,7 +1623,7 @@
/// Hook for derived classes to implement combined matching and rewriting.
PatternMatchResult
- matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FunctionType type = funcOp.getType();
@@ -1543,16 +1638,12 @@
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
return matchFailure();
- // Create a new function with an updated signature.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
- convertedResults, funcOp.getContext()));
-
- // Tell the rewriter to convert the region signature.
- rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
- rewriter.eraseOp(funcOp);
+ // Update the function signature in-place.
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(FunctionType::get(result.getConvertedTypes(),
+ convertedResults, funcOp.getContext()));
+ rewriter.applySignatureConversion(&funcOp.getBody(), result);
+ });
return matchSuccess();
}
diff --git a/third_party/mlir/lib/Transforms/Inliner.cpp b/third_party/mlir/lib/Transforms/Inliner.cpp
index b158948..b2cee7d 100644
--- a/third_party/mlir/lib/Transforms/Inliner.cpp
+++ b/third_party/mlir/lib/Transforms/Inliner.cpp
@@ -1,19 +1,10 @@
//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connect Components(SCCs) of the CallGraph. This enables a more
diff --git a/third_party/mlir/lib/Transforms/LoopCoalescing.cpp b/third_party/mlir/lib/Transforms/LoopCoalescing.cpp
index c1eec56..2aee688 100644
--- a/third_party/mlir/lib/Transforms/LoopCoalescing.cpp
+++ b/third_party/mlir/lib/Transforms/LoopCoalescing.cpp
@@ -1,19 +1,10 @@
//===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
diff --git a/third_party/mlir/lib/Transforms/LoopFusion.cpp b/third_party/mlir/lib/Transforms/LoopFusion.cpp
index 5694c99..fcfc1d7 100644
--- a/third_party/mlir/lib/Transforms/LoopFusion.cpp
+++ b/third_party/mlir/lib/Transforms/LoopFusion.cpp
@@ -1,19 +1,10 @@
//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements loop fusion.
//
@@ -172,7 +163,7 @@
Node(unsigned id, Operation *op) : id(id), op(op) {}
// Returns the load op count for 'memref'.
- unsigned getLoadOpCount(Value *memref) {
+ unsigned getLoadOpCount(Value memref) {
unsigned loadOpCount = 0;
for (auto *loadOpInst : loads) {
if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
@@ -182,7 +173,7 @@
}
// Returns the store op count for 'memref'.
- unsigned getStoreOpCount(Value *memref) {
+ unsigned getStoreOpCount(Value memref) {
unsigned storeOpCount = 0;
for (auto *storeOpInst : stores) {
if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
@@ -192,7 +183,7 @@
}
// Returns all store ops in 'storeOps' which access 'memref'.
- void getStoreOpsForMemref(Value *memref,
+ void getStoreOpsForMemref(Value memref,
SmallVectorImpl<Operation *> *storeOps) {
for (auto *storeOpInst : stores) {
if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
@@ -201,7 +192,7 @@
}
// Returns all load ops in 'loadOps' which access 'memref'.
- void getLoadOpsForMemref(Value *memref,
+ void getLoadOpsForMemref(Value memref,
SmallVectorImpl<Operation *> *loadOps) {
for (auto *loadOpInst : loads) {
if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
@@ -211,13 +202,13 @@
// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
// has at least one load and store operation.
- void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
- llvm::SmallDenseSet<Value *, 2> loadMemrefs;
+ void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
+ llvm::SmallDenseSet<Value, 2> loadMemrefs;
for (auto *loadOpInst : loads) {
loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
}
for (auto *storeOpInst : stores) {
- auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
if (loadMemrefs.count(memref) > 0)
loadAndStoreMemrefSet->insert(memref);
}
@@ -239,7 +230,7 @@
// defines an SSA value and another graph node which uses the SSA value
// (e.g. a constant operation defining a value which is used inside a loop
// nest).
- Value *value;
+ Value value;
};
// Map from node id to Node.
@@ -250,7 +241,7 @@
DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
// Map from memref to a count on the dependence edges associated with that
// memref.
- DenseMap<Value *, unsigned> memrefEdgeCount;
+ DenseMap<Value, unsigned> memrefEdgeCount;
// The next unique identifier to use for newly created graph nodes.
unsigned nextNodeId = 0;
@@ -309,7 +300,7 @@
bool writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
- auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
auto *op = memref->getDefiningOp();
// Return true if 'memref' is a block argument.
if (!op)
@@ -338,7 +329,7 @@
const auto &nodeOutEdges = outEdgeIt->second;
for (auto *op : node->stores) {
auto storeOp = cast<AffineStoreOp>(op);
- auto *memref = storeOp.getMemRef();
+ auto memref = storeOp.getMemRef();
// Skip this store if there are no dependences on its memref. This means
// that store either:
// *) writes to a memref that is only read within the same loop nest
@@ -381,7 +372,7 @@
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
- bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
+ bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
return false;
}
@@ -395,7 +386,7 @@
}
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
- void addEdge(unsigned srcId, unsigned dstId, Value *value) {
+ void addEdge(unsigned srcId, unsigned dstId, Value value) {
if (!hasEdge(srcId, dstId, value)) {
outEdges[srcId].push_back({dstId, value});
inEdges[dstId].push_back({srcId, value});
@@ -405,7 +396,7 @@
}
// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
- void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
+ void removeEdge(unsigned srcId, unsigned dstId, Value value) {
assert(inEdges.count(dstId) > 0);
assert(outEdges.count(srcId) > 0);
if (value->getType().isa<MemRefType>()) {
@@ -459,7 +450,7 @@
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
- unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
+ unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
unsigned inEdgeCount = 0;
if (inEdges.count(id) > 0)
for (auto &inEdge : inEdges[id])
@@ -474,7 +465,7 @@
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
- unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
+ unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0)
for (auto &outEdge : outEdges[id])
@@ -548,7 +539,7 @@
// Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
// has been replaced in node at 'dstId' by a private memref depending
// on the value of 'createPrivateMemRef'.
- void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef,
+ void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef,
bool createPrivateMemRef) {
// For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'.
if (inEdges.count(srcId) > 0) {
@@ -681,7 +672,7 @@
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(FuncOp f) {
- DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
+ DenseMap<Value, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
if (f.getBlocks().size() != 1)
@@ -701,12 +692,12 @@
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
- auto *memref = cast<AffineLoadOp>(opInst).getMemRef();
+ auto memref = cast<AffineLoadOp>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
- auto *memref = cast<AffineStoreOp>(opInst).getMemRef();
+ auto memref = cast<AffineStoreOp>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[&op] = node.id;
@@ -715,14 +706,14 @@
// Create graph node for top-level load op.
Node node(nextNodeId++, &op);
node.loads.push_back(&op);
- auto *memref = cast<AffineLoadOp>(op).getMemRef();
+ auto memref = cast<AffineLoadOp>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &op);
node.stores.push_back(&op);
- auto *memref = cast<AffineStoreOp>(op).getMemRef();
+ auto memref = cast<AffineStoreOp>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (op.getNumRegions() != 0) {
@@ -743,7 +734,7 @@
if (!node.loads.empty() || !node.stores.empty())
continue;
auto *opInst = node.op;
- for (auto *value : opInst->getResults()) {
+ for (auto value : opInst->getResults()) {
for (auto *user : value->getUsers()) {
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*user, &loops);
@@ -777,7 +768,7 @@
// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
-static void moveLoadsAccessingMemrefTo(Value *memref,
+static void moveLoadsAccessingMemrefTo(Value memref,
SmallVectorImpl<Operation *> *srcLoads,
SmallVectorImpl<Operation *> *dstLoads) {
dstLoads->clear();
@@ -893,10 +884,10 @@
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
-static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
- unsigned dstLoopDepth,
- Optional<unsigned> fastMemorySpace,
- uint64_t localBufSizeThreshold) {
+static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+ unsigned dstLoopDepth,
+ Optional<unsigned> fastMemorySpace,
+ uint64_t localBufSizeThreshold) {
auto *forInst = forOp.getOperation();
// Create builder to insert alloc op just before 'forOp'.
@@ -904,7 +895,7 @@
// Builder to create constants at the top level.
OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
// Create new memref type based on slice bounds.
- auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+ auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
unsigned rank = oldMemRefType.getRank();
@@ -928,7 +919,7 @@
// 'outerIVs' holds the values that this memory region is symbolic/parametric
// on; this would correspond to loop IVs surrounding the level at which the
// slice is being materialized.
- SmallVector<Value *, 8> outerIVs;
+ SmallVector<Value, 8> outerIVs;
cst->getIdValues(rank, cst->getNumIds(), &outerIVs);
// Build 'rank' AffineExprs from MemRefRegion 'lbs'
@@ -960,7 +951,7 @@
auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
{}, newMemSpace);
// Gather alloc operands for the dynamic dimensions of the memref.
- SmallVector<Value *, 4> allocOperands;
+ SmallVector<Value, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
@@ -973,7 +964,7 @@
// consumer loop nests to reduce their live range. Currently they are added
// at the beginning of the function, because loop nests can be reordered
// during the fusion pass.
- Value *newMemRef =
+ Value newMemRef =
top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);
// Build an AffineMap to remap access functions based on lower bound offsets.
@@ -1016,7 +1007,7 @@
MemRefDependenceGraph *mdg) {
assert(srcLiveOutStoreOp && "Expected a valid store op");
auto *dstNode = mdg->getNode(dstId);
- Value *memref = srcLiveOutStoreOp.getMemRef();
+ Value memref = srcLiveOutStoreOp.getMemRef();
// Return false if 'srcNode' has more than one output edge on 'memref'.
if (mdg->getOutEdgeCount(srcId, memref) > 1)
return false;
@@ -1495,10 +1486,10 @@
SmallVector<Operation *, 4> loads = dstNode->loads;
SmallVector<Operation *, 4> dstLoadOpInsts;
- DenseSet<Value *> visitedMemrefs;
+ DenseSet<Value> visitedMemrefs;
while (!loads.empty()) {
// Get memref of load on top of the stack.
- auto *memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+ auto memref = cast<AffineLoadOp>(loads.back()).getMemRef();
if (visitedMemrefs.count(memref) > 0)
continue;
visitedMemrefs.insert(memref);
@@ -1653,7 +1644,7 @@
}
// TODO(andydavis) Use union of memref write regions to compute
// private memref footprint.
- auto *newMemRef = createPrivateMemRef(
+ auto newMemRef = createPrivateMemRef(
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
fastMemorySpace, localBufSizeThreshold);
visitedMemrefs.insert(newMemRef);
@@ -1671,7 +1662,7 @@
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
- auto *loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+ auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
if (visitedMemrefs.count(loadMemRef) == 0)
loads.push_back(loadOpInst);
}
@@ -1737,10 +1728,10 @@
// Attempt to fuse 'dstNode' with sibling nodes in the graph.
void fuseWithSiblingNodes(Node *dstNode) {
DenseSet<unsigned> visitedSibNodeIds;
- std::pair<unsigned, Value *> idAndMemref;
+ std::pair<unsigned, Value> idAndMemref;
while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
unsigned sibId = idAndMemref.first;
- Value *memref = idAndMemref.second;
+ Value memref = idAndMemref.second;
// TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
// stores to the same memref in 'sibNode' loop nest.
auto *sibNode = mdg->getNode(sibId);
@@ -1804,10 +1795,10 @@
// 'idAndMemrefToFuse' on success. Returns false otherwise.
bool findSiblingNodeToFuse(Node *dstNode,
DenseSet<unsigned> *visitedSibNodeIds,
- std::pair<unsigned, Value *> *idAndMemrefToFuse) {
+ std::pair<unsigned, Value> *idAndMemrefToFuse) {
// Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
// on 'memref'.
- auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
+ auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
// Skip if 'outEdge' is not a read-after-write dependence.
// TODO(andydavis) Remove restrict to single load op restriction.
if (sibNode->getLoadOpCount(memref) != 1)
@@ -1819,15 +1810,15 @@
return false;
// Skip sib node if it loads to (and stores from) the same memref on
// which it also has an input dependence edge.
- DenseSet<Value *> loadAndStoreMemrefSet;
+ DenseSet<Value> loadAndStoreMemrefSet;
sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
- if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
+ if (llvm::any_of(loadAndStoreMemrefSet, [=](Value memref) {
return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
}))
return false;
// Check that all stores are to the same memref.
- DenseSet<Value *> storeMemrefs;
+ DenseSet<Value> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
}
@@ -1856,7 +1847,7 @@
if (visitedSibNodeIds->count(sibNode->id) > 0)
continue;
// Skip 'use' if it does not load from the same memref as 'dstNode'.
- auto *memref = loadOp.getMemRef();
+ auto memref = loadOp.getMemRef();
if (dstNode->getLoadOpCount(memref) == 0)
continue;
// Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
@@ -1950,7 +1941,7 @@
for (auto &pair : mdg->memrefEdgeCount) {
if (pair.second > 0)
continue;
- auto *memref = pair.first;
+ auto memref = pair.first;
// Skip if there exist other uses (return operation or function calls).
if (!memref->use_empty())
continue;
diff --git a/third_party/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/third_party/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index 4932494..fb3d0c0 100644
--- a/third_party/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/third_party/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -1,19 +1,10 @@
//===- LoopInvariantCodeMotion.cpp - Code to perform loop fusion-----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements loop invariant code motion.
//
@@ -50,7 +41,7 @@
// - the op has no side-effects. If sideEffecting is Never, sideeffects of this
// op and its nested ops are ignored.
static bool canBeHoisted(Operation *op,
- function_ref<bool(Value *)> definedOutside,
+ function_ref<bool(Value)> definedOutside,
SideEffecting sideEffecting,
SideEffectsInterface &interface) {
// Check that dependencies are defined outside of loop.
@@ -92,7 +83,7 @@
SmallVector<Operation *, 8> opsToMove;
// Helper to check whether an operation is loop invariant wrt. SSA properties.
- auto isDefinedOutsideOfBody = [&](Value *value) {
+ auto isDefinedOutsideOfBody = [&](Value value) {
auto definingOp = value->getDefiningOp();
return (definingOp && !!willBeMovedSet.count(definingOp)) ||
looplike.isDefinedOutsideOfLoop(value);
diff --git a/third_party/mlir/lib/Transforms/LoopTiling.cpp b/third_party/mlir/lib/Transforms/LoopTiling.cpp
index 1065478..d3dc817 100644
--- a/third_party/mlir/lib/Transforms/LoopTiling.cpp
+++ b/third_party/mlir/lib/Transforms/LoopTiling.cpp
@@ -1,19 +1,10 @@
//===- LoopTiling.cpp --- Loop tiling pass ------------------------------*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to tile loop nests.
//
@@ -120,8 +111,8 @@
for (unsigned i = 0; i < width; i++) {
auto lbOperands = origLoops[i].getLowerBoundOperands();
auto ubOperands = origLoops[i].getUpperBoundOperands();
- SmallVector<Value *, 4> newLbOperands(lbOperands);
- SmallVector<Value *, 4> newUbOperands(ubOperands);
+ SmallVector<Value, 4> newLbOperands(lbOperands);
+ SmallVector<Value, 4> newUbOperands(ubOperands);
newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap());
newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap());
newLoops[i].setStep(tileSizes[i]);
@@ -147,7 +138,7 @@
// with 'i' (tile-space loop) appended to it. The new upper bound map is
// the original one with an additional expression i + tileSize appended.
auto ub = origLoops[i].getUpperBound();
- SmallVector<Value *, 4> ubOperands;
+ SmallVector<Value, 4> ubOperands;
ubOperands.reserve(ub.getNumOperands() + 1);
auto origUbMap = ub.getMap();
// Add dim operands from original upper bound.
@@ -235,9 +226,9 @@
// Move the loop body of the original nest to the new one.
moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
- SmallVector<Value *, 8> origLoopIVs;
+ SmallVector<Value, 8> origLoopIVs;
extractForInductionVars(band, &origLoopIVs);
- SmallVector<Optional<Value *>, 6> ids(origLoopIVs.begin(), origLoopIVs.end());
+ SmallVector<Optional<Value>, 6> ids(origLoopIVs.begin(), origLoopIVs.end());
FlatAffineConstraints cst;
getIndexSet(band, &cst);
diff --git a/third_party/mlir/lib/Transforms/LoopUnroll.cpp b/third_party/mlir/lib/Transforms/LoopUnroll.cpp
index 40f48ad..e94c6c8 100644
--- a/third_party/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/third_party/mlir/lib/Transforms/LoopUnroll.cpp
@@ -1,19 +1,10 @@
//===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements loop unrolling.
//
diff --git a/third_party/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/third_party/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 230869a..6c74d54 100644
--- a/third_party/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/third_party/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -1,19 +1,10 @@
//===- LoopUnrollAndJam.cpp - Code to perform loop unroll and jam ---------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements loop unroll and jam. Unroll and jam is a transformation
// that improves locality, in particular, register reuse, while also improving
@@ -191,7 +182,7 @@
// Adjust the lower bound of the cleanup loop; its upper bound is the same
// as the original loop's upper bound.
AffineMap cleanupMap;
- SmallVector<Value *, 4> cleanupOperands;
+ SmallVector<Value, 4> cleanupOperands;
getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap,
&cleanupOperands, builder);
cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap);
@@ -208,7 +199,7 @@
int64_t step = forOp.getStep();
forOp.setStep(step * unrollJamFactor);
- auto *forOpIV = forOp.getInductionVar();
+ auto forOpIV = forOp.getInductionVar();
// Unroll and jam (appends unrollJamFactor - 1 additional copies).
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
// Operand map persists across all sub-blocks.
diff --git a/third_party/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/third_party/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index c531ca5..e2514e1 100644
--- a/third_party/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/third_party/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -1,19 +1,10 @@
//===- MemRefDataFlowOpt.cpp - MemRef DataFlow Optimization pass ------ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to forward memref stores to loads, thereby
// potentially getting rid of intermediate memref's entirely.
@@ -76,7 +67,7 @@
void forwardStoreToLoad(AffineLoadOp loadOp);
// A list of memref's that are potentially dead / could be eliminated.
- SmallPtrSet<Value *, 4> memrefsToErase;
+ SmallPtrSet<Value, 4> memrefsToErase;
// Load op's whose results were replaced by those forwarded from stores.
SmallVector<Operation *, 8> loadOpsToErase;
@@ -180,7 +171,7 @@
return;
// Perform the actual store to load forwarding.
- Value *storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
+ Value storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
loadOp.replaceAllUsesWith(storeVal);
// Record the memref for a later sweep to optimize away.
memrefsToErase.insert(loadOp.getMemRef());
@@ -213,7 +204,7 @@
// Check if the store fwd'ed memrefs are now left with only stores and can
// thus be completely deleted. Note: the canonicalize pass should be able
// to do this as well, but we'll do it here since we collected these anyway.
- for (auto *memref : memrefsToErase) {
+ for (auto memref : memrefsToErase) {
// If the memref hasn't been alloc'ed in this function, skip.
Operation *defInst = memref->getDefiningOp();
if (!defInst || !isa<AllocOp>(defInst))
diff --git a/third_party/mlir/lib/Transforms/PipelineDataTransfer.cpp b/third_party/mlir/lib/Transforms/PipelineDataTransfer.cpp
index fdf0135..dce0273 100644
--- a/third_party/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/third_party/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -1,19 +1,10 @@
//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
@@ -70,7 +61,7 @@
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
-static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
+static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
auto *forBody = forOp.getBody();
OpBuilder bInner(forBody, forBody->begin());
@@ -94,7 +85,7 @@
auto *forInst = forOp.getOperation();
OpBuilder bOuter(forInst);
// Put together alloc operands for any dynamic dimensions of the memref.
- SmallVector<Value *, 4> allocOperands;
+ SmallVector<Value, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
@@ -103,7 +94,7 @@
}
// Create and place the alloc right before the 'affine.for' operation.
- Value *newMemRef =
+ Value newMemRef =
bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension.
@@ -212,7 +203,7 @@
continue;
// We only double buffer if the buffer is not live out of loop.
- auto *memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
+ auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
bool escapingUses = false;
for (auto *user : memref->getUsers()) {
// We can double buffer regardless of dealloc's outside the loop.
@@ -270,7 +261,7 @@
// dimension.
for (auto &pair : startWaitPairs) {
auto *dmaStartInst = pair.first;
- Value *oldMemRef = dmaStartInst->getOperand(
+ Value oldMemRef = dmaStartInst->getOperand(
cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos());
if (!doubleBuffer(oldMemRef, forOp)) {
// Normally, double buffering should not fail because we already checked
@@ -301,7 +292,7 @@
// Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) {
auto *dmaFinishInst = pair.second;
- Value *oldTagMemRef =
+ Value oldTagMemRef =
dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
if (!doubleBuffer(oldTagMemRef, forOp)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
@@ -342,7 +333,7 @@
// If a slice wasn't created, the reachable affine.apply op's from its
// operands are the ones that go with it.
SmallVector<Operation *, 4> affineApplyInsts;
- SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
+ SmallVector<Value, 4> operands(dmaStartInst->getOperands());
getReachableAffineApplyOps(operands, affineApplyInsts);
for (auto *op : affineApplyInsts) {
instShiftMap[op] = 0;
diff --git a/third_party/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/third_party/mlir/lib/Transforms/SimplifyAffineStructures.cpp
index 9512ff7..217e06b 100644
--- a/third_party/mlir/lib/Transforms/SimplifyAffineStructures.cpp
+++ b/third_party/mlir/lib/Transforms/SimplifyAffineStructures.cpp
@@ -1,19 +1,10 @@
//===- SimplifyAffineStructures.cpp ---------------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to simplify affine structures.
//
diff --git a/third_party/mlir/lib/Transforms/StripDebugInfo.cpp b/third_party/mlir/lib/Transforms/StripDebugInfo.cpp
index 772df3d..cdfc7fd 100644
--- a/third_party/mlir/lib/Transforms/StripDebugInfo.cpp
+++ b/third_party/mlir/lib/Transforms/StripDebugInfo.cpp
@@ -1,19 +1,10 @@
//===- StripDebugInfo.cpp - Pass to strip debug information ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
diff --git a/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp b/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp
index d4b7caa..719c6fa 100644
--- a/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -1,19 +1,10 @@
//===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simply their logic.
@@ -90,7 +81,7 @@
return failure();
// Try to fold the operation.
- SmallVector<Value *, 8> results;
+ SmallVector<Value, 8> results;
if (failed(tryToFold(op, results, processGeneratedConstants)))
return failure();
@@ -138,7 +129,7 @@
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult OperationFolder::tryToFold(
- Operation *op, SmallVectorImpl<Value *> &results,
+ Operation *op, SmallVectorImpl<Value> &results,
function_ref<void(Operation *)> processGeneratedConstants) {
SmallVector<Attribute, 8> operandConstants;
SmallVector<OpFoldResult, 8> foldResults;
@@ -181,13 +172,13 @@
assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
// Check if the result was an SSA value.
- if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
+ if (auto repl = foldResults[i].dyn_cast<Value>()) {
results.emplace_back(repl);
continue;
}
// Check to see if there is a canonicalized version of this constant.
- auto *res = op->getResult(i);
+ auto res = op->getResult(i);
Attribute attrRepl = foldResults[i].get<Attribute>();
if (auto *constOp =
tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
diff --git a/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index e2ca3f8..1eb9c57 100644
--- a/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -1,19 +1,10 @@
//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements mlir::applyPatternsGreedily.
//
@@ -107,7 +98,7 @@
// simplifications to its users - make sure to add them to the worklist
// before the root is changed.
void notifyRootReplaced(Operation *op) override {
- for (auto *result : op->getResults())
+ for (auto result : op->getResults())
for (auto *user : result->getUsers())
addToWorklist(user);
}
@@ -118,7 +109,7 @@
// operation is modified or removed, as it may trigger further
// simplifications.
template <typename Operands> void addToWorklist(Operands &&operands) {
- for (Value *operand : operands) {
+ for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist.
// TODO(riverriddle) This is based on the fact that zero use operations
@@ -160,7 +151,7 @@
region.walk(collectOps);
// These are scratch vectors used in the folding loop below.
- SmallVector<Value *, 8> originalOperands, resultValues;
+ SmallVector<Value, 8> originalOperands, resultValues;
changed = false;
while (!worklist.empty()) {
@@ -189,7 +180,7 @@
// Add all the users of the result to the worklist so we make sure
// to revisit them.
- for (auto *result : op->getResults())
+ for (auto result : op->getResults())
for (auto *operand : result->getUsers())
addToWorklist(operand);
diff --git a/third_party/mlir/lib/Transforms/Utils/InliningUtils.cpp b/third_party/mlir/lib/Transforms/Utils/InliningUtils.cpp
index e8466aa..1ac286c 100644
--- a/third_party/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/third_party/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -1,19 +1,10 @@
//===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous inlining utilities.
//
@@ -55,7 +46,7 @@
BlockAndValueMapping &mapper) {
auto remapOperands = [&](Operation *op) {
for (auto &operand : op->getOpOperands())
- if (auto *mappedOp = mapper.lookupOrNull(operand.get()))
+ if (auto mappedOp = mapper.lookupOrNull(operand.get()))
operand.set(mappedOp);
};
for (auto &block : inlinedBlocks)
@@ -98,7 +89,7 @@
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void InlinerInterface::handleTerminator(Operation *op,
- ArrayRef<Value *> valuesToRepl) const {
+ ArrayRef<Value> valuesToRepl) const {
auto *handler = getInterfaceFor(op);
assert(handler && "expected valid dialect handler");
handler->handleTerminator(op, valuesToRepl);
@@ -137,7 +128,7 @@
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
BlockAndValueMapping &mapper,
- ArrayRef<Value *> resultsToReplace,
+ ArrayRef<Value> resultsToReplace,
Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
// We expect the region to have at least one block.
@@ -147,7 +138,7 @@
// Check that all of the region arguments have been mapped.
auto *srcEntryBlock = &src->front();
if (llvm::any_of(srcEntryBlock->getArguments(),
- [&](BlockArgument *arg) { return !mapper.contains(arg); }))
+ [&](BlockArgument arg) { return !mapper.contains(arg); }))
return failure();
// The insertion point must be within a block.
@@ -207,7 +198,7 @@
} else {
// Otherwise, there were multiple blocks inlined. Add arguments to the post
// insertion block to represent the results to replace.
- for (Value *resultToRepl : resultsToReplace) {
+ for (Value resultToRepl : resultsToReplace) {
resultToRepl->replaceAllUsesWith(
postInsertBlock->addArgument(resultToRepl->getType()));
}
@@ -229,8 +220,8 @@
/// in-favor of the region arguments when inlining.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
Operation *inlinePoint,
- ArrayRef<Value *> inlinedOperands,
- ArrayRef<Value *> resultsToReplace,
+ ArrayRef<Value> inlinedOperands,
+ ArrayRef<Value> resultsToReplace,
Optional<Location> inlineLoc,
bool shouldCloneInlinedRegion) {
// We expect the region to have at least one block.
@@ -246,7 +237,7 @@
for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) {
// Verify that the types of the provided values match the function argument
// types.
- BlockArgument *regionArg = entryBlock->getArgument(i);
+ BlockArgument regionArg = entryBlock->getArgument(i);
if (inlinedOperands[i]->getType() != regionArg->getType())
return failure();
mapper.map(regionArg, inlinedOperands[i]);
@@ -259,10 +250,10 @@
/// Utility function used to generate a cast operation from the given interface,
/// or return nullptr if a cast could not be generated.
-static Value *materializeConversion(const DialectInlinerInterface *interface,
- SmallVectorImpl<Operation *> &castOps,
- OpBuilder &castBuilder, Value *arg,
- Type type, Location conversionLoc) {
+static Value materializeConversion(const DialectInlinerInterface *interface,
+ SmallVectorImpl<Operation *> &castOps,
+ OpBuilder &castBuilder, Value arg, Type type,
+ Location conversionLoc) {
if (!interface)
return nullptr;
@@ -297,8 +288,8 @@
// Make sure that the number of arguments and results matchup between the call
// and the region.
- SmallVector<Value *, 8> callOperands(call.getArgOperands());
- SmallVector<Value *, 8> callResults(call.getOperation()->getResults());
+ SmallVector<Value, 8> callOperands(call.getArgOperands());
+ SmallVector<Value, 8> callResults(call.getOperation()->getResults());
if (callOperands.size() != entryBlock->getNumArguments() ||
callResults.size() != callableResultTypes.size())
return failure();
@@ -325,8 +316,8 @@
// Map the provided call operands to the arguments of the region.
BlockAndValueMapping mapper;
for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
- BlockArgument *regionArg = entryBlock->getArgument(i);
- Value *operand = callOperands[i];
+ BlockArgument regionArg = entryBlock->getArgument(i);
+ Value operand = callOperands[i];
// If the call operand doesn't match the expected region argument, try to
// generate a cast.
@@ -342,13 +333,13 @@
// Ensure that the resultant values of the call, match the callable.
castBuilder.setInsertionPointAfter(call);
for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
- Value *callResult = callResults[i];
+ Value callResult = callResults[i];
if (callResult->getType() == callableResultTypes[i])
continue;
// Generate a conversion that will produce the original type, so that the IR
// is still valid after the original call gets replaced.
- Value *castResult =
+ Value castResult =
materializeConversion(callInterface, castOps, castBuilder, callResult,
callResult->getType(), castLoc);
if (!castResult)
diff --git a/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
index fd80339..b0d9fdf 100644
--- a/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -1,19 +1,10 @@
//===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements loop fusion transformation utility functions.
//
@@ -45,7 +36,7 @@
// Gathers all load and store memref accesses in 'opA' into 'values', where
// 'values[memref] == true' for each store operation.
static void getLoadAndStoreMemRefAccesses(Operation *opA,
- DenseMap<Value *, bool> &values) {
+ DenseMap<Value, bool> &values) {
opA->walk([&](Operation *op) {
if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
if (values.count(loadOp.getMemRef()) == 0)
@@ -60,7 +51,7 @@
// accessed 'values' and at least one of the access is a store operation.
// Returns false otherwise.
static bool isDependentLoadOrStoreOp(Operation *op,
- DenseMap<Value *, bool> &values) {
+ DenseMap<Value, bool> &values) {
if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
return values.count(loadOp.getMemRef()) > 0 &&
values[loadOp.getMemRef()] == true;
@@ -75,7 +66,7 @@
static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
// Record memref values from all loads/store in loop nest rooted at 'opA'.
// Map from memref value to bool which is true if store, false otherwise.
- DenseMap<Value *, bool> values;
+ DenseMap<Value, bool> values;
getLoadAndStoreMemRefAccesses(opA, values);
// For each 'opX' in block in range ('opA', 'opB'), check if there is a data
@@ -101,7 +92,7 @@
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
// Record memref values from all loads/store in loop nest rooted at 'opB'.
// Map from memref value to bool which is true if store, false otherwise.
- DenseMap<Value *, bool> values;
+ DenseMap<Value, bool> values;
getLoadAndStoreMemRefAccesses(opB, values);
// For each 'opX' in block in range ('opA', 'opB') in reverse order,
@@ -121,8 +112,8 @@
}
return WalkResult::advance();
}
- for (auto *value : op->getResults()) {
- for (auto *user : value->getUsers()) {
+ for (auto value : op->getResults()) {
+ for (auto user : value->getUsers()) {
SmallVector<AffineForOp, 4> loops;
// Check if any loop in loop nest surrounding 'user' is 'opB'.
getLoopIVs(*user, &loops);
@@ -443,7 +434,7 @@
// Subtract from operation count the loads/store we expect load/store
// forwarding to remove.
unsigned storeCount = 0;
- llvm::SmallDenseSet<Value *, 4> storeMemrefs;
+ llvm::SmallDenseSet<Value, 4> storeMemrefs;
srcForOp.walk([&](Operation *op) {
if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
storeMemrefs.insert(storeOp.getMemRef());
@@ -455,7 +446,7 @@
computeCostMap[insertPointParent] = -storeCount;
// Subtract out any load users of 'storeMemrefs' nested below
// 'insertPointParent'.
- for (auto *value : storeMemrefs) {
+ for (auto value : storeMemrefs) {
for (auto *user : value->getUsers()) {
if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
SmallVector<AffineForOp, 4> loops;
diff --git a/third_party/mlir/lib/Transforms/Utils/LoopUtils.cpp b/third_party/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 419df8d..0fece54 100644
--- a/third_party/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/third_party/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1,19 +1,10 @@
//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
@@ -52,7 +43,7 @@
/// expression.
void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
AffineMap *map,
- SmallVectorImpl<Value *> *operands,
+ SmallVectorImpl<Value> *operands,
OpBuilder &b) {
auto lbMap = forOp.getLowerBoundMap();
@@ -63,7 +54,7 @@
}
AffineMap tripCountMap;
- SmallVector<Value *, 4> tripCountOperands;
+ SmallVector<Value, 4> tripCountOperands;
buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);
// Sometimes the trip count cannot be expressed as an affine expression.
@@ -82,7 +73,7 @@
// lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all
// these affine.apply's make up the cleanup loop lower bound.
SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
- SmallVector<Value *, 4> bumpValues(tripCountMap.getNumResults());
+ SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults());
for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
auto tripCountExpr = tripCountMap.getResult(i);
bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
@@ -105,7 +96,7 @@
*map = simplifyAffineMap(*map);
canonicalizeMapAndOperands(map, operands);
// Remove any affine.apply's that became dead from the simplification above.
- for (auto *v : bumpValues) {
+ for (auto v : bumpValues) {
if (v->use_empty()) {
v->getDefiningOp()->erase();
}
@@ -127,7 +118,7 @@
return failure();
// Replaces all IV uses to its single iteration value.
- auto *iv = forOp.getInductionVar();
+ auto iv = forOp.getInductionVar();
Operation *op = forOp.getOperation();
if (!iv->use_empty()) {
if (forOp.hasConstantLowerBound()) {
@@ -137,7 +128,7 @@
iv->replaceAllUsesWith(constOp);
} else {
AffineBound lb = forOp.getLowerBound();
- SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
+ SmallVector<Value, 4> lbOperands(lb.operand_begin(), lb.operand_end());
OpBuilder builder(op->getBlock(), Block::iterator(op));
if (lb.getMap() == builder.getDimIdentityMap()) {
// No need of generating an affine.apply.
@@ -178,8 +169,8 @@
const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>>
&instGroupQueue,
unsigned offset, AffineForOp srcForInst, OpBuilder b) {
- SmallVector<Value *, 4> lbOperands(srcForInst.getLowerBoundOperands());
- SmallVector<Value *, 4> ubOperands(srcForInst.getUpperBoundOperands());
+ SmallVector<Value, 4> lbOperands(srcForInst.getLowerBoundOperands());
+ SmallVector<Value, 4> ubOperands(srcForInst.getUpperBoundOperands());
assert(lbMap.getNumInputs() == lbOperands.size());
assert(ubMap.getNumInputs() == ubOperands.size());
@@ -187,8 +178,8 @@
auto loopChunk =
b.create<AffineForOp>(srcForInst.getLoc(), lbOperands, lbMap, ubOperands,
ubMap, srcForInst.getStep());
- auto *loopChunkIV = loopChunk.getInductionVar();
- auto *srcIV = srcForInst.getInductionVar();
+ auto loopChunkIV = loopChunk.getInductionVar();
+ auto srcIV = srcForInst.getInductionVar();
BlockAndValueMapping operandMap;
@@ -449,7 +440,7 @@
OpBuilder builder(op->getBlock(), ++Block::iterator(op));
auto cleanupForInst = cast<AffineForOp>(builder.clone(*op));
AffineMap cleanupMap;
- SmallVector<Value *, 4> cleanupOperands;
+ SmallVector<Value, 4> cleanupOperands;
getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands,
builder);
assert(cleanupMap &&
@@ -477,7 +468,7 @@
Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2);
// Unroll the contents of 'forOp' (append unrollFactor-1 additional copies).
- auto *forOpIV = forOp.getInductionVar();
+ auto forOpIV = forOp.getInductionVar();
for (unsigned i = 1; i < unrollFactor; i++) {
BlockAndValueMapping operandMap;
@@ -669,8 +660,8 @@
// ...
// }
// ```
-static void augmentMapAndBounds(OpBuilder &b, Value *iv, AffineMap *map,
- SmallVector<Value *, 4> *operands,
+static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
+ SmallVector<Value, 4> *operands,
int64_t offset = 0) {
auto bounds = llvm::to_vector<4>(map->getResults());
bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
@@ -699,16 +690,16 @@
// Lower-bound map creation.
auto lbMap = forOp.getLowerBoundMap();
- SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);
// Upper-bound map creation.
auto ubMap = forOp.getUpperBoundMap();
- SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
/*offset=*/scaledStep);
- auto *iv = forOp.getInductionVar();
+ auto iv = forOp.getInductionVar();
SmallVector<AffineForOp, 8> innerLoops;
for (auto t : targets) {
// Insert newForOp before the terminator of `t`.
@@ -729,10 +720,10 @@
return innerLoops;
}
-static Loops stripmineSink(loop::ForOp forOp, Value *factor,
+static Loops stripmineSink(loop::ForOp forOp, Value factor,
ArrayRef<loop::ForOp> targets) {
- auto *originalStep = forOp.step();
- auto *iv = forOp.getInductionVar();
+ auto originalStep = forOp.step();
+ auto iv = forOp.getInductionVar();
OpBuilder b(forOp);
forOp.setStep(b.create<MulIOp>(forOp.getLoc(), originalStep, factor));
@@ -745,10 +736,10 @@
// Insert newForOp before the terminator of `t`.
OpBuilder b(t.getBodyBuilder());
- Value *stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
- Value *less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
- forOp.upperBound(), stepped);
- Value *ub =
+ Value stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
+ Value less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
+ forOp.upperBound(), stepped);
+ Value ub =
b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
// Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
@@ -799,7 +790,7 @@
}
SmallVector<Loops, 8> mlir::tile(ArrayRef<loop::ForOp> forOps,
- ArrayRef<Value *> sizes,
+ ArrayRef<Value> sizes,
ArrayRef<loop::ForOp> targets) {
return tileImpl(forOps, sizes, targets);
}
@@ -821,13 +812,12 @@
return tileImpl(forOps, sizes, target);
}
-Loops mlir::tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value *> sizes,
+Loops mlir::tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value> sizes,
loop::ForOp target) {
return tileImpl(forOps, sizes, target);
}
-Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp,
- ArrayRef<Value *> sizes) {
+Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef<Value> sizes) {
// Collect perfectly nested loops. If more size values provided than nested
// loops available, truncate `sizes`.
SmallVector<loop::ForOp, 4> forOps;
@@ -842,29 +832,29 @@
// Build the IR that performs ceil division of a positive value by a constant:
// ceildiv(a, B) = divis(a + (B-1), B)
// where divis is rounding-to-zero division.
-static Value *ceilDivPositive(OpBuilder &builder, Location loc, Value *dividend,
- int64_t divisor) {
+static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
+ int64_t divisor) {
assert(divisor > 0 && "expected positive divisor");
assert(dividend->getType().isIndex() && "expected index-typed value");
- Value *divisorMinusOneCst = builder.create<ConstantIndexOp>(loc, divisor - 1);
- Value *divisorCst = builder.create<ConstantIndexOp>(loc, divisor);
- Value *sum = builder.create<AddIOp>(loc, dividend, divisorMinusOneCst);
- return builder.create<DivISOp>(loc, sum, divisorCst);
+ Value divisorMinusOneCst = builder.create<ConstantIndexOp>(loc, divisor - 1);
+ Value divisorCst = builder.create<ConstantIndexOp>(loc, divisor);
+ Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOneCst);
+ return builder.create<SignedDivIOp>(loc, sum, divisorCst);
}
// Build the IR that performs ceil division of a positive value by another
// positive value:
// ceildiv(a, b) = divis(a + (b - 1), b)
// where divis is rounding-to-zero division.
-static Value *ceilDivPositive(OpBuilder &builder, Location loc, Value *dividend,
- Value *divisor) {
+static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
+ Value divisor) {
assert(dividend->getType().isIndex() && "expected index-typed value");
- Value *cstOne = builder.create<ConstantIndexOp>(loc, 1);
- Value *divisorMinusOne = builder.create<SubIOp>(loc, divisor, cstOne);
- Value *sum = builder.create<AddIOp>(loc, dividend, divisorMinusOne);
- return builder.create<DivISOp>(loc, sum, divisor);
+ Value cstOne = builder.create<ConstantIndexOp>(loc, 1);
+ Value divisorMinusOne = builder.create<SubIOp>(loc, divisor, cstOne);
+ Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOne);
+ return builder.create<SignedDivIOp>(loc, sum, divisor);
}
// Hoist the ops within `outer` that appear before `inner`.
@@ -945,7 +935,7 @@
// iterations. Given that the loop current executes
// numIterations = ceildiv((upperBound - lowerBound), step)
// iterations, we need to tile with size ceildiv(numIterations, size[i]).
- SmallVector<Value *, 4> tileSizes;
+ SmallVector<Value, 4> tileSizes;
tileSizes.reserve(sizes.size());
for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
@@ -953,10 +943,10 @@
auto forOp = forOps[i];
OpBuilder builder(forOp);
auto loc = forOp.getLoc();
- Value *diff =
+ Value diff =
builder.create<SubIOp>(loc, forOp.upperBound(), forOp.lowerBound());
- Value *numIterations = ceilDivPositive(builder, loc, diff, forOp.step());
- Value *iterationsPerBlock =
+ Value numIterations = ceilDivPositive(builder, loc, diff, forOp.step());
+ Value iterationsPerBlock =
ceilDivPositive(builder, loc, numIterations, sizes[i]);
tileSizes.push_back(iterationsPerBlock);
}
@@ -976,7 +966,7 @@
// Replaces all uses of `orig` with `replacement` except if the user is listed
// in `exceptions`.
static void
-replaceAllUsesExcept(Value *orig, Value *replacement,
+replaceAllUsesExcept(Value orig, Value replacement,
const SmallPtrSetImpl<Operation *> &exceptions) {
for (auto &use : llvm::make_early_inc_range(orig->getUses())) {
if (exceptions.count(use.getOwner()) == 0)
@@ -1018,30 +1008,30 @@
// of the loop to go from 0 to the number of iterations, if necessary.
// TODO(zinenko): introduce support for negative steps or emit dynamic asserts
// on step positivity, whatever gets implemented first.
- Value *diff =
+ Value diff =
builder.create<SubIOp>(loc, loop.upperBound(), loop.lowerBound());
- Value *numIterations = ceilDivPositive(builder, loc, diff, loop.step());
+ Value numIterations = ceilDivPositive(builder, loc, diff, loop.step());
loop.setUpperBound(numIterations);
- Value *lb = loop.lowerBound();
+ Value lb = loop.lowerBound();
if (!isZeroBased) {
- Value *cst0 = builder.create<ConstantIndexOp>(loc, 0);
+ Value cst0 = builder.create<ConstantIndexOp>(loc, 0);
loop.setLowerBound(cst0);
}
- Value *step = loop.step();
+ Value step = loop.step();
if (!isStepOne) {
- Value *cst1 = builder.create<ConstantIndexOp>(loc, 1);
+ Value cst1 = builder.create<ConstantIndexOp>(loc, 1);
loop.setStep(cst1);
}
// Insert code computing the value of the original loop induction variable
// from the "normalized" one.
builder.setInsertionPointToStart(inner.getBody());
- Value *scaled =
+ Value scaled =
isStepOne ? loop.getInductionVar()
: builder.create<MulIOp>(loc, loop.getInductionVar(), step);
- Value *shifted =
+ Value shifted =
isZeroBased ? scaled : builder.create<AddIOp>(loc, scaled, lb);
SmallPtrSet<Operation *, 2> preserve{scaled->getDefiningOp(),
@@ -1065,7 +1055,7 @@
// of the number of iterations of all loops.
OpBuilder builder(outermost);
Location loc = outermost.getLoc();
- Value *upperBound = outermost.upperBound();
+ Value upperBound = outermost.upperBound();
for (auto loop : loops.drop_front())
upperBound = builder.create<MulIOp>(loc, upperBound, loop.upperBound());
outermost.setUpperBound(upperBound);
@@ -1080,16 +1070,16 @@
// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
// Compute these iteratively from the innermost loop by creating a "running
// quotient" of division by the range.
- Value *previous = outermost.getInductionVar();
+ Value previous = outermost.getInductionVar();
for (unsigned i = 0, e = loops.size(); i < e; ++i) {
unsigned idx = loops.size() - i - 1;
if (i != 0)
- previous =
- builder.create<DivISOp>(loc, previous, loops[idx + 1].upperBound());
+ previous = builder.create<SignedDivIOp>(loc, previous,
+ loops[idx + 1].upperBound());
- Value *iv = (i == e - 1) ? previous
- : builder.create<RemISOp>(loc, previous,
- loops[idx].upperBound());
+ Value iv = (i == e - 1) ? previous
+ : builder.create<SignedRemIOp>(
+ loc, previous, loops[idx].upperBound());
replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
loops.back().region());
}
@@ -1104,25 +1094,24 @@
second.erase();
}
-void mlir::mapLoopToProcessorIds(loop::ForOp forOp,
- ArrayRef<Value *> processorId,
- ArrayRef<Value *> numProcessors) {
+void mlir::mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value> processorId,
+ ArrayRef<Value> numProcessors) {
assert(processorId.size() == numProcessors.size());
if (processorId.empty())
return;
OpBuilder b(forOp);
Location loc(forOp.getLoc());
- Value *mul = processorId.front();
+ Value mul = processorId.front();
for (unsigned i = 1, e = processorId.size(); i < e; ++i)
mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]),
processorId[i]);
- Value *lb = b.create<AddIOp>(loc, forOp.lowerBound(),
- b.create<MulIOp>(loc, forOp.step(), mul));
+ Value lb = b.create<AddIOp>(loc, forOp.lowerBound(),
+ b.create<MulIOp>(loc, forOp.step(), mul));
forOp.setLowerBound(lb);
- Value *step = forOp.step();
- for (auto *numProcs : numProcessors)
+ Value step = forOp.step();
+ for (auto numProcs : numProcessors)
step = b.create<MulIOp>(loc, step, numProcs);
forOp.setStep(step);
}
@@ -1139,7 +1128,7 @@
Block::iterator *copyInPlacementStart,
Block::iterator *copyOutPlacementStart) {
const auto *cst = region.getConstraints();
- SmallVector<Value *, 4> symbols;
+ SmallVector<Value, 4> symbols;
cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
SmallVector<AffineForOp, 4> enclosingFors;
@@ -1202,10 +1191,10 @@
/// returns the outermost AffineForOp of the copy loop nest. `memIndicesStart'
/// holds the lower coordinates of the region in the original memref to copy
/// in/out. If `copyOut' is true, generates a copy-out; otherwise a copy-in.
-static AffineForOp generatePointWiseCopy(Location loc, Value *memref,
- Value *fastMemRef,
+static AffineForOp generatePointWiseCopy(Location loc, Value memref,
+ Value fastMemRef,
AffineMap memAffineMap,
- ArrayRef<Value *> memIndicesStart,
+ ArrayRef<Value> memIndicesStart,
ArrayRef<int64_t> fastBufferShape,
bool isCopyOut, OpBuilder b) {
assert(!memIndicesStart.empty() && "only 1-d or more memrefs");
@@ -1215,7 +1204,7 @@
// for y = ...
// fast_buf[x][y] = buf[mem_x + x][mem_y + y]
- SmallVector<Value *, 4> fastBufIndices, memIndices;
+ SmallVector<Value, 4> fastBufIndices, memIndices;
AffineForOp copyNestRoot;
for (unsigned d = 0, e = fastBufferShape.size(); d < e; ++d) {
auto forOp = b.create<AffineForOp>(loc, 0, fastBufferShape[d]);
@@ -1224,7 +1213,7 @@
b = forOp.getBodyBuilder();
fastBufIndices.push_back(forOp.getInductionVar());
- Value *memBase =
+ Value memBase =
(memAffineMap == b.getMultiDimIdentityMap(memAffineMap.getNumDims()))
? memIndicesStart[d]
: b.create<AffineApplyOp>(
@@ -1277,7 +1266,7 @@
const MemRefRegion ®ion, Block *block, Block::iterator begin,
Block::iterator end, Block *copyPlacementBlock,
Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
- AffineCopyOptions copyOptions, DenseMap<Value *, Value *> &fastBufferMap,
+ AffineCopyOptions copyOptions, DenseMap<Value, Value> &fastBufferMap,
DenseSet<Operation *> ©Nests, uint64_t *sizeInBytes,
Block::iterator *nBegin, Block::iterator *nEnd) {
*nBegin = begin;
@@ -1285,7 +1274,7 @@
FuncOp f = begin->getParentOfType<FuncOp>();
OpBuilder topBuilder(f.getBody());
- Value *zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
+ Value zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
if (begin == end)
return success();
@@ -1305,7 +1294,7 @@
OpBuilder top(func.getBody());
auto loc = region.loc;
- auto *memref = region.memref;
+ auto memref = region.memref;
auto memRefType = memref->getType().cast<MemRefType>();
auto layoutMaps = memRefType.getAffineMaps();
@@ -1317,9 +1306,9 @@
// Indices to use for the copying.
// Indices for the original memref being copied from/to.
- SmallVector<Value *, 4> memIndices;
+ SmallVector<Value, 4> memIndices;
// Indices for the faster buffer being copied into/from.
- SmallVector<Value *, 4> bufIndices;
+ SmallVector<Value, 4> bufIndices;
unsigned rank = memRefType.getRank();
SmallVector<int64_t, 4> fastBufferShape;
@@ -1345,7 +1334,7 @@
// 'regionSymbols' hold values that this memory region is symbolic/parametric
// on; these typically include loop IVs surrounding the level at which the
// copy generation is being done or other valid symbols in MLIR.
- SmallVector<Value *, 8> regionSymbols;
+ SmallVector<Value, 8> regionSymbols;
cst->getIdValues(rank, cst->getNumIds(), ®ionSymbols);
// Construct the index expressions for the fast memory buffer. The index
@@ -1393,7 +1382,7 @@
}
// The faster memory space buffer.
- Value *fastMemRef;
+ Value fastMemRef;
// Check if a buffer was already created.
bool existingBuf = fastBufferMap.count(memref) > 0;
@@ -1433,8 +1422,8 @@
return failure();
}
- Value *stride = nullptr;
- Value *numEltPerStride = nullptr;
+ Value stride = nullptr;
+ Value numEltPerStride = nullptr;
if (!strideInfos.empty()) {
stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
numEltPerStride =
@@ -1473,7 +1462,7 @@
copyOptions.tagMemorySpace);
auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
- SmallVector<Value *, 4> tagIndices({zeroIndex});
+ SmallVector<Value, 4> tagIndices({zeroIndex});
auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
if (!region.isWrite()) {
@@ -1582,7 +1571,7 @@
SmallVector<AffineForOp, 4> ivs;
getLoopIVs(*opInst, &ivs);
ivs.resize(numParamLoopIVs);
- SmallVector<Value *, 4> symbols;
+ SmallVector<Value, 4> symbols;
extractForInductionVars(ivs, &symbols);
regionCst->reset(rank, numParamLoopIVs, 0);
regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
@@ -1629,12 +1618,12 @@
// List of memory regions to copy for. We need a map vector to have a
// guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
// since the alloc's for example are identical except for the SSA id.
- SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4> readRegions;
- SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4> writeRegions;
+ SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
+ SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
// Map from original memref's to the fast buffers that their accesses are
// replaced with.
- DenseMap<Value *, Value *> fastBufferMap;
+ DenseMap<Value, Value> fastBufferMap;
// To check for errors when walking the block.
bool error = false;
@@ -1684,7 +1673,7 @@
// Attempts to update; returns true if 'region' exists in targetRegions.
auto updateRegion =
- [&](const SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4>
+ [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
&targetRegions) {
auto it = targetRegions.find(region->memref);
if (it == targetRegions.end())
@@ -1736,7 +1725,7 @@
uint64_t totalCopyBuffersSizeInBytes = 0;
bool ret = true;
auto processRegions =
- [&](const SmallMapVector<Value *, std::unique_ptr<MemRefRegion>, 4>
+ [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
®ions) {
for (const auto ®ionEntry : regions) {
// For each region, hoist copy in/out past all hoistable
diff --git a/third_party/mlir/lib/Transforms/Utils/RegionUtils.cpp b/third_party/mlir/lib/Transforms/Utils/RegionUtils.cpp
index b91b189..ca26074 100644
--- a/third_party/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/third_party/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1,19 +1,10 @@
//===- RegionUtils.cpp - Region-related transformation utilities ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/IR/Block.h"
@@ -27,9 +18,9 @@
using namespace mlir;
-void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement,
+void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
Region ®ion) {
- for (IROperand &use : llvm::make_early_inc_range(orig->getUses())) {
+ for (auto &use : llvm::make_early_inc_range(orig->getUses())) {
if (region.isAncestor(use.getOwner()->getParentRegion()))
use.set(replacement);
}
@@ -63,14 +54,14 @@
}
void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit,
- llvm::SetVector<Value *> &values) {
+ llvm::SetVector<Value> &values) {
visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
values.insert(operand->get());
});
}
void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
- llvm::SetVector<Value *> &values) {
+ llvm::SetVector<Value> &values) {
for (Region ®ion : regions)
getUsedValuesDefinedAbove(region, region, values);
}
@@ -146,8 +137,8 @@
class LiveMap {
public:
/// Value methods.
- bool wasProvenLive(Value *value) { return liveValues.count(value); }
- void setProvedLive(Value *value) {
+ bool wasProvenLive(Value value) { return liveValues.count(value); }
+ void setProvedLive(Value value) {
changed |= liveValues.insert(value).second;
}
@@ -161,7 +152,7 @@
private:
bool changed = false;
- DenseSet<Value *> liveValues;
+ DenseSet<Value> liveValues;
DenseSet<Operation *> liveOps;
};
} // namespace
@@ -188,7 +179,7 @@
return false;
}
-static void processValue(Value *value, LiveMap &liveMap) {
+static void processValue(Value value, LiveMap &liveMap) {
bool provedLive = llvm::any_of(value->getUses(), [&](OpOperand &use) {
if (isUseSpeciallyKnownDead(use, liveMap))
return false;
@@ -222,9 +213,9 @@
liveMap.setProvedLive(op);
return;
}
- for (Value *value : op->getResults())
+ for (Value value : op->getResults())
processValue(value, liveMap);
- bool provedLive = llvm::any_of(op->getResults(), [&](Value *value) {
+ bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
return liveMap.wasProvenLive(value);
});
if (provedLive)
@@ -240,7 +231,7 @@
// faster convergence to a fixed point (we try to visit uses before defs).
for (Operation &op : llvm::reverse(block->getOperations()))
propagateLiveness(&op, liveMap);
- for (Value *value : block->getArguments())
+ for (Value value : block->getArguments())
processValue(value, liveMap);
}
}
@@ -259,7 +250,7 @@
// Iterating args in reverse is needed for correctness, to avoid
// shifting later args when earlier args are erased.
unsigned arg = argE - argI - 1;
- Value *value = terminator->getSuccessor(succ)->getArgument(arg);
+ Value value = terminator->getSuccessor(succ)->getArgument(arg);
if (!liveMap.wasProvenLive(value)) {
terminator->eraseSuccessorOperand(succ, arg);
}
diff --git a/third_party/mlir/lib/Transforms/Utils/Utils.cpp b/third_party/mlir/lib/Transforms/Utils/Utils.cpp
index 57a9253..a662918 100644
--- a/third_party/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/third_party/mlir/lib/Transforms/Utils/Utils.cpp
@@ -1,19 +1,10 @@
//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
@@ -47,7 +38,7 @@
}
/// Return the AffineMapAttr associated with memory 'op' on 'memref'.
-static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
+static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) {
return TypeSwitch<Operation *, NamedAttribute>(op)
.Case<AffineDmaStartOp, AffineLoadOp, AffinePrefetchOp, AffineStoreOp,
AffineDmaWaitOp>(
@@ -55,12 +46,12 @@
}
// Perform the replacement in `op`.
-LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
Operation *op,
- ArrayRef<Value *> extraIndices,
+ ArrayRef<Value> extraIndices,
AffineMap indexRemap,
- ArrayRef<Value *> extraOperands,
- ArrayRef<Value *> symbolOperands) {
+ ArrayRef<Value> extraOperands,
+ ArrayRef<Value> symbolOperands) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
@@ -106,13 +97,13 @@
NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
unsigned oldMapNumInputs = oldMap.getNumInputs();
- SmallVector<Value *, 4> oldMapOperands(
+ SmallVector<Value, 4> oldMapOperands(
op->operand_begin() + memRefOperandPos + 1,
op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
// Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
- SmallVector<Value *, 4> oldMemRefOperands;
- SmallVector<Value *, 4> affineApplyOps;
+ SmallVector<Value, 4> oldMemRefOperands;
+ SmallVector<Value, 4> affineApplyOps;
oldMemRefOperands.reserve(oldMemRefRank);
if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
for (auto resultExpr : oldMap.getResults()) {
@@ -130,14 +121,14 @@
// Construct new indices as a remap of the old ones if a remapping has been
// provided. The indices of a memref come right after it, i.e.,
// at position memRefOperandPos + 1.
- SmallVector<Value *, 4> remapOperands;
+ SmallVector<Value, 4> remapOperands;
remapOperands.reserve(extraOperands.size() + oldMemRefRank +
symbolOperands.size());
remapOperands.append(extraOperands.begin(), extraOperands.end());
remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
remapOperands.append(symbolOperands.begin(), symbolOperands.end());
- SmallVector<Value *, 4> remapOutputs;
+ SmallVector<Value, 4> remapOutputs;
remapOutputs.reserve(oldMemRefRank);
if (indexRemap &&
@@ -156,11 +147,11 @@
remapOutputs.append(remapOperands.begin(), remapOperands.end());
}
- SmallVector<Value *, 4> newMapOperands;
+ SmallVector<Value, 4> newMapOperands;
newMapOperands.reserve(newMemRefRank);
// Prepend 'extraIndices' in 'newMapOperands'.
- for (auto *extraIndex : extraIndices) {
+ for (auto extraIndex : extraIndices) {
assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
"single result op's expected to generate these indices");
assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
@@ -179,7 +170,7 @@
newMap = simplifyAffineMap(newMap);
canonicalizeMapAndOperands(&newMap, &newMapOperands);
// Remove any affine.apply's that became dead as a result of composition.
- for (auto *value : affineApplyOps)
+ for (auto value : affineApplyOps)
if (value->use_empty())
value->getDefiningOp()->erase();
@@ -203,7 +194,7 @@
// Result types don't change. Both memref's are of the same elemental type.
state.types.reserve(op->getNumResults());
- for (auto *result : op->getResults())
+ for (auto result : op->getResults())
state.types.push_back(result->getType());
// Add attribute for 'newMap', other Attributes do not change.
@@ -224,11 +215,11 @@
return success();
}
-LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
- ArrayRef<Value *> extraIndices,
+LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
+ ArrayRef<Value> extraIndices,
AffineMap indexRemap,
- ArrayRef<Value *> extraOperands,
- ArrayRef<Value *> symbolOperands,
+ ArrayRef<Value> extraOperands,
+ ArrayRef<Value> symbolOperands,
Operation *domInstFilter,
Operation *postDomInstFilter) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
@@ -331,9 +322,9 @@
void mlir::createAffineComputationSlice(
Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
// Collect all operands that are results of affine apply ops.
- SmallVector<Value *, 4> subOperands;
+ SmallVector<Value, 4> subOperands;
subOperands.reserve(opInst->getNumOperands());
- for (auto *operand : opInst->getOperands())
+ for (auto operand : opInst->getOperands())
if (isa_and_nonnull<AffineApplyOp>(operand->getDefiningOp()))
subOperands.push_back(operand);
@@ -348,7 +339,7 @@
// which case there would be nothing to do.
bool localized = true;
for (auto *op : affineApplyOps) {
- for (auto *result : op->getResults()) {
+ for (auto result : op->getResults()) {
for (auto *user : result->getUsers()) {
if (user != opInst) {
localized = false;
@@ -361,7 +352,7 @@
return;
OpBuilder builder(opInst);
- SmallVector<Value *, 4> composedOpOperands(subOperands);
+ SmallVector<Value, 4> composedOpOperands(subOperands);
auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);
@@ -378,7 +369,7 @@
// affine apply op above instead of existing ones (subOperands). So, they
// differ from opInst's operands only for those operands in 'subOperands', for
// which they will be replaced by the corresponding one from 'sliceOps'.
- SmallVector<Value *, 4> newOperands(opInst->getOperands());
+ SmallVector<Value, 4> newOperands(opInst->getOperands());
for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
// Replace the subOperands from among the new operands.
unsigned j, f;
@@ -451,8 +442,8 @@
newShape[d] = ubConst.getValue() + 1;
}
- auto *oldMemRef = allocOp.getResult();
- SmallVector<Value *, 4> symbolOperands(allocOp.getSymbolicOperands());
+ auto oldMemRef = allocOp.getResult();
+ SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands());
auto newMemRefType = MemRefType::get(newShape, memrefType.getElementType(),
b.getMultiDimIdentityMap(newRank));
diff --git a/third_party/mlir/lib/Transforms/Vectorize.cpp b/third_party/mlir/lib/Transforms/Vectorize.cpp
index e3212d5..6b2b3e1 100644
--- a/third_party/mlir/lib/Transforms/Vectorize.cpp
+++ b/third_party/mlir/lib/Transforms/Vectorize.cpp
@@ -1,19 +1,10 @@
//===- Vectorize.cpp - Vectorize Pass Impl --------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements vectorization of loops, operations and data types to
// a target-independent, n-D super-vector abstraction.
@@ -705,7 +696,7 @@
// Map of old scalar Operation to new vectorized Operation.
DenseMap<Operation *, Operation *> vectorizationMap;
// Map of old scalar Value to new vectorized Value.
- DenseMap<Value *, Value *> replacementMap;
+ DenseMap<Value, Value> replacementMap;
// The strategy drives which loop to vectorize by which amount.
const VectorizationStrategy *strategy;
// Use-def roots. These represent the starting points for the worklist in the
@@ -728,7 +719,7 @@
OperationFolder *folder;
private:
- void registerReplacement(Value *key, Value *value);
+ void registerReplacement(Value key, Value value);
};
} // end namespace
@@ -768,7 +759,7 @@
}
}
-void VectorizationState::registerReplacement(Value *key, Value *value) {
+void VectorizationState::registerReplacement(Value key, Value value) {
assert(replacementMap.count(key) == 0 && "replacement already registered");
replacementMap.insert(std::make_pair(key, value));
}
@@ -776,7 +767,7 @@
// Apply 'map' with 'mapOperands' returning resulting values in 'results'.
static void computeMemoryOpIndices(Operation *op, AffineMap map,
ValueRange mapOperands,
- SmallVectorImpl<Value *> &results) {
+ SmallVectorImpl<Value> &results) {
OpBuilder builder(op);
for (auto resultExpr : map.getResults()) {
auto singleResMap =
@@ -803,7 +794,7 @@
/// Such special cases force us to delay the vectorization of the stores until
/// the last step. Here we merely register the store operation.
template <typename LoadOrStoreOpPointer>
-static LogicalResult vectorizeRootOrTerminal(Value *iv,
+static LogicalResult vectorizeRootOrTerminal(Value iv,
LoadOrStoreOpPointer memoryOp,
VectorizationState *state) {
auto memRefType = memoryOp.getMemRef()->getType().template cast<MemRefType>();
@@ -823,7 +814,7 @@
if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
OpBuilder b(opInst);
ValueRange mapOperands = load.getMapOperands();
- SmallVector<Value *, 8> indices;
+ SmallVector<Value, 8> indices;
indices.reserve(load.getMemRefType().getRank());
if (load.getAffineMap() !=
b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
@@ -838,8 +829,7 @@
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<vector::TransferReadOp>(
- opInst->getLoc(), vectorType, memoryOp.getMemRef(),
- map(makePtrDynCaster<Value>(), indices),
+ opInst->getLoc(), vectorType, memoryOp.getMemRef(), indices,
AffineMapAttr::get(permutationMap),
// TODO(b/144455320) add a proper padding value, not just 0.0 : f32
state->folder->create<ConstantFloatOp>(b, opInst->getLoc(),
@@ -951,7 +941,7 @@
/// element type.
/// If `type` is not a valid vector type or if the scalar constant is not a
/// valid vector element type, returns nullptr.
-static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) {
+static Value vectorizeConstant(Operation *op, ConstantOp constant, Type type) {
if (!type || !type.isa<VectorType>() ||
!VectorType::isValidElementType(constant.getType())) {
return nullptr;
@@ -989,8 +979,8 @@
/// vectorization is possible with the above logic. Returns nullptr otherwise.
///
/// TODO(ntv): handle more complex cases.
-static Value *vectorizeOperand(Value *operand, Operation *op,
- VectorizationState *state) {
+static Value vectorizeOperand(Value operand, Operation *op,
+ VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs()));
// 1. If this value has already been vectorized this round, we are done.
@@ -1004,7 +994,7 @@
// been vectorized. This would be invalid IR.
auto it = state->replacementMap.find(operand);
if (it != state->replacementMap.end()) {
- auto *res = it->second;
+ auto res = it->second;
LLVM_DEBUG(dbgs() << "-> delayed replacement by: ");
LLVM_DEBUG(res->print(dbgs()));
return res;
@@ -1047,12 +1037,12 @@
if (auto store = dyn_cast<AffineStoreOp>(opInst)) {
OpBuilder b(opInst);
- auto *memRef = store.getMemRef();
- auto *value = store.getValueToStore();
- auto *vectorValue = vectorizeOperand(value, opInst, state);
+ auto memRef = store.getMemRef();
+ auto value = store.getValueToStore();
+ auto vectorValue = vectorizeOperand(value, opInst, state);
ValueRange mapOperands = store.getMapOperands();
- SmallVector<Value *, 8> indices;
+ SmallVector<Value, 8> indices;
indices.reserve(store.getMemRefType().getRank());
if (store.getAffineMap() !=
b.getMultiDimIdentityMap(store.getMemRefType().getRank())) {
@@ -1081,16 +1071,16 @@
return nullptr;
SmallVector<Type, 8> vectorTypes;
- for (auto *v : opInst->getResults()) {
+ for (auto v : opInst->getResults()) {
vectorTypes.push_back(
VectorType::get(state->strategy->vectorSizes, v->getType()));
}
- SmallVector<Value *, 8> vectorOperands;
- for (auto *v : opInst->getOperands()) {
+ SmallVector<Value, 8> vectorOperands;
+ for (auto v : opInst->getOperands()) {
vectorOperands.push_back(vectorizeOperand(v, opInst, state));
}
// Check whether a single operand is null. If so, vectorization failed.
- bool success = llvm::all_of(vectorOperands, [](Value *op) { return op; });
+ bool success = llvm::all_of(vectorOperands, [](Value op) { return op; });
if (!success) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize");
return nullptr;
diff --git a/third_party/mlir/lib/Transforms/ViewOpGraph.cpp b/third_party/mlir/lib/Transforms/ViewOpGraph.cpp
index 591562d..508c547 100644
--- a/third_party/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/third_party/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -1,19 +1,10 @@
//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ViewOpGraph.h"
#include "mlir/IR/Block.h"
diff --git a/third_party/mlir/lib/Transforms/ViewRegionGraph.cpp b/third_party/mlir/lib/Transforms/ViewRegionGraph.cpp
index db55415..7711108 100644
--- a/third_party/mlir/lib/Transforms/ViewRegionGraph.cpp
+++ b/third_party/mlir/lib/Transforms/ViewRegionGraph.cpp
@@ -1,19 +1,10 @@
//===- ViewRegionGraph.cpp - View/write graphviz graphs -------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ViewRegionGraph.h"
#include "mlir/IR/RegionGraphTraits.h"
diff --git a/third_party/mlir/lib/Translation/Translation.cpp b/third_party/mlir/lib/Translation/Translation.cpp
index 8b5f987..80c1e48 100644
--- a/third_party/mlir/lib/Translation/Translation.cpp
+++ b/third_party/mlir/lib/Translation/Translation.cpp
@@ -1,19 +1,10 @@
//===- Translation.cpp - Translation registry -----------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Definitions of the translation registry.
//
diff --git a/third_party/mlir/test/APITest.h b/third_party/mlir/test/APITest.h
index 9475bae..08d64a0 100644
--- a/third_party/mlir/test/APITest.h
+++ b/third_party/mlir/test/APITest.h
@@ -1,19 +1,10 @@
//===- Test.h - Simple macros for API unit tests ----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file define simple macros for declaring test functions and running them.
// The actual checking must be performed on the outputs with FileCheck.
diff --git a/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index d231392..d07f606 100644
--- a/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/third_party/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -1,19 +1,10 @@
//===- TestLinalgTransformPatterns.td - Test patterns --*- tablegen ----*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the pattern definition file for declarative Linalg transformations
// tests.
diff --git a/third_party/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td b/third_party/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td
index 228a8a0..29875cc 100644
--- a/third_party/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td
+++ b/third_party/mlir/test/lib/DeclarativeTransforms/TestVectorTransformPatterns.td
@@ -1,19 +1,10 @@
//===- TestVectorTransformPatterns.td - Test patterns ---*- tablegen ----*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is the pattern definition file for declarative Vector transformations
// tests.
diff --git a/third_party/mlir/test/lib/IR/TestFunc.cpp b/third_party/mlir/test/lib/IR/TestFunc.cpp
index 880d078..3e13159 100644
--- a/third_party/mlir/test/lib/IR/TestFunc.cpp
+++ b/third_party/mlir/test/lib/IR/TestFunc.cpp
@@ -1,19 +1,10 @@
//===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
diff --git a/third_party/mlir/test/lib/IR/TestMatchers.cpp b/third_party/mlir/test/lib/IR/TestMatchers.cpp
index 5985a88..b62daa8 100644
--- a/third_party/mlir/test/lib/IR/TestMatchers.cpp
+++ b/third_party/mlir/test/lib/IR/TestMatchers.cpp
@@ -1,19 +1,10 @@
//===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Function.h"
diff --git a/third_party/mlir/test/lib/IR/TestSymbolUses.cpp b/third_party/mlir/test/lib/IR/TestSymbolUses.cpp
index 8ef4bb4..c8fb1d8 100644
--- a/third_party/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/third_party/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -1,19 +1,10 @@
//===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/IR/Function.h"
diff --git a/third_party/mlir/test/lib/Pass/TestPassManager.cpp b/third_party/mlir/test/lib/Pass/TestPassManager.cpp
index d1e1a6d..cc926e1 100644
--- a/third_party/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/third_party/mlir/test/lib/Pass/TestPassManager.cpp
@@ -1,19 +1,10 @@
//===- TestPassManager.cpp - Test pass manager functionality --------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
@@ -30,43 +21,34 @@
};
class TestOptionsPass : public FunctionPass<TestOptionsPass> {
public:
- struct Options : public PassOptions<Options> {
- List<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
- llvm::cl::desc("Example list option")};
- List<std::string> stringListOption{
+ struct Options : public PassPipelineOptions<Options> {
+ ListOption<int> listOption{*this, "list",
+ llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::desc("Example list option")};
+ ListOption<std::string> stringListOption{
*this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Example string list option")};
Option<std::string> stringOption{*this, "string",
llvm::cl::desc("Example string option")};
};
+ TestOptionsPass() = default;
+ TestOptionsPass(const TestOptionsPass &) {}
TestOptionsPass(const Options &options) {
- listOption.assign(options.listOption.begin(), options.listOption.end());
- stringOption = options.stringOption;
- stringListOption.assign(options.stringListOption.begin(),
- options.stringListOption.end());
- }
-
- void printAsTextualPipeline(raw_ostream &os) final {
- os << "test-options-pass{";
- if (!listOption.empty()) {
- os << "list=";
- // Not interleaveComma to avoid spaces between the elements.
- interleave(listOption, os, ",");
- }
- if (!stringListOption.empty()) {
- os << " string-list=";
- interleave(stringListOption, os, ",");
- }
- if (!stringOption.empty())
- os << " string=" << stringOption;
- os << "}";
+ listOption->assign(options.listOption.begin(), options.listOption.end());
+ stringOption.setValue(options.stringOption);
+ stringListOption->assign(options.stringListOption.begin(),
+ options.stringListOption.end());
}
void runOnFunction() final {}
- SmallVector<int64_t, 4> listOption;
- SmallVector<std::string, 4> stringListOption;
- std::string stringOption;
+ ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::desc("Example list option")};
+ ListOption<std::string> stringListOption{
+ *this, "string-list", llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::desc("Example string list option")};
+ Option<std::string> stringOption{*this, "string",
+ llvm::cl::desc("Example string option")};
};
/// A test pass that always aborts to enable testing the crash recovery
@@ -106,7 +88,7 @@
(void)parsePassPipeline("test-pm-nested-pipeline", pm);
}
-static PassRegistration<TestOptionsPass, TestOptionsPass::Options>
+static PassRegistration<TestOptionsPass>
reg("test-options-pass", "Test options parsing capabilities");
static PassRegistration<TestModulePass>
diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
index 059cfb3..21cf69e 100644
--- a/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/third_party/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -1,19 +1,10 @@
//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/IR/Function.h"
@@ -41,6 +32,20 @@
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
setNameFn(asmOp, "result");
}
+
+ void getAsmBlockArgumentNames(Block *block,
+ OpAsmSetValueNameFn setNameFn) const final {
+ auto op = block->getParentOp();
+ auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
+ if (!arrayAttr)
+ return;
+ auto args = block->getArguments();
+ auto e = std::min(arrayAttr.size(), args.size());
+ for (unsigned i = 0; i < e; ++i) {
+ if (auto strAttr = arrayAttr.getValue()[i].dyn_cast<StringAttr>())
+ setNameFn(args[i], strAttr.getValue());
+ }
+ }
};
struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
@@ -86,7 +91,7 @@
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
- ArrayRef<Value *> valuesToRepl) const final {
+ ArrayRef<Value> valuesToRepl) const final {
// Only handle "test.return" here.
auto returnOp = dyn_cast<TestReturnOp>(op);
if (!returnOp)
@@ -103,7 +108,7 @@
/// operation that takes 'input' as the only operand, and produces a single
/// result of 'resultType'. If a conversion can not be generated, nullptr
/// should be returned.
- Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+ Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type resultType,
Location conversionLoc) const final {
// Only allow conversion for i16/i32 types.
@@ -217,7 +222,7 @@
// Create a return terminator in the inner region, pass as operand to the
// terminator the returned values from the wrapped operation.
- SmallVector<Value *, 8> return_operands(wrapped_op->getResults());
+ SmallVector<Value, 8> return_operands(wrapped_op->getResults());
OpBuilder builder(parser.getBuilder().getContext());
builder.setInsertionPointToEnd(&block);
builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
@@ -283,7 +288,7 @@
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
- for (Value *input : this->operands()) {
+ for (Value input : this->operands()) {
results.push_back(input);
}
return success();
diff --git a/third_party/mlir/test/lib/TestDialect/TestDialect.h b/third_party/mlir/test/lib/TestDialect/TestDialect.h
index 783b8a1..20db0f3 100644
--- a/third_party/mlir/test/lib/TestDialect/TestDialect.h
+++ b/third_party/mlir/test/lib/TestDialect/TestDialect.h
@@ -1,19 +1,10 @@
//===- TestDialect.h - MLIR Dialect for testing -----------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines a fake 'test' dialect that can be used for testing things
// that do not have a respective counterpart in the main source directories.
diff --git a/third_party/mlir/test/lib/TestDialect/TestOps.td b/third_party/mlir/test/lib/TestDialect/TestOps.td
index e33d9c2..dacb796 100644
--- a/third_party/mlir/test/lib/TestDialect/TestOps.td
+++ b/third_party/mlir/test/lib/TestDialect/TestOps.td
@@ -1,19 +1,10 @@
//===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#ifndef TEST_OPS
#define TEST_OPS
@@ -644,7 +635,7 @@
let builders = [
OpBuilder<
- "Builder *builder, OperationState &state, Value *operand",
+ "Builder *builder, OperationState &state, Value operand",
[{
state.types.assign({builder->getIntegerType(32)});
state.addOperands({operand});
diff --git a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
index 94eb792..929c4a9 100644
--- a/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -1,19 +1,10 @@
//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/IR/PatternMatch.h"
@@ -22,11 +13,11 @@
using namespace mlir;
// Native function for testing NativeCodeCall
-static Value *chooseOperand(Value *input1, Value *input2, BoolAttr choice) {
+static Value chooseOperand(Value input1, Value input2, BoolAttr choice) {
return choice.getValue() ? input1 : input2;
}
-static void createOpI(PatternRewriter &rewriter, Value *input) {
+static void createOpI(PatternRewriter &rewriter, Value input) {
rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
}
@@ -73,7 +64,7 @@
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
- SmallVector<Value *, 4> values(op->getOperands());
+ SmallVector<Value, 4> values(op->getOperands());
SmallVector<Type, 2> inferedReturnTypes;
if (failed(retTypeFn.inferReturnTypes(op->getLoc(), values,
op->getAttrs(), op->getRegions(),
@@ -132,7 +123,7 @@
: ConversionPattern("test.region", 1, ctx) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Inline this region into the parent region.
auto &parentRegion = *op->getParentRegion();
@@ -165,7 +156,7 @@
// Add an explicitly illegal operation to ensure the conversion fails.
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
- rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value *>());
+ rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>());
// Drop this operation.
rewriter.eraseOp(op);
@@ -182,7 +173,7 @@
: ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Region ®ion = op->getRegion(0);
Block *entry = ®ion.front();
@@ -208,7 +199,7 @@
TestPassthroughInvalidOp(MLIRContext *ctx)
: ConversionPattern("test.invalid", 1, ctx) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
llvm::None);
@@ -220,7 +211,7 @@
TestSplitReturnType(MLIRContext *ctx)
: ConversionPattern("test.return", 1, ctx) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Check for a return of F32.
if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32())
@@ -245,7 +236,7 @@
TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 1, ctx) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is I32, change the type to F32.
if (!(*op->result_type_begin()).isInteger(32))
@@ -258,7 +249,7 @@
TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 1, ctx) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is F32, change the type to F64.
if (!(*op->result_type_begin()).isF32())
@@ -271,7 +262,7 @@
TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 10, ctx) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Always convert to B16, even though it is not a legal type. This tests
// that values are unmapped correctly.
@@ -283,7 +274,7 @@
TestUpdateConsumerType(MLIRContext *ctx)
: ConversionPattern("test.type_consumer", 1, ctx) {}
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Verify that the incoming operand has been successfully remapped to F64.
if (!operands[0]->getType().isF64())
@@ -344,7 +335,7 @@
/// Override the hook to materialize a conversion. This is necessary because
/// we generate 1->N type mappings.
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
- ArrayRef<Value *> inputs,
+ ArrayRef<Value> inputs,
Location loc) override {
return rewriter.create<TestCastOp>(loc, resultType, inputs);
}
@@ -467,13 +458,13 @@
using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
PatternMatchResult
- matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value *> operands,
+ matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto origOps = op.getOperands();
assert(std::distance(origOps.begin(), origOps.end()) == 1 &&
"One operand expected");
- Value *origOp = *origOps.begin();
- SmallVector<Value *, 2> remappedOperands;
+ Value origOp = *origOps.begin();
+ SmallVector<Value, 2> remappedOperands;
// Replicate the remapped original operand twice. Note that we don't used
// the remapped 'operand' since the goal is testing 'getRemappedValue'.
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
diff --git a/third_party/mlir/test/lib/Transforms/TestCallGraph.cpp b/third_party/mlir/test/lib/Transforms/TestCallGraph.cpp
index debf5e7..6378d95 100644
--- a/third_party/mlir/test/lib/Transforms/TestCallGraph.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestCallGraph.cpp
@@ -1,19 +1,10 @@
//===- TestCallGraph.cpp - Test callgraph construction and iteration ------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains test passes for constructing and iterating over a
// callgraph.
diff --git a/third_party/mlir/test/lib/Transforms/TestConstantFold.cpp b/third_party/mlir/test/lib/Transforms/TestConstantFold.cpp
index 5a0e9ed..f660bcc 100644
--- a/third_party/mlir/test/lib/Transforms/TestConstantFold.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestConstantFold.cpp
@@ -1,19 +1,10 @@
//===- TestConstantFold.cpp - Pass to test constant folding ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
diff --git a/third_party/mlir/test/lib/Transforms/TestInlining.cpp b/third_party/mlir/test/lib/Transforms/TestInlining.cpp
index 0571dc6..3637828 100644
--- a/third_party/mlir/test/lib/Transforms/TestInlining.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestInlining.cpp
@@ -1,19 +1,10 @@
//===- TestInlining.cpp - Pass to inline calls in the test dialect --------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// TODO(riverriddle) This pass is only necessary because the main inlining pass
// has no abstracted away the call+callee relationship. When the inlining
diff --git a/third_party/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/third_party/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 37030ca..6ea995d 100644
--- a/third_party/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -1,19 +1,10 @@
//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
diff --git a/third_party/mlir/test/lib/Transforms/TestLiveness.cpp b/third_party/mlir/test/lib/Transforms/TestLiveness.cpp
index d970602..2372574 100644
--- a/third_party/mlir/test/lib/Transforms/TestLiveness.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestLiveness.cpp
@@ -1,20 +1,11 @@
//===- TestLiveness.cpp - Test liveness construction and information
//-------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains test passes for constructing and resolving liveness
// information.
diff --git a/third_party/mlir/test/lib/Transforms/TestLoopFusion.cpp b/third_party/mlir/test/lib/Transforms/TestLoopFusion.cpp
index 7dc722f..23e5035 100644
--- a/third_party/mlir/test/lib/Transforms/TestLoopFusion.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestLoopFusion.cpp
@@ -1,19 +1,10 @@
//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to test various loop fusion utility functions.
//
diff --git a/third_party/mlir/test/lib/Transforms/TestLoopMapping.cpp b/third_party/mlir/test/lib/Transforms/TestLoopMapping.cpp
index c25fea9..86e5713 100644
--- a/third_party/mlir/test/lib/Transforms/TestLoopMapping.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestLoopMapping.cpp
@@ -1,19 +1,10 @@
//===- TestLoopMapping.cpp --- Parametric loop mapping pass ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to parametrically map loop.for loops to virtual
// processing element dimensions.
@@ -41,7 +32,7 @@
// SSA values for the transformation are created out of thin air by
// unregistered "new_processor_id_and_range" operations. This is enough to
// emulate mapping conditions.
- SmallVector<Value *, 8> processorIds, numProcessors;
+ SmallVector<Value, 8> processorIds, numProcessors;
func.walk([&processorIds, &numProcessors](Operation *op) {
if (op->getName().getStringRef() != "new_processor_id_and_range")
return;
diff --git a/third_party/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp b/third_party/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp
index 9a8e191..e793ee5 100644
--- a/third_party/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestLoopParametricTiling.cpp
@@ -1,19 +1,10 @@
//===- TestLoopParametricTiling.cpp --- Parametric loop tiling pass -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a pass to parametrically tile nests of standard loops.
//
@@ -34,18 +25,10 @@
class SimpleParametricLoopTilingPass
: public FunctionPass<SimpleParametricLoopTilingPass> {
public:
- struct Options : public PassOptions<Options> {
- List<int> clOuterLoopSizes{
- *this, "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated,
- llvm::cl::desc(
- "fixed number of iterations that the outer loops should have")};
- };
-
- explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes)
- : sizes(outerLoopSizes.begin(), outerLoopSizes.end()) {}
- explicit SimpleParametricLoopTilingPass(const Options &options) {
- sizes.assign(options.clOuterLoopSizes.begin(),
- options.clOuterLoopSizes.end());
+ SimpleParametricLoopTilingPass() = default;
+ SimpleParametricLoopTilingPass(const SimpleParametricLoopTilingPass &) {}
+ explicit SimpleParametricLoopTilingPass(ArrayRef<int64_t> outerLoopSizes) {
+ sizes = outerLoopSizes;
}
void runOnFunction() override {
@@ -58,7 +41,10 @@
});
}
- SmallVector<int64_t, 4> sizes;
+ ListOption<int64_t> sizes{
+ *this, "test-outer-loop-sizes", llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::desc(
+ "fixed number of iterations that the outer loops should have")};
};
} // end namespace
@@ -67,8 +53,7 @@
return std::make_unique<SimpleParametricLoopTilingPass>(outerLoopSizes);
}
-static PassRegistration<SimpleParametricLoopTilingPass,
- SimpleParametricLoopTilingPass::Options>
+static PassRegistration<SimpleParametricLoopTilingPass>
reg("test-extract-fixed-outer-loops",
"test application of parametric tiling to the outer loops so that the "
"ranges of outer loops become static");
diff --git a/third_party/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp b/third_party/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp
index 40788b2..d5e0b7d 100644
--- a/third_party/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp
@@ -1,19 +1,10 @@
//===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/StandardTypes.h"
diff --git a/third_party/mlir/test/lib/Transforms/TestOpaqueLoc.cpp b/third_party/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
index 0db5332..9a261c0 100644
--- a/third_party/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
@@ -1,19 +1,10 @@
//===- TestOpaqueLoc.cpp - Pass to test opaque locations ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
diff --git a/third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp b/third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
index e5f5f74..a31f8e4 100644
--- a/third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp
@@ -1,19 +1,10 @@
//===- TestVectorToLoopsConversion.cpp - Test VectorTransfers lowering ----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include <type_traits>
diff --git a/third_party/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/third_party/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 1d51306..664d49a 100644
--- a/third_party/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -1,19 +1,10 @@
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
#include <type_traits>
diff --git a/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp b/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
index 7efc74f..6f4d948 100644
--- a/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
+++ b/third_party/mlir/test/lib/Transforms/TestVectorizationUtils.cpp
@@ -1,19 +1,10 @@
//===- VectorizerTestPass.cpp - VectorizerTestPass Pass Impl --------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file implements a simple testing pass for vectorization functionality.
//
@@ -245,7 +236,7 @@
for (auto m : matches) {
auto app = cast<AffineApplyOp>(m.getMatchedOperation());
OpBuilder b(m.getMatchedOperation());
- SmallVector<Value *, 8> operands(app.getOperands());
+ SmallVector<Value, 8> operands(app.getOperands());
makeComposedAffineApply(b, app.getLoc(), app.getAffineMap(), operands);
}
}
diff --git a/third_party/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/third_party/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
index f7023c4c..144f73d 100644
--- a/third_party/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
+++ b/third_party/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
@@ -1,19 +1,10 @@
//===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Main entry point to a command line utility that executes an MLIR file on the
// CPU by translating MLIR to LLVM IR before JIT-compiling and executing the
diff --git a/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 0698095..9f1591b 100644
--- a/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/third_party/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -1,19 +1,10 @@
//===- cuda-runtime-wrappers.cpp - MLIR CUDA runner wrapper library -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Implements C wrappers around the CUDA library for easy linking in ORC jit.
// Also adds some debugging helpers that are helpful when writing MLIR code to
diff --git a/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
index c1ca4eb..d6160d6 100644
--- a/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
+++ b/third_party/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
@@ -1,19 +1,10 @@
//===- mlir-cpu-runner.cpp - MLIR CPU Execution Driver---------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is a command line utility that executes an MLIR file on the GPU by
// translating MLIR to NVVM/LVVM IR before JIT-compiling and executing the
diff --git a/third_party/mlir/tools/mlir-opt/mlir-opt.cpp b/third_party/mlir/tools/mlir-opt/mlir-opt.cpp
index d01f66d..b0dd1b5 100644
--- a/third_party/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/third_party/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -1,19 +1,10 @@
//===- mlir-opt.cpp - MLIR Optimizer Driver -------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// Main entry function for mlir-opt for when built as standalone binary.
//
diff --git a/third_party/mlir/tools/mlir-tblgen/DocGenUtilities.h b/third_party/mlir/tools/mlir-tblgen/DocGenUtilities.h
index b761774..1b3c854 100644
--- a/third_party/mlir/tools/mlir-tblgen/DocGenUtilities.h
+++ b/third_party/mlir/tools/mlir-tblgen/DocGenUtilities.h
@@ -1,19 +1,10 @@
//===- DocGenUtilities.h - MLIR doc gen utilities ---------------*- C++ -*-===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines common utilities for generating documents from tablegen
// structures.
diff --git a/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp b/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp
index e278fdd..610a380 100644
--- a/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -1,19 +1,10 @@
//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// EnumsGen generates common utility functions for enums.
//
diff --git a/third_party/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/third_party/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index f4b1279..30f720e 100644
--- a/third_party/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -1,19 +1,10 @@
//===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file uses tablegen definitions of the LLVM IR Dialect operations to
// generate the code building the LLVM IR from it.
diff --git a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index dd56458..f5b3e01 100644
--- a/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1,19 +1,10 @@
//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
@@ -713,11 +704,12 @@
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`. Uses `rangeType` as the return type of getters that
-// return a range of operands (individual operands are `Value *` and each
-// element in the range must also be `Value *`); use `rangeBeginCall` to get an
-// iterator to the beginning of the operand range; use `rangeSizeCall` to obtain
-// the number of operands. `getOperandCallPattern` contains the code necessary
-// to obtain a single operand whose position will be substituted instead of
+// return a range of operands (individual operands are `Value ` and each
+// element in the range must also be `Value `); use `rangeBeginCall` to get
+// an iterator to the beginning of the operand range; use `rangeSizeCall` to
+// obtain the number of operands. `getOperandCallPattern` contains the code
+// necessary to obtain a single operand whose position will be substituted
+// instead of
// "{0}" marker in the pattern. Note that the pattern should work for any kind
// of ops, in particular for one-operand ops that may not have the
// `getOperand(unsigned)` method.
@@ -790,7 +782,7 @@
auto &m = opClass.newMethod(rangeType, operand.name);
m.body() << " return getODSOperands(" << i << ");";
} else {
- auto &m = opClass.newMethod("Value *", operand.name);
+ auto &m = opClass.newMethod("Value ", operand.name);
m.body() << " return *getODSOperands(" << i << ").begin();";
}
}
@@ -868,7 +860,7 @@
auto &m = opClass.newMethod("Operation::result_range", result.name);
m.body() << " return getODSResults(" << i << ");";
} else {
- auto &m = opClass.newMethod("Value *", result.name);
+ auto &m = opClass.newMethod("Value ", result.name);
m.body() << " return *getODSResults(" << i << ").begin();";
}
}
@@ -1246,7 +1238,7 @@
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
- paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value *");
+ paramList.append(operand.isVariadic() ? ", ValueRange " : ", Value ");
paramList.append(getArgumentName(op, numOperands));
++numOperands;
} else {
@@ -1535,7 +1527,7 @@
continue;
// Emit a loop to check all the dynamic values in the pack.
- body << formatv(" for (Value *v : getODS{0}{1}s({2})) {{\n",
+ body << formatv(" for (Value v : getODS{0}{1}s({2})) {{\n",
// Capitalize the first letter to match the function name
valueKind.substr(0, 1).upper(), valueKind.substr(1),
staticValue.index());
@@ -1690,7 +1682,7 @@
namespace {
// Helper class to emit Op operand adaptors to an output stream. Operand
-// adaptors are wrappers around ArrayRef<Value *> that provide named operand
+// adaptors are wrappers around ArrayRef<Value> that provide named operand
// getters identical to those defined in the Op.
class OpOperandAdaptorEmitter {
public:
@@ -1706,12 +1698,12 @@
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
: adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
- adapterClass.newField("ArrayRef<Value *>", "tblgen_operands");
- auto &constructor = adapterClass.newConstructor("ArrayRef<Value *> values");
+ adapterClass.newField("ArrayRef<Value>", "tblgen_operands");
+ auto &constructor = adapterClass.newConstructor("ArrayRef<Value> values");
constructor.body() << " tblgen_operands = values;\n";
generateNamedOperandGetters(op, adapterClass,
- /*rangeType=*/"ArrayRef<Value *>",
+ /*rangeType=*/"ArrayRef<Value>",
/*rangeBeginCall=*/"tblgen_operands.begin()",
/*rangeSizeCall=*/"tblgen_operands.size()",
/*getOperandCallPattern=*/"tblgen_operands[{0}]");
diff --git a/third_party/mlir/tools/mlir-tblgen/OpDocGen.cpp b/third_party/mlir/tools/mlir-tblgen/OpDocGen.cpp
index 8b048d9..87a2723 100644
--- a/third_party/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -1,19 +1,10 @@
//===- OpDocGen.cpp - MLIR operation documentation generator --------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// OpDocGen uses the description of operations to generate documentation for the
// operations.
diff --git a/third_party/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/third_party/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index a48bd25..a96736c 100644
--- a/third_party/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -1,19 +1,10 @@
//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// OpInterfacesGen generates definitions for operation interfaces.
//
diff --git a/third_party/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp b/third_party/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp
index 9181d0e..90b60e5 100644
--- a/third_party/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp
@@ -1,19 +1,10 @@
//===- ReferenceImplGen.cpp - MLIR reference implementation generator -----===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// ReferenceImplGen uses the description of operations to generate reference
// implementations for the ops.
diff --git a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
index b2376e8..824ddae 100644
--- a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1,19 +1,10 @@
//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
//
@@ -576,14 +567,14 @@
os.indent(4) << "rewriter.eraseOp(op0);\n";
} else {
// Process replacement result patterns.
- os.indent(4) << "SmallVector<Value *, 4> tblgen_repl_values;\n";
+ os.indent(4) << "SmallVector<Value, 4> tblgen_repl_values;\n";
for (int i = replStartIndex; i < numResultPatterns; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
os.indent(4) << "\n";
// Resolve each symbol for all range use so that we can loop over them.
os << symbolInfoMap.getAllRangeUse(
- val, " for (auto *v : {0}) {{ tblgen_repl_values.push_back(v); }",
+ val, " for (auto v : {0}) {{ tblgen_repl_values.push_back(v); }",
"\n");
}
os.indent(4) << "\n";
@@ -819,7 +810,7 @@
int numResults = resultOp.getNumResults();
if (numResults != 0) {
for (int i = 0; i < numResults; ++i)
- os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) {{"
+ os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
"tblgen_types.push_back(v->getType()); }\n",
resultIndex + i);
}
@@ -835,8 +826,8 @@
Operator &resultOp = node.getDialectOp(opMap);
// Now prepare operands used for building this op:
- // * If the operand is non-variadic, we create a `Value*` local variable.
- // * If the operand is variadic, we create a `SmallVector<Value*>` local
+ // * If the operand is non-variadic, we create a `Value` local variable.
+ // * If the operand is variadic, we create a `SmallVector<Value>` local
// variable.
int valueIndex = 0; // An index for uniquing local variable names.
@@ -851,7 +842,7 @@
std::string varName;
if (operand->isVariadic()) {
varName = formatv("tblgen_values_{0}", valueIndex++);
- os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
+ os.indent(6) << formatv("SmallVector<Value, 4> {0};\n", varName);
std::string range;
if (node.isNestedDagArg(argIndex)) {
range = childNodeNames[argIndex];
@@ -861,11 +852,11 @@
// Resolve the symbol for all range use so that we have a uniform way of
// capturing the values.
range = symbolInfoMap.getValueAndRangeUse(range);
- os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
+ os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
varName);
} else {
varName = formatv("tblgen_value_{0}", valueIndex++);
- os.indent(6) << formatv("Value *{0} = ", varName);
+ os.indent(6) << formatv("Value {0} = ", varName);
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
} else {
@@ -934,7 +925,7 @@
Operator &resultOp = node.getDialectOp(opMap);
os.indent(6) << formatv(
- "SmallVector<Value *, 4> tblgen_values; (void)tblgen_values;\n");
+ "SmallVector<Value, 4> tblgen_values; (void)tblgen_values;\n");
os.indent(6) << formatv(
"SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n");
@@ -975,7 +966,7 @@
// capturing the values.
range = symbolInfoMap.getValueAndRangeUse(range);
os.indent(6) << formatv(
- "for (auto *v : {0}) tblgen_values.push_back(v);\n", range);
+ "for (auto v : {0}) tblgen_values.push_back(v);\n", range);
} else {
os.indent(6) << formatv("tblgen_values.push_back(", varName);
if (node.isNestedDagArg(argIndex)) {
diff --git a/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index f1712ef..d65b216 100644
--- a/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -1,19 +1,10 @@
//===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// SPIRVSerializationGen generates common utility functions for SPIR-V
// serialization.
@@ -470,7 +461,7 @@
emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex,
resultTypes, valueID, os);
- os << formatv(" SmallVector<Value *, 4> {0};\n", operands);
+ os << formatv(" SmallVector<Value, 4> {0};\n", operands);
os << formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes);
// Operand deserialization
emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex,
diff --git a/third_party/mlir/tools/mlir-tblgen/StructsGen.cpp b/third_party/mlir/tools/mlir-tblgen/StructsGen.cpp
index d884495..576085e 100644
--- a/third_party/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -1,19 +1,10 @@
//===- StructsGen.cpp - MLIR struct utility generator ---------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// StructsGen generates common utility functions for grouping attributes into a
// set of structured data.
diff --git a/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp b/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
index 993a05d..3c9778b 100644
--- a/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/mlir-tblgen.cpp
@@ -1,19 +1,10 @@
//===- mlir-tblgen.cpp - Top-Level TableGen implementation for MLIR -------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file contains the main function for MLIR's TableGen.
//
diff --git a/third_party/mlir/tools/mlir-translate/mlir-translate.cpp b/third_party/mlir/tools/mlir-translate/mlir-translate.cpp
index b5622e3..3b15c5f 100644
--- a/third_party/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/third_party/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -1,19 +1,10 @@
//===- mlir-translate.cpp - MLIR Translate Driver -------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This is a command line utility that translates a file from/to MLIR using one
// of the registered translations.
diff --git a/third_party/mlir/utils/generate-test-checks.py b/third_party/mlir/utils/generate-test-checks.py
index 3bb4ffe..6dc40c7 100755
--- a/third_party/mlir/utils/generate-test-checks.py
+++ b/third_party/mlir/utils/generate-test-checks.py
@@ -17,19 +17,9 @@
about what constitutes a good test!
"""
-# Copyright 2019 The MLIR Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import argparse
import os # Used to advertise this file's name ("autogenerated_note").
diff --git a/third_party/mlir/utils/spirv/define_enum.sh b/third_party/mlir/utils/spirv/define_enum.sh
index 9da898f..87b88c9 100755
--- a/third_party/mlir/utils/spirv/define_enum.sh
+++ b/third_party/mlir/utils/spirv/define_enum.sh
@@ -1,18 +1,8 @@
#!/bin/bash
-# Copyright 2019 The MLIR Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Script for defining a new enum attr using SPIR-V spec from the Internet.
#
diff --git a/third_party/mlir/utils/spirv/define_inst.sh b/third_party/mlir/utils/spirv/define_inst.sh
index f11078a..db58813 100755
--- a/third_party/mlir/utils/spirv/define_inst.sh
+++ b/third_party/mlir/utils/spirv/define_inst.sh
@@ -1,47 +1,35 @@
#!/bin/bash
-# Copyright 2019 The MLIR Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Script for defining a new op using SPIR-V spec from the Internet.
#
# Run as:
-# ./define_inst.sh <filename> <inst_category> (<opname>)*
+# ./define_inst.sh <filename> <baseclass> (<opname>)*
# <filename> is required, which is the file name of MLIR SPIR-V op definitions
# spec.
-# <inst_category> is required. It can be one of
-# (Op|ArithmeticOp|LogicalOp|ControlFlowOp|StructureOp). Based on the
-# inst_category the file SPIRV<inst_category>s.td is updated with the
-# instruction definition. If <opname> is missing, this script updates existing
-# ones in SPIRV<inst_category>s.td
+# <baseclass> is required. It will be the direct base class the newly defined
+# op will derive from.
+# If <opname> is missing, this script updates existing ones in <filename>.
# For example:
-# ./define_inst.sh SPIRVArithmeticOps.td ArithmeticOp OpIAdd
+# ./define_inst.sh SPIRVArithmeticOps.td ArithmeticBinaryOp OpIAdd
# ./define_inst.sh SPIRVLogicalOps.td LogicalOp OpFOrdEqual
set -e
file_name=$1
-inst_category=$2
+baseclass=$2
-case $inst_category in
- Op | ArithmeticOp | LogicalOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp)
+case $baseclass in
+ Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp)
;;
*)
- echo "Usage : " $0 "<filename> <inst_category> (<opname>)*"
+ echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
- echo "<inst_category> must be one of " \
- "(Op|ArithmeticOp|LogicalOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)"
+ echo "<baseclass> must be one of " \
+ "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)"
exit 1;
;;
esac
@@ -55,7 +43,7 @@
python3 ${current_dir}/gen_spirv_dialect.py \
--op-td-path \
${current_dir}/../../include/mlir/Dialect/SPIRV/${file_name} \
- --inst-category $inst_category --new-inst "$@"
+ --inst-category $baseclass --new-inst "$@"
${current_dir}/define_opcodes.sh "$@"
diff --git a/third_party/mlir/utils/spirv/define_opcodes.sh b/third_party/mlir/utils/spirv/define_opcodes.sh
index 05c3657..7b9aeab 100755
--- a/third_party/mlir/utils/spirv/define_opcodes.sh
+++ b/third_party/mlir/utils/spirv/define_opcodes.sh
@@ -1,18 +1,8 @@
#!/bin/bash
-# Copyright 2019 The MLIR Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Script for defining map for opname to opcode using SPIR-V spec from the
# Internet
diff --git a/third_party/mlir/utils/spirv/gen_spirv_dialect.py b/third_party/mlir/utils/spirv/gen_spirv_dialect.py
index be7116c..2433cf4 100755
--- a/third_party/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/third_party/mlir/utils/spirv/gen_spirv_dialect.py
@@ -1,19 +1,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-# Copyright 2019 The MLIR Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Script for updating SPIR-V dialect by scraping information from SPIR-V
# HTML and JSON specs from the Internet.
diff --git a/third_party/repo.bzl b/third_party/repo.bzl
index 8cb10cd..cb3e06a 100644
--- a/third_party/repo.bzl
+++ b/third_party/repo.bzl
@@ -118,6 +118,10 @@
for internal_src, external_dest in ctx.attr.system_link_files.items():
ctx.symlink(Label(internal_src), ctx.path(external_dest))
+ if ctx.attr.additional_build_files:
+ for internal_src, external_dest in ctx.attr.additional_build_files.items():
+ ctx.symlink(Label(internal_src), ctx.path(external_dest))
+
tf_http_archive = repository_rule(
implementation = _tf_http_archive,
attrs = {
@@ -130,6 +134,7 @@
"build_file": attr.label(),
"system_build_file": attr.label(),
"system_link_files": attr.string_dict(),
+ "additional_build_files": attr.string_dict(),
},
environ = [
"TF_SYSTEM_LIBS",