Merge pull request #39399 from marload:patch-0511-01
PiperOrigin-RevId: 311208506
Change-Id: Ic4d22e1b01554cfe7eaab759ecbc0c2cadc3c002
diff --git a/.bazelrc b/.bazelrc
index cf15d09..224238d 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -163,6 +163,8 @@
build:dbg --config=opt -c dbg
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
+# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
+build:dbg --copt -DDEBUG_BUILD
build:tensorrt --action_env TF_NEED_TENSORRT=1
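For readers unfamiliar with the flag, the new `--copt -DDEBUG_BUILD` line simply defines a preprocessor macro in every `--config=dbg` compile. A minimal C++ sketch (not part of this change; the function and messages are hypothetical) of how code can branch on it:

#include <cstdio>

// Hypothetical helper: emits extra diagnostics only under --config=dbg,
// where the .bazelrc line above defines DEBUG_BUILD.
void LogBuildMode() {
#ifdef DEBUG_BUILD
  std::printf("dbg build: extra diagnostics enabled\n");
#else
  std::printf("opt build\n");
#endif
}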
diff --git a/configure.py b/configure.py
index ca2ff59..945c303 100644
--- a/configure.py
+++ b/configure.py
@@ -217,7 +217,8 @@
elif not os.path.exists(python_bin_path):
print('Invalid python path: {} cannot be found.'.format(python_bin_path))
else:
- print('{} is not executable. Is it the python binary?'.format(python_bin_path))
+ print('{} is not executable. Is it the python binary?'.format(
+ python_bin_path))
environ_cp['PYTHON_BIN_PATH'] = ''
# Convert python path to Windows style before checking lib and version
@@ -320,7 +321,8 @@
Raise the error to avoid infinitely looping.
"""
if not question:
- question = 'Do you wish to build TensorFlow with {} support?'.format(query_item)
+ question = 'Do you wish to build TensorFlow with {} support?'.format(
+ query_item)
if not yes_reply:
yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
if not no_reply:
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index bf3af3c..ab4316d 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -530,6 +530,13 @@
# TODO(b/154762408) Remove this package group once it's no longer needed.
package_group(name = "composite_tensor_whitelist")
+# Packages that use private type symbols, until they are exported.
+# TODO(b/154650521) Remove.
+package_group(
+ name = "types_whitelist",
+ packages = ["//learning/deepmind/tensorflow/replicator/..."],
+)
+
filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 73c2f78..5c01ccb 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -924,7 +924,7 @@
context->GetDevicePlacementPolicy());
}
-TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
+TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 070b3a9..5afe304 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -137,7 +137,7 @@
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
-TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
TF_Status* status);
// Indicates that the caller will not be using `h` any more.
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
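With the parameter now `const TF_Tensor*`, callers can pass read-only tensors without a cast. A minimal sketch under the new signature (status checks elided; all calls are from the existing C API):

TF_Status* status = TF_NewStatus();
TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
                                 sizeof(float));
*static_cast<float*>(TF_TensorData(t)) = 42.0f;
const TF_Tensor* readonly = t;  // No const_cast needed after this change.
TFE_TensorHandle* handle = TFE_NewTensorHandle(readonly, status);
TFE_DeleteTensorHandle(handle);
TF_DeleteTensor(t);
TF_DeleteStatus(status);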
diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc
index 8c9aa97..bd99189 100644
--- a/tensorflow/c/eager/c_api_unified_experimental_test.cc
+++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc
@@ -29,7 +29,7 @@
namespace tensorflow {
namespace {
-TEST(UnifedCAPI, TestBasicEager) {
+TEST(UnifiedCAPI, TestBasicEager) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -81,7 +81,7 @@
TF_DeleteExecutionContext(ctx);
}
-TEST(UnifedCAPI, TestBasicGraph) {
+TEST(UnifiedCAPI, TestBasicGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -185,7 +185,7 @@
TF_DeleteExecutionContext(eager_execution_ctx);
}
-TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
+TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -201,7 +201,7 @@
TF_DeleteExecutionContext(ctx);
}
-TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
+TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -222,7 +222,7 @@
TF_DeleteExecutionContext(graph_ctx);
}
-TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
+TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
@@ -243,7 +243,7 @@
TF_DeleteExecutionContext(graph_ctx);
}
-TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
+TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@@ -289,7 +289,7 @@
TF_DeleteExecutionContext(graph_ctx);
}
-TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
+TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
index f4dbcc6..3b2640e 100644
--- a/tensorflow/c/eager/parallel_device/BUILD
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -27,6 +27,7 @@
name = "parallel_device",
srcs = [":sources"],
hdrs = [":headers"],
+ visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
@@ -43,6 +44,7 @@
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
+ ":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@@ -52,3 +54,19 @@
"//tensorflow/core:test_main",
],
)
+
+# Note: ParallelDevice-specific ops are experimental and are not currently
+# linked into TensorFlow by default; they are only used in a few tests.
+filegroup(
+ name = "parallel_device_ops_srcs",
+ srcs = ["parallel_device_ops.cc"],
+ visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
+)
+
+cc_library(
+ name = "parallel_device_ops",
+ srcs = [":parallel_device_ops_srcs"],
+ visibility = ["//tensorflow:internal"],
+ deps = ["//tensorflow/core:framework"],
+ alwayslink = 1,
+)
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc
index e684680..27c2699 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -92,6 +92,10 @@
TFE_TensorHandle* tensor,
TF_Status* status) const;
+ // A parallel tensor with scalar integers numbering component devices.
+ std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
+ TF_Status* status) const;
+
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
@@ -208,6 +212,46 @@
status);
}
+std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
+ TFE_Context* context, TF_Status* status) const {
+ // TODO(allenl): We could cache DeviceIDs (keyed by context).
+ std::vector<TensorHandlePtr> components;
+ components.reserve(underlying_devices_.size());
+ for (int device_index = 0; device_index < underlying_devices_.size();
+ ++device_index) {
+ int64_t* device_id = new int64_t;
+ *device_id = device_index;
+ std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+ TF_NewTensor(
+ TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
+ sizeof(int64_t),
+ [](void* data, size_t, void* arg) {
+ delete reinterpret_cast<int64_t*>(data);
+ },
+ nullptr),
+ TF_DeleteTensor);
+ // TODO(allenl): Here and when executing regular operations, we could hold
+ // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
+ // device names repeatedly.
+ OpPtr const_op(TFE_NewOp(context, "Const", status));
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
+ status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
+ TFE_TensorHandle* device_handle;
+ int num_outputs = 1;
+ TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ components.emplace_back(device_handle);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ }
+ return ParallelTensor::FromTensorHandles(*this, std::move(components),
+ status);
+}
+
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
@@ -282,6 +326,13 @@
}
result.emplace(std::move(outputs));
return result;
+ } else if (operation_name == std::string("DeviceID")) {
+ std::vector<MaybeParallelTensorOwned> result_content;
+ result_content.reserve(1);
+ result_content.push_back(DeviceIDs(context, status));
+ if (TF_GetCode(status) != TF_OK) return result;
+ result.emplace(std::move(result_content));
+ return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
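The per-device constant construction in DeviceIDs above leans on TF_NewTensor's ownership-transfer contract: the runtime invokes the supplied deallocator when the tensor is destroyed. A standalone sketch of just that pattern (the value is illustrative):

int64_t* device_id = new int64_t(0);
TF_Tensor* scalar = TF_NewTensor(
    TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id, sizeof(int64_t),
    // Runs when the tensor is destroyed; frees the heap-allocated scalar.
    [](void* data, size_t, void* arg) { delete static_cast<int64_t*>(data); },
    /*deallocator_arg=*/nullptr);
// ... use `scalar`, e.g. as a "Const" op attribute ...
TF_DeleteTensor(scalar);  // Triggers the deallocator above.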
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc
new file mode 100644
index 0000000..1decffc
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+// TODO(allenl): Figure out if we need this op, and if so whether we should move
+// it to core TF. Right now the eager C API does some checking of op
+// registrations before calling into custom devices, but we may be able to avoid
+// that.
+REGISTER_OP("DeviceID")
+ .Output("device_id: int64")
+ .SetIsStateful()
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape);
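Once a parallel device is registered, the new op can be driven through the eager C API much like the tests later in this change do. A compressed sketch (the device name is hypothetical; status checks elided):

TFE_Op* op = TFE_NewOp(context, "DeviceID", status);
TFE_OpSetDevice(op, "/job:localhost/replica:0/task:0/device:CUSTOM:0", status);
TFE_TensorHandle* result;
int num_retvals = 1;
TFE_Execute(op, &result, &num_retvals, status);
// On the parallel device, `result` packs one scalar int64 per component.
TFE_DeleteOp(op);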
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
index 9b0613b..fdc1404 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc
@@ -278,14 +278,15 @@
}
// Expects that `handle` is equal to `expected_value`.
-void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
+template <typename value_type>
+void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
- ASSERT_EQ(expected_value,
- *static_cast<float*>(TF_TensorData(value_zero.get())));
+ EXPECT_EQ(expected_value,
+ *static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
@@ -343,8 +344,8 @@
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
- AssertScalarFloatEq(components[0].get(), 20.);
- AssertScalarFloatEq(components[1].get(), 20.);
+ ExpectScalarEq<float>(components[0].get(), 20.);
+ ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@@ -373,8 +374,8 @@
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
- AssertScalarFloatEq(components[0].get(), 23.);
- AssertScalarFloatEq(components[1].get(), 18.);
+ ExpectScalarEq<float>(components[0].get(), 23.);
+ ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@@ -383,6 +384,32 @@
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
+ // Compute the device ID twice and verify the result each time.
+ for (int i = 0; i < 2; ++i) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+ TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_OpSetDevice(op.get(), device_name, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_TensorHandle* result_handle;
+ int num_retvals = 1;
+ TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ std::array<TensorHandlePtr, 2> components;
+ ExtractPerDeviceValues(context, result_handle, &components, status.get());
+ TFE_DeleteTensorHandle(result_handle);
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ ExpectScalarEq<int64_t>(components[0].get(), 0);
+ ExpectScalarEq<int64_t>(components[1].get(), 1);
+ std::string first_device =
+ TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+ ASSERT_EQ(underlying_devices[0], first_device);
+ std::string second_device =
+ TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+ ASSERT_EQ(underlying_devices[1], second_device);
+ }
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
@@ -498,8 +525,8 @@
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// The value of the original tensor is replicated on each device.
- AssertScalarFloatEq(components[0].get(), 3.);
- AssertScalarFloatEq(components[1].get(), 3.);
+ ExpectScalarEq<float>(components[0].get(), 3.);
+ ExpectScalarEq<float>(components[1].get(), 3.);
// Verify that the mirrors are placed on the component devices.
std::string first_device =
@@ -630,7 +657,7 @@
&second_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
- AssertScalarFloatEq(second_components[1].get(), 9.);
+ ExpectScalarEq<float>(second_components[1].get(), 9.);
// Verify that the mirrors are placed on the component devices.
std::string first_device = TFE_TensorHandleBackingDeviceName(
@@ -644,8 +671,8 @@
std::array<TensorHandlePtr, 2> first_components;
ExtractPerDeviceValues(context.get(), second_components[0].get(),
&first_components, status.get());
- AssertScalarFloatEq(first_components[0].get(), 3.);
- AssertScalarFloatEq(first_components[1].get(), 6.);
+ ExpectScalarEq<float>(first_components[0].get(), 3.);
+ ExpectScalarEq<float>(first_components[1].get(), 6.);
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
status.get());
@@ -806,8 +833,8 @@
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
- AssertScalarFloatEq(result_components[0].get(), 3.);
- AssertScalarFloatEq(result_components[1].get(), 3.);
+ ExpectScalarEq<float>(result_components[0].get(), 3.);
+ ExpectScalarEq<float>(result_components[1].get(), 3.);
}
void RegisterCollectiveMulFunction(TFE_Context* context,
@@ -909,8 +936,8 @@
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
- AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
- AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
+ ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
+ ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
std::string first_device = TFE_TensorHandleBackingDeviceName(
result_components[0].get(), status.get());
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index e8cb40f..e1fad8e 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -178,7 +178,7 @@
name = "ops",
srcs = ["framework/ops.cc"],
hdrs = ["framework/ops.h"],
- android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+ android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -197,7 +197,7 @@
"framework/scope_internal.h",
],
hdrs = ["framework/scope.h"],
- android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+ android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
common_deps = [
":ops",
],
@@ -237,7 +237,7 @@
name = "client_session",
srcs = ["client/client_session.cc"],
hdrs = ["client/client_session.h"],
- android_deps = ["//tensorflow/core:android_tensorflow_lib"],
+ android_deps = ["//tensorflow/core:portable_tensorflow_lib"],
common_deps = [
":ops",
":scope",
@@ -275,7 +275,7 @@
srcs = ["ops/const_op.cc"],
hdrs = ["ops/const_op.h"],
android_deps = [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
],
common_deps = [
":ops",
@@ -304,7 +304,7 @@
srcs = ["ops/while_loop.cc"],
hdrs = ["ops/while_loop.h"],
android_deps = [
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
],
common_deps = [
":cc_ops",
diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD
index 93acf1b..045d4e6 100644
--- a/tensorflow/cc/experimental/base/public/BUILD
+++ b/tensorflow/cc/experimental/base/public/BUILD
@@ -62,3 +62,17 @@
"//tensorflow/c:tf_tensor",
],
)
+
+cc_library(
+ name = "tensorhandle",
+ hdrs = [
+ "tensorhandle.h",
+ ],
+ deps = [
+ ":runtime",
+ ":status",
+ ":tensor",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_experimental",
+ ],
+)
diff --git a/tensorflow/cc/experimental/base/public/runtime.h b/tensorflow/cc/experimental/base/public/runtime.h
index 47fd886..711a38c 100644
--- a/tensorflow/cc/experimental/base/public/runtime.h
+++ b/tensorflow/cc/experimental/base/public/runtime.h
@@ -21,6 +21,7 @@
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
+namespace experimental {
namespace cc {
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
@@ -40,6 +41,7 @@
private:
friend class RuntimeBuilder;
friend class SavedModelAPI;
+ friend class TensorHandle;
// Wraps a TFE_Context. Takes ownership of ctx.
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
@@ -63,6 +65,7 @@
};
} // namespace cc
+} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
diff --git a/tensorflow/cc/experimental/base/public/runtime_builder.h b/tensorflow/cc/experimental/base/public/runtime_builder.h
index ed3c93a..737e06c 100644
--- a/tensorflow/cc/experimental/base/public/runtime_builder.h
+++ b/tensorflow/cc/experimental/base/public/runtime_builder.h
@@ -24,6 +24,7 @@
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
+namespace experimental {
namespace cc {
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
@@ -79,6 +80,7 @@
}
} // namespace cc
+} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
diff --git a/tensorflow/cc/experimental/base/public/status.h b/tensorflow/cc/experimental/base/public/status.h
index f91f2ca..98c8cf6 100644
--- a/tensorflow/cc/experimental/base/public/status.h
+++ b/tensorflow/cc/experimental/base/public/status.h
@@ -22,6 +22,7 @@
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
+namespace experimental {
namespace cc {
// Status is a wrapper around an error code and an optional error message.
@@ -57,6 +58,7 @@
friend class RuntimeBuilder;
friend class Runtime;
friend class SavedModelAPI;
+ friend class TensorHandle;
// Wraps a TF_Status*, and takes ownership of it.
explicit Status(TF_Status* status) : status_(status) {}
@@ -88,6 +90,7 @@
}
} // namespace cc
+} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
diff --git a/tensorflow/cc/experimental/base/public/tensor.h b/tensorflow/cc/experimental/base/public/tensor.h
index 26b0e5d..fc44726 100644
--- a/tensorflow/cc/experimental/base/public/tensor.h
+++ b/tensorflow/cc/experimental/base/public/tensor.h
@@ -28,6 +28,7 @@
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
+namespace experimental {
namespace cc {
// Tensor represents an n-dimensional array of values.
@@ -168,6 +169,7 @@
}
} // namespace cc
+} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
diff --git a/tensorflow/cc/experimental/base/public/tensorhandle.h b/tensorflow/cc/experimental/base/public/tensorhandle.h
new file mode 100644
index 0000000..99453ee
--- /dev/null
+++ b/tensorflow/cc/experimental/base/public/tensorhandle.h
@@ -0,0 +1,98 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
+#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/cc/experimental/base/public/runtime.h"
+#include "tensorflow/cc/experimental/base/public/status.h"
+#include "tensorflow/cc/experimental/base/public/tensor.h"
+
+namespace tensorflow {
+namespace experimental {
+namespace cc {
+
+// An opaque representation of a tensor computed/managed by the Tensorflow
+// runtime (tensorflow::experimental::cc::Runtime). Unlike a Tensor, a
+// TensorHandle may refer to tensors placed in memory of different devices or
+// remote address spaces. Note that tensorflow::experimental::cc::Runtime MUST
+// outlive all TensorHandles created from it.
+class TensorHandle {
+ public:
+ // Unwraps a Tensor from the given TensorHandle. If an error occurred,
+ // status->ok() will be false, and the returned Tensor must not be used.
+ Tensor Resolve(Status* status);
+
+ // Constructs a TensorHandle from a Tensor. If an error occurred,
+ // status->ok() will be false, and the returned TensorHandle must not be used.
+ static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
+ Status* status);
+
+ // TensorHandle is movable but not copyable.
+ TensorHandle(TensorHandle&&) = default;
+ TensorHandle& operator=(TensorHandle&&) = default;
+
+ private:
+ // Wraps a TFE_TensorHandle. Takes ownership of handle.
+ explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
+
+ // TensorHandle is not copyable
+ TensorHandle(const TensorHandle&) = delete;
+ TensorHandle& operator=(const TensorHandle&) = delete;
+
+ // Returns the underlying TFE_TensorHandle that this object wraps.
+ // This object retains ownership of the pointer.
+ TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
+
+ // Deletes the currently wrapped TFE_TensorHandle, swaps in `handle`, and
+ // takes ownership of `handle`.
+ void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
+
+ struct TFETensorHandleDeleter {
+ void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
+ };
+ std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
+};
+
+inline Tensor TensorHandle::Resolve(Status* status) {
+ TF_Tensor* tensor =
+ TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
+ if (!status->ok()) {
+ return Tensor(nullptr);
+ }
+ return Tensor(tensor);
+}
+
+inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
+ const Runtime& runtime,
+ Status* status) {
+ TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
+ runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
+ if (!status->ok()) {
+ return TensorHandle(nullptr);
+ }
+ return TensorHandle(tensor_handle);
+}
+
+} // namespace cc
+} // namespace experimental
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
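A minimal usage sketch of the new class, mirroring the round trip the tests below exercise (status checks elided):

using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using tensorflow::experimental::cc::TensorHandle;

Status status;
RuntimeBuilder builder;
std::unique_ptr<Runtime> runtime = builder.Build(&status);
float value = 42.0f;
Tensor tensor = Tensor::FromBuffer(TF_FLOAT, /*shape=*/{}, &value,
                                   sizeof(value),
                                   /*deleter=*/[](void*, size_t) {}, &status);
TensorHandle handle = TensorHandle::FromTensor(tensor, *runtime, &status);
Tensor resolved = handle.Resolve(&status);  // Unwraps back to a Tensor.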
diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD
index a2b634a..f449d61 100644
--- a/tensorflow/cc/experimental/base/tests/BUILD
+++ b/tensorflow/cc/experimental/base/tests/BUILD
@@ -5,12 +5,22 @@
licenses = ["notice"], # Apache 2.0
)
+cc_library(
+ name = "tensor_types_test_util",
+ testonly = True,
+ hdrs = ["tensor_types_test_util.h"],
+ deps = [
+ "//tensorflow/c:tf_datatype",
+ ],
+)
+
tf_cc_test(
name = "tensor_test",
srcs = [
"tensor_test.cc",
],
deps = [
+ ":tensor_types_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor",
@@ -19,3 +29,22 @@
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "tensorhandle_test",
+ srcs = [
+ "tensorhandle_test.cc",
+ ],
+ deps = [
+ ":tensor_types_test_util",
+ "//tensorflow/c:tf_datatype",
+ "//tensorflow/cc/experimental/base/public:runtime",
+ "//tensorflow/cc/experimental/base/public:runtime_builder",
+ "//tensorflow/cc/experimental/base/public:status",
+ "//tensorflow/cc/experimental/base/public:tensor",
+ "//tensorflow/cc/experimental/base/public:tensorhandle",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/cc/experimental/base/tests/tensor_test.cc b/tensorflow/cc/experimental/base/tests/tensor_test.cc
index 86a50ba..33f9ab6 100644
--- a/tensorflow/cc/experimental/base/tests/tensor_test.cc
+++ b/tensorflow/cc/experimental/base/tests/tensor_test.cc
@@ -16,69 +16,22 @@
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include <stddef.h>
-
-#include <cstdint>
+#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
-namespace tensorflow {
namespace {
-// Each of the following struct types have two members: a kDType that
-// corresponds to a TF_Datatype enum value, and a typedef "type"
-// of its corresponding C++ type. These types allow us to write Dtype-agnostic
-// tests via GoogleTest's TypedTests:
-// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
-struct FloatType {
- using type = float;
- static constexpr TF_DataType kDType = TF_FLOAT;
-};
+using tensorflow::experimental::cc::Status;
+using tensorflow::experimental::cc::Tensor;
-struct DoubleType {
- using type = double;
- static constexpr TF_DataType kDType = TF_DOUBLE;
-};
-
-struct Int32Type {
- using type = int32_t;
- static constexpr TF_DataType kDType = TF_INT32;
-};
-
-struct UINT8Type {
- using type = uint8_t;
- static constexpr TF_DataType kDType = TF_UINT8;
-};
-
-struct INT8Type {
- using type = int8_t;
- static constexpr TF_DataType kDType = TF_INT8;
-};
-
-struct INT64Type {
- using type = int64_t;
- static constexpr TF_DataType kDType = TF_INT64;
-};
-
-struct UINT16Type {
- using type = uint16_t;
- static constexpr TF_DataType kDType = TF_UINT16;
-};
-
-struct UINT32Type {
- using type = uint32_t;
- static constexpr TF_DataType kDType = TF_UINT32;
-};
-
-struct UINT64Type {
- using type = uint64_t;
- static constexpr TF_DataType kDType = TF_UINT64;
-};
-
-using SimpleTypes =
- ::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
- INT64Type, UINT16Type, UINT32Type, UINT64Type>;
+using SimpleTypes = ::testing::Types<
+ tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
+ tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
+ tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
template <typename T>
class ConstructScalarTensorTest : public ::testing::Test {};
@@ -88,14 +41,13 @@
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
- cc::Status status;
+ Status status;
TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42;
- cc::Tensor tensor =
- cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
- /*data=*/&value,
- /*len=*/sizeof(value),
- /*deleter=*/[](void*, size_t) {}, &status);
+ Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
+ /*data=*/&value,
+ /*len=*/sizeof(value),
+ /*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0);
@@ -113,7 +65,7 @@
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
- cc::Status status;
+ Status status;
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
@@ -121,7 +73,7 @@
std::vector<int64_t> shape;
shape.push_back(value.size());
- cc::Tensor tensor = cc::Tensor::FromBuffer(
+ Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
@@ -130,7 +82,7 @@
EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype);
- gtl::ArraySlice<typename TypeParam::type> tensor_view(
+ tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
@@ -152,14 +104,14 @@
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
- cc::Status status;
+ Status status;
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3});
- cc::Tensor tensor = cc::Tensor::FromBuffer(
+ Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
@@ -169,7 +121,7 @@
EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype);
- gtl::ArraySlice<typename TypeParam::type> tensor_view(
+ tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
@@ -185,22 +137,22 @@
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
bool done = false;
- cc::Status status;
+ Status status;
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
{
// data_vector is a rank 1 tensor.
std::vector<int64_t> shape;
shape.push_back(data_vector.size());
- cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
+ Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
done = true;
};
- cc::Tensor tensor =
- cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
- /*data=*/data_vector.data(),
- /*len=*/data_vector.size() * sizeof(int32_t),
- /*deleter=*/callback, &status);
+ Tensor tensor =
+ Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
+ /*data=*/data_vector.data(),
+ /*len=*/data_vector.size() * sizeof(int32_t),
+ /*deleter=*/callback, &status);
ASSERT_TRUE(status.ok()) << status.message();
}
// At this point, tensor has been destroyed, and the deleter callback should
@@ -209,4 +161,3 @@
}
} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h
new file mode 100644
index 0000000..af9cad7
--- /dev/null
+++ b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h
@@ -0,0 +1,76 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
+#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
+
+#include <stdint.h>
+
+#include "tensorflow/c/tf_datatype.h"
+
+namespace tensorflow {
+
+// Each of the following struct types have two members: a kDType that
+// corresponds to a TF_Datatype enum value, and a typedef "type"
+// of its corresponding C++ type. These types allow us to write Dtype-agnostic
+// tests via GoogleTest's TypedTests:
+// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
+struct FloatType {
+ using type = float;
+ static constexpr TF_DataType kDType = TF_FLOAT;
+};
+
+struct DoubleType {
+ using type = double;
+ static constexpr TF_DataType kDType = TF_DOUBLE;
+};
+
+struct Int32Type {
+ using type = int32_t;
+ static constexpr TF_DataType kDType = TF_INT32;
+};
+
+struct UINT8Type {
+ using type = uint8_t;
+ static constexpr TF_DataType kDType = TF_UINT8;
+};
+
+struct INT8Type {
+ using type = int8_t;
+ static constexpr TF_DataType kDType = TF_INT8;
+};
+
+struct INT64Type {
+ using type = int64_t;
+ static constexpr TF_DataType kDType = TF_INT64;
+};
+
+struct UINT16Type {
+ using type = uint16_t;
+ static constexpr TF_DataType kDType = TF_UINT16;
+};
+
+struct UINT32Type {
+ using type = uint32_t;
+ static constexpr TF_DataType kDType = TF_UINT32;
+};
+
+struct UINT64Type {
+ using type = uint64_t;
+ static constexpr TF_DataType kDType = TF_UINT64;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
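These structs exist so a test can be written once per behavior and instantiated once per dtype via GoogleTest typed tests. A minimal sketch (the suite name and assertion are illustrative):

template <typename T>
class DtypeTaggedTest : public ::testing::Test {};
using TwoTypes = ::testing::Types<tensorflow::FloatType, tensorflow::Int32Type>;
TYPED_TEST_SUITE(DtypeTaggedTest, TwoTypes);

TYPED_TEST(DtypeTaggedTest, TagMatchesCppType) {
  // TypeParam::type is the C++ type; TypeParam::kDType tags the TF dtype.
  typename TypeParam::type value = 0;
  (void)value;
  EXPECT_GT(static_cast<int>(TypeParam::kDType), 0);  // TF_FLOAT == 1, etc.
}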
diff --git a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
new file mode 100644
index 0000000..cfeaba4
--- /dev/null
+++ b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <memory>
+
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/cc/experimental/base/public/runtime.h"
+#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
+#include "tensorflow/cc/experimental/base/public/tensor.h"
+#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+using tensorflow::experimental::cc::Runtime;
+using tensorflow::experimental::cc::RuntimeBuilder;
+using tensorflow::experimental::cc::Status;
+using tensorflow::experimental::cc::Tensor;
+using tensorflow::experimental::cc::TensorHandle;
+
+using SimpleTypes = ::testing::Types<
+ tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
+ tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
+ tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
+
+template <typename T>
+class ConstructScalarTensorHandleTest : public ::testing::Test {};
+TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
+
+// This test constructs a scalar tensor for each of the types in "SimpleTypes",
+// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
+// verify the expected dims, dtype, value, num bytes, and num elements.
+TYPED_TEST(ConstructScalarTensorHandleTest,
+ ValidTensorAttributesAfterConstruction) {
+ Status status;
+ RuntimeBuilder runtime_builder;
+ std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ TF_DataType dtype = TypeParam::kDType;
+ typename TypeParam::type value = 42;
+ Tensor original_tensor =
+ Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
+ /*data=*/&value,
+ /*len=*/sizeof(value),
+ /*deleter=*/[](void*, size_t) {}, &status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ TensorHandle handle =
+ TensorHandle::FromTensor(original_tensor, *runtime, &status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ Tensor tensor = handle.Resolve(&status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ EXPECT_EQ(tensor.dims(), 0);
+ EXPECT_EQ(tensor.dtype(), dtype);
+ EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
+ EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
+ EXPECT_EQ(tensor.num_elements(), 1);
+}
+
+template <typename T>
+class Construct1DTensorHandleTest : public ::testing::Test {};
+TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
+
+// This test constructs a 1D tensor for each of the types in "SimpleTypes",
+// and verifies the expected dimensions, dtype, value, number of bytes, and
+// number of elements.
+TYPED_TEST(Construct1DTensorHandleTest,
+ ValidTensorAttributesAfterConstruction) {
+ Status status;
+ RuntimeBuilder runtime_builder;
+ std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ TF_DataType dtype = TypeParam::kDType;
+ // This is our 1D tensor of varying dtype.
+ std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
+ // Shape is Rank 1 vector.
+ std::vector<int64_t> shape;
+ shape.push_back(value.size());
+
+ Tensor original_tensor = Tensor::FromBuffer(
+ /*dtype=*/dtype, /*shape=*/shape,
+ /*data=*/value.data(),
+ /*len=*/value.size() * sizeof(typename TypeParam::type),
+ /*deleter=*/[](void*, size_t) {}, &status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ TensorHandle handle =
+ TensorHandle::FromTensor(original_tensor, *runtime, &status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ Tensor tensor = handle.Resolve(&status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ EXPECT_EQ(tensor.dims(), 1);
+ EXPECT_EQ(tensor.dtype(), dtype);
+ tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
+ reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
+ EXPECT_EQ(tensor_view[0], 42);
+ EXPECT_EQ(tensor_view[1], 100);
+ EXPECT_EQ(tensor_view[2], 0);
+ EXPECT_EQ(tensor_view[3], 1);
+ EXPECT_EQ(tensor_view[4], 4);
+ EXPECT_EQ(tensor_view[5], 29);
+
+ EXPECT_EQ(tensor.num_bytes(),
+ value.size() * sizeof(typename TypeParam::type));
+ EXPECT_EQ(tensor.num_elements(), value.size());
+}
+
+template <typename T>
+class Construct2DTensorHandleTest : public ::testing::Test {};
+TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
+
+// This test constructs a 2D tensor for each of the types in "SimpleTypes",
+// and verifies the expected dimensions, dtype, value, number of bytes, and
+// number of elements.
+TYPED_TEST(Construct2DTensorHandleTest,
+ ValidTensorAttributesAfterConstruction) {
+ Status status;
+ RuntimeBuilder runtime_builder;
+ std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ TF_DataType dtype = TypeParam::kDType;
+ // This is our 1D tensor of varying dtype.
+ std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
+ // Shape is Rank 2 vector with shape 2 x 3.
+ std::vector<int64_t> shape({2, 3});
+
+ Tensor original_tensor = Tensor::FromBuffer(
+ /*dtype=*/dtype, /*shape=*/shape,
+ /*data=*/value.data(),
+ /*len=*/value.size() * sizeof(typename TypeParam::type),
+ /*deleter=*/[](void*, size_t) {}, &status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ TensorHandle handle =
+ TensorHandle::FromTensor(original_tensor, *runtime, &status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ Tensor tensor = handle.Resolve(&status);
+ ASSERT_TRUE(status.ok()) << status.message();
+
+ EXPECT_EQ(tensor.dims(), 2);
+ EXPECT_EQ(tensor.dtype(), dtype);
+ tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
+ reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
+ EXPECT_EQ(tensor_view[0], 42);
+ EXPECT_EQ(tensor_view[1], 100);
+ EXPECT_EQ(tensor_view[2], 0);
+ EXPECT_EQ(tensor_view[3], 1);
+ EXPECT_EQ(tensor_view[4], 4);
+ EXPECT_EQ(tensor_view[5], 29);
+
+ EXPECT_EQ(tensor.num_bytes(),
+ value.size() * sizeof(typename TypeParam::type));
+ EXPECT_EQ(tensor.num_elements(), value.size());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function.h b/tensorflow/cc/saved_model/experimental/public/concrete_function.h
index f57ba05..1adaf70 100644
--- a/tensorflow/cc/saved_model/experimental/public/concrete_function.h
+++ b/tensorflow/cc/saved_model/experimental/public/concrete_function.h
@@ -24,6 +24,7 @@
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
namespace tensorflow {
+namespace experimental {
namespace cc {
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
@@ -54,6 +55,7 @@
}
} // namespace cc
+} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_
diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h
index bab9527..88cb779 100644
--- a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h
+++ b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h
@@ -22,6 +22,7 @@
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
namespace tensorflow {
+namespace experimental {
namespace cc {
// ConcreteFunctionList helps convert an opaque pointer to an array of
@@ -56,6 +57,7 @@
}
} // namespace cc
+} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
diff --git a/tensorflow/cc/saved_model/experimental/public/function_metadata.h b/tensorflow/cc/saved_model/experimental/public/function_metadata.h
index c3dcc45..11e1a86 100644
--- a/tensorflow/cc/saved_model/experimental/public/function_metadata.h
+++ b/tensorflow/cc/saved_model/experimental/public/function_metadata.h
@@ -21,6 +21,7 @@
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
namespace tensorflow {
+namespace experimental {
namespace cc {
// FunctionMetadata stores additional function information, including
@@ -40,6 +41,7 @@
};
} // namespace cc
+} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h
index 814479d..04018bf 100644
--- a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h
+++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h
@@ -28,6 +28,7 @@
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
namespace tensorflow {
+namespace experimental {
namespace cc {
// SavedModelAPI offers a way to load Tensorflow Saved Models
@@ -155,6 +156,7 @@
}
} // namespace cc
+} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_
diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc
index 155c586..7f7f6b0 100644
--- a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc
+++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc
@@ -26,10 +26,14 @@
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
-namespace tensorflow {
namespace {
+using tensorflow::experimental::cc::Runtime;
+using tensorflow::experimental::cc::RuntimeBuilder;
+using tensorflow::experimental::cc::SavedModelAPI;
+using tensorflow::experimental::cc::Status;
+
constexpr char kTestData[] = "cc/saved_model/testdata";
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
@@ -43,21 +47,21 @@
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
- cc::Status status;
- cc::RuntimeBuilder builder;
+ Status status;
+ RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
- std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
+ std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unordered_set<std::string> tags = {"serve"};
- std::unique_ptr<cc::SavedModelAPI> model =
- cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
+ std::unique_ptr<SavedModelAPI> model =
+ SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
@@ -67,20 +71,20 @@
}
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
- cc::Status status;
- cc::RuntimeBuilder builder;
+ Status status;
+ RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
- std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
+ std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
- std::unique_ptr<cc::SavedModelAPI> model =
- cc::SavedModelAPI::Load(model_dir, *runtime, &status);
+ std::unique_ptr<SavedModelAPI> model =
+ SavedModelAPI::Load(model_dir, *runtime, &status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
@@ -94,4 +98,3 @@
} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index d907d28..f99b280 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -695,9 +695,9 @@
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
- "@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps",
+ "@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py
index 6d3131a..f1271d0 100644
--- a/tensorflow/compiler/mlir/runlit.cfg.py
+++ b/tensorflow/compiler/mlir/runlit.cfg.py
@@ -70,9 +70,9 @@
]
tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
- 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
- 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
- 'xla-opt'
+ 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
+ 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
+ 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py
index 661e620..3e7596c 100644
--- a/tensorflow/compiler/mlir/runlit.site.cfg.py
+++ b/tensorflow/compiler/mlir/runlit.site.cfg.py
@@ -44,6 +44,7 @@
'tensorflow/compiler/mlir',
'tensorflow/compiler/mlir/lite',
'tensorflow/compiler/mlir/tensorflow',
+ 'tensorflow/compiler/mlir/tfjs',
'tensorflow/compiler/mlir/xla',
'tensorflow/compiler/aot',
'tensorflow/compiler/xla/service/mlir_gpu',
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 9099f2b..0edf0f3 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -556,7 +556,7 @@
deps = [
":tensorflow",
"@llvm-project//mlir:IR",
- "@llvm-project//mlir:LoopOpsTransforms",
+ "@llvm-project//mlir:SCFTransforms",
],
alwayslink = 1,
)
@@ -823,6 +823,7 @@
":mangling_util",
":tensorflow_attributes",
":tensorflow_types",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index a33b339..9a29fa4 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -192,6 +192,44 @@
let verifier = [{ return Verify(*this); }];
}
+def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> {
+ let summary = "An Op to exchange data across TPU replicas.";
+
+ let description = [{
+On each replica, the input is split into `split_count` blocks along
+`split_dimension` and sent to the other replicas given `group_assignment`. After
+receiving `split_count` - 1 blocks from the other replicas, we concatenate the
+blocks along `concat_dimension` as the output.
+
+For example, suppose there are 2 TPU replicas:
+replica 0 receives input: `[[A, B]]`
+replica 1 receives input: `[[C, D]]`
+
+group_assignment=`[[0, 1]]`
+concat_dimension=0
+split_dimension=1
+split_count=2
+
+replica 0's output: `[[A], [C]]`
+replica 1's output: `[[B], [D]]`
+ }];
+
+ let arguments = (ins
+ TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
+ I32Tensor:$group_assignment,
+
+ I64Attr:$concat_dimension,
+ I64Attr:$split_dimension,
+ I64Attr:$split_count
+ );
+
+ let results = (outs
+ TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Returns the argument of a complex number.";
@@ -1408,6 +1446,30 @@
let hasCanonicalizer = 1;
}
+def TF_ConjugateTransposeOp : TF_Op<"ConjugateTranspose", [NoSideEffect]> {
+ let summary = [{
+Shuffle dimensions of x according to a permutation and conjugate the result.
+ }];
+
+ let description = [{
+The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+ `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+ `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])`
+ }];
+
+ let arguments = (ins
+ TF_Tensor:$x,
+ TF_I32OrI64Tensor:$perm
+ );
+
+ let results = (outs
+ TF_Tensor:$y
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+ TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>;
+}
+
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
let summary = [{
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
@@ -3563,6 +3625,31 @@
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
}
+def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
+ let summary = [{
+ Create a copy of `x` with rows specified by 'i' updated to values 'v'.
+
+ }];
+
+ let description = [{
+ Creates a copy of tensor 'x' and updates the rows specified in tensor 'i'
+ with the values 'v'. Originally this function was mutative; for compilation
+ we instead make this operation create and operate on a copy.
+ }];
+
+ let arguments = (ins
+ TF_Tensor:$x,
+ I32Tensor:$i,
+ TF_Tensor:$v
+ );
+
+ let results = (outs
+ TF_Tensor:$y
+ );
+
+ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the reciprocal of x element-wise.";
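To make the new InplaceUpdate definition's row semantics concrete, a small worked example (values invented for illustration):

// x = [[1, 2],     i = [0]     v = [[10, 20]]
//      [3, 4]]
// y = InplaceUpdate(x, i, v)
//   = [[10, 20],   // row 0 replaced by v[0]
//      [ 3,  4]]   // row 1 unchanged; x itself is not mutated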
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index edba135..85baff5 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -986,18 +986,15 @@
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- for (NamedAttribute named_attr : attributes) {
- if (named_attr.first.strref() != "value") continue;
- auto value = named_attr.second;
- if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
- inferredReturnTypes.assign({elem_attr.getType()});
- return success();
- }
- return emitOptionalError(location,
- "attribute 'value' failed to satisfy constraint: "
- "constant vector/tensor");
+ auto value = attributes.get("value");
+ if (!value) return emitOptionalError(location, "missing attribute 'value'");
+ if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
+ inferredReturnTypes.assign({elem_attr.getType()});
+ return success();
}
- return emitOptionalError(location, "missing attribute 'value'");
+ return emitOptionalError(location,
+ "attribute 'value' failed to satisfy constraint: "
+ "constant vector/tensor");
}
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
index 77ca08c..eb67bdc 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
@@ -1,13 +1,17 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure
-// Tests extraction of a single outside compiled cluster with no input or output dependecies.
+// Tests extraction of outside compiled ops at the head of a TPU computation.
-// CHECK-LABEL: func @nodep_single_head_outside_compilation
-func @nodep_single_head_outside_compilation() -> () {
- // CHECK: "tf.A"
- // CHECK-NEXT: "tf_device.launch"
- "tf_device.launch"() ( {
- "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
+func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
+ // CHECK: tf_device.launch
+ // CHECK: "tf.A"
+ // CHECK-NEXT: tf_device.return
+ //
+ // CHECK: "tf_device.cluster"
+ // CHECK: "tf.C"
+ // CHECK-NEXT: tf_device.return
+ "tf_device.cluster"() ( {
+ "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
"tf.B"() : () -> ()
"tf.C"() : () -> ()
tf_device.return
@@ -15,15 +19,62 @@
return
}
-// CHECK-LABEL: func @nodep_multiple_head_outside_compilation
-func @nodep_multiple_head_outside_compilation() -> () {
- // CHECK: "tf.A"
- // CHECK-NEXT: "tf.B"
- // CHECK-NEXT: "tf_device.launch"
- "tf_device.launch"() ( {
- "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
- "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
- "tf.C"() : () -> ()
+// CHECK-LABEL: func @multiple_head_outside_compilation
+func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
+ // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+ // CHECK: %[[A_OUT:.*]] = "tf.A"
+ // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
+ // CHECK: "tf.C"
+ // CHECK-NEXT: tf_device.return %[[B_OUT]]
+ //
+ // CHECK: "tf_device.cluster"
+ // CHECK: "tf.D"(%[[LAUNCH_OUT]])
+ // CHECK-NEXT: tf_device.return
+ "tf_device.cluster"() ( {
+ %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+ %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+ "tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
+ "tf.D"(%1) : (tensor<i32>) -> ()
+ tf_device.return
+ }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+ return
+}
+
+// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
+func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
+ // CHECK-NOT: tf_device.launch
+ // CHECK: "tf_device.cluster"
+ // CHECK-NEXT: "tf.A"
+ // CHECK-NEXT: "tf.B"
+ // CHECK-NEXT: "tf.C"
+ // CHECK-NEXT: tf_device.return
+ "tf_device.cluster"() ( {
+ %0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
+ %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
+ "tf.C"(%1) : (tensor<i32>) -> ()
+ tf_device.return
+ }) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
+ return
+}
+
+// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
+func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
+ // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+ // CHECK: %[[A_OUT:.*]] = "tf.A"
+ // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
+ // CHECK-NEXT: tf_device.return %[[D_OUT]]
+ //
+ // CHECK: "tf_device.cluster"
+ // CHECK: "tf.B"
+ // CHECK: "tf.C"
+ // CHECK: "tf.E"
+ // CHECK-NEXT: tf_device.return
+ "tf_device.cluster"() ( {
+ %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+ %1 = "tf.B"() {} : () -> (tensor<i32>)
+ %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
+ %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
+ %4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index af0119d..b8a48bb 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -1222,6 +1222,41 @@
// -----
+// Tests a simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
+ // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
+ func @replicated_parallel_tpu_cluster_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+ %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+ // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+ %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+ // CHECK: "tf._TPUCompileMlir"
+ // CHECK: "tf.TPUCompileSucceededAssert"
+ // CHECK: "tf_device.parallel_execute"
+ // CHECK: "tf.TPUExecute"
+ %3 = "tf_device.parallel_execute"() ( {
+ "tf.D"() : () -> ()
+ tf_device.return
+ }, {
+ %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
+
+ tf_device.return %4 : tensor<?xi32>
+ }) : () -> (tensor<?xi32>)
+ tf_device.return %3 : tensor<?xi32>
+ }
+ %2 = "tf.C"(%1#1) : (tensor<?xi32>) -> tensor<?xi32>
+ return %2 : tensor<?xi32>
+ }
+
+ func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ %0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+ return %0 : tensor<?xi32>
+ }
+}
+
+// -----
+
// Tests devices are set properly for non replicated model parallelism.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} {
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
index c1d99c2..0b1ff2b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h
@@ -258,7 +258,7 @@
// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
// at head/tail of TPU cluster to run before/after TPU computation.
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass();
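Since the pass is now module-scoped, it has to be scheduled on a ModuleOp pass manager. A minimal hedged sketch (pass-manager setup assumed; not verified against this exact revision):

// Sketch: run the now module-level pass over a module.
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::TFTPU::CreateTPUExtractHeadTailOutsideCompilationPass());
if (mlir::failed(pm.run(module))) {
  // Handle pass failure.
}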
// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index efe82c4..5a2cae3 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -66,8 +66,7 @@
namespace mlir {
namespace TF {
namespace {
-Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
- FuncOp func) {
+Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
// Find any return ops.
SmallVector<ReturnOp, 4> return_ops;
for (Block& block : func) {
@@ -137,9 +136,9 @@
cast_op = b.create<TF::CastOp>(op->getLoc(), old_type, result,
/*truncate=*/b.getBoolAttr(false));
}
- return mlir::Value(cast_op);
+ return Value(cast_op);
};
- for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
+ for (OpOperand& use : make_early_inc_range(result.getUses())) {
if (use.getOwner()->getDialect() != tf_dialect &&
!IsSupportedNonTFOp(use.getOwner()))
use.set(get_cast_op());
@@ -162,7 +161,7 @@
bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
Operation* op, Dialect* tf_dialect) {
bool changed = false;
- for (auto entry : llvm::zip(pass_through_operands, op->getResults())) {
+ for (auto entry : zip(pass_through_operands, op->getResults())) {
Type operand_type = std::get<0>(entry).getType();
Value result = std::get<1>(entry);
if (result.getType() == operand_type) continue;
@@ -204,7 +203,7 @@
tf_dialect);
}
// TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp.
- if (auto tensor_cast = dyn_cast<mlir::TensorCastOp>(op)) {
+ if (auto tensor_cast = dyn_cast<TensorCastOp>(op)) {
return InferShapeForPassThroughOps(
tensor_cast.getOperation()->getOperands(), op, tf_dialect);
}
@@ -254,7 +253,7 @@
// match the i-th operand type). Returns true if anything is changed.
bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
bool changed = false;
- for (auto entry : llvm::zip(operands, results)) {
+ for (auto entry : zip(operands, results)) {
Type operand_type = std::get<0>(entry).getType();
Type result_type = std::get<1>(entry).getType();
if (operand_type == result_type) continue;
@@ -291,14 +290,13 @@
CallInterfaceCallable callable = call_op.getCallableForCallee();
SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>();
if (!sym) return false;
- FuncOp func =
- dyn_cast<mlir::FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
+ FuncOp func = dyn_cast<FuncOp>(SymbolTable::lookupNearestSymbolFrom(op, sym));
if (!func) return false;
bool changed = false;
// Map each of the results of the call to the returned type of the
// function.
- for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) {
+ for (auto result : zip(op->getResults(), func.getType().getResults())) {
if (std::get<0>(result).getType() == std::get<1>(result)) continue;
// Skip already statically shaped results.
if (!CanBeRefined(std::get<0>(result).getType())) continue;
@@ -335,7 +333,7 @@
// Map each of the results of the call to the returned type of the
// function.
bool changed = false;
- for (auto result : llvm::zip(op->getResults(), inferred)) {
+ for (auto result : zip(op->getResults(), inferred)) {
if (std::get<0>(result).getType() == std::get<1>(result)) continue;
// Inserts a cast back to the original type if any user is not in the
@@ -356,7 +354,7 @@
// so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
// scalar value).
struct ValuePort {
- llvm::PointerUnion<Operation*, BlockArgument> producer;
+ PointerUnion<Operation*, BlockArgument> producer;
SmallVector<unsigned int, 2> port;
bool operator==(const ValuePort& other) const {
@@ -374,39 +372,38 @@
port = {0};
}
}
- ValuePort(llvm::PointerUnion<Operation*, BlockArgument> producer,
+ ValuePort(PointerUnion<Operation*, BlockArgument> producer,
SmallVector<unsigned int, 2> port)
: producer(producer), port(port) {}
- llvm::raw_ostream& print(llvm::raw_ostream& os) const {
+ raw_ostream& print(raw_ostream& os) const {
if (auto* op = producer.dyn_cast<Operation*>())
os << "op " << op->getName();
if (auto ba = producer.dyn_cast<BlockArgument>())
os << "block_arg " << ba.getArgNumber();
- os << llvm::formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
+ os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
return os;
}
};
struct ValuePortHasher {
std::size_t operator()(const ValuePort& other) const {
- return llvm::hash_combine(
- llvm::hash_value(other.producer.getOpaqueValue()),
- llvm::hash_value(ArrayRef<unsigned int>(other.port)));
+ return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
+ hash_value(ArrayRef<unsigned int>(other.port)));
}
};
using ValuePortResultMap =
std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
-using ComputedQueryFn = llvm::function_ref<bool(ValuePort)>;
-using ValueQueryFn = llvm::function_ref<Attribute(const ValuePort&)>;
-using ValuePortInputs = llvm::SmallVectorImpl<ValuePort>;
+using ComputedQueryFn = function_ref<bool(ValuePort)>;
+using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
+using ValuePortInputs = SmallVectorImpl<ValuePort>;
-// TODO(jpienaar): InputsRequiredForOutput and ComputeOutputComponent are
+// TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
// intended to be switched to op interfaces once more refined.
-LogicalResult InputsRequiredForOutput(ValuePort value_port,
- ComputedQueryFn has_been_computed,
- ValuePortInputs* inputs) {
+LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
+ ComputedQueryFn has_been_computed,
+ ValuePortInputs* inputs) {
auto op = value_port.producer.dyn_cast<Operation*>();
auto& port = value_port.port;
if (!op) return failure();
@@ -460,26 +457,94 @@
return nullptr;
}
-ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) {
+// Context used during ShapeInference. This class contains common information
+// that is required by the individual shape inference helper functions (e.g.,
+// TF Graph version, constant values computed, etc.)
+class ShapeInference {
+ public:
+ ShapeInference(int64_t graph_version, MLIRContext* context);
+
+ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
+ ValuePortInputs* inputs) {
+ return ::mlir::TF::ComputeInputsRequiredForOutput(
+ value_port,
+ [this](const ValuePort& port) {
+ return results_.find(port) != results_.end();
+ },
+ inputs);
+ }
+
+ Attribute ComputeOutputComponent(const ValuePort& value_port) {
+ return ::mlir::TF::ComputeOutputComponent(
+ value_port, [this](const ValuePort& port) { return results_[port]; });
+ }
+
+  // Returns a ShapeHandle if the op result could be computed as a shape.
+ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
+
+ void RecordValue(const ValuePort& value_port, Attribute value) {
+ results_[value_port] = value;
+ }
+
+  // Performs shape inference on the provided op and returns true if the type
+  // of at least one result has been changed.
+  // A tf.Cast() is inserted for any use that isn't in the TensorFlow dialect.
+  // `graph_version` indicates the current GraphDef compatibility version
+  // (the versions field in graph.proto).
+ bool InferShapeForSingleOperation(Operation* op);
+
+  // Infers shapes in the provided region, including nested regions, iterating
+  // until a fixed point is reached, with a limit of max_iteration. Returns
+  // success if a fixed point is reached before max_iteration.
+ LogicalResult InferShapeUntilFixPoint(Region* region,
+ int64_t max_iteration = 10);
+
+  // Updates input types and refines shapes inside the bodies of functions
+  // that are attached to ControlFlow ops (If/While). These functions include
+  // Then/Else branches of IfOp and Cond/Body functions of WhileOp. These
+  // functions share the following common properties:
+  //   1) They are never reused, i.e., each has a single use in the module.
+  //   2) Their input types match those of their parent ops (excluding inputs
+  //      like predicate).
+  // Returns failure if the shapes could not be refined.
+ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
+ ArrayRef<Type> input_types,
+ int64_t max_iteration);
+
+  // Propagates the given input types to the functions named in `func_names`.
+ LogicalResult PropagateShapeToFunctions(
+ ModuleOp module, Operation::operand_type_range input_types,
+ ArrayRef<StringRef> func_names, int64_t max_iteration);
+
+ // Shape propagation for call/control flow ops.
+ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
+ int64_t max_iteration);
+
+ private:
+  // Mapping from a ValuePort (which corresponds to an OpResult, or something
+  // smaller, e.g., the first element of the OpResult produced) to an
+  // Attribute if the ValuePort corresponds to a constant value.
+ ValuePortResultMap results_;
+ int64_t graph_version_;
+ MLIRContext* context_;
+ Dialect* tf_dialect_;
+};
+
+ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
+ : graph_version_(graph_version) {
+ context_ = context;
+ tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
+}
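A condensed usage sketch of the new class, mirroring how InferShapeForFunction below drives it (error handling trimmed):

// Sketch: one ShapeInference context per function; the results_ map caches
// partially evaluated constants across queries.
ShapeInference context(graph_version, func.getContext());
if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
  return failure();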
+
+ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
+ InferenceContext* ic) {
LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
auto rt = result.getType().dyn_cast<RankedTensorType>();
if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
int dim_size = rt.getDimSize(0);
// Worklist to direct partial evaluation.
- llvm::SmallVector<ValuePort, 4> worklist;
- // The ValuePort evaluated results.
- // TODO(jpienaar): This could be cached across invocations (e.g., part of some
- // inference context).
- ValuePortResultMap evaluated;
- // Returns whether a ValuePort has been previously computed.
- auto has_been_computed = [&evaluated](const ValuePort& port) {
- return evaluated.find(port) != evaluated.end();
- };
- // Returns previously computed ValuePort value.
- auto values = [&evaluated](const ValuePort& port) -> Attribute {
- return evaluated[port];
- };
+ SmallVector<ValuePort, 4> worklist;
// Simple evaluator that attempts to partially evaluate the input value even
// if unable to evaluate the complete output. Below follows a simple stack
@@ -498,7 +563,7 @@
LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front "));
SmallVector<ValuePort, 4> inputs;
- auto res = InputsRequiredForOutput(front, has_been_computed, &inputs);
+ auto res = ComputeInputsRequiredForOutput(front, &inputs);
if (failed(res)) {
// Abort if unable to find which required inputs need to be computed.
worklist.clear();
@@ -513,16 +578,16 @@
continue;
}
- auto ret = ComputeOutputComponent(front, values);
+ auto ret = ComputeOutputComponent(front);
if (!ret) continue;
- evaluated[front] = ret;
+ RecordValue(front, ret);
LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
// If worklist is empty, then this is the root query op.
if (worklist.empty()) {
LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
- if (auto dea = ret.dyn_cast<mlir::DenseIntElementsAttr>()) {
+ if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
if (dea.getNumElements() != 1) {
LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n");
return {};
@@ -536,9 +601,8 @@
return ic->MakeShape(dims);
}
-bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
- int64_t graph_version) {
- assert(tf_dialect == op->getDialect());
+bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
+ assert(tf_dialect_ == op->getDialect());
// The shape function of these ops sometimes does not propagate subtypes
// (handle shapes) for resource and variant types. We use a simple passthrough
// to make sure they are preserved in the output.
@@ -550,7 +614,7 @@
// If no result for this op needs shape inference, we have a fast-path return.
// But if the type is a resource/variant, we do not skip it because we might
// not have the handle shapes.
- if (llvm::none_of(op->getResultTypes(), CanBeRefined)) {
+ if (none_of(op->getResultTypes(), CanBeRefined)) {
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
<< op->getName() << "'.\n");
return false;
@@ -565,8 +629,8 @@
// This is necessary to avoid reprocessing the tf.Cast that are inserted at
// the end of this function.
if (isa<CastOp>(op) &&
- llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) {
- return user->getDialect() != tf_dialect;
+ all_of(op->getResult(0).getUsers(), [&](Operation* user) {
+ return user->getDialect() != tf_dialect_;
})) {
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
"dialect operation users '"
@@ -646,7 +710,7 @@
// Perform the shape inference using an InferenceContext with the input
// shapes. This object is abstracting the information that the ShapeInference
// function operates on.
- InferenceContext c(graph_version, *node_def, op_reg_data->op_def,
+ InferenceContext c(graph_version_, *node_def, op_reg_data->op_def,
input_shapes, input_tensors,
/*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
auto status = c.Run(op_reg_data->shape_inference_fn);
@@ -659,7 +723,7 @@
// Determine if, during shape computation, the shape functions attempted to
// query an input operand as shape where the input was not known/constant.
bool requires_inputs =
- llvm::any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
+ any_of(llvm::seq<int>(0, c.num_inputs()), [&](int input) {
return c.requested_input_tensor_as_partial_shape(input) &&
!input_tensors[input];
});
@@ -723,7 +787,7 @@
new_element_type.isa<TF::VariantType>()) {
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
if (handle_shapes_types) {
- llvm::SmallVector<mlir::TensorType, 1> subtypes;
+ SmallVector<TensorType, 1> subtypes;
OpBuilder b(op);
for (const auto& shape_n_type : *handle_shapes_types) {
Type element_type;
@@ -743,7 +807,7 @@
if (result.getType() == new_type) continue;
// Inserts a cast back to the original type if any user is not in the TF
// dialect.
- AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
+ AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_,
result.getType());
// Finally we inferred the shape and replace the type for this result.
result.setType(new_type);
@@ -755,23 +819,13 @@
return changed;
}
-// Updates input types and refine shapes inside body of functions that are
-// attached to ControlFlow ops (If/While). These functions include Then/Else
-// branches of IfOp and Cond/Body functions of WhileOp. These functions share
-// following common properties:
-// 1) They are never reused, ie. having a single use in module.
-// 2) Their input types match those of their parent ops (excluding inputs like
-// predicate).
-// Returns a boolean indicating whether any change has been applied.
-LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
- llvm::ArrayRef<Type> input_types,
- int64_t graph_version,
- int64_t max_iteration) {
+LogicalResult ShapeInference::RefineShapeForControlFlowFunc(
+ FuncOp func, ArrayRef<Type> input_types, int64_t max_iteration) {
ModuleOp module = func.getParentOfType<ModuleOp>();
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
int num_uses = std::distance(func_uses->begin(), func_uses->end());
if (num_uses != 1) {
- func.emitWarning(llvm::formatv(
+ func.emitWarning(formatv(
"expected control flow function {0} to have exactly 1 use, found {1}.",
func.getName(), num_uses));
return failure();
@@ -785,8 +839,7 @@
arg_and_idx.value().setType(input_types[arg_and_idx.index()]);
}
- auto res =
- InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration);
+ auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration);
if (failed(res)) return res;
auto new_return_types = InferShapeForFunctionReturnType(func);
@@ -798,20 +851,18 @@
return success();
}
-LogicalResult PropagateShapeToFunctions(
+LogicalResult ShapeInference::PropagateShapeToFunctions(
ModuleOp module, Operation::operand_type_range input_types,
- llvm::ArrayRef<StringRef> func_names, int64_t graph_version,
- int64_t max_iteration) {
- bool success = true;
+ ArrayRef<StringRef> func_names, int64_t max_iteration) {
+ bool all_succeeded = true;
auto types = llvm::to_vector<4>(input_types);
for (auto func_name : func_names) {
FuncOp func = module.lookupSymbol<FuncOp>(func_name);
- if (failed(RefineShapeForControlFlowFunc(func, types, graph_version,
- max_iteration))) {
- success = false;
- }
+ all_succeeded =
+ succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) &&
+ all_succeeded;
}
- return mlir::success(success);
+ return success(all_succeeded);
}
// If the callee has only one use, propagates any constant operand of call_op to
@@ -831,7 +882,7 @@
// the constant inside the function.
for (auto arg : func.getArguments()) {
auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
- if (llvm::isa_and_nonnull<TF::ConstOp>(operand)) {
+ if (isa_and_nonnull<TF::ConstOp>(operand)) {
arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
}
}
@@ -850,33 +901,31 @@
for (auto retval :
llvm::enumerate(func.front().getTerminator()->getOperands())) {
auto retval_op = retval.value().getDefiningOp();
- if (llvm::isa_and_nonnull<TF::ConstOp>(retval_op)) {
+ if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
op->getResult(retval.index())
.replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
}
}
}
-LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
- int64_t graph_version,
- int64_t max_iteration) {
+LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
+ Operation* op, int64_t max_iteration) {
ModuleOp module = op->getParentOfType<ModuleOp>();
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
return PropagateShapeToFunctions(
- module, llvm::drop_begin(if_op.getOperandTypes(), 1),
- {if_op.then_branch(), if_op.else_branch()}, graph_version,
- max_iteration);
+ module, drop_begin(if_op.getOperandTypes(), 1),
+ {if_op.then_branch(), if_op.else_branch()}, max_iteration);
} else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
return PropagateShapeToFunctions(module, while_op.getOperandTypes(),
{while_op.cond(), while_op.body()},
- graph_version, max_iteration);
+ max_iteration);
} else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
CallInterfaceCallable callable = call_op.getCallableForCallee();
if (SymbolRefAttr sym = callable.dyn_cast<SymbolRefAttr>()) {
PropagateConstantToCallee(call_op, sym, module);
if (failed(PropagateShapeToFunctions(
module, call_op.getArgOperands().getTypes(),
- {sym.getRootReference()}, graph_version, max_iteration))) {
+ {sym.getRootReference()}, max_iteration))) {
return failure();
}
PropagateConstantFromCallee(call_op, sym, module);
@@ -889,13 +938,10 @@
return success();
}
-LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
- int64_t max_iteration) {
- MLIRContext* ctx = region->getContext();
- Dialect* tf_dialect = ctx->getRegisteredDialect<TensorFlowDialect>();
-
- // An operation folder that is used to attempt folding before inference.
- OperationFolder folder(ctx);
+LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
+ int64_t max_iteration) {
+  // An operation folder that is used to attempt folding before inference.
+ OperationFolder folder(context_);
bool changed = true;
// TODO(aminim): we could have a more efficient traversal by guiding the
@@ -908,14 +954,14 @@
<< "Shape inference, iteration " << iteration << "\n");
region->walk([&](Operation* op) {
if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
- changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect);
+ changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
// TODO(jpienaar): Debug why we can't just return here. We end up with
// additional constant due to the propagation of constant into attached
// function if we return already.
}
- if (op->getDialect() != tf_dialect) {
- changed |= InferShapeForNonTFDialectOperation(op, tf_dialect);
+ if (op->getDialect() != tf_dialect_) {
+ changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_);
return;
}
@@ -924,13 +970,12 @@
// Best-effort shape inference in attached functions. Do not return
// failure even if it doesn't get to fixed point.
- if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version,
- max_iteration))) {
+ if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) {
op->emitWarning() << "unable to refine shape of attached function "
"arguments and bodies";
}
- changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version);
+ changed |= InferShapeForSingleOperation(op);
});
}
@@ -945,31 +990,43 @@
LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<ArrayRef<int64_t>> arg_shapes,
int64_t graph_version) {
- mlir::FunctionType func_type = func.getType();
+ ShapeInference context(graph_version, func.getContext());
+ if (arg_shapes.empty()) {
+ if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
+ return failure();
+ // TODO(b/156276510): Verify that it is always fine to refine a function's
+ // return type, as long as we do not change the argument shapes.
+ if (auto return_types = InferShapeForFunctionReturnType(func)) {
+ func.setType(FunctionType::get(func.getType().getInputs(),
+ return_types.getValue(),
+ func.getContext()));
+ }
+
+ return success();
+ }
+ FunctionType func_type = func.getType();
bool needs_refinement = false;
- llvm::SmallVector<mlir::Type, 4> new_arg_types;
+ SmallVector<Type, 4> new_arg_types;
new_arg_types.reserve(func_type.getNumInputs());
// Update argument types in-place using the provided arg_shapes.
for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
ArrayRef<int64_t> shape = arg_shapes[i];
- mlir::Type element_type;
- if (auto input_ty =
- func_type.getInput(i).dyn_cast<mlir::RankedTensorType>()) {
+ Type element_type;
+ if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
if (!input_ty || input_ty.getShape().size() != shape.size()) {
return failure();
}
element_type = input_ty.getElementType();
} else {
- auto unranked_input_ty =
- func_type.getInput(i).dyn_cast<mlir::TensorType>();
+ auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
if (!unranked_input_ty) {
return failure();
}
element_type = unranked_input_ty.getElementType();
}
- auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
+ auto new_arg_type = RankedTensorType::get(shape, element_type);
if (new_arg_type != func_type.getInput(i)) {
// If the new type is more detailed, trigger shape inference.
func.getArgument(i).setType(new_arg_type);
@@ -982,28 +1039,17 @@
return success();
}
- mlir::LogicalResult result =
- mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version);
+ LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody());
if (failed(result)) {
return failure();
}
auto return_types = InferShapeForFunctionReturnType(func);
- func.setType(mlir::FunctionType::get(new_arg_types,
- return_types.hasValue()
- ? return_types.getValue()
- : func.getType().getResults(),
- func.getContext()));
-
- return success();
-}
-
-LogicalResult InferShapeForFunctionType(FuncOp func) {
- if (auto return_types = InferShapeForFunctionReturnType(func)) {
- func.setType(mlir::FunctionType::get(func.getType().getInputs(),
- return_types.getValue(),
- func.getContext()));
- }
+ func.setType(FunctionType::get(new_arg_types,
+ return_types.hasValue()
+ ? return_types.getValue()
+ : func.getType().getResults(),
+ func.getContext()));
return success();
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
index 0524ec6..e36d8d5 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h
@@ -27,30 +27,13 @@
namespace TF {
-// Performs shape inference on the provided op and return true if the type of
-// at least one result has been changed.
-// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect.
-// `graph_version` indicates the current GraphDef compatibility versions
-// (the versions field in graph.proto).
-bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
- int64_t graph_version);
-
-// Infers shape on the provided region, including nested ones, iterate until fix
-// point with a limit of max_iteration. Returns success if fix point is reached
-// before max_iteration.
-LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version,
- int64_t max_iteration = 10);
-
// Given a list of refined shapes matching the function arguments of func, runs
// shape inference over the function to propagate this updated information.
+// If arg_shapes is empty, the argument shapes are left unchanged.
LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<ArrayRef<int64_t>> arg_shapes,
int64_t graph_version);
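With the empty-arg_shapes mode documented above, callers need only one entry point for both cases. A hedged sketch of the no-known-shapes path (this mirrors the pass change later in this diff):

// Sketch: empty arg_shapes leaves argument shapes unchanged and only
// refines body and return types.
if (failed(mlir::TF::InferShapeForFunction(func, /*arg_shapes=*/{}, producer)))
  return signalPassFailure();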
-// Refines the return type of the given function by folding tf.Cast that
-// precedes the return instruction.
-LogicalResult InferShapeForFunctionType(FuncOp func);
-
} // namespace TF
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
index 48e4e77..acdfc0e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc
@@ -58,10 +58,8 @@
}
int64_t producer = producer_or.ValueOrDie();
for (auto func : module.getOps<FuncOp>()) {
- InferShapeUntilFixPoint(&func.getBody(), producer);
- // TODO(yuanzx): Verify that it is always fine to refine a function's
- // return type, as long as we do not change the argument shapes.
- InferShapeForFunctionType(func);
+ if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer)))
+ return signalPassFailure();
}
}
};
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
index 141feeb..b9e2144 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
@@ -14,11 +14,23 @@
==============================================================================*/
#include <memory>
+#include <type_traits>
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h" // from @llvm-project
+#include "mlir/IR/Block.h" // from @llvm-project
+#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
+#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
namespace mlir {
namespace TFTPU {
@@ -30,30 +42,182 @@
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
-struct TPUExtractHeadTailOutsideCompilation
- : public PassWrapper<TPUExtractHeadTailOutsideCompilation, FunctionPass> {
- void runOnFunction() override;
-};
+bool HasOutsideCompilationAttribute(Operation* op) {
+ return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
+}
-void TPUExtractHeadTailOutsideCompilation::runOnFunction() {
- getFunction().walk([&](tf_device::LaunchOp launch) {
- Block& launch_block = launch.GetBody();
- for (auto& op : llvm::make_early_inc_range(launch_block.getOperations())) {
- // TODO(b/155115766): Handle outputs that should be inputs to TPU
- // LaunchOp.
- if (auto attr =
- op.getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
- op.moveBefore(launch);
- } else {
+// Returns whether all operands of `op` are values contained in
+// `input_value_set`.
+bool OpContainsOperandsFromSet(Operation* op,
+ const llvm::SetVector<Value>& input_value_set) {
+ for (auto operand : op->getOperands())
+ if (input_value_set.count(operand) == 0) return false;
+
+ return true;
+}
+
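+// If `op` is outside compiled and all of its operands are already available
+// outside the cluster (i.e., are in `outside_compiled_op_usages`), records
+// `op` and marks its results as available to subsequent outside compiled ops.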
+void RecordOutsideCompiledOpsAndUsages(
+ Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
+ llvm::SetVector<Value>* outside_compiled_op_usages) {
+ if (HasOutsideCompilationAttribute(op) &&
+ OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
+ outside_compiled_ops->insert(op);
+ outside_compiled_op_usages->insert(op->getResults().begin(),
+ op->getResults().end());
+ }
+}
+
+// Traverses the MLIR graph and collects, into `outside_compiled_ops`, the
+// outside compiled ops that are connected to the inputs of the TPU
+// computation.
+void ExtractOutsideCompiledOpsConnectedToHead(
+ Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
+ llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
+ llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
+ for (auto& usage : input_value.getUses()) {
+ auto head_operation = usage.getOwner();
+ RecordOutsideCompiledOpsAndUsages(head_operation,
+ &parent_outside_compiled_ops_at_head,
+ values_used_in_host_cluster);
+ }
+
+  // Traverse the graph and find all outside compiled ops reachable from
+  // `input_value`.
+ while (!parent_outside_compiled_ops_at_head.empty()) {
+ llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
+ for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
+ auto op_results = head_outside_compiled_op->getOpResults();
+ for (auto op_result : op_results) {
+ for (auto& use : op_result.getUses()) {
+ auto connected_op = use.getOwner();
+ RecordOutsideCompiledOpsAndUsages(connected_op,
+ &connected_outside_compiled_ops,
+ values_used_in_host_cluster);
+ }
+ }
+ }
+
+ outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
+ parent_outside_compiled_ops_at_head.end());
+ std::swap(parent_outside_compiled_ops_at_head,
+ connected_outside_compiled_ops);
+ }
+}
+
+// TODO(hongjunchoi): Also handle ops without inputs that are outside
+// compiled.
+//
+// Returns the set of ops that are outside compiled and directly connected
+// to the inputs of the TPU computation.
+llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
+ tf_device::ClusterOp tpu_cluster) {
+ llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
+ llvm::SetVector<Value> values_used_in_cluster;
+ auto& cluster_region = tpu_cluster.body();
+ getUsedValuesDefinedAbove(cluster_region, cluster_region,
+ values_used_in_cluster);
+
+ auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
+ for (auto input_value : input_value_list)
+ ExtractOutsideCompiledOpsConnectedToHead(
+ input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
+ return outside_compiled_at_head_ops;
+}
+
+// Returns the output values of the extracted head outside compiled cluster
+// that are used by the TPU computation.
+llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
+ const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
+ llvm::SmallVector<Value, 8> outputs;
+ outputs.reserve(head_outside_compiled_ops.size());
+
+ for (auto op : head_outside_compiled_ops) {
+ for (Operation* user : op->getUsers()) {
+ if (!head_outside_compiled_ops.count(user)) {
+ outputs.append(op->result_begin(), op->result_end());
break;
}
}
+ }
+
+ return outputs;
+}
+
+// Creates a new tf_device.launch op containing the outside compiled ops
+// extracted from the head of the TPU computation.
+llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
+ OpBuilder* builder, tf_device::ClusterOp cluster,
+ const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
+ if (head_outside_compiled_ops.empty())
+ return llvm::Optional<tf_device::LaunchOp>();
+
+ // Create tf_device.launch op to separate all extracted outside compiled ops
+ // before the tf_device.cluster.
+ auto output_values =
+ GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
+
+ llvm::SmallVector<Type, 8> output_return_types;
+ output_return_types.reserve(output_values.size());
+ for (auto output : output_values)
+ output_return_types.emplace_back(output.getType());
+
+ builder->setInsertionPoint(cluster);
+ auto host_launch_op = builder->create<tf_device::LaunchOp>(
+ cluster.getLoc(), builder->getStringAttr(""), output_return_types);
+
+  // Replace all uses of the outside compiled ops inside the TPU computation
+  // with the results of the launch op created above.
+ for (auto output_and_index : llvm::enumerate(output_values)) {
+ auto output_index = output_and_index.index();
+ auto output = output_and_index.value();
+ for (auto& use : output.getUses()) {
+ if (!head_outside_compiled_ops.count(use.getOwner()))
+ use.set(host_launch_op.getResult(output_index));
+ }
+ }
+
+ // Create terminator op for the newly created launch op.
+ host_launch_op.body().push_back(new Block());
+ builder->setInsertionPointToEnd(&host_launch_op.GetBody());
+ auto terminator = builder->create<tf_device::ReturnOp>(
+ host_launch_op.getLoc(), output_values);
+
+  // Move all outside compiled ops from the cluster op to the launch op.
+ for (auto outside_compiled_op : head_outside_compiled_ops)
+ outside_compiled_op->moveBefore(terminator);
+
+ return host_launch_op;
+}
+
+struct TPUExtractHeadTailOutsideCompilation
+ : public PassWrapper<TPUExtractHeadTailOutsideCompilation,
+ OperationPass<ModuleOp>> {
+ void runOnOperation() override;
+};
+
+void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
+  // Get runtime device information from the closest parent module.
+ auto module = getOperation();
+ mlir::TF::RuntimeDevices devices;
+ if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
+ return signalPassFailure();
+
+ OpBuilder builder(&getContext());
+ module.walk([&](tf_device::ClusterOp cluster) {
+ auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
+ IsolateHeadExtractedOpsToLaunchOp(&builder, cluster,
+ head_outside_compiled_ops);
+
+ // TODO(b/156030523): Update device attribute of newly created host launch
+ // op as well as enclosing Replicate op (if TPU computation is replicated)
+ // with host device names.
+
+ // TODO(b/155115766): Implement tail outside compiled op extraction.
});
}
} // anonymous namespace
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass() {
return std::make_unique<TPUExtractHeadTailOutsideCompilation>();
}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 98ff0de..f5e9da9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -92,7 +92,7 @@
//
// Would become following ops (unimportant attributes, types are omitted):
// %1 = "tf.Shape"(%0)
-// %2:2 = "tf.MLIRCompileToTPU"(%1) {module = "<Serialized @tpu_func>"}
+// %2:2 = "tf._TPUCompileMlir"(%1) {module = "<Serialized @tpu_func>"}
// "tf.TPUCompileSucceededAssert"(%2#0)
// %3 = "tf.TPUExecute"(%0, %2#1)
// %4 = "tf.SomeOp"(%3)
@@ -448,19 +448,20 @@
// core, and all replica devices per core are grouped together.
void AssignDevicesToReplicate(
tf_device::ReplicateOp replicate,
- llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+ llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+ tpu_devices,
OpBuilder* builder) {
if (!replicate) return;
- const int num_replicas = execution_devices.size();
- const int num_cores_per_replica = execution_devices.front().size();
+ const int num_replicas = tpu_devices.size();
+ const int num_cores_per_replica = tpu_devices.front().size();
llvm::SmallVector<NamedAttribute, 8> device_attrs;
for (int core = 0; core < num_cores_per_replica; ++core) {
llvm::SmallVector<StringRef, 8> devices_by_core;
devices_by_core.reserve(num_replicas);
for (int replica = 0; replica < num_replicas; ++replica)
- devices_by_core.push_back(execution_devices[replica][core]);
+ devices_by_core.push_back(tpu_devices[replica][core].device);
device_attrs.push_back(
builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core),
@@ -492,11 +493,12 @@
// Creates a tf_device.parallel_execute op that wraps TPUExecute op to
// represent execution of TPU program in multiple logical cores.
LogicalResult BuildParallelExecuteOp(
- llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+ llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+ tpu_devices,
llvm::ArrayRef<xla::OpSharding> output_sharding_config,
Operation* compile_op, tf_device::ClusterFuncOp cluster_func,
OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) {
- const int num_cores_per_replica = execution_devices.front().size();
+ const int num_cores_per_replica = tpu_devices.front().size();
// parallel_execute op returns concatenated list of return values of
// all its regions.
//
@@ -528,7 +530,7 @@
num_cores_per_replica, cluster_func, builder, &input_list);
if (failed(result)) return failure();
- const bool replicated = execution_devices.size() != 1;
+ const bool replicated = tpu_devices.size() != 1;
// For each logical core, create a region with TPUExecute op.
assert(input_list.size() == num_cores_per_replica);
for (int core = 0; core < num_cores_per_replica; ++core) {
@@ -553,7 +555,7 @@
// op.
std::string device = replicated
? tensorflow::GetDeviceAliasForLogicalCore(core)
- : execution_devices.front()[core];
+ : tpu_devices.front()[core].device;
auto region_launch_op =
WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device);
@@ -566,13 +568,14 @@
}
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
- llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
+ llvm::ArrayRef<llvm::SmallVector<tensorflow::TPUDeviceAndHost, 8>>
+ tpu_devices,
Operation* execute_op, OpBuilder* builder) {
- const bool replicated = execution_devices.size() != 1;
+ const bool replicated = tpu_devices.size() != 1;
// If computation is replicated, use aliased device. Otherwise there is only
// one execution device and the device is assigned to the execute op.
std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0)
- : execution_devices.front().front();
+ : tpu_devices.front().front().device;
return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
}
@@ -687,6 +690,16 @@
// Create compile op.
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
builder->setInsertionPoint(cluster_func);
+
+  // Create the TPUCompileMlir and TPUCompileSucceededAssert ops outside of
+  // the parallel_execute region, if one exists.
+ if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func.getParentOp())) {
+ // Currently, outside compilation and model parallelism are not supported
+ // together.
+ assert(num_cores_per_replica == 1);
+ builder->setInsertionPoint(cluster_func.getParentOp());
+ }
+
Operation* compile_op = BuildCompileOp(
cluster_func, num_replicas, num_cores_per_replica,
tpu_device_assignment.compilation_device,
@@ -704,7 +717,7 @@
BuildTPUCompileSucceededAssertOp(
compile_op, tpu_device_assignment.compilation_device, builder);
- AssignDevicesToReplicate(replicate, tpu_device_assignment.execution_devices,
+ AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
builder);
llvm::SmallVector<xla::OpSharding, 4> output_shardings;
@@ -712,12 +725,13 @@
num_cores_per_replica, cluster_func, &output_shardings);
if (failed(result)) return failure();
+ builder->setInsertionPoint(cluster_func);
if (num_cores_per_replica > 1) {
// For model parallelism, tf_device.parallel_execute is used to express
// concurrent device execution across multiple logical devices.
tf_device::ParallelExecuteOp execute_op;
- result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices,
+ result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices,
output_shardings, compile_op, cluster_func,
builder, &execute_op);
if (failed(result)) return failure();
@@ -740,7 +754,7 @@
if (failed(result)) return failure();
tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
- tpu_device_assignment.execution_devices, execute_op, builder);
+ tpu_device_assignment.tpu_devices, execute_op, builder);
cluster_func.replaceAllUsesWith(launch_op);
}
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 2374687..e8ca691 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -293,6 +293,12 @@
tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  // Run the shape inference pass to propagate shapes through tensor_cast
+  // operations from static to dynamic shapes. Such casts can be generated
+  // when shape information was originally missing on a TF op but the
+  // corresponding HLO op has a static shape after lowering.
+ tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
+
// Run LegalizeTFPass again because the previous legalization passes can
// expose more graph pruning and canonicalization opportunities that are
// necessary for the second LegalizeTFPass(allow_partial_conversion=false)
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
index fcfef56..b28f26b 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
@@ -31,12 +31,14 @@
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@@ -131,13 +133,21 @@
case DTYPE: \
return ConvertFlatTensor<CTYPE>(input_tensor, type);
- // TODO(fengliuai): customize the conversions for more types.
+ // TODO(fengliuai): customize the conversions for quantized and string types.
switch (input_dtype) {
CONVERT_FLAT(DT_BOOL, bool)
CONVERT_FLAT(DT_FLOAT, float)
CONVERT_FLAT(DT_DOUBLE, double)
+ CONVERT_FLAT(DT_INT8, int8)
+ CONVERT_FLAT(DT_INT16, int16)
CONVERT_FLAT(DT_INT32, int32)
CONVERT_FLAT(DT_INT64, int64)
+ CONVERT_FLAT(DT_UINT8, uint8)
+ CONVERT_FLAT(DT_UINT16, uint16)
+ CONVERT_FLAT(DT_UINT32, uint32)
+ CONVERT_FLAT(DT_UINT64, uint64)
+ CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
+ CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)
// BFLOAT16 is a special case that it needs to be cast to double type to
// match its storage type.
@@ -207,12 +217,20 @@
// Converts an MLIR dense string elements attribute to a repeated string
// field.
-Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
- TensorProto* output_tensor) {
- for (const auto& val : attr.getRawStringData()) {
- output_tensor->add_string_val(val.data(), val.size());
+void ConvertStringElementsAttr(
+ const DenseStringElementsAttr attr,
+ protobuf::RepeatedPtrField<std::string>* output) {
+ for (const auto& val : attr.getRawStringData())
+ output->Add({val.data(), val.size()});
+}
+
+template <typename T>
+void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
+ protobuf::RepeatedField<T>* output) {
+ for (const auto& val : attr.getValues<std::complex<T>>()) {
+ output->Add(val.real());
+ output->Add(val.imag());
}
- return Status::OK();
}
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
@@ -226,139 +244,80 @@
return InvalidArgument("Unexpected elements attribute type from MLIR.");
}
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the double_val field updated.
-Status ConvertDoubleElementsAttr(const ElementsAttr attr,
- TensorProto* output_tensor) {
- if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
- if (elts.isSplat()) {
- output_tensor->add_double_val(elts.getSplatValue<double>());
- } else {
- for (auto value : elts.getValues<double>())
- output_tensor->add_double_val(value);
- }
- return Status::OK();
- }
- return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the float_val field updated.
-Status ConvertFloatElementsAttr(const ElementsAttr attr,
- TensorProto* output_tensor) {
- if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
- if (elts.isSplat()) {
- output_tensor->add_float_val(elts.getSplatValue<float>());
- } else {
- for (auto value : elts.getValues<float>())
- output_tensor->add_float_val(value);
- }
- return Status::OK();
- }
- return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the half_val field updated.
-Status ConvertHalfElementsAttr(const ElementsAttr attr,
- TensorProto* output_tensor) {
- if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
- if (elts.isSplat()) {
- output_tensor->add_half_val(
- (*elts.begin()).bitcastToAPInt().getSExtValue());
- } else {
- for (const auto& value : elts.getFloatValues())
- output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
- }
- return Status::OK();
- }
- return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the int_val field updated.
-Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
- TensorProto* output_tensor) {
- if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
- if (elts.isSplat()) {
- output_tensor->add_int_val((*elts.begin()).getSExtValue());
- } else {
- for (const auto& val : elts)
- output_tensor->add_int_val(val.getSExtValue());
- }
- return Status::OK();
- }
- return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
- TensorProto* output_tensor) {
- auto elts = attr.dyn_cast<DenseFPElementsAttr>();
- if (!elts) {
- return ConvertOpaqueElementsAttr(attr, output_tensor);
- }
-
- // Bfloat16 is internally represented as `double` in MLIR.
- if (elts.isSplat()) {
- double v = elts.getSplatValue<double>();
- bfloat16 bf16_val = static_cast<bfloat16>(v);
- output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
+// Converts an MLIR elements attribute and adds its values to the specified
+// repeated field.
+template <typename T>
+void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
+ protobuf::RepeatedField<T>* output) {
+ if (attr.isSplat()) {
+ output->Add(attr.getSplatValue<T>());
} else {
- for (auto v : elts.getValues<double>()) {
+ for (auto value : attr.getValues<T>()) output->Add(value);
+ }
+}
+
+// Converts an MLIR elements attribute containing half values and adds it to
+// the specified repeated field.
+void ConvertHalfElementsAttr(const DenseFPElementsAttr attr,
+ protobuf::RepeatedField<int>* output_tensor) {
+ if (attr.isSplat()) {
+ output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
+ } else {
+ for (const llvm::APFloat value : attr.getFloatValues())
+ output_tensor->Add(value.bitcastToAPInt().getSExtValue());
+ }
+}
+
+// Converts an MLIR elements attribute containing int values and adds it to
+// the specified repeated field.
+void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
+ protobuf::RepeatedField<int>* output) {
+ if (attr.isSplat()) {
+ output->Add((*attr.begin()).getSExtValue());
+ } else {
+ for (const llvm::APInt val : attr) output->Add(val.getSExtValue());
+ }
+}
+
+void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
+ protobuf::RepeatedField<int>* output) {
+ // Bfloat16 is internally represented as `double` in MLIR.
+ if (attr.isSplat()) {
+ double v = attr.getSplatValue<double>();
+ bfloat16 bf16_val = static_cast<bfloat16>(v);
+ output->Add(absl::bit_cast<int16>(bf16_val));
+ } else {
+ for (auto v : attr.getValues<double>()) {
bfloat16 bf16_val = static_cast<bfloat16>(v);
- output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
+ output->Add(absl::bit_cast<int16>(bf16_val));
}
}
-
- return Status::OK();
}
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with the int64_val field updated.
-Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
- TensorProto* output_tensor) {
- if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
- if (elts.isSplat()) {
- output_tensor->add_int64_val((*elts.begin()).getSExtValue());
- } else {
- for (const auto& val : elts)
- output_tensor->add_int64_val(val.getSExtValue());
- }
- return Status::OK();
- }
- return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-// Converts an MLIR elements attribute to a TensorFlow tensor proto
-// with bool_val field updated.
-Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
- TensorProto* output_tensor) {
- if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
- for (const auto& val : elts) {
- output_tensor->add_bool_val(val.getBoolValue());
- }
- return Status::OK();
- }
- return ConvertOpaqueElementsAttr(attr, output_tensor);
-}
-
-Status ConvertToTensorProto(const ElementsAttr attr,
- TensorProto* output_tensor) {
+Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
auto type = attr.getType();
auto shape = type.getShape();
DataType output_dtype;
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
- output_tensor->set_dtype(output_dtype);
- ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
+ output->set_dtype(output_dtype);
+ ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
+
+ if (attr.isa<OpaqueElementsAttr>())
+ return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
+
+ auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
+ if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
switch (output_dtype) {
case DT_FLOAT:
- return ConvertFloatElementsAttr(attr, output_tensor);
+ ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
+ break;
case DT_HALF:
- // Handles both DenseFPElementsAttr and OpaqueElementsAttr.
- return ConvertHalfElementsAttr(attr, output_tensor);
+ ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
+ output->mutable_half_val());
+ break;
case DT_DOUBLE:
- return ConvertDoubleElementsAttr(attr, output_tensor);
+ ConvertElementsAttr(dense_attr, output->mutable_double_val());
+ break;
case DT_QUINT8:
case DT_UINT8:
case DT_INT8:
@@ -366,20 +325,40 @@
case DT_UINT16:
case DT_INT16:
case DT_INT32:
- return ConvertIntElementsAttr(attr, output_tensor);
+ ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
+ output->mutable_int_val());
+ break;
+ case DT_UINT32:
+ ConvertElementsAttr(dense_attr, output->mutable_uint32_val());
+ break;
+ case DT_UINT64:
+ ConvertElementsAttr(dense_attr, output->mutable_uint64_val());
+ break;
case DT_INT64:
- return ConvertInt64ElementsAttr(attr, output_tensor);
+ ConvertElementsAttr(dense_attr, output->mutable_int64_val());
+ break;
case DT_BOOL:
- return ConvertBoolElementsAttr(attr, output_tensor);
+ ConvertElementsAttr(dense_attr, output->mutable_bool_val());
+ break;
case DT_BFLOAT16:
- return ConvertBfloat16ElementsAttr(attr, output_tensor);
+ ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
+ output->mutable_half_val());
+ break;
case DT_STRING:
- return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
- output_tensor);
+ ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
+ output->mutable_string_val());
+ break;
+ case DT_COMPLEX64:
+ ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
+ break;
+ case DT_COMPLEX128:
+ ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
+ break;
default:
- return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
- output_tensor);
+ return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
+ DataTypeString(output_dtype)));
}
+ return Status::OK();
}
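
// Usage sketch (illustrative): serializing a dense attribute taken from a
// constant op, assuming `attr` is an ElementsAttr:
//   TensorProto proto;
//   TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &proto));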
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
index d711c19..bf96e3d 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc
@@ -16,6 +16,7 @@
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include <cstring>
+#include <initializer_list>
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@@ -99,48 +100,74 @@
EXPECT_EQ(string_values[3], mlir::StringRef("four"));
}
-TEST(ConvertTypeToTensorTypeTest, Convert16BitFloats) {
+class ConvertTensorTest : public ::testing::Test {
+ protected:
+ template <typename T>
+ void VerifyConversion(std::initializer_list<T> values, DataType dtype,
+ mlir::Type expected_ty) {
+ mlir::Builder b(expected_ty.getContext());
+ Tensor tensor(dtype, TensorShape({static_cast<int64>(values.size())}));
+ tensor.flat<T>().setValues(values);
+
+ auto value_or = ConvertTensor(tensor, &b);
+ TF_ASSERT_OK(value_or.status());
+ auto attr = value_or.ValueOrDie();
+
+ EXPECT_EQ(attr.getType().getElementType(), expected_ty);
+
+ Tensor out;
+ TF_ASSERT_OK(ConvertToTensor(attr, &out));
+
+ test::ExpectTensorEqual<T>(tensor, out);
+ }
+};
+
+TEST_F(ConvertTensorTest, Simple) {
RegisterDialects();
+
mlir::MLIRContext context;
- mlir::Builder b(&context);
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
+ {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
+ ASSERT_NO_FATAL_FAILURE(
+ VerifyConversion<bfloat16>({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16,
+ mlir::FloatType::getBF16(&context)));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<float>(
+ {1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context)));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<double>(
+ {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
- {
- // Create the sample tensor to convert.
- Tensor tensor(DT_HALF, TensorShape({1}));
- auto Tt = tensor.flat<Eigen::half>();
- Tt.setValues({Eigen::half(1.0)});
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
+ {1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
+ {1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
+ {1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
+ {1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
- auto value_or = ConvertTensor(tensor, &b);
- TF_EXPECT_OK(value_or.status());
- auto attr = value_or.ValueOrDie();
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
+ {1, 2}, DT_UINT8,
+ mlir::IntegerType::get(
+ 8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
+ {1, 2}, DT_UINT16,
+ mlir::IntegerType::get(
+ 16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
+ {1, 2}, DT_UINT32,
+ mlir::IntegerType::get(
+ 32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
+ {1, 2}, DT_UINT64,
+ mlir::IntegerType::get(
+ 64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
- EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
- EXPECT_TRUE(attr.getType().getElementType().isF16());
-
- Tensor out;
- TF_ASSERT_OK(ConvertToTensor(attr, &out));
-
- test::ExpectTensorEqual<Eigen::half>(tensor, out);
- }
-
- {
- // Create the sample tensor to convert.
- Tensor tensor(DT_BFLOAT16, TensorShape({2}));
- auto Tt = tensor.flat<bfloat16>();
- Tt.setValues({bfloat16(1.0), bfloat16(-1.0)});
-
- auto value_or = ConvertTensor(tensor, &b);
- TF_EXPECT_OK(value_or.status());
- auto attr = value_or.ValueOrDie();
-
- EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
- EXPECT_TRUE(attr.getType().getElementType().isBF16());
-
- Tensor out;
- TF_ASSERT_OK(ConvertToTensor(attr, &out));
-
- test::ExpectTensorEqual<bfloat16>(tensor, out);
- }
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
+ {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
+ mlir::ComplexType::get(mlir::FloatType::getF32(&context))));
+ ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<double>>(
+ {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128,
+ mlir::ComplexType::get(mlir::FloatType::getF64(&context))));
}
} // namespace
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
index cc79525..4877cbc 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
@@ -59,6 +59,18 @@
namespace tensorflow {
namespace {
+// Returns the static set of recognized TensorFlow op prefixes.
+std::set<std::string>* GlobalOpPrefixes() {
+ static std::set<std::string>* global_op_prefixes = [] {
+ std::set<std::string>* result = new std::set<std::string>;
+ result->insert("tf.");
+ result->insert("_tf.");
+ result->insert("tf_executor.");
+ return result;
+ }();
+ return global_op_prefixes;
+}
+
// Converts a location to the debug information for the node def.
Status ConvertLocation(mlir::Location inst_loc,
NodeDef::ExperimentalDebugInfo* debug_info) {
@@ -268,8 +280,10 @@
// - ".sink" or ".Sink": only the NextIteration operation has this suffix. We
// don't need to consider ".source"/".Source" because the nodes with this
// suffix are skipped by the caller and will not be added to the graph.
- if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") &&
- !op_name.consume_front("tf_executor.")) {
+ auto prefixes = GlobalOpPrefixes();
+ if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) {
+ return op_name.consume_front(prefix);
+ })) {
return errors::FailedPrecondition("op node '", op_name.str(),
"' was not a TF op!");
}
@@ -506,4 +520,9 @@
inst->getName().getStringRef().compare("_tf.LegacyCall") == 0;
}
+Status AddTensorFlowOpPrefix(std::string prefix) {
+ GlobalOpPrefixes()->insert(prefix);
+ return Status::OK();
+}
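+// Usage sketch: a dialect whose ops should round-trip through the GraphDef
+// exporter registers its prefix once at startup, e.g.
+//   TF_RETURN_IF_ERROR(tensorflow::AddTensorFlowOpPrefix("tfjs."));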
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h
index 32ed528..58fe39f 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h
@@ -34,10 +34,17 @@
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
+namespace mlir {
+class ShapedType;
+} // namespace mlir
+
namespace tensorflow {
using stream_executor::port::StatusOr;
+// Adds a custom op prefix for TensorFlow dialects.
+Status AddTensorFlowOpPrefix(std::string);
+
// Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control
// dialect back into a TensorFlow valid op name.
StatusOr<llvm::StringRef> GetTensorFlowOpName(llvm::StringRef);
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
index ddbcc91..06c10c2 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
@@ -164,12 +164,19 @@
return DeviceNameUtils::ParsedNameToString(system_device);
}
+// Finds the host CPU device for a given TPU device.
+std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) {
+ tpu_device.type = DEVICE_CPU;
+ tpu_device.id = 0;
+ return DeviceNameUtils::ParsedNameToString(tpu_device);
+}
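+// For example, per the tests, the TPU device
+// "/job:worker/replica:0/task:1/device:TPU:3" maps to host CPU device
+// "/job:worker/replica:0/task:1/device:CPU:0".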
+
// Determines execution devices when topology and device assignment are not
// defined. This is a special case where a single core computation is replicated
// to every core in the mesh. TPU devices are simply added to
// `execution_devices` of one replica. `num_replicas` must be 1 or the total
// number of TPU devices available, and `num_cores_per_replica` must be 1.
-StatusOr<ExecutionDevices> GetFullMeshTPUExecutionDeviceAssignment(
+StatusOr<TPUDevicesAndHosts> GetFullMeshTPUExecutionDeviceAssignment(
int num_replicas, int num_cores_per_replica,
llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices) {
const int num_tasks = tpu_devices.size();
@@ -185,17 +192,18 @@
"'num_cores_per_replica' must be equal to 1, got ",
num_cores_per_replica);
- ExecutionDevices execution_devices;
- execution_devices.reserve(num_replicas);
+ TPUDevicesAndHosts devices_and_hosts;
+ devices_and_hosts.reserve(num_replicas);
for (int i = 0; i < num_replicas; ++i) {
const int task = i / num_tpus_per_task;
const int device = i % num_tpus_per_task;
- execution_devices.push_back(
- {tensorflow::DeviceNameUtils::ParsedNameToString(
- tpu_devices[task][device])});
+ const auto& tpu_device = tpu_devices[task][device];
+ devices_and_hosts.push_back({TPUDeviceAndHost(
+ /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device),
+ /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))});
}
- return execution_devices;
+ return devices_and_hosts;
}
// Helper struct for keeping track of task and device for an associated TPU
@@ -326,7 +334,7 @@
// - number of device coordinates (in tuple 3) match number 'num_replicas' *
// 'num_cores_per_replica'
// - a TPU device associated with each device coordinate
-StatusOr<std::pair<ExecutionDevices, xla::DeviceAssignmentProto>>
+StatusOr<std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>>
GetGeneralTPUExecutionDeviceAssignment(
int num_replicas, int num_cores_per_replica,
llvm::ArrayRef<llvm::SmallVector<Device, 8>> tpu_devices,
@@ -361,9 +369,9 @@
std::vector<bool> used_device_ids(
location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1),
false);
- ExecutionDevices execution_devices(
- num_replicas,
- llvm::SmallVector<std::string, 8>(num_cores_per_replica, ""));
+ TPUDevicesAndHosts devices_and_hosts(
+ num_replicas, llvm::SmallVector<TPUDeviceAndHost, 8>(
+ num_cores_per_replica, TPUDeviceAndHost()));
xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica);
int pos = 0;
for (int replica = 0; replica < num_replicas; ++replica) {
@@ -393,16 +401,18 @@
used_device_ids[device_id] = true;
device_assignment(replica, logical_core) = device_id;
- execution_devices[replica][logical_core] =
- DeviceNameUtils::ParsedNameToString(tpu_devices[task][device]);
+ auto& device_and_host = devices_and_hosts[replica][logical_core];
+ const auto& tpu_device = tpu_devices[task][device];
+ device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device);
+ device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device);
}
}
xla::DeviceAssignmentProto device_assignment_proto;
TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto));
- return std::pair<ExecutionDevices, xla::DeviceAssignmentProto>(
- std::move(execution_devices), std::move(device_assignment_proto));
+ return std::pair<TPUDevicesAndHosts, xla::DeviceAssignmentProto>(
+ std::move(devices_and_hosts), std::move(device_assignment_proto));
}
} // anonymous namespace
@@ -447,27 +457,4 @@
return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str();
}
-StatusOr<std::string> GetCPUHostForTPUDevice(llvm::StringRef tpu_device) {
- Device device;
- if (!DeviceNameUtils::ParseFullName(tpu_device.str(), &device))
- return errors::InvalidArgument("'", tpu_device.str(),
- "' is not a valid device");
-
- device.type = DEVICE_CPU;
- device.id = 0;
- return DeviceNameUtils::ParsedNameToString(device);
-}
-
-StatusOr<llvm::SmallVector<std::string, 8>> GetCPUHostsForTPUDevices(
- llvm::ArrayRef<std::string> tpu_devices) {
- llvm::SmallVector<std::string, 8> cpu_devices;
- cpu_devices.reserve(tpu_devices.size());
- for (const auto& tpu_device : tpu_devices) {
- TF_ASSIGN_OR_RETURN(cpu_devices.emplace_back(),
- GetCPUHostForTPUDevice(tpu_device));
- }
-
- return cpu_devices;
-}
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
index 47ce7f1..5fdb6b8 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h
@@ -30,29 +30,40 @@
namespace tensorflow {
using stream_executor::port::StatusOr;
-// TPU devices to be used for execution (e.g. devices for TPUExecute ops). They
-// are ordered by `num_replicas` followed by `num_cores_per_replica`.
-using ExecutionDevices =
- llvm::SmallVector<llvm::SmallVector<std::string, 8>, 8>;
+// A TPU device for execution alongside its associated host CPU device.
+struct TPUDeviceAndHost {
+ TPUDeviceAndHost() {}
+ TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host)
+ : device(device), host(host) {}
-// TPU compilation device, execution devices, and optionally execution device
-// IDs. Execution device IDs are populated if `topology` and `device_assignment`
-// are provided.
+ std::string device;
+ std::string host;
+};
+
+// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and
+// their associated host CPU devices (for outside compilation). They are ordered
+// by `num_replicas` followed by `num_cores_per_replica`.
+using TPUDevicesAndHosts =
+ llvm::SmallVector<llvm::SmallVector<TPUDeviceAndHost, 8>, 8>;
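+//
+// Access sketch (illustrative): entries are indexed by replica, then logical
+// core:
+//   const TPUDeviceAndHost& dh = devices_and_hosts[replica][core];
+//   dh.device;  // e.g. "/job:worker/replica:0/task:0/device:TPU:0"
+//   dh.host;    // e.g. "/job:worker/replica:0/task:0/device:CPU:0"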
+
+// TPU compilation device, execution and associated host devices, and optionally
+// execution device IDs. Execution device IDs are populated if `topology` and
+// `device_assignment` are provided.
struct TPUDeviceAssignment {
TPUDeviceAssignment(llvm::StringRef compilation_device,
- ExecutionDevices&& execution_devices)
+ TPUDevicesAndHosts&& tpu_devices)
: compilation_device(compilation_device),
- execution_devices(std::move(execution_devices)) {}
+ tpu_devices(std::move(tpu_devices)) {}
TPUDeviceAssignment(llvm::StringRef compilation_device,
- ExecutionDevices&& execution_devices,
+ TPUDevicesAndHosts&& tpu_devices,
xla::DeviceAssignmentProto&& xla_device_assignment)
: compilation_device(compilation_device),
- execution_devices(std::move(execution_devices)),
+ tpu_devices(std::move(tpu_devices)),
xla_device_assignment(std::move(xla_device_assignment)) {}
std::string compilation_device;
- ExecutionDevices execution_devices;
+ TPUDevicesAndHosts tpu_devices;
llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
};
@@ -216,17 +227,6 @@
// logical core.
std::string GetDeviceAliasForLogicalCore(int core_index);
-// Finds associated CPU host device for given TPU device. This assumes a
-// matching CPU host device exists based on TPU device name. An error will be
-// returned if the TPU device name is invalid.
-StatusOr<std::string> GetCPUHostForTPUDevice(llvm::StringRef tpu_device);
-
-// Finds associated CPU host devices for given TPU devices. This assumes a
-// matching CPU host device exist based on each TPU device name. An error will
-// be returned if a TPU device name is invalid.
-StatusOr<llvm::SmallVector<std::string, 8>> GetCPUHostsForTPUDevices(
- llvm::ArrayRef<std::string> tpu_devices);
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
index 57e123a..7ac5635a 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc
@@ -323,30 +323,46 @@
TF_ASSERT_OK(status_or.status());
- auto& tpu_device_assignment = status_or.ValueOrDie();
+ const auto& tpu_device_assignment = status_or.ValueOrDie();
EXPECT_EQ(tpu_device_assignment.compilation_device,
"/job:worker/replica:0/task:0/device:CPU:0");
- auto& execution_devices = tpu_device_assignment.execution_devices;
- ASSERT_EQ(execution_devices.size(), 8);
- for (const auto& replica_execution_device : execution_devices)
- ASSERT_EQ(replica_execution_device.size(), 1);
+ const auto& tpu_devices = tpu_device_assignment.tpu_devices;
+ ASSERT_EQ(tpu_devices.size(), 8);
+ for (const auto& replica_tpu_devices : tpu_devices)
+ ASSERT_EQ(replica_tpu_devices.size(), 1);
- EXPECT_EQ(execution_devices[0][0],
+ EXPECT_EQ(tpu_devices[0][0].device,
"/job:worker/replica:0/task:0/device:TPU:0");
- EXPECT_EQ(execution_devices[1][0],
+ EXPECT_EQ(tpu_devices[0][0].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[1][0].device,
"/job:worker/replica:0/task:0/device:TPU:1");
- EXPECT_EQ(execution_devices[2][0],
+ EXPECT_EQ(tpu_devices[1][0].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[2][0].device,
"/job:worker/replica:0/task:0/device:TPU:2");
- EXPECT_EQ(execution_devices[3][0],
+ EXPECT_EQ(tpu_devices[2][0].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[3][0].device,
"/job:worker/replica:0/task:0/device:TPU:3");
- EXPECT_EQ(execution_devices[4][0],
+ EXPECT_EQ(tpu_devices[3][0].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[4][0].device,
"/job:worker/replica:0/task:1/device:TPU:0");
- EXPECT_EQ(execution_devices[5][0],
+ EXPECT_EQ(tpu_devices[4][0].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
+ EXPECT_EQ(tpu_devices[5][0].device,
"/job:worker/replica:0/task:1/device:TPU:1");
- EXPECT_EQ(execution_devices[6][0],
+ EXPECT_EQ(tpu_devices[5][0].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
+ EXPECT_EQ(tpu_devices[6][0].device,
"/job:worker/replica:0/task:1/device:TPU:2");
- EXPECT_EQ(execution_devices[7][0],
+ EXPECT_EQ(tpu_devices[6][0].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
+ EXPECT_EQ(tpu_devices[7][0].device,
"/job:worker/replica:0/task:1/device:TPU:3");
+ EXPECT_EQ(tpu_devices[7][0].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue());
}
@@ -410,30 +426,46 @@
TF_ASSERT_OK(status_or.status());
- auto& tpu_device_assignment = status_or.ValueOrDie();
+ const auto& tpu_device_assignment = status_or.ValueOrDie();
EXPECT_EQ(tpu_device_assignment.compilation_device,
"/job:worker/replica:0/task:0/device:CPU:0");
- auto& execution_devices = tpu_device_assignment.execution_devices;
- ASSERT_EQ(execution_devices.size(), 4);
- for (const auto& replica_execution_device : execution_devices)
- ASSERT_EQ(replica_execution_device.size(), 2);
+ const auto& tpu_devices = tpu_device_assignment.tpu_devices;
+ ASSERT_EQ(tpu_devices.size(), 4);
+ for (const auto& replica_tpu_devices : tpu_devices)
+ ASSERT_EQ(replica_tpu_devices.size(), 2);
- EXPECT_EQ(execution_devices[0][0],
+ EXPECT_EQ(tpu_devices[0][0].device,
"/job:worker/replica:0/task:0/device:TPU:0");
- EXPECT_EQ(execution_devices[0][1],
+ EXPECT_EQ(tpu_devices[0][0].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[0][1].device,
"/job:worker/replica:0/task:1/device:TPU:3");
- EXPECT_EQ(execution_devices[1][0],
+ EXPECT_EQ(tpu_devices[0][1].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
+ EXPECT_EQ(tpu_devices[1][0].device,
"/job:worker/replica:0/task:0/device:TPU:1");
- EXPECT_EQ(execution_devices[1][1],
+ EXPECT_EQ(tpu_devices[1][0].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[1][1].device,
"/job:worker/replica:0/task:1/device:TPU:2");
- EXPECT_EQ(execution_devices[2][0],
+ EXPECT_EQ(tpu_devices[1][1].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
+ EXPECT_EQ(tpu_devices[2][0].device,
"/job:worker/replica:0/task:0/device:TPU:3");
- EXPECT_EQ(execution_devices[2][1],
+ EXPECT_EQ(tpu_devices[2][0].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[2][1].device,
"/job:worker/replica:0/task:1/device:TPU:0");
- EXPECT_EQ(execution_devices[3][0],
+ EXPECT_EQ(tpu_devices[2][1].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
+ EXPECT_EQ(tpu_devices[3][0].device,
"/job:worker/replica:0/task:0/device:TPU:2");
- EXPECT_EQ(execution_devices[3][1],
+ EXPECT_EQ(tpu_devices[3][0].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[3][1].device,
"/job:worker/replica:0/task:1/device:TPU:1");
+ EXPECT_EQ(tpu_devices[3][1].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
ASSERT_TRUE(xla_device_assignment.hasValue());
@@ -511,23 +543,35 @@
EXPECT_EQ(tpu_device_assignment.compilation_device,
"/job:worker/replica:0/task:0/device:CPU:0");
- auto& execution_devices = tpu_device_assignment.execution_devices;
- ASSERT_EQ(execution_devices.size(), 2);
- for (const auto& replica_execution_device : execution_devices)
- ASSERT_EQ(replica_execution_device.size(), 3);
+ auto& tpu_devices = tpu_device_assignment.tpu_devices;
+ ASSERT_EQ(tpu_devices.size(), 2);
+ for (const auto& replica_tpu_devices : tpu_devices)
+ ASSERT_EQ(replica_tpu_devices.size(), 3);
- EXPECT_EQ(execution_devices[0][0],
+ EXPECT_EQ(tpu_devices[0][0].device,
"/job:worker/replica:0/task:1/device:TPU:1");
- EXPECT_EQ(execution_devices[0][1],
+ EXPECT_EQ(tpu_devices[0][0].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
+ EXPECT_EQ(tpu_devices[0][1].device,
"/job:worker/replica:0/task:1/device:TPU:0");
- EXPECT_EQ(execution_devices[0][2],
+ EXPECT_EQ(tpu_devices[0][1].host,
+ "/job:worker/replica:0/task:1/device:CPU:0");
+ EXPECT_EQ(tpu_devices[0][2].device,
"/job:worker/replica:0/task:2/device:TPU:0");
- EXPECT_EQ(execution_devices[1][0],
+ EXPECT_EQ(tpu_devices[0][2].host,
+ "/job:worker/replica:0/task:2/device:CPU:0");
+ EXPECT_EQ(tpu_devices[1][0].device,
"/job:worker/replica:0/task:2/device:TPU:1");
- EXPECT_EQ(execution_devices[1][1],
+ EXPECT_EQ(tpu_devices[1][0].host,
+ "/job:worker/replica:0/task:2/device:CPU:0");
+ EXPECT_EQ(tpu_devices[1][1].device,
"/job:worker/replica:0/task:0/device:TPU:0");
- EXPECT_EQ(execution_devices[1][2],
+ EXPECT_EQ(tpu_devices[1][1].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
+ EXPECT_EQ(tpu_devices[1][2].device,
"/job:worker/replica:0/task:0/device:TPU:1");
+ EXPECT_EQ(tpu_devices[1][2].host,
+ "/job:worker/replica:0/task:0/device:CPU:0");
auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment;
ASSERT_TRUE(xla_device_assignment.hasValue());
@@ -552,44 +596,5 @@
EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
}
-struct ParameterizedCPUHostForTPUDeviceTest
- : ::testing::TestWithParam<std::tuple<std::string, std::string>> {};
-
-TEST_P(ParameterizedCPUHostForTPUDeviceTest, CPUHostForTPUDevice) {
- auto status_or_device = GetCPUHostForTPUDevice(std::get<0>(GetParam()));
- TF_ASSERT_OK(status_or_device.status());
- EXPECT_EQ(status_or_device.ValueOrDie(), std::get<1>(GetParam()));
-}
-
-INSTANTIATE_TEST_SUITE_P(
- CPUHostForTPUDevice, ParameterizedCPUHostForTPUDeviceTest,
- ::testing::Values(
- std::make_tuple("/job:worker/replica:0/task:0/device:TPU:0",
- "/job:worker/replica:0/task:0/device:CPU:0"),
- std::make_tuple("/job:worker/replica:0/task:1/device:TPU:1",
- "/job:worker/replica:0/task:1/device:CPU:0")));
-
-TEST(TPURewriteDeviceUtilTest, CPUHostForTPUDeviceInvalidDevice) {
- auto status_or_device = GetCPUHostForTPUDevice("bad_device");
- ASSERT_FALSE(status_or_device.ok());
-}
-
-TEST(TPURewriteDeviceUtilTest, CPUHostsForTPUDevices) {
- auto status_or_devices =
- GetCPUHostsForTPUDevices({"/job:worker/replica:0/task:0/device:TPU:0",
- "/job:worker/replica:0/task:1/device:TPU:1"});
- TF_ASSERT_OK(status_or_devices.status());
- const auto& devices = status_or_devices.ValueOrDie();
- ASSERT_EQ(devices.size(), 2);
- EXPECT_EQ(devices[0], "/job:worker/replica:0/task:0/device:CPU:0");
- EXPECT_EQ(devices[1], "/job:worker/replica:0/task:1/device:CPU:0");
-}
-
-TEST(TPURewriteDeviceUtilTest, CPUHostsForTPUDevicesInvalidDevice) {
- auto status_or_devices = GetCPUHostsForTPUDevices(
- {"/job:worker/replica:0/task:0/device:TPU:0", "bad_device"});
- ASSERT_FALSE(status_or_devices.ok());
-}
-
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD
index 9b731d2..806a77e 100644
--- a/tensorflow/compiler/mlir/tfjs/BUILD
+++ b/tensorflow/compiler/mlir/tfjs/BUILD
@@ -1,4 +1,5 @@
load("//third_party/mlir:tblgen.bzl", "gentbl")
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
package(
default_visibility = ["//visibility:public"],
@@ -131,10 +132,106 @@
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
- "//tensorflow/compiler/mlir/tensorflow:translate_lib",
- "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)
+
+cc_library(
+ name = "json_translate_lib",
+ srcs = [
+ "translate/json_translate.cc",
+ ],
+ hdrs = [
+ "translate/json_translate.h",
+ ],
+ deps = [
+ ":tensorflow_js",
+ ":tensorflow_js_dialect_registration",
+ "//tensorflow/compiler/mlir/tensorflow",
+ "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
+ "//tensorflow/compiler/mlir/tensorflow:export_utils",
+ "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+ "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:Translation",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "tf_to_tfjs_json",
+ srcs = ["translate/tf_to_tfjs_json.cc"],
+ hdrs = [
+ "translate/tf_to_tfjs_json.h",
+ ],
+ deps = [
+ ":json_translate_lib",
+ ":tfjs_optimize",
+ "//tensorflow/compiler/mlir/tensorflow",
+ "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
+ "//tensorflow/compiler/mlir/tensorflow:error_util",
+ "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
+ "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
+ "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
+ "//tensorflow/compiler/mlir/tensorflow:translate_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/stream_executor/lib",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:AllPassesAndDialects",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_binary(
+ name = "json_translate",
+ deps = [
+ ":json_translate_lib",
+ "@llvm-project//mlir:MlirTranslateMain",
+ ],
+)
+
+filegroup(
+ name = "tf_tfjs_translate_main",
+ srcs = [
+ "translate/tf_tfjs_translate.cc",
+ ],
+)
+
+tf_cc_binary(
+ name = "tf_tfjs_translate",
+ srcs = [":tf_tfjs_translate_main"],
+ deps = [
+ ":json_translate_lib",
+ ":tensorflow_js_passes",
+ ":tf_to_tfjs_json",
+ ":tfjs_optimize",
+ "//tensorflow/compiler/mlir:init_mlir",
+ "//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/platform:errors",
+ "//tensorflow/stream_executor/lib",
+ "@com_google_absl//absl/strings",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ ],
+)
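+
+# Example invocation (sketch, mirroring the lit RUN lines in tests/e2e):
+#   tf_tfjs_translate add.pbtxt -tf-input-arrays=input0,input1 \
+#     -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 \
+#     -tf-output-arrays=Mul -o add.json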
diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
index 318895d..545183a 100644
--- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
+++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h
@@ -28,6 +28,7 @@
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
+
namespace mlir {
namespace tfjs {
diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD
new file mode 100644
index 0000000..5c8d37d
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD
@@ -0,0 +1,23 @@
+load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
+
+licenses(["notice"])
+
+glob_lit_tests(
+ data = [
+ ":test_utilities",
+ ],
+ driver = "@llvm-project//mlir:run_lit.sh",
+ test_file_exts = [
+ "pbtxt",
+ ],
+)
+
+# Bundle together all of the test utilities that are used by tests.
+filegroup(
+ name = "test_utilities",
+ testonly = True,
+ data = [
+ "//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate",
+ "@llvm-project//llvm:FileCheck",
+ ],
+)
diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt
new file mode 100644
index 0000000..f6a324f
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt
@@ -0,0 +1,78 @@
+# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure
+# Add two tensor<10xi32> inputs and multiply the sum with itself, returning
+# the result.
+
+node {
+ name: "Add"
+ op: "Add"
+ input: "input0"
+ input: "input1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "input1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "Mul"
+ op: "Mul"
+ input: "Add"
+ input: "Add"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+versions {
+ producer: 27
+}
+
+# CHECK: "name": "input0"
+# CHECK-NEXT: "op": "Placeholder"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "input1",
+# CHECK-NEXT: "op": "Placeholder"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Add"
+# CHECK-NEXT: "op": "AddV2"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "input0"
+# CHECK-NEXT: "input1"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Mul1"
+# CHECK-NEXT: "op": "Mul"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "Add"
+# CHECK-NEXT: "Add"
+# CHECK: "type": "DT_INT32"
+# CHECK: "name": "Mul"
+# CHECK-NEXT: "op": "_Retval"
+# CHECK-NEXT: "input":
+# CHECK-NEXT: "Mul1"
+# CHECK: "type": "DT_INT32"
+# CHECK: "library"
+# CHECK: "versions"
+# CHECK: "producer": 27
+
diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt
new file mode 100644
index 0000000..810db71
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt
@@ -0,0 +1,175 @@
+# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure
+# Compute Relu(x) + alpha * Relu(-x) (the PReLU pattern) on a tensor<10xf32>
+# input; the pattern should be fused into a single custom Prelu op.
+
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+ experimental_debug_info {
+ }
+}
+node {
+ name: "alpha"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.5
+ }
+ }
+ }
+ experimental_debug_info {
+ }
+}
+node {
+ name: "Relu"
+ op: "Relu"
+ input: "input0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ experimental_debug_info {
+ }
+}
+node {
+ name: "Neg"
+ op: "Neg"
+ input: "input0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ experimental_debug_info {
+ }
+}
+node {
+ name: "Relu1"
+ op: "Relu"
+ input: "Neg"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ experimental_debug_info {
+ }
+}
+node {
+ name: "Mul"
+ op: "Mul"
+ input: "alpha"
+ input: "Relu1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ experimental_debug_info {
+ }
+}
+node {
+ name: "Add"
+ op: "Add"
+ input: "Relu"
+ input: "Mul"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ experimental_debug_info {
+ }
+}
+node {
+ name: "main"
+ op: "_Retval"
+ input: "Add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "index"
+ value {
+ i: 0
+ }
+ }
+}
+library {
+}
+versions {
+ producer: 344
+}
+
+# CHECK: "node":
+# CHECK: "name": "input0",
+# CHECK-NEXT: "op": "Placeholder",
+# CHECK-NEXT: "attr":
+# CHECK: "type": "DT_FLOAT"
+# CHECK: "name": "Add.Relu.Neg.Relu1.Mul",
+# CHECK-NEXT: "op": "Const",
+# CHECK-NEXT: "attr":
+# CHECK: "value":
+# CHECK: "tensor":
+# CHECK: "dtype": "DT_FLOAT",
+# CHECK: "tensorShape": {},
+# CHECK: "floatVal":
+# CHECK: -0.5
+# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1",
+# CHECK-NEXT: "op": "Prelu",
+# CHECK-NEXT: "input":
+# CHECK: "input0",
+# CHECK: "Add.Relu.Neg.Relu1.Mul"
+# CHECK: "attr":
+# CHECK: "_output_shapes":
+# CHECK: "list":
+# CHECK: "shape":
+# CHECK: "dim":
+# CHECK: "size": "10"
+# CHECK: "experimentalDebugInfo": {}
+# CHECK: "name": "Add",
+# CHECK-NEXT: "op": "_Retval",
+# CHECK-NEXT: "input":
+# CHECK: "Add.Relu.Neg.Relu1.Mul1"
+# CHECK: "attr":
+# CHECK: "T":
+# CHECK: "type": "DT_FLOAT"
+# CHECK: "library": {},
+# CHECK: "versions":
+# CHECK: "producer": 344
+
diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc
index 631bb1a..a445937 100644
--- a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc
+++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
-#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
@@ -47,6 +46,11 @@
// Canonicalize, CSE etc.
pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
+
+ // Raise to the executor dialect in order to use the GraphDef converter.
+ pm->addNestedPass<mlir::FuncOp>(
+ mlir::CreateFunctionalToExecutorDialectConversionPass());
+ pm->addNestedPass<mlir::FuncOp>(mlir::CreateBreakUpIslandsPass());
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc
new file mode 100644
index 0000000..7f4b8ff
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "mlir/IR/Attributes.h" // from @llvm-project
+#include "mlir/IR/Module.h" // from @llvm-project
+#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Translation.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
+
+using mlir::ModuleOp;
+using mlir::TranslateFromMLIRRegistration;
+using std::string;
+using tensorflow::Status;
+using xla::StatusOr;
+
+// Translates the given MLIR module in the TFJS dialect to TFJS JSON
+// format. Returns true on success.
+bool tfjs::MlirToJSONTranslateFunction(ModuleOp module,
+ std::string* serialized_json) {
+ string json_output;
+ // Allow TF to treat TFJS ops as TF ops.
+ if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) {
+ LOG(ERROR) << "Failed to add tfjs op prefix.";
+ return false;
+ }
+ tensorflow::GraphExportConfig confs;
+ confs.export_shapes = true;
+ confs.export_library = true;
+ tensorflow::FunctionLibraryDefinition flib_def(
+ tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
+ absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;
+ auto graph = absl::make_unique<tensorflow::Graph>(flib_def);
+ auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def,
+ &control_ret_nodes);
+ if (!status.ok()) {
+ LOG(ERROR) << "Graph export failed: " << status;
+ return false;
+ }
+ auto graphdef = absl::make_unique<tensorflow::GraphDef>();
+ graph->ToGraphDef(graphdef.get());
+
+ // Replace the _Arg nodes of the main function with Placeholder ops.
+ auto nodes = graphdef->mutable_node();
+ for (const auto& node : llvm::enumerate(*nodes)) {
+ if (node.value().op() == "_Arg") {
+ nodes->Mutable(node.index())->set_op("Placeholder");
+ }
+ }
+
+ tensorflow::protobuf::util::JsonPrintOptions json_options;
+ json_options.add_whitespace = true;
+ auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString(
+ *graphdef, &json_output, json_options);
+ if (!jsonStatus.ok()) {
+ LOG(ERROR) << "Proto2Json failed: " << jsonStatus;
+ return false;
+ }
+ *serialized_json = std::move(json_output);
+ return true;
+}
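+// Usage sketch (illustrative):
+//   std::string json;
+//   if (!tfjs::MlirToJSONTranslateFunction(module, &json)) { /* logged */ }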
+
+static mlir::LogicalResult MlirToJSONFileTranslateFunction(
+ ModuleOp module, llvm::raw_ostream& output) {
+ std::string serialized_json;
+ if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json))
+ return mlir::failure();
+
+ output << serialized_json;
+ return mlir::success();
+}
+
+static TranslateFromMLIRRegistration MLIRToJSONFileTranslate(
+ "mlir-to-tfjs-json", MlirToJSONFileTranslateFunction);
diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.h b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h
new file mode 100644
index 0000000..0a931f7
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
+#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
+
+#include <string>
+
+#include "mlir/IR/Module.h" // from @llvm-project
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tfjs {
+
+// Translates the given MLIR `module` into a JSON string. Returns true on
+// success, false otherwise.
+bool MlirToJSONTranslateFunction(mlir::ModuleOp module,
+ std::string* serialized_json);
+} // namespace tfjs
+
+#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_
diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc
new file mode 100644
index 0000000..e735a3c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc
@@ -0,0 +1,173 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include <string>
+
+#include "absl/strings/str_split.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/IR/Diagnostics.h" // from @llvm-project
+#include "mlir/IR/Function.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Module.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Support/FileUtilities.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/init_mlir.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h"
+#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
+#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+using llvm::cl::opt;
+using mlir::MLIRContext;
+using stream_executor::port::StatusOr;
+
+// NOLINTNEXTLINE
+opt<std::string> input_file_name(llvm::cl::Positional,
+ llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"));
+
+// NOLINTNEXTLINE
+opt<bool> import_saved_model_object_graph(
+ "savedmodel-objectgraph-to-mlir",
+ llvm::cl::desc("Import a saved model to its MLIR representation"),
+ llvm::cl::value_desc("dir"));
+
+// NOLINTNEXTLINE
+opt<bool> import_saved_model_signature_defs(
+ "savedmodel-signaturedefs-to-mlir",
+ llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
+ llvm::cl::value_desc("dir"));
+
+// NOLINTNEXTLINE
+opt<std::string> saved_model_tags(
+ "tf-savedmodel-tags",
+ llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, "
+ "separated by ','"),
+ llvm::cl::init("serve"));
+
+// NOLINTNEXTLINE
+opt<std::string> saved_model_exported_names(
+ "tf-savedmodel-exported-names",
+ llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty "
+ "(the default) means export all."),
+ llvm::cl::init(""));
+
+// NOLINTNEXTLINE
+opt<std::string> output_file_name("o", llvm::cl::desc("<output file>"),
+ llvm::cl::value_desc("filename"),
+ llvm::cl::init("-"));
+// NOLINTNEXTLINE
+opt<bool> input_mlir(
+ "input-mlir",
+ llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of "
+ "GraphDef format"),
+ llvm::cl::init(false), llvm::cl::Hidden);
+// NOLINTNEXTLINE
+opt<bool> output_mlir(
+ "output-mlir",
+ llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"),
+ llvm::cl::init(false));
+
+// The following approach allows injecting opdefs, in addition to those
+// already in the global TF registry, to be linked in prior to importing the
+// graph. The primary goal is to support custom ops. This is not intended as a
+// general solution for custom ops going forward, but mainly supports older
+// models like mobilenet_ssd. More appropriate mechanisms, such as op hints or
+// using functions to represent composable ops like
+// https://github.com/tensorflow/community/pull/113, should be encouraged.
+// NOLINTNEXTLINE
+llvm::cl::list<std::string> custom_opdefs(
+ "tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing "
+ "graphdef"));
+
+// Debugging flag to print function mapping in the JSON.
+// NOLINTNEXTLINE
+static opt<bool> print_function_result_mapping(
+ "print-function-result-mapping",
+ llvm::cl::desc(
+ "Print the mapping of function result to json output buffer"),
+ llvm::cl::init(false));
+
+enum TranslationStatus { kTrSuccess, kTrFailure };
+
+static int PrintFunctionResultMapping(const std::string& result) {
+ std::cout << result << std::endl;
+ return kTrSuccess;
+}
+
+int main(int argc, char** argv) {
+ tensorflow::InitMlir y(&argc, &argv);
+
+ llvm::cl::ParseCommandLineOptions(argc, argv,
+ "TF GraphDef to TFJS JSON converter\n");
+
+ MLIRContext context;
+ llvm::SourceMgr source_mgr;
+ mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
+
+ StatusOr<mlir::OwningModuleRef> module;
+
+ if (import_saved_model_object_graph || import_saved_model_signature_defs) {
+ if (input_mlir)
+ module = tensorflow::errors::InvalidArgument(
+ "Importing saved model should not have input_mlir set");
+ else
+ module = tensorflow::ImportSavedModel(
+ import_saved_model_object_graph, import_saved_model_signature_defs,
+ custom_opdefs, input_file_name, saved_model_tags,
+ saved_model_exported_names, &context);
+ } else {
+ module = tensorflow::LoadFromGraphdefOrMlirSource(
+ input_file_name, input_mlir, custom_opdefs, debug_info_file,
+ input_arrays, input_dtypes, input_shapes, output_arrays,
+ /*prune_unused_nodes=*/true, &source_mgr, &context);
+ }
+
+ // If errors occur, the library call above has already logged the error
+ // message, so we can just return here.
+ if (!module.ok()) return kTrFailure;
+
+ mlir::PassManager pm(&context);
+
+ tensorflow::AddTFToTFJSConversionPasses(&pm);
+
+ std::string result;
+ auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(),
+ output_mlir, &result, &pm);
+ if (!status.ok()) return kTrFailure;
+
+ std::string error_msg;
+ auto output = mlir::openOutputFile(output_file_name, &error_msg);
+ if (output == nullptr) {
+ llvm::errs() << error_msg << '\n';
+ return kTrFailure;
+ }
+ output->os() << result;
+ output->keep();
+
+ // Print out debugging info related to function mapping.
+ if (print_function_result_mapping) return PrintFunctionResultMapping(result);
+ return kTrSuccess;
+}
diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc
new file mode 100644
index 0000000..7dc9ea0
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc
@@ -0,0 +1,152 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Module.h" // from @llvm-project
+#include "mlir/Parser.h" // from @llvm-project
+#include "mlir/Pass/PassManager.h" // from @llvm-project
+#include "mlir/Support/FileUtilities.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace tensorflow {
+
+using mlir::MLIRContext;
+using mlir::ModuleOp;
+using mlir::OwningModuleRef;
+using stream_executor::port::StatusOr;
+
+namespace {
+tensorflow::Status RegisterCustomOps(
+ const std::vector<std::string>& extra_tf_opdefs) {
+ for (const auto& tf_opdefs_string : extra_tf_opdefs) {
+ tensorflow::OpDef opdef;
+ if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
+ &opdef)) {
+ LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
+ return errors::InvalidArgument("fail to parse extra OpDef");
+ }
+ // Register extra opdefs.
+ tensorflow::OpRegistry::Global()->Register(
+ [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status {
+ *op_reg_data = tensorflow::OpRegistrationData(opdef);
+ return Status::OK();
+ });
+ }
+ return Status::OK();
+}
+} // namespace
+
+StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
+ const std::string& input_filename, bool input_mlir,
+ const std::vector<std::string>& extra_tf_opdefs,
+ absl::string_view debug_info_file, absl::string_view input_arrays,
+ absl::string_view input_dtypes, absl::string_view input_shapes,
+ absl::string_view output_arrays, bool prune_unused_nodes,
+ llvm::SourceMgr* source_mgr, MLIRContext* context) {
+ // Set up the input file.
+ std::string error_message;
+ auto file = mlir::openInputFile(input_filename, &error_message);
+ if (!file) {
+ llvm::errs() << error_message << "\n";
+ return errors::InvalidArgument("fail to open input file");
+ }
+
+ if (input_mlir) {
+ source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+ return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context));
+ }
+
+ TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
+
+ return tensorflow::GraphdefToMlirTranslateFunction(
+ file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
+ input_shapes, output_arrays, /*control_output_arrays=*/"",
+ prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
+ /*graph_as_function=*/false, /*upgrade_legacy=*/true,
+ /*enable_shape_inference=*/true, context);
+}
+
+Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir,
+ std::string* result,
+ mlir::PassManager* pass_manager) {
+ mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
+ /*propagate=*/true);
+ if (failed(pass_manager->run(module))) {
+ return statusHandler.ConsumeStatus();
+ }
+
+ if (export_to_mlir) {
+ llvm::raw_string_ostream os(*result);
+ module.print(os);
+ return Status::OK();
+ }
+
+ return tfjs::MlirToJSONTranslateFunction(module, result)
+ ? Status::OK()
+ : statusHandler.ConsumeStatus();
+}
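+// Usage sketch (illustrative), given a loaded `module` and a configured
+// mlir::PassManager `pm`:
+//   std::string json;
+//   TF_RETURN_IF_ERROR(
+//       ConvertTFOpsToTfjsJSON(module, /*export_to_mlir=*/false, &json, &pm));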
+
+StatusOr<mlir::OwningModuleRef> ImportSavedModel(
+ bool import_saved_model, bool import_saved_model_v1,
+ const std::vector<std::string>& extra_tf_opdefs,
+ const std::string& input_filename, const std::string& saved_model_tags,
+ const std::string& saved_model_exported_names, mlir::MLIRContext* context) {
+ std::unordered_set<std::string> tags = absl::StrSplit(saved_model_tags, ',');
+ std::vector<std::string> exported_names_in_vector =
+ absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
+ absl::Span<std::string> exported_names(exported_names_in_vector);
+ if (import_saved_model) {
+ auto module = tensorflow::SavedModelObjectGraphToMlirImport(
+ input_filename, tags, absl::Span<std::string>(exported_names), context);
+ if (!module)
+ return tensorflow::errors::InvalidArgument("fail to open input file");
+ TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
+ return module;
+ } else if (import_saved_model_v1) {
+ auto module = tensorflow::SavedModelSignatureDefsToMlirImport(
+ input_filename, tags, exported_names, context);
+
+ if (!module)
+ return tensorflow::errors::InvalidArgument("fail to open input file");
+ TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs));
+ return module;
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Should be either saved model v1 or v2");
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h
new file mode 100644
index 0000000..d68f0e7
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
+#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Module.h" // from @llvm-project
+#include "mlir/Pass/PassManager.h" // from @llvm-project
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace tensorflow {
+
+// Load a TF model from a GraphDef definition or a TF control flow dialect MLIR
+// source into an MLIR module. If `input_mlir` is true, load from an MLIR
+// source file; otherwise, load from a GraphDef. Setting `prune_unused_nodes`
+// to true prunes unreachable nodes if `output_arrays` is specified.
+stream_executor::port::StatusOr<mlir::OwningModuleRef>
+LoadFromGraphdefOrMlirSource(
+ const std::string& input_filename, bool input_mlir,
+ const std::vector<std::string>& extra_tf_opdefs,
+ absl::string_view debug_info_file, absl::string_view input_arrays,
+ absl::string_view input_dtypes, absl::string_view input_shapes,
+ absl::string_view output_arrays, bool prune_unused_nodes,
+ llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
+
+// Load a SavedModel (either v1 or v2) into an MLIR module.
+stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
+ bool import_saved_model, bool import_saved_model_v1,
+ const std::vector<std::string>& extra_tf_opdefs,
+ const std::string& input_filename, const std::string& saved_model_tags,
+ const std::string& saved_model_exported_names, mlir::MLIRContext* context);
+
+// Given an MLIR module in the TF executor dialect and a set of parameters,
+// applies a set of passes to convert the module to the TFJS dialect and
+// serializes the result to a JSON string.
+// If `export_to_mlir` is true, the result is exported in MLIR text format;
+// otherwise it is exported as JSON.
+Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir,
+ std::string* result,
+ mlir::PassManager* pass_manager);
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_
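
Putting the header's entry points together, a driver might look roughly like the sketch below; the path, tags, and error handling are illustrative assumptions, not part of this change:

    mlir::MLIRContext context;
    auto module_or = tensorflow::ImportSavedModel(
        /*import_saved_model=*/true, /*import_saved_model_v1=*/false,
        /*extra_tf_opdefs=*/{}, "/path/to/saved_model",
        /*saved_model_tags=*/"serve",
        /*saved_model_exported_names=*/"", &context);
    if (!module_or.ok()) {
      // Handle the import failure.
    }

    std::string json;
    mlir::PassManager pm(&context);  // populated with the TFJS pipeline elsewhere
    auto status = tensorflow::ConvertTFOpsToTfjsJSON(
        module_or.ValueOrDie().get(), /*export_to_mlir=*/false, &json, &pm);
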
diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
index 8a187cf..9257114 100644
--- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
+++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc
@@ -29,11 +29,25 @@
#include "tfrt/tensor/dense_host_tensor_view.h"
namespace tensorflow {
+namespace {
-void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
+llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) {
+ if (index_path.size() == 1 && index_path[0].isa<mlir::StringAttr>()) {
+ // TODO(chky): Support cases where index_path is not a single string.
+ return index_path[0].cast<mlir::StringAttr>().getValue();
+ }
+ return "";
+}
+
+} // namespace
+
+void MapFunctionSignaturesFromTFSavedModelMLIR(
mlir::ModuleOp module,
llvm::function_ref<void(
llvm::StringRef func_name,
+ llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
+ input_names_and_devices,
+ llvm::ArrayRef<llvm::StringRef> output_names,
llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
map_fn) {
  // Create global_tensors for each function.
@@ -44,17 +58,38 @@
auto func_names = mlir::tf_saved_model::GetExportedNames(func);
if (func_names.empty()) return;
- // Here we walk through each arguments and find out the variables used by
- // this function.
+    // Here we walk through each argument and find the input/output names,
+    // input devices, and variables used by this function.
+ llvm::SmallVector<std::pair<llvm::StringRef, llvm::StringRef>, 4>
+ input_names_and_devices;
llvm::SmallVector<mlir::tf_saved_model::GlobalTensorOp, 4> global_tensors;
for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) {
+ if (auto input_index_path = func.getArgAttrOfType<mlir::ArrayAttr>(
+ i, "tf_saved_model.index_path")) {
+ std::pair<llvm::StringRef, llvm::StringRef> name_and_device;
+ name_and_device.first = ProcessIndexPath(input_index_path);
+ if (auto input_device =
+ func.getArgAttrOfType<mlir::StringAttr>(i, "tf.device")) {
+ name_and_device.second = input_device.getValue();
+ }
+ input_names_and_devices.push_back(name_and_device);
+ }
if (auto variable =
mlir::tf_saved_model::LookupBoundInput(func, i, symbol_table)) {
global_tensors.push_back(variable);
}
}
- for (auto func_name : func_names) map_fn(func_name, global_tensors);
+ llvm::SmallVector<llvm::StringRef, 4> output_names;
+ for (unsigned i = 0, e = func.getNumResults(); i != e; ++i) {
+ if (auto output_index_path = func.getResultAttrOfType<mlir::ArrayAttr>(
+ i, "tf_saved_model.index_path")) {
+ output_names.push_back(ProcessIndexPath(output_index_path));
+ }
+ }
+
+ for (auto func_name : func_names)
+ map_fn(func_name, input_names_and_devices, output_names, global_tensors);
});
}
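
Since the mapper's callback now receives input names/devices and output names in addition to the captured global tensors, a caller sketch (illustrative only; the printing is hypothetical) looks like:

    MapFunctionSignaturesFromTFSavedModelMLIR(
        module,
        [](llvm::StringRef func_name,
           llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
               input_names_and_devices,
           llvm::ArrayRef<llvm::StringRef> output_names,
           llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors) {
          for (const auto& input : input_names_and_devices)
            llvm::errs() << func_name << ": input '" << input.first
                         << "' on device '" << input.second << "'\n";
          for (llvm::StringRef output : output_names)
            llvm::errs() << func_name << ": output '" << output << "'\n";
        });
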
diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
index de24ea2..06a6c5a 100644
--- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
+++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h
@@ -57,12 +57,15 @@
std::string force_data_format;
};
-// Map captured global tensors for each function.
-void MapFunctionGlobalTensorCapturesFromTFSavedModelMLIR(
+// Map signatures (e.g. input/output names, variables) for each function.
+void MapFunctionSignaturesFromTFSavedModelMLIR(
mlir::ModuleOp module,
- llvm::function_ref<
- void(llvm::StringRef func_name,
- llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> captures)>
+ llvm::function_ref<void(
+ llvm::StringRef func_name,
+ llvm::ArrayRef<std::pair<llvm::StringRef, llvm::StringRef>>
+ input_names_and_devices,
+ llvm::ArrayRef<llvm::StringRef> output_names,
+ llvm::ArrayRef<mlir::tf_saved_model::GlobalTensorOp> global_tensors)>
map_fn);
// Compile MLIR in TF saved model dialect into BEF.
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/build_defs.bzl b/tensorflow/compiler/mlir/tools/kernel_gen/build_defs.bzl
deleted file mode 100644
index cec9968..0000000
--- a/tensorflow/compiler/mlir/tools/kernel_gen/build_defs.bzl
+++ /dev/null
@@ -1,96 +0,0 @@
-load("//third_party/gpus/cuda:build_defs.bzl", "cuda_gpu_select_list")
-
-def _lookup_file(filegroup, path):
- """Extracts file at (relative) path in filegroup."""
- for file in filegroup.files.to_list():
- if file.path.endswith(path):
- return file
- return None
-
-def _gen_kernel_image_hdr_impl(ctx):
- if not ctx.attr.gpu_archs:
- fail("No GPU architecture specified, use --config=cuda or similar.")
-
- name = ctx.attr.name
- tile_sizes = ctx.attr.tile_size.replace("x", ",")
- same_shape = []
- if ctx.attr.same_shape:
- same_shape.append("--same_shape=%s" % ctx.attr.same_shape)
-
- cubins = []
- images = []
- for arch in ctx.attr.gpu_archs:
- filename = "%s.%s.cubin" % (name, arch)
- cubin = ctx.actions.declare_file(filename)
- ctx.actions.run(
- outputs = [cubin],
- executable = ctx.executable._tool,
- arguments = same_shape + [
- "--tile_sizes=%s" % tile_sizes,
- "--arch=%s" % arch.split("_")[1],
- "--output=%s" % cubin.path,
- ctx.attr.op,
- ],
- mnemonic = "compile",
- )
- cubins.append(cubin)
- images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
-
- # Generate fatbin file from all cubins.
- fatbin = ctx.actions.declare_file("%s.fatbin" % name)
- ctx.actions.run(
- outputs = [fatbin],
- inputs = cubins,
- executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"),
- arguments = [
- "--64",
- "--cmdline=--compile-only",
- "--link",
- "--compress-all",
- "--create=%s" % fatbin.path,
- ] + images,
- mnemonic = "fatbinary",
- )
-
- bin2c = _lookup_file(ctx.attr._cuda_root, "bin/bin2c")
- ctx.actions.run_shell(
- outputs = [ctx.outputs.out],
- inputs = [fatbin],
- tools = [bin2c],
- command = "%s --static --const --type=int --name=%s %s 1> %s" %
- (bin2c.path, ctx.attr.symbol, fatbin.path, ctx.outputs.out.path),
- mnemonic = "bin2c",
- )
-
-_gen_kernel_image_hdr = rule(
- implementation = _gen_kernel_image_hdr_impl,
- output_to_genfiles = True,
- attrs = {
- "op": attr.string(mandatory = True),
- "tile_size": attr.string(mandatory = True),
- "same_shape": attr.string(),
- "out": attr.output(mandatory = True),
- "symbol": attr.string(mandatory = True),
- "gpu_archs": attr.string_list(mandatory = True),
- "_cuda_root": attr.label(
- default = Label("//third_party/gpus/cuda:cuda_root"),
- ),
- "_tool": attr.label(
- executable = True,
- default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
- cfg = "host",
- ),
- },
-)
-
-def gen_kernel_image_hdr(name, op, tile_size, same_shape = None):
- """Generates a C header with fatbin data from a Tensorflow op."""
- _gen_kernel_image_hdr(
- name = name,
- op = op,
- tile_size = tile_size,
- same_shape = same_shape,
- out = "include/tfrt/gpu/ops/tf/%s.h" % name,
- symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
- gpu_archs = cuda_gpu_select_list("sm_{}"),
- )
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
index 46af4e4..b1c4b1b 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc
@@ -64,9 +64,9 @@
StatusOr<std::string> GetLibdeviceDir(
const xla::HloModuleConfig& hlo_module_config) {
- for (const string& cuda_root : tensorflow::CandidateCudaRoots(
+ for (const std::string& cuda_root : tensorflow::CandidateCudaRoots(
hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) {
- string libdevice_dir =
+ std::string libdevice_dir =
tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
VLOG(2) << "Looking for libdevice at " << libdevice_dir;
if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
@@ -136,7 +136,7 @@
: public mlir::PassWrapper<PropagateStaticKnowledge,
mlir::OperationPass<mlir::LLVM::LLVMFuncOp>> {
explicit PropagateStaticKnowledge(mlir::FunctionType type,
- llvm::ArrayRef<unsigned> same_shape_)
+ llvm::ArrayRef<uint32_t> same_shape_)
: func_type(type), same_shape(same_shape_) {}
void runOnOperation() override {
@@ -152,8 +152,8 @@
func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1));
mlir::Value zero = b.create<mlir::LLVM::ConstantOp>(
func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0));
- unsigned arg_pos = 0;
- std::vector<unsigned> positions;
+ uint32_t arg_pos = 0;
+ std::vector<uint32_t> positions;
for (mlir::Type arg_type : func_type.getInputs()) {
positions.push_back(arg_pos);
func.getArgument(arg_pos + 2).replaceAllUsesWith(zero);
@@ -165,13 +165,13 @@
// can use that here. Simply replace usages of the shape parameters within
// the function body to a single shape parameter.
if (!same_shape.empty()) {
- int first = same_shape.front();
- int first_offset = positions.at(first);
+ auto first = same_shape.front();
+ auto first_offset = positions.at(first);
mlir::ShapedType first_type =
func_type.getInput(first).cast<mlir::ShapedType>();
- unsigned rank = first_type.getRank();
- for (int same : same_shape.drop_front(1)) {
- unsigned same_offset = positions.at(same);
+ uint32_t rank = first_type.getRank();
+ for (auto same : same_shape.drop_front(1)) {
+ uint32_t same_offset = positions.at(same);
auto same_type = func_type.getInput(same).cast<mlir::ShapedType>();
if (same_type.getRank() != rank) {
func.emitOpError() << "same shape constraints on arguments with "
@@ -180,7 +180,7 @@
signalPassFailure();
}
- for (int i = 0; i < 2 * rank; ++i) {
+ for (uint32_t i = 0; i < 2 * rank; ++i) {
// Replace uses for second arg data with first arg.
auto same_arg = func.getArgument(same_offset + 3 + i);
auto first_arg = func.getArgument(first_offset + 3 + i);
@@ -191,11 +191,11 @@
}
mlir::FunctionType func_type;
- llvm::ArrayRef<unsigned> same_shape;
+ llvm::ArrayRef<uint32_t> same_shape;
};
Status PropagateStaticShapeKnowledgeToKernel(
- mlir::ModuleOp module, llvm::ArrayRef<unsigned> same_shape) {
+ mlir::ModuleOp module, llvm::ArrayRef<uint32_t> same_shape) {
// Grab the original signature from the single function.
auto func = *module.getBody()->op_begin<mlir::FuncOp>();
@@ -218,10 +218,10 @@
}
} // namespace
-StatusOr<std::vector<uint8>> tensorflow::kernel_gen::GenerateCubinForTfCode(
- llvm::StringRef tf_code, std::pair<int, int> compute_capability,
- llvm::ArrayRef<unsigned> tile_sizes, llvm::ArrayRef<unsigned> same_shape,
- llvm::ArrayRef<unsigned> unroll_factors) {
+StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
+ llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
+ llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
+ llvm::ArrayRef<uint32_t> unroll_factors) {
mlir::MLIRContext context;
context.allowUnregisteredDialects(); // TODO(b/152572127)
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h
index c874633..47626ba 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h
@@ -30,11 +30,12 @@
namespace tensorflow {
namespace kernel_gen {
-xla::StatusOr<std::vector<uint8>> GenerateCubinForTfCode(
- llvm::StringRef tf_code, std::pair<int, int> compute_capability = {7, 5},
- llvm::ArrayRef<unsigned> tile_sizes = {16, 64},
- llvm::ArrayRef<unsigned> same_shape = {},
- llvm::ArrayRef<unsigned> unroll_factors = {});
+xla::StatusOr<std::vector<uint8_t>> GenerateCubinForTfCode(
+ llvm::StringRef tf_code,
+ std::pair<int32_t, int32_t> compute_capability = {7, 5},
+ llvm::ArrayRef<uint32_t> tile_sizes = {16, 64},
+ llvm::ArrayRef<uint32_t> same_shape = {},
+ llvm::ArrayRef<uint32_t> unroll_factors = {});
} // namespace kernel_gen
} // namespace tensorflow
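
A hedged usage sketch for the updated signature; `tf_code` is a hypothetical MLIR source string, and the defaults shown ({7, 5}, i.e. sm_75, and tile sizes {16, 64}) come from the declaration above:

    llvm::StringRef tf_code = "...";  // MLIR for the op to compile (assumed)
    auto cubin_or = tensorflow::kernel_gen::GenerateCubinForTfCode(tf_code);
    if (cubin_or.ok()) {
      std::vector<uint8_t> cubin_data = cubin_or.ConsumeValueOrDie();
      // cubin_data now holds the compiled CUBIN bytes.
    }
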
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc
index d39edd8..8edc567 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc
@@ -31,9 +31,9 @@
#include "tensorflow/core/util/command_line_flags.h"
namespace {
-bool ParseStringList(std::string string_list, std::vector<uint32>* result) {
+bool ParseStringList(std::string string_list, std::vector<uint32_t>* result) {
result->clear();
- uint32 item;
+ uint32_t item;
auto items = absl::StrSplit(string_list, ',');
for (const auto& item_str : items) {
if (!absl::SimpleAtoi(item_str, &item)) {
@@ -48,10 +48,10 @@
int main(int argc, char** argv) {
std::string output_file = "foo.bin";
- int32 architecture = 50;
- std::vector<uint32> tile_sizes;
- std::vector<uint32> unroll_factors;
- std::vector<uint32> same_shape;
+ int32_t architecture = 50;
+ std::vector<uint32_t> tile_sizes;
+ std::vector<uint32_t> unroll_factors;
+ std::vector<uint32_t> same_shape;
auto parse_tile_sizes = [&tile_sizes](std::string tile_sizes_str) {
if (!ParseStringList(tile_sizes_str, &tile_sizes)) {
@@ -91,8 +91,8 @@
return 1;
}
- std::pair<int32, int32> compute_capability(architecture / 10,
- architecture % 10);
+ std::pair<int32_t, int32_t> compute_capability(architecture / 10,
+ architecture % 10);
auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode(
argv[1], compute_capability, tile_sizes, same_shape, unroll_factors);
@@ -102,7 +102,7 @@
return 1;
}
- std::vector<uint8> cubin_data = cubin.ConsumeValueOrDie();
+ std::vector<uint8_t> cubin_data = cubin.ConsumeValueOrDie();
auto status = tensorflow::WriteStringToFile(
tensorflow::Env::Default(), output_file,
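
For reference, the flag plumbing above composes as follows (values illustrative; `ParseStringList` is file-local to this tool): it fills the vector from a comma-separated string and returns false on any non-integer item, and an `architecture` of 75 becomes the compute-capability pair (7, 5) via the /10 and %10 split:

    std::vector<uint32_t> tile_sizes;
    bool ok = ParseStringList("16,64", &tile_sizes);  // ok == true, {16, 64}
    int32_t architecture = 75;
    std::pair<int32_t, int32_t> cc(architecture / 10,
                                   architecture % 10);  // (7, 5)
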
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index e0e93c3..d9108e8 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -240,8 +240,8 @@
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
- "@llvm-project//mlir:LoopOps",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
@@ -278,8 +278,8 @@
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
- "@llvm-project//mlir:LoopOps",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index 874f4a1..68eafb8 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -1170,9 +1170,22 @@
//===----------------------------------------------------------------------===//
OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
+ auto input = operand();
+
// No dimensions to reverse.
- if (dimensions().getNumElements() == 0) return operand();
- return nullptr;
+ if (dimensions().getNumElements() == 0) return input;
+
+  // Reversing is a no-op when every reversed dimension has size 1, so the
+  // input can be returned unchanged; otherwise the op cannot be folded.
+  auto shaped_type = input.getType().cast<ShapedType>();
+ for (auto dim : dimensions().getValues<APInt>()) {
+ if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) {
+ return nullptr;
+ }
+ }
+
+ return input;
}
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 917a50f..0db9563 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -229,7 +229,7 @@
BASE_HLO_SqrtOp;
def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh",
- [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType],
+ [NoSideEffect, SameOperandsAndResultType],
HLO_FpOrComplexTensor>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 99d1da7..cc334d8 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -56,6 +56,20 @@
return mlir::DenseIntElementsAttr::get(ty, mlir_values);
}
+static mlir::DenseIntElementsAttr ConvertPadding(
+ absl::Span<const std::pair<tensorflow::int64, tensorflow::int64>> padding,
+ mlir::Builder* builder) {
+ llvm::SmallVector<int64_t, 8> elements;
+ elements.reserve(padding.size() * 2);
+ for (const auto& vals : padding) {
+ elements.push_back(vals.first);
+ elements.push_back(vals.second);
+ }
+ auto ty = mlir::RankedTensorType::get(
+ {static_cast<int64_t>(padding.size()), 2}, builder->getIntegerType(64));
+ return mlir::DenseIntElementsAttr::get(ty, elements);
+}
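
To make the flattening concrete: for two spatial dimensions padded (low, high) = (1, 2) and (0, 3), the helper above yields a dense<[[1, 2], [0, 3]]> : tensor<2x2xi64> attribute. An illustrative call (the context is assumed, and the helper is static, so this only compiles within the same translation unit):

    std::vector<std::pair<tensorflow::int64, tensorflow::int64>> padding = {
        {1, 2}, {0, 3}};
    mlir::Builder builder(&context);
    mlir::DenseIntElementsAttr attr =
        ConvertPadding(absl::MakeConstSpan(padding), &builder);
    // attr == dense<[[1, 2], [0, 3]]> : tensor<2x2xi64>
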
+
MlirHloBuilder::~MlirHloBuilder() = default;
StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) {
@@ -79,6 +93,31 @@
});
}
+StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
+ const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count, int64 batch_group_count,
+ const PrecisionConfig* precision_config) {
+ TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+ shape, builder_));
+ mlir::ArrayAttr config_attr;
+ if (precision_config)
+ config_attr = ConvertPrecisionConfig(precision_config, &builder_);
+ auto op = builder_.create<mlir::xla_hlo::ConvOp>(
+ loc_, ty, GetValue(lhs), GetValue(rhs),
+ GetI64ElementsAttr(window_strides, &builder_),
+ ConvertPadding(padding, &builder_),
+ GetI64ElementsAttr(lhs_dilation, &builder_),
+ GetI64ElementsAttr(rhs_dilation, &builder_),
+ ConvertConvDimensionNumbers(dimension_numbers, &builder_),
+ builder_.getI64IntegerAttr(feature_group_count),
+ builder_.getI64IntegerAttr(batch_group_count), config_attr);
+ return MakeXlaOp(op);
+}
+
StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
index dbcb685..5a84d60 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h
@@ -110,6 +110,16 @@
private:
XlaOp ConstantLiteral(const LiteralSlice& literal) override;
+ StatusOr<XlaOp> ConvGeneralDilatedInternal(
+ const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count, int64 batch_group_count,
+ const PrecisionConfig* precision_config) override;
+
StatusOr<XlaOp> TransposeInternal(
const Shape& shape, XlaOp operand,
absl::Span<const int64> permutation) override;
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 21668b7..228a26b 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -986,6 +986,21 @@
return LowerFunctionCall(&call_op, builder, &value_map);
}
+ if (auto op = dyn_cast<mlir::TensorCastOp>(inst)) {
+ Value operand = op.getOperand();
+ auto ty = operand.getType().dyn_cast<ShapedType>();
+    // If this is a cast from a statically shaped tensor, it is a no-op for
+    // HLO export and we can reuse the operand.
+ if (!ty || !ty.hasStaticShape()) {
+ inst->emitOpError()
+ << "requires static shaped operand for HLO translation";
+ return failure();
+ }
+
+ value_map[op.getResult()] = value_map[operand];
+ return success();
+ }
+
// TODO(jpienaar): This doesn't support layouts yet.
if (matchPattern(inst, m_Constant(&const_attr))) {
auto literal_or = CreateLiteralFromAttr(const_attr);
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
index 262533b..53296b2 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -1,4 +1,4 @@
-// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure
+// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -13,33 +13,42 @@
// -----
+func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+ return %arg0 : tensor<4xf32>
+}
+// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
+// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
+// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
+
+// -----
+
// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
- // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
- // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
- // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
- // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
- // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32>
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
%3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
%4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
- // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
- // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
- // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
- // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
- // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
- // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
return %5 : tensor<4xf32>
- // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
}
+// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
+// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
+// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
+// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
+// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
+// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
+// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
+// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
+// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
+// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
+// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
+// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
// -----
@@ -47,20 +56,20 @@
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
- // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
- // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32>
+ // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
+ // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
+ // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
- // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
index aa949a0..a856ee5 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir
@@ -530,3 +530,15 @@
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64):
// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: func @convert_f32_to_i32
+func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
+ %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
+ return %result : tensor<2x2xi32>
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
+// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
index cda1dc4..6a2b68a 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
@@ -8,7 +8,9 @@
// CHECK-SAME: ) {
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// The only expected instruction is a copy from the input into the output.
- // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][][] : memref<16xi8> to memref<2x2xf32>
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[C02:.*]] = constant 0 : index
+ // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C02]]][] : memref<16xi8> to memref<2x2xf32>
// CHECK: xla_lhlo.copy
// CHECK-SAME: %[[ARG0]], %[[OUTPUT]]
return %value : tensor<2x2xf32>
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index aef9b17..a5353be 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1596,6 +1596,44 @@
return %0, %1 : tensor<i32>, tensor<i32>
}
+
+//===----------------------------------------------------------------------===//
+// ReverseV2 op legalization.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @reverse_func_32
+func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> {
+ %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>)
+
+ // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
+ %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32>
+
+ // CHECK: return [[VAL]] : tensor<5xi32>
+ return %reversed : tensor<5xi32>
+}
+
+// CHECK-LABEL: @reverse_func_64
+func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> {
+ %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>)
+
+ // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
+ %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32>
+
+ // CHECK: return [[VAL]] : tensor<5xi32>
+ return %reversed : tensor<5xi32>
+}
+
+// CHECK-LABEL: @reverse_func_neg
+func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> {
+ %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
+
+ // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>}
+ %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32>
+
+ // CHECK: return [[VAL]] : tensor<5x5xi32>
+ return %reversed : tensor<5x5xi32>
+}
+
//===----------------------------------------------------------------------===//
// StatefulPartitionedCall op legalization.
//===----------------------------------------------------------------------===//
@@ -4074,6 +4112,41 @@
return %0 : tensor<4x16xf32>
}
+// CHECK-LABEL: inplace_update_one
+func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> {
+ // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0>
+ // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]])
+ // CHECK-DAG: [[UPDATE:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]])
+ %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32>
+
+ // CHECK: return [[UPDATE]]
+ return %0 : tensor<8x4xf32>
+}
+
+// CHECK-LABEL: inplace_update_three
+func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> {
+ // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0>
+ // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ // CHECK-DAG: [[SLICE3:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ // CHECK-DAG: [[SLICE4:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+ // CHECK-DAG: [[SLICE5:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+ // CHECK-DAG: [[SLICE6:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+ // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]])
+ // CHECK-DAG: [[RESHAPE2:%.+]] = "xla_hlo.reshape"([[SLICE2]])
+ // CHECK-DAG: [[RESHAPE3:%.+]] = "xla_hlo.reshape"([[SLICE3]])
+ // CHECK-DAG: [[UPDATE1:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]])
+ // CHECK-DAG: [[UPDATE2:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]])
+ // CHECK-DAG: [[UPDATE3:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]])
+ %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32>
+
+ // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32>
+ return %0 : tensor<8x8x4xf32>
+}
+
+
// CHECK-LABEL: xla_dynamic_update_slice
func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> {
// CHECK: [[SLICE0:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
@@ -4097,6 +4170,21 @@
}
//===----------------------------------------------------------------------===//
+// AllToAll op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @alltoall_basic
+func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> {
+ %group_assignment = "tf.Const" () {
+ value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32>
+ } : () -> tensor<3x4xi32>
+ %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32>
+ // CHECK: xla_hlo.all_to_all
+ // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64>
+ return %result : tensor<10xf32>
+}
+
+//===----------------------------------------------------------------------===//
// Cumsum op legalizations.
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
index 3605e8e..bb8010b 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
@@ -411,6 +411,19 @@
// -----
+// CHECK-LABEL: func @convert_f32_to_i32
+func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
+ "xla_lhlo.convert"(%input, %result)
+ : (memref<2x2xf32>, memref<2x2xi32>) -> ()
+ return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i32):
+// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
+
+// -----
+
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
index 4050340..2340650 100644
--- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir
@@ -20,6 +20,17 @@
// -----
+// CHECK-LABEL: @addBroadcastEqual
+func @addBroadcastEqual(%arg0: tensor<4x1xf32>, %arg1: tensor<1x4xf32>) -> tensor<4x4xf32> {
+ // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4xf32>
+ // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<4x4xf32>
+ // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<4x4xf32>
+ %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32>
+ return %0 : tensor<4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: @addBroadcastMultidimension
func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32>
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index ed06863..15fa915 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -1044,3 +1044,16 @@
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = u8[4] parameter(0)
// ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
+
+// -----
+
+// CHECK: HloModule
+func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) {
+ %0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
+ %1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32>
+ return %1 : tensor<*xi32>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[ARG0:.*]] = s32[4] parameter(0)
+// ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]])
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index aa29241..10f3576 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -27,6 +27,7 @@
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@@ -39,16 +40,11 @@
namespace {
constexpr StringRef kTempBufferAttr = "temp";
-
-/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc.
-Operation* FindInsertionPointForCopy(Value value) {
- for (const auto& user : value.getUsers()) {
- if (auto dealloc = dyn_cast<DeallocOp>(user)) {
- return user;
- }
- }
- return nullptr;
-}
+template <typename T>
+using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
+using StdReturnOpConverter =
+ NonVoidToVoidReturnOpConverter<mlir::ReturnOp, xla_lhlo::TerminatorOp,
+ xla_lhlo::CopyOp>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,
@@ -92,8 +88,9 @@
return alloc;
}
-Value InsertAllocAndDealloc(Location loc, Value result,
- ConversionPatternRewriter* rewriter) {
+Value InsertAlloc(Location loc, OpResult result,
+ BufferAssignmentPlacer* bufferAssignment,
+ ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
if (!result_type || !result_type.hasStaticShape()) {
result.getDefiningOp()->emitOpError()
@@ -101,31 +98,21 @@
}
auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType());
-
- Operation* op = result.getDefiningOp();
- auto block = op->getBlock();
-
- OpBuilder allocBuilder(op);
- allocBuilder.setInsertionPointToStart(block); // Inserting at the beginning
- auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
-
- alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
-
- allocBuilder.setInsertionPoint(block, std::prev(block->end()));
- allocBuilder.create<DeallocOp>(loc, alloc);
-
+ OpBuilder::InsertionGuard guard(*rewriter);
+ rewriter->restoreInsertionPoint(
+ bufferAssignment->computeAllocPosition(result));
+ auto alloc = rewriter->create<AllocOp>(loc, memref_type);
return alloc;
}
template <typename HloOpTy>
-class HloToLhloOpConverter : public ConversionPattern {
+class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
public:
- explicit HloToLhloOpConverter(MLIRContext* context)
- : ConversionPattern(HloOpTy::getOperationName(), 1, context) {}
-
+ using BaseOpConversion<HloOpTy>::BaseOpConversion;
LogicalResult matchAndRewrite(
- Operation* op, ArrayRef<Value> operands,
+ HloOpTy hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
+ Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
@@ -135,8 +122,8 @@
return failure();
}
if (resultType.hasStaticShape()) {
- buffer_args.push_back(
- InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter));
+ buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(),
+ this->bufferAssignment, &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
@@ -156,9 +143,9 @@
};
struct HloToLhloDynamicBroadcastInDimOpConverter
- : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
+ : public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
@@ -175,10 +162,9 @@
}
};
-struct HloToLhloReduceOpConverter
- : public OpConversionPattern<xla_hlo::ReduceOp> {
+struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp op, ArrayRef<Value> operands,
@@ -194,7 +180,8 @@
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
- buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter));
+ buffer_args.push_back(
+ InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
@@ -230,12 +217,12 @@
}
};
-class HloToLhloTensorLoadOpConverter : public ConversionPattern {
+class HloToLhloTensorLoadOpConverter
+ : public BaseOpConversion<mlir::TensorLoadOp> {
public:
- explicit HloToLhloTensorLoadOpConverter(MLIRContext* context)
- : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {}
+ using BaseOpConversion<mlir::TensorLoadOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
- Operation* op, ArrayRef<Value> operands,
+ mlir::TensorLoadOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOp(op, operands);
return success();
@@ -243,13 +230,13 @@
};
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
-class HloToLhloTensorStoreOpConverter : public ConversionPattern {
+class HloToLhloTensorStoreOpConverter
+ : public BaseOpConversion<mlir::TensorStoreOp> {
public:
- explicit HloToLhloTensorStoreOpConverter(MLIRContext* context)
- : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {}
+ using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
- Operation* op, ArrayRef<Value> operands,
+ mlir::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
op, llvm::None, operands.front(), operands.back());
@@ -291,7 +278,6 @@
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.multiply"(%0, %arg0, %arg3) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
-// dealloc %0 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// }) : () -> ()
// return
@@ -313,14 +299,13 @@
// %arg1: memref<4xf32>,
// %arg2: memref<4xf32>) {
// %0 = alloc() : memref<4xf32>
-// %1 = alloc() : memref<4xf32>
+
// "xla_lhlo.maximum"(%arg0, %arg1, %0) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
+// %1 = alloc() : memref<4xf32>
// "xla_lhlo.add"(%arg0, %0, %1) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
-// dealloc %0 : memref<4xf32>
-// dealloc %1 : memref<4xf32>
// "xla_lhlo.terminator"() : () -> ()
// }
@@ -346,101 +331,25 @@
});
auto module = getOperation();
- populateHLOToLHLOConversionPattern(module.getContext(), &patterns);
-
- // Do partial conversion so we can have unknown ops in tests.
- if (failed(applyPartialConversion(module, target, patterns, nullptr))) {
- signalPassFailure();
- }
+ BufferAssignmentTypeConverter converter;
+ module.walk([&](FuncOp func) {
+ BufferAssignmentPlacer bufferAssignment(func);
+ OwningRewritePatternList patterns;
+ populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
+ &converter, &patterns);
+ return WalkResult(
+ applyPartialConversion(func, target, patterns, &converter));
+ });
}
};
-
-Type ConvertType(Type t) {
- if (auto tensorType = t.dyn_cast<RankedTensorType>()) {
- return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- }
- return t;
-}
-
} // namespace
-/// Transforms FuncOp arguments and results from tensors to buffers. Tensor
-/// results are converted to memrefs and appended to the argument list.
-class HloToLhloFuncOpConverter : public OpConversionPattern<FuncOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- FuncOp funcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter& rewriter) const final {
- if (funcOp.getBody().getBlocks().size() > 1) {
- funcOp.emitOpError() << "tensor to buffer conversion expects a single "
- "block in the region containing the operation";
- return failure();
- }
-
- auto funcType = funcOp.getType();
-
- TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
- for (auto argType : llvm::enumerate(funcType.getInputs())) {
- conversion.addInputs(argType.index(), ConvertType(argType.value()));
- }
- for (auto resType : funcType.getResults()) {
- conversion.addInputs(ConvertType(resType));
- }
- rewriter.updateRootInPlace(funcOp, [&] {
- funcOp.setType(
- rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None));
- rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
- });
- return success();
- }
-};
-
-/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each
-/// result to the corresponding buffer argument.
-class StdToLhloReturnOpConverter : public OpConversionPattern<mlir::ReturnOp> {
- public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mlir::ReturnOp returnOp, ArrayRef<Value> operands,
- ConversionPatternRewriter& rewriter) const final {
- auto numReturnValues = returnOp.getNumOperands();
- auto funcOp = returnOp.getParentOfType<FuncOp>();
- auto numFuncArgs = funcOp.getNumArguments();
- auto loc = returnOp.getLoc();
-
- for (auto operand : llvm::enumerate(operands)) {
- auto returnArgNumber = numFuncArgs - numReturnValues + operand.index();
- auto dstBuffer = funcOp.getArgument(returnArgNumber);
- if (dstBuffer == operand.value()) {
- continue;
- }
-
- auto dealloc = FindInsertionPointForCopy(operand.value());
-
- if (dealloc == nullptr) {
- returnOp.emitOpError()
- << "Missing dealloc for operand " << operand.index();
- return failure();
- }
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(dealloc);
- rewriter.create<xla_lhlo::CopyOp>(loc, llvm::None, operand.value(),
- funcOp.getArgument(returnArgNumber));
- }
- rewriter.replaceOpWithNewOp<xla_lhlo::TerminatorOp>(returnOp);
- return success();
- }
-};
-
-void populateHLOToLHLOConversionPattern(MLIRContext* context,
- OwningRewritePatternList* patterns) {
+void populateHLOToLHLOConversionPattern(
+ MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
+ TypeConverter* converter, OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
HloToLhloDynamicBroadcastInDimOpConverter,
- HloToLhloFuncOpConverter,
HloToLhloOpConverter<xla_hlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp>,
@@ -472,8 +381,9 @@
HloToLhloReduceOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter,
- StdToLhloReturnOpConverter
- >(context);
+ FunctionAndBlockSignatureConverter,
+ StdReturnOpConverter
+ >(context, bufferAssignment, converter);
// clang-format on
}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index de808bc..a0a5e47 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -2590,6 +2590,21 @@
}
};
+ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
+ auto int_attr = attr.cast<DenseIntElementsAttr>();
+ auto type = val.getType().cast<ShapedType>();
+
+ SmallVector<int64_t, 6> axis;
+ axis.reserve(int_attr.getNumElements());
+
+ int64_t rank = type.getRank();
+ for (auto val : int_attr.getValues<APInt>()) {
+ axis.push_back((val.getSExtValue() + rank) % rank);
+ }
+
+ return builder->getI64TensorAttr(axis);
+}
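
The `(val + rank) % rank` normalization maps negative axes to their positive equivalents, which is what the @reverse_func_neg test earlier in this diff exercises. A worked example as comments:

    // rank = 2, axis = -1  ->  (-1 + 2) % 2 == 1
    // rank = 2, axis =  0  ->  ( 0 + 2) % 2 == 0
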
+
/// Converts the LinSpace tensorflow op to a xla_hlo.iota op with a scaling
/// and offset applied to generate the linspace values. The output tensor needs
/// to have a static shape. The implementation is defined in C++ because there
@@ -4182,6 +4197,68 @@
}
};
+// Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO.
+class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TF::InplaceUpdateOp op,
+ PatternRewriter &rewriter) const override {
+ auto input = op.x();
+ auto indices = op.i();
+ auto updates = op.v();
+
+ // Slice each row of `i` and `v` to perform a separate dynamic-update-slice
+ // on the contents of `x`.
+ auto input_type = input.getType().cast<ShapedType>();
+ auto updates_type = updates.getType().cast<ShapedType>();
+ auto indices_type = indices.getType().cast<ShapedType>();
+ if (!indices_type.hasStaticShape()) return failure();
+
+ if (indices_type.getRank() != 1) return failure();
+
+ SmallVector<Type, 4> unpacked_indices_type(
+ indices_type.getDimSize(0),
+ RankedTensorType::get({}, indices_type.getElementType()));
+ auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(64), 0);
+ auto unpacked_indices = rewriter.create<TF::UnpackOp>(
+ op.getLoc(), unpacked_indices_type, indices, zero_attr);
+
+ SmallVector<int64_t, 4> split_updates_shape;
+ split_updates_shape.append(updates_type.getShape().begin(),
+ updates_type.getShape().end());
+ split_updates_shape.front() = 1;
+ SmallVector<Type, 4> split_updates_type;
+ split_updates_type.resize(
+ updates_type.getShape().front(),
+ RankedTensorType::get(split_updates_shape,
+ updates_type.getElementType()));
+
+ auto cst =
+ rewriter.create<xla_hlo::ConstOp>(op.getLoc(), zero_attr).getResult();
+ auto split_updates = rewriter.create<TF::SplitOp>(
+ op.getLoc(), split_updates_type, cst, updates);
+
+ SmallVector<Value, 6> input_indices;
+ input_indices.resize(input_type.getRank(), cst);
+
+ SmallVector<int64_t, 6> starts(updates_type.getRank(), 0);
+ SmallVector<int64_t, 6> strides(updates_type.getRank(), 1);
+ SmallVector<int64_t, 6> limits(updates_type.getShape().begin(),
+ updates_type.getShape().end());
+
+ for (auto pair :
+ llvm::zip(unpacked_indices.output(), split_updates.output())) {
+ input_indices.front() = std::get<0>(pair);
+ input = rewriter.create<xla_hlo::DynamicUpdateSliceOp>(
+ op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
+ }
+
+ rewriter.replaceOp(op, input);
+ return success();
+ }
+};
+
// Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO.
class ConvertXlaDynamicUpdateSliceOp
: public OpRewritePattern<TF::XlaDynamicUpdateSliceOp> {
@@ -4863,12 +4940,13 @@
ConvertConv3DBackpropInputOp, ConvertCumsumOp, ConvertDiagPartOp,
ConvertEinsumOp, ConvertFusedBatchNormGradOp,
ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op,
- ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp,
- ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp,
- ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp,
- ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
- ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op,
- ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
+ ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
+ ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp,
+ ConvertAvgPoolOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp,
+ ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp,
+ ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp,
+ ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
+ ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index d53dbdc..2a27c1f 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -274,6 +274,13 @@
(CastElementsToI64Elements $group_assignment))>;
//===----------------------------------------------------------------------===//
+// All2All op patterns.
+//===----------------------------------------------------------------------===//
+
+def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count),
+ (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>;
+
+//===----------------------------------------------------------------------===//
// FFT op patterns.
//===----------------------------------------------------------------------===//
@@ -514,6 +521,16 @@
}
//===----------------------------------------------------------------------===//
+// Reverse op patterns.
+//===----------------------------------------------------------------------===//
+
+// Handles axis conversion for TF reverse.
+def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">;
+
+def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)),
+ (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>;
+
+//===----------------------------------------------------------------------===//
// Ternary op patterns.
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index be6ba16..86a2def 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -102,6 +102,7 @@
TypeID::get<TF::CastOp>(),
TypeID::get<TF::ClipByValueOp>(),
TypeID::get<TF::ComplexAbsOp>(),
+ TypeID::get<TF::ConjugateTransposeOp>(),
TypeID::get<TF::CoshOp>(),
TypeID::get<TF::CrossOp>(),
TypeID::get<TF::DataFormatDimMapOp>(),
@@ -135,9 +136,11 @@
TypeID::get<TF::MulOp>(),
TypeID::get<TF::NegOp>(),
TypeID::get<TF::NotEqualOp>(),
+ TypeID::get<TF::PadOp>(),
TypeID::get<TF::PlaceholderWithDefaultOp>(),
TypeID::get<TF::PowOp>(),
TypeID::get<TF::RealDivOp>(),
+ TypeID::get<TF::ReciprocalOp>(),
TypeID::get<TF::ReciprocalGradOp>(),
TypeID::get<TF::Relu6GradOp>(),
TypeID::get<TF::RightShiftOp>(),
@@ -162,6 +165,8 @@
TypeID::get<TF::TruncateModOp>(),
TypeID::get<TF::UnpackOp>(),
TypeID::get<TF::XdivyOp>(),
+ TypeID::get<TF::XlaBroadcastHelperOp>(),
+ TypeID::get<TF::XlaConvOp>(),
TypeID::get<TF::XlaDotOp>(),
TypeID::get<TF::XlaPadOp>(),
TypeID::get<TF::Xlog1pyOp>(),
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
index e6f3ac0..f0eb3cc 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc
@@ -21,7 +21,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
-#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
@@ -112,7 +112,7 @@
auto step = rewriter.create<mlir::ConstantOp>(
loc, rewriter.getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
- auto loop = rewriter.create<mlir::loop::ForOp>(loc, zero, upper, step);
+ auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);
rewriter.setInsertionPointToStart(loop.getBody());
// Compute memrefs for the value to reduce. This makes it easier to just
@@ -173,8 +173,7 @@
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
- gpu::GPUDialect, loop::LoopOpsDialect,
- XlaLhloDialect>();
+ gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
target.addIllegalOp<ReduceOp>();
auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
index 54b3acd..c5f5b39 100644
--- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc
@@ -18,7 +18,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
-#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@@ -64,12 +64,12 @@
// into a reduction operator of loop.reduce by doing buffer allocation for
// scalar arguments and the result of `loop.reduce` to make it compatible with
// LHLO ops.
-void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op,
+void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
Block* lhlo_block, OpBuilder* b) {
Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(&loop_reduce_op_body);
- b->create<loop::ReduceReturnOp>(
+ b->create<scf::ReduceReturnOp>(
loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
lhlo_block, b));
}
@@ -136,9 +136,9 @@
return mapped_ivs;
}
-// Returns loop::Parallel over a shaped value with static or dynamic shape.
-loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
- OpBuilder* b) {
+// Returns scf::Parallel over a shaped value with static or dynamic shape.
+scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
+ OpBuilder* b) {
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
@@ -151,10 +151,10 @@
lower.push_back(zero);
step.push_back(one);
}
- return b->create<loop::ParallelOp>(loc, lower, upper, step);
+ return b->create<scf::ParallelOp>(loc, lower, upper, step);
}
-// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp.
+// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
// contains the reduction operator.
@@ -197,7 +197,7 @@
// TODO(b/137624192) Implement variadic reduce.
if (xla_reduce_op.out().size() != 1) return failure();
- loop::ReduceOp reduce_op =
+ scf::ReduceOp reduce_op =
CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
&xla_reduce_op.body().front(), &rewriter);
@@ -225,7 +225,7 @@
// } : f32
// loop.yield
// }
- loop::ReduceOp CreateReduceOpInNestedParallelLoops(
+ scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceOp xla_reduce_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_op.getLoc();
@@ -254,13 +254,13 @@
SmallVector<Value, 1> init_value = {
rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())};
// Outer ParallelOp is not needed if it is a reduction across all dims.
- loop::ParallelOp outer;
+ scf::ParallelOp outer;
if (!parallel_lower.empty()) {
- outer = rewriter->create<loop::ParallelOp>(loc, parallel_lower,
- parallel_upper, parallel_step);
+ outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
+ parallel_upper, parallel_step);
rewriter->setInsertionPointToStart(outer.getBody());
}
- loop::ParallelOp inner = rewriter->create<loop::ParallelOp>(
+ scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
loc, reduce_lower, reduce_upper, reduce_step, init_value);
Value reduction_result = *inner.getResults().begin();
@@ -294,7 +294,7 @@
rewriter->setInsertionPointToStart(inner.getBody());
Value elem = rewriter->create<mlir::LoadOp>(
loc, *xla_reduce_op.operands().begin(), indices);
- return rewriter->create<loop::ReduceOp>(loc, elem);
+ return rewriter->create<scf::ReduceOp>(loc, elem);
}
};
@@ -314,8 +314,8 @@
// accumulator = reduction_operator(output[O], value)
// output[O] = accumulator
//
-// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a
-// loop::ReduceOp.
+// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
+// scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the output
// buffer. The inner `ParallelOp` refers to the reduction loops that traverse
// reduction windows and `ReduceOp` contains the reduction operator.
@@ -366,12 +366,12 @@
LogicalResult matchAndRewrite(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
- loop::ParallelOp output_loop, window_loop;
+ scf::ParallelOp output_loop, window_loop;
std::tie(output_loop, window_loop) =
CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
&rewriter);
- loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
+ scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
xla_reduce_window_op, output_loop, window_loop, &rewriter);
ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
@@ -381,7 +381,7 @@
}
private:
- std::pair<loop::ParallelOp, loop::ParallelOp>
+ std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow(
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
ConversionPatternRewriter* rewriter) const {
@@ -405,7 +405,7 @@
window_upper.push_back(
rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
}
- auto window_loop = rewriter->create<loop::ParallelOp>(
+ auto window_loop = rewriter->create<scf::ParallelOp>(
loc, window_lower, window_upper, window_step, init_value);
Value reduction_result = *window_loop.getResults().begin();
@@ -414,9 +414,9 @@
return std::make_pair(output_loop, window_loop);
}
- loop::ReduceOp CreateReduceOpInNestedParallelLoops(
+ scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
- loop::ParallelOp output_loop, loop::ParallelOp window_loop,
+ scf::ParallelOp output_loop, scf::ParallelOp window_loop,
ConversionPatternRewriter* rewriter) const {
rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc();
@@ -436,20 +436,20 @@
xla_reduce_window_op, output_loop.getInductionVars(),
window_loop.getInductionVars(), rewriter);
- auto elem_or_init = rewriter->create<loop::IfOp>(
+ auto elem_or_init = rewriter->create<scf::IfOp>(
loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
/*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
Value elem = then_builder.create<mlir::LoadOp>(
loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
- then_builder.create<loop::YieldOp>(loc, elem);
+ then_builder.create<scf::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
- else_builder.create<loop::YieldOp>(loc, *window_loop.initVals().begin());
+ else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
- return rewriter->create<loop::ReduceOp>(loc,
- *elem_or_init.results().begin());
+ return rewriter->create<scf::ReduceOp>(loc,
+ *elem_or_init.results().begin());
}
};
@@ -490,7 +490,7 @@
ConversionPatternRewriter& rewriter) const final {
auto loc = s_and_s_op.getLoc();
InitializeOutput(s_and_s_op, &rewriter);
- loop::ParallelOp loop_over_src =
+ scf::ParallelOp loop_over_src =
MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
rewriter.setInsertionPointToStart(loop_over_src.getBody());
@@ -520,7 +520,7 @@
auto loc = s_and_s_op.getLoc();
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
- loop::ParallelOp loop_over_output =
+ scf::ParallelOp loop_over_output =
MakeLoopOverShape(loc, s_and_s_op.out(), b);
OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(loop_over_output.getBody());
@@ -531,10 +531,10 @@
struct WindowLoops {
SmallVector<Value, 2> selected_ivs;
SmallVector<Value, 2> window_ivs;
- loop::ForOp inner_loop;
+ scf::ForOp inner_loop;
};
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
- loop::ParallelOp loop_over_src,
+ scf::ParallelOp loop_over_src,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value zero = b->create<ConstantIndexOp>(loc, 0);
@@ -558,12 +558,12 @@
s_and_s_op.window_dimensions()->getIntValues()) {
Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
result.inner_loop =
- b->create<loop::ForOp>(loc, zero, upper, one, iter_args);
+ b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
if (b->getInsertionBlock() == loop_over_src.getBody()) {
ip = b->saveInsertionPoint();
result.selected_ivs = result.inner_loop.getResults().take_front(rank);
} else {
- b->create<loop::YieldOp>(loc, result.inner_loop.getResults());
+ b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
}
b->setInsertionPointToStart(result.inner_loop.getBody());
iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
@@ -599,7 +599,7 @@
};
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
- loop::ParallelOp loop_over_src,
+ scf::ParallelOp loop_over_src,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
@@ -614,7 +614,7 @@
IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());
- auto if_in_bounds = inner_loop_b.create<loop::IfOp>(
+ auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
/*withElseRegion=*/true);
@@ -623,16 +623,16 @@
OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
auto select_or_init_results = SelectOrInitialize(
s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
- in_bounds_then_b.create<loop::YieldOp>(loc, select_or_init_results);
+ in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
}
// Case when we are in the pad.
{
OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
- in_bounds_else_b.create<loop::YieldOp>(loc, ivs_val_flag.to_vector());
+ in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
}
- inner_loop_b.create<loop::YieldOp>(loc, if_in_bounds.getResults());
+ inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
return window_loops.selected_ivs;
}
@@ -647,8 +647,8 @@
Value operand_elem =
b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
auto if_init =
- b->create<loop::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
- /*withElseRegion=*/true);
+ b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
+ /*withElseRegion=*/true);
// Init == true, i.e. iter args are already initialized with a selected
// element in boundaries of the operand. Select function has to be computed
// here.
@@ -660,32 +660,31 @@
ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
&lhlo_select, &if_init_then_b);
- auto if_pred =
- if_init_then_b.create<loop::IfOp>(loc, iter_arg_types, pred,
- /*withElseRegion=*/true);
+ auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
+ /*withElseRegion=*/true);
// Pred == true, therefore pack newly selected ivs, val and init flag back
// to iter_args and return.
{
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
- if_pred_then_b.create<loop::YieldOp>(
+ if_pred_then_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
}
// Pred == false, therefore return old iter_args.
{
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
- if_pred_else_b.create<loop::YieldOp>(loc, ivs_val_flag->to_vector());
+ if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
}
- if_init_then_b.create<loop::YieldOp>(loc, if_pred.getResults());
+ if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
}
// Init == false, i.e. only pad was visited before and this is the first
// element in the boundaries of the operand.
{
OpBuilder if_init_else_b = if_init.getElseBodyBuilder();
- if_init_else_b.create<loop::YieldOp>(
+ if_init_else_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
}
return if_init.getResults();
@@ -708,7 +707,7 @@
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
- loop::LoopOpsDialect, XlaLhloDialect>();
+ scf::SCFDialect, XlaLhloDialect>();
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
xla_lhlo::SelectAndScatterOp>();
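In scalar terms, the loop nest this pass now builds from scf ops has the shape of the following sketch; the 1-D shape and max reducer are illustrative stand-ins for the generic window and reduction computation handled above:

#include <algorithm>
#include <cstdint>
#include <vector>

void ReduceWindowMax1D(const std::vector<float>& input, int64_t window,
                       int64_t stride, int64_t pad_low, float init,
                       std::vector<float>& output) {
  for (size_t o = 0; o < output.size(); ++o) {  // outer scf.parallel: output
    float acc = init;                           // load of the init value
    for (int64_t w = 0; w < window; ++w) {      // inner loop over the window
      int64_t i = static_cast<int64_t>(o) * stride + w - pad_low;
      if (i >= 0 && i < static_cast<int64_t>(input.size()))  // scf.if in_bounds
        acc = std::max(acc, input[i]);          // reduction operator body
    }
    output[o] = acc;
  }
}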
diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
index 982ec4f..c317dc3 100644
--- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h
@@ -281,11 +281,9 @@
// No conversion is needed for the same width integers
return args.front();
}
- // TODO(dfki-ehna): Add other primitive type conversions
- // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
- // return b.create<mlir::FpToSiOp>(loc, result_types,
- // args,mlir::None);
- // }
+ if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
+ return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
+ }
return nullptr;
}
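The newly enabled FPToSIOp lowers to LLVM's fptosi, which truncates toward zero. The same semantics in standard C++, for reference:

#include <cstdint>
#include <cstdio>

int main() {
  const float xs[] = {1.9f, -1.9f, 0.5f};
  for (float x : xs)
    std::printf("%g -> %d\n", x, static_cast<int32_t>(x));  // 1, -1, 0
  return 0;
}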
diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
index a4ffa57..bf66640 100644
--- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc
@@ -50,12 +50,6 @@
template <typename SrcOp>
bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter,
Value *out_lhs, Value *out_rhs) {
- if (!op.broadcast_dimensions().hasValue()) {
- // Note: the op may still have an implicit broadcast on it, such as
- // for (tensor<1xf32>, tensor<4xf32>).
- return false;
- }
-
// Insert BroadcastInDimOps for the left-hand-side and right-hand-side args,
// replacing the original LHS and RHS args in the source op with the results
// of the broadcasts.
@@ -79,25 +73,7 @@
auto lhs_rank = lhs_ranked_type.getRank();
auto rhs_rank = rhs_ranked_type.getRank();
-
- // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg.
- // Use the original op.broadcast_dimensions for the lower rank arg.
- auto higher_rank_broadcast_dims =
- GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter);
- DenseIntElementsAttr lhs_broadcast_dims;
- DenseIntElementsAttr rhs_broadcast_dims;
- if (lhs_rank > rhs_rank) {
- lhs_broadcast_dims = higher_rank_broadcast_dims;
- rhs_broadcast_dims = op.broadcast_dimensions().getValue();
- } else if (lhs_rank < rhs_rank) {
- lhs_broadcast_dims = op.broadcast_dimensions().getValue();
- rhs_broadcast_dims = higher_rank_broadcast_dims;
- } else {
- // This shouldn't happen for legal ops. If the broadcast_dimensions
- // attribute is set, the ranks should be different.
- // TODO(scotttodd): Add a custom verification for ops and assert here.
- return false;
- }
+ ArrayRef<int64_t> op_shape = op_ranked_type.getShape();
// BroadcastInDimOp must have the same element type for operands and results,
// so preserve the original output shape and the original input element type.
@@ -105,16 +81,32 @@
// broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32>
// broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32>
// SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
- ArrayRef<int64_t> op_shape = op_ranked_type.getShape();
- auto lhs_type =
- RankedTensorType::get(op_shape, lhs_ranked_type.getElementType());
- auto rhs_type =
- RankedTensorType::get(op_shape, rhs_ranked_type.getElementType());
+ if (lhs_ranked_type.getShape() != op_ranked_type.getShape()) {
+ auto type =
+ RankedTensorType::get(op_shape, lhs_ranked_type.getElementType());
+ DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, lhs_rank, rewriter);
+ if (lhs_rank < rhs_rank) {
+ attr = op.broadcast_dimensions().getValue();
+ }
- *out_lhs = rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), lhs_type,
- lhs, lhs_broadcast_dims);
- *out_rhs = rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), rhs_type,
- rhs, rhs_broadcast_dims);
+ lhs =
+ rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, lhs, attr);
+ }
+
+ if (rhs_ranked_type.getShape() != op_ranked_type.getShape()) {
+ auto type =
+ RankedTensorType::get(op_shape, rhs_ranked_type.getElementType());
+ DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, rhs_rank, rewriter);
+ if (rhs_rank < lhs_rank) {
+ attr = op.broadcast_dimensions().getValue();
+ }
+
+ rhs =
+ rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, rhs, attr);
+ }
+
+ *out_lhs = lhs;
+ *out_rhs = rhs;
return true;
}
@@ -359,9 +351,15 @@
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
ConversionTarget *conversionTarget) {
-#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \
- conversionTarget->addDynamicallyLegalOp<OpType>( \
- [](OpType op) { return !op.broadcast_dimensions().hasValue(); });
+#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \
+ conversionTarget->addDynamicallyLegalOp<OpType>([](OpType op) { \
+ if (op.broadcast_dimensions().hasValue()) return false; \
+ auto l = op.lhs().getType().cast<ShapedType>(); \
+ auto r = op.rhs().getType().cast<ShapedType>(); \
+ if (!l.hasRank() || !r.hasRank()) return false; \
+ return l.getShape() == r.getShape(); \
+ });
+
// Binary elementwise ops.
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op);
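Under the rewritten logic an operand is broadcast only when its shape differs from the result shape: the higher-rank side gets the identity dimension sequence [0, rank) from GetI64ElementsAttrForSeq, while the lower-rank side reuses op.broadcast_dimensions. A plain C++ stand-in for that identity sequence, to make the attribute's contents concrete:

#include <cstdint>
#include <numeric>
#include <vector>

// The dims GetI64ElementsAttrForSeq(0, rank, ...) encodes: 0, 1, ..., rank-1.
std::vector<int64_t> IdentityBroadcastDims(int64_t rank) {
  std::vector<int64_t> dims(rank);
  std::iota(dims.begin(), dims.end(), 0);
  return dims;
}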
diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
index ad81cda..9cde6f8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h
@@ -23,6 +23,7 @@
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
namespace mlir {
+class BufferAssignmentPlacer;
namespace xla_hlo {
// Collection of rewrite patterns for lowering a general dot product.
@@ -38,9 +39,9 @@
MLIRContext *ctx);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
-void populateHLOToLHLOConversionPattern(MLIRContext *context,
- OwningRewritePatternList *patterns);
-
+void populateHLOToLHLOConversionPattern(
+ MLIRContext *context, BufferAssignmentPlacer *bufferAssignment,
+ TypeConverter *converter, OwningRewritePatternList *patterns);
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context,
OwningRewritePatternList *patterns);
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc
index 436a3e7..a12bd9e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc
@@ -251,17 +251,15 @@
// Create the view for this slice size, possible with an affine map to model
// the offset. The result is cached in the slices_ map.
- SmallVector<AffineMap, 1> offset_map;
- if (slice.offset()) {
- offset_map.push_back(AffineMap::get(
- /*dimCount=*/1, /*symbolCount=*/0,
- {getAffineDimExpr(0, builder_.getContext()) + slice.offset()},
- builder_.getContext()));
- }
- auto slice_type = MemRefType::get({slice.size()}, i8_type_, offset_map);
+ // The std.view result type does not carry the static offset: encoding it
+ // there is not useful. Rather, the view op itself carries the static offset.
+ auto slice_type = MemRefType::get({slice.size()}, i8_type_, {});
- auto slice_view = builder_.create<ViewOp>(
- alloc_buffer.getLoc(), slice_type, alloc_buffer, /*operands=*/llvm::None);
+ Value byte_shift =
+ builder_.create<ConstantIndexOp>(alloc_buffer.getLoc(), slice.offset());
+ auto slice_view =
+ builder_.create<ViewOp>(alloc_buffer.getLoc(), slice_type, alloc_buffer,
+ byte_shift, /*sizes=*/ArrayRef<Value>{});
slices_.insert({slice_key, slice_view});
return slice_view;
}
@@ -277,9 +275,12 @@
Value slice_view = GetOrCreateView(out_slice);
TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType<MemRefType>(
target_shape, builder_));
+ Value byte_shift =
+ builder_.create<ConstantIndexOp>(builder_.getUnknownLoc(), 0);
if (slice_view.getType() != out_type)
- slice_view = builder_.create<ViewOp>(builder_.getUnknownLoc(), out_type,
- slice_view, llvm::None);
+ slice_view =
+ builder_.create<ViewOp>(builder_.getUnknownLoc(), out_type, slice_view,
+ byte_shift, /*sizes=*/ArrayRef<Value>{});
return slice_view;
}
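The byte_shift operand replaces the old affine-map offset: the view simply begins byte_shift bytes into the i8 allocation. A trivial pointer-arithmetic analogue:

#include <cstdint>

// A std.view with a byte shift is, at buffer level, just an offset pointer
// into the underlying i8 allocation.
int8_t* SliceView(int8_t* alloc_buffer, int64_t byte_shift) {
  return alloc_buffer + byte_shift;
}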
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 5cca1e6..ea4ba8d 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -562,6 +562,7 @@
name = "dynamic_slice_ops_test",
size = "small",
srcs = ["dynamic_slice_ops_test.py"],
+ enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -1386,6 +1387,7 @@
size = "medium",
srcs = ["fused_batchnorm_test.py"],
python_version = "PY3",
+ shard_count = 5,
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index bd01319..00ed6d8 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1096,8 +1096,6 @@
x,
expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
- @test_util.disable_mlir_bridge(
- "TODO(b/155097273): Handle complex dtype constants")
def testExpandDims(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -1195,8 +1193,6 @@
np.full([1, 1, 3, 5], 3., dtype=np.float32),
expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
- @test_util.disable_mlir_bridge(
- "TODO(b/155097273): Handle complex dtype constants")
def testPad(self):
for dtype, pad_type in itertools.product(
self.numeric_types, [np.int32, np.int64]):
@@ -1337,8 +1333,6 @@
],
dtype=dtype))
- @test_util.disable_mlir_bridge(
- "TODO(b/155097273): Handle complex dtype constants")
def testReshape(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -1471,8 +1465,6 @@
[1, 2]],
dtype=dtype))
- @test_util.disable_mlir_bridge(
- "TODO(b/155097273): Handle complex dtype constants")
def testTranspose(self):
for dtype in self.numeric_types:
self._testBinary(
@@ -1491,8 +1483,6 @@
np.array([1, 0], dtype=np.int32),
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
- @test_util.disable_mlir_bridge(
- "TODO(b/155097273): Handle complex dtype constants")
def testConjugateTranspose(self):
for dtype in self.complex_types:
self._testBinary(
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index ff35587..a1bb64e 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -48,7 +48,8 @@
{'start': 1, 'end': 2, 'num': 1},
{'start': 1, 'end': 4, 'num': 3},
{'start': 0, 'end': 41, 'num': 42})
- @test_util.disable_mlir_bridge('Requires dynamic shape handling')
+ @test_util.disable_mlir_bridge(
+ 'TODO(b/156174708): Dynamic result types not supported')
def testLinspace(self, start, end, num):
expected = np.linspace(start, end, num, dtype=np.float32)
result = self._testTernary(
@@ -182,7 +183,6 @@
np.array([8, 9], dtype=dtype),
expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
- @test_util.disable_mlir_bridge('TODO(b/155097273)')
def testSlice(self):
for dtype in self.numeric_types:
self._testTernary(
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index dc11c24..3e36f67 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -601,8 +601,7 @@
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
- @test_util.disable_mlir_bridge(
- "Complex types not supported in CreateDenseElementsAttrFromLiteral")
+ @test_util.disable_mlir_bridge("TODO(b/156135423): Fix ConvertSigmoidOp")
def testComplexOps(self):
for dtype in self.complex_types:
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index b01c5ae..1f83701 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -51,7 +51,6 @@
equality_fn = self.assertAllClose
equality_fn(result, expected, rtol=1e-3)
- @test_util.disable_mlir_bridge('Not supported yet')
def testAdd(self):
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
@@ -109,7 +108,6 @@
xla_data_pb2.PrecisionConfig.HIGHEST)
@parameterized.parameters(*PRECISION_VALUES)
- @test_util.disable_mlir_bridge('Not supported yet')
def testConv(self, precision):
for dtype in set(self.float_types).intersection(
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
@@ -194,8 +192,6 @@
args=(np.array([1, 2, 3], dtype=dtype),),
expected=np.array([-1, -2, -3], dtype=dtype))
- @test_util.disable_mlir_bridge(
- 'Requires XlaPad op shape inference to have static result types')
def testPad(self):
for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index cd52e2f..404f9eb 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -70,6 +70,12 @@
return *this;
}
+ExecutableBuildOptions& ExecutableBuildOptions::set_use_spmd_partitioning(
+ bool use_spmd_partitioning) {
+ use_spmd_partitioning_ = use_spmd_partitioning;
+ return *this;
+}
+
ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment(
const DeviceAssignment& device_assignment) {
device_assignment_ = device_assignment;
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 360ad02..9a7fdd9 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -77,6 +77,11 @@
int num_partitions() const { return num_partitions_; }
ExecutableBuildOptions& set_num_partitions(int num_partitions);
+ // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
+ // num_partitions > 1 and XLA is requested to partition the input program.
+ bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
+ ExecutableBuildOptions& set_use_spmd_partitioning(bool use_spmd_partitioning);
+
// If set, this specifies a static device assignment for the computation.
// Otherwise, the computation will be compiled generically and can be run with
// any device assignment compatible with the computation's replica and
@@ -104,6 +109,7 @@
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
int num_replicas_ = 1;
int num_partitions_ = 1;
+ bool use_spmd_partitioning_ = false;
absl::optional<DeviceAssignment> device_assignment_;
bool alias_passthrough_params_ = false;
};
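Typical use of the new flag, using only the setters declared in this header:

#include "tensorflow/compiler/xla/client/executable_build_options.h"

// Request SPMD (rather than MPMD) partitioning across `num_partitions` cores.
xla::ExecutableBuildOptions MakeSpmdOptions(int num_partitions) {
  xla::ExecutableBuildOptions options;
  options.set_num_partitions(num_partitions);
  options.set_use_spmd_partitioning(true);
  return options;
}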
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 0b146a4..bd70ce8 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1301,7 +1301,6 @@
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
TF_RETURN_IF_ERROR(
@@ -1314,30 +1313,45 @@
window_dimensions[i] =
rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
}
- TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
+
+ TF_ASSIGN_OR_RETURN(Window window,
ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides, padding,
lhs_dilation, rhs_dilation));
-
- TF_ASSIGN_OR_RETURN(
- Shape shape, ShapeInference::InferConvolveShape(
- *lhs_shape, *rhs_shape, feature_group_count,
- batch_group_count, instr.window(), dimension_numbers));
- *instr.mutable_shape() = shape.ToProto();
-
- *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
- instr.set_feature_group_count(feature_group_count);
- instr.set_batch_group_count(batch_group_count);
-
- if (precision_config != nullptr) {
- *instr.mutable_precision_config() = *precision_config;
- }
-
- return AddInstruction(std::move(instr), HloOpcode::kConvolution,
- {lhs, rhs});
+ TF_ASSIGN_OR_RETURN(Shape shape,
+ ShapeInference::InferConvolveShape(
+ *lhs_shape, *rhs_shape, feature_group_count,
+ batch_group_count, window, dimension_numbers));
+ return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
+ padding, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count,
+ batch_group_count, precision_config);
});
}
+StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
+ const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count, int64 batch_group_count,
+ const PrecisionConfig* precision_config) {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape.ToProto();
+
+ *instr.mutable_window() = window;
+ *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
+ instr.set_feature_group_count(feature_group_count);
+ instr.set_batch_group_count(batch_group_count);
+
+ if (precision_config != nullptr) {
+ *instr.mutable_precision_config() = *precision_config;
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs});
+}
+
XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
const absl::Span<const int64> fft_length) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index bfb97d7..33fe62e 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -491,6 +491,16 @@
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
+ virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
+ const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
+ absl::Span<const int64> window_strides,
+ absl::Span<const std::pair<int64, int64>> padding,
+ absl::Span<const int64> lhs_dilation,
+ absl::Span<const int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count, int64 batch_group_count,
+ const PrecisionConfig* precision_config);
+
XlaOp Fft(XlaOp operand, FftType fft_type,
absl::Span<const int64> fft_length);
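The ConvGeneralDilated split follows the template-method pattern: the public entry point validates and infers shapes, then dispatches to a virtual *Internal hook that a derived builder can override. A minimal sketch of the pattern itself (all names hypothetical, not the actual XlaBuilder API):

#include <cstdio>

class Builder {
 public:
  int Conv(int lhs, int rhs) {
    int shape = lhs + rhs;  // stands in for shape inference and validation
    return ConvInternal(shape, lhs, rhs);
  }
 protected:
  virtual int ConvInternal(int shape, int lhs, int rhs) {
    return shape;  // default: materialize the instruction
  }
};

class TracingBuilder : public Builder {
 protected:
  int ConvInternal(int shape, int lhs, int rhs) override {
    std::printf("conv with inferred shape %d\n", shape);  // post-inference hook
    return Builder::ConvInternal(shape, lhs, rhs);
  }
};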
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
index ef0caff..6d4482a 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
@@ -20,6 +20,9 @@
from absl import logging
+# Import xla_client to load shared C++ extensions (just CompileOptions at the
+# time of writing).
+from tensorflow.compiler.xla.python import xla_client # pylint: disable=unused-import
from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 499c4e2..3349528 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3846,6 +3846,7 @@
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"@llvm-project//llvm:core",
"@llvm-project//llvm:transform_utils",
],
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 8c76e91..ce9c8a4 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -91,6 +91,7 @@
TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize(
execution_options.mutable_device_assignment()));
}
+ execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning());
for (const AotXlaComputationInstance& instance : computations) {
TF_RET_CHECK(instance.computation.has_host_program_shape());
*execution_options.mutable_shape_with_output_layout() =
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index cf64615..57b24e3 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -76,6 +76,7 @@
virtual int64 replica_count() const { return 0; }
virtual int64 num_cores() const { return 0; }
+ virtual bool use_spmd_partitioning() const { return false; }
// Optional allocator that may be used for allocating temp space on the device
// during compilation.
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index e21ca01..05364a4 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -109,24 +109,6 @@
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) {
switch (hlo->opcode()) {
- case HloOpcode::kMap:
- return [this, hlo, &operand_to_generator](
- const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- std::vector<llvm::Value*> operands;
- for (int i = 0; i < hlo->operand_count(); i++) {
- TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
- operand_to_generator.at(hlo->operand(i))(index));
- operands.push_back(operand_value);
- }
- return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
- operands, llvm_ir::IrName(hlo));
- };
- case HloOpcode::kReduceWindow:
- return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
- return ir_emitter_->EmitElementalReduceWindow(
- Cast<HloReduceWindowInstruction>(hlo),
- operand_to_generator.at(hlo->operand(0)), index);
- };
case HloOpcode::kConvolution:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return ir_emitter_->EmitElementalConvolution(
@@ -134,22 +116,6 @@
operand_to_generator.at(hlo->operand(0)),
operand_to_generator.at(hlo->operand(1)), index);
};
- case HloOpcode::kReduce:
- return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
- auto reduce_instr = Cast<HloReduceInstruction>(hlo);
- std::vector<llvm_ir::ElementGenerator> input_generators;
- for (const HloInstruction* instr : reduce_instr->inputs()) {
- input_generators.push_back(operand_to_generator.at(instr));
- }
-
- std::vector<llvm_ir::ElementGenerator> initial_value_generators;
- for (const HloInstruction* instr : reduce_instr->init_values()) {
- initial_value_generators.push_back(operand_to_generator.at(instr));
- }
- return ir_emitter_->EmitElementalReduce(
- reduce_instr, std::move(input_generators),
- std::move(initial_value_generators), index);
- };
default:
return ElementalIrEmitter::MakeElementGenerator(hlo,
operand_to_generator);
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index e3fba93..5c9f667 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -44,6 +44,12 @@
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;
+ StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
+ const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+ absl::string_view name) override {
+ return ir_emitter_->EmitThreadLocalCall(callee, parameters, name);
+ }
+
IrEmitter* ir_emitter_;
};
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index c19fa77..2b715bf 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -695,101 +695,6 @@
return Status::OK();
}
-llvm::Value* IrEmitter::EmitElementalMap(
- const HloMapInstruction& map_instr,
- absl::Span<llvm::Value* const> elemental_operands, absl::string_view name) {
- return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(),
- elemental_operands, name);
-}
-
-StatusOr<llvm::Value*> IrEmitter::EmitElementalReduceWindow(
- const HloReduceWindowInstruction* reduce_window,
- const llvm_ir::ElementGenerator& input_generator,
- const llvm_ir::IrArray::Index& index) {
- const HloInstruction* operand = reduce_window->operand(0);
- const Window& window = reduce_window->window();
-
- // We fold inputs into the accumulator and initialize it to
- // the initial value on the reduce_window.
- PrimitiveType operand_element_type = operand->shape().element_type();
- llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
- "reduce_window_accumulator_address", &b_,
- MinimumAlignmentForPrimitiveType(operand_element_type));
- Store(Load(GetEmittedValueFor(reduce_window->operand(1))),
- accumulator_address);
-
- llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_);
- std::vector<int64> window_size;
- for (const auto& dim : window.dimensions()) {
- window_size.push_back(dim.size());
- }
- const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
- ShapeUtil::MakeShape(operand_element_type, window_size), "window");
- CHECK_EQ(window_index.size(), index.size());
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
- std::vector<llvm::Value*> input_multi_index(index.size());
- llvm::Value* in_bounds_condition = nullptr;
- for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* strided_index =
- NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_multi_index[i] = NSWSub(
- NSWAdd(strided_index,
- NSWMul(window_index[i],
- b_.getInt64(window.dimensions(i).window_dilation()))),
- b_.getInt64(window.dimensions(i).padding_low()));
-
- // We need to verify that we are not in the dilated base area.
- llvm::Value* dilation_condition =
- ICmpEQ(SRem(input_multi_index[i],
- b_.getInt64(window.dimensions(i).base_dilation())),
- b_.getInt64(0));
- if (in_bounds_condition == nullptr) {
- in_bounds_condition = dilation_condition;
- } else {
- in_bounds_condition = And(in_bounds_condition, dilation_condition);
- }
-
- // Apply base dilation to the index.
- input_multi_index[i] =
- SDiv(input_multi_index[i],
- b_.getInt64(window.dimensions(i).base_dilation()));
-
- // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we
- // are in the padding so that we can skip the computation. That is
- // equivalent to input_multi_index[i] < bound as an *unsigned* comparison,
- // since a negative value will wrap to a large positive value.
- llvm::Value* index_condition =
- ICmpULT(input_multi_index[i],
- b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
- if (in_bounds_condition == nullptr) {
- in_bounds_condition = index_condition;
- } else {
- in_bounds_condition = And(in_bounds_condition, index_condition);
- }
- }
- CHECK(in_bounds_condition != nullptr);
-
- llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
- SetToFirstInsertPoint(if_data.true_block, &b_);
-
- // We are not in the padding, so carry out the computation.
- llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(),
- b_.getInt64Ty());
- TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
- input_generator(input_index));
- llvm::Value* result = EmitScalarReturningThreadLocalCall(
- *reduce_window->to_apply(), {Load(accumulator_address), input_value},
- "reducer_function");
- Store(result, accumulator_address);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
- return Load(accumulator_address);
-}
-
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// Pseudo code for reduce window:
//
@@ -2099,108 +2004,6 @@
return true;
}
-StatusOr<llvm::Value*> IrEmitter::EmitElementalReduce(
- const HloReduceInstruction* reduce,
- std::vector<llvm_ir::ElementGenerator> input_generators,
- std::vector<llvm_ir::ElementGenerator> initial_value_generators,
- const llvm_ir::IrArray::Index& index) {
- const Shape& out_shape = reduce->shape();
- bool is_variadic = !out_shape.IsArray();
- int accumulators_count = 1;
- if (is_variadic) {
- CHECK(out_shape.IsTuple());
- accumulators_count = out_shape.tuple_shapes_size();
- }
-
- absl::Span<const int64> reduced_dimensions(reduce->dimensions());
-
- std::vector<llvm::Value*> accumulator_addrs;
- std::vector<llvm::Type*> accumulator_types;
- for (int i = 0; i < accumulators_count; i++) {
- const Shape& element_shape =
- is_variadic ? out_shape.tuple_shapes(i) : out_shape;
- PrimitiveType accumulator_type = element_shape.element_type();
- llvm::Type* accumulator_llvm_type =
- llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
- accumulator_types.push_back(accumulator_llvm_type);
-
- // Initialize an accumulator with init_value.
- llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_,
- MinimumAlignmentForPrimitiveType(accumulator_type));
- TF_ASSIGN_OR_RETURN(
- llvm::Value* const init_value,
- initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType())));
- Store(init_value, accumulator_addr);
- accumulator_addrs.push_back(accumulator_addr);
- }
-
- // The enclosing loops go over all the target elements. Now we have to compute
- // the actual target element. For this, we build a new loop nest to iterate
- // over all the reduction dimensions in the argument.
- // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
- // are placed for each dimension in dimensions, and all the rest are nullptrs.
- llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_);
- const HloInstruction* arg = reduce->operand(0);
- std::vector<llvm::Value*> input_multi_index =
- loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
- "reduction_dim");
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_);
-
- // Build a full index for the input argument, using input_multi_index as the
- // base. In input_multi_index only the reduction dimensions are filled in. We
- // fill in the rest of the dimensions with induction Value*s taken from
- // 'index' which iterates over the target array. See the high-level
- // description in the XLA documentation for details.
- llvm_ir::IrArray::Index::const_iterator it = index.begin();
-
- for (auto& i : input_multi_index) {
- if (i == nullptr) {
- i = *it++;
- }
- }
- CHECK(index.end() == it);
- llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
- b_.getInt64Ty());
-
- std::vector<llvm::Value*> reduction_operands;
- for (llvm::Value* accum : accumulator_addrs) {
- llvm::Value* accum_value = Load(accum);
- reduction_operands.push_back(accum_value);
- }
-
- for (int i = 0; i < accumulators_count; i++) {
- TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
- input_generators[i](input_index));
- reduction_operands.push_back(input_element);
- }
-
- std::vector<llvm::Value*> results = EmitThreadLocalCall(
- *reduce->to_apply(), reduction_operands, "reduce_function");
-
- CHECK(results.size() == accumulators_count);
- for (int i = 0; i < accumulators_count; i++) {
- Store(results[i], accumulator_addrs[i]);
- }
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
-
- if (is_variadic) {
- // Emit a structure, as that what the LoopEmitter expects.
- llvm::Value* returned_structure = llvm::UndefValue::get(
- llvm::StructType::get(b_.getContext(), accumulator_types));
- for (int i = 0; i < accumulators_count; i++) {
- llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
- returned_structure =
- b_.CreateInsertValue(returned_structure, accumulator_value, i);
- }
- return returned_structure;
- } else {
- CHECK_EQ(accumulator_addrs.size(), 1);
- return Load(accumulator_addrs[0]);
- }
-}
-
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index cc5aa3f..24524c6 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -58,6 +58,8 @@
// functions.
class IrEmitter : public DfsHloVisitorWithDefault,
public IrBuilderMixin<IrEmitter> {
+ friend class CpuElementalIrEmitter;
+
public:
using GeneratorForOperandIrArrays =
std::function<std::vector<llvm_ir::IrArray>()>;
@@ -113,28 +115,12 @@
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
- // Emit code to map one element according to `map_instr`.
- llvm::Value* EmitElementalMap(
- const HloMapInstruction& map_instr,
- absl::Span<llvm::Value* const> elemental_operands,
- absl::string_view name);
- // Emit code to emit the element at `index` for a reduce window instruction.
- StatusOr<llvm::Value*> EmitElementalReduceWindow(
- const HloReduceWindowInstruction* reduce_window,
- const llvm_ir::ElementGenerator& input_generator,
- const llvm_ir::IrArray::Index& index);
// Emit code to emit the element at `index` for a convolution instruction.
StatusOr<llvm::Value*> EmitElementalConvolution(
const HloConvolutionInstruction* convolution,
const llvm_ir::ElementGenerator& input_generator,
const llvm_ir::ElementGenerator& kernel_generator,
const llvm_ir::IrArray::Index& index);
- // Emit code to emit the element at `index` for a reduce instruction.
- StatusOr<llvm::Value*> EmitElementalReduce(
- const HloReduceInstruction* reduce,
- std::vector<llvm_ir::ElementGenerator> input_generators,
- std::vector<llvm_ir::ElementGenerator> initial_value_generator,
- const llvm_ir::IrArray::Index& index);
protected:
//
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 30300b8..8cb660d 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -2422,6 +2422,43 @@
-> StatusOr<llvm::Value*> {
return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
};
+ case HloOpcode::kMap:
+ return [this, hlo, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ std::vector<llvm::Value*> operands;
+ for (int i = 0; i < hlo->operand_count(); i++) {
+ TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+ operand_to_generator.at(hlo->operand(i))(index));
+ operands.push_back(operand_value);
+ }
+ std::vector<llvm_ir::ElementGenerator> input_generators;
+ for (const HloInstruction* instr : hlo->operands()) {
+ input_generators.push_back(operand_to_generator.at(instr));
+ }
+ return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
+ };
+ case HloOpcode::kReduceWindow:
+ return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+ return EmitElementalReduceWindow(
+ Cast<HloReduceWindowInstruction>(hlo),
+ operand_to_generator.at(hlo->operand(0)),
+ operand_to_generator.at(hlo->operand(1)), index);
+ };
+ case HloOpcode::kReduce:
+ return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+ auto reduce_instr = Cast<HloReduceInstruction>(hlo);
+ std::vector<llvm_ir::ElementGenerator> input_generators;
+ for (const HloInstruction* instr : reduce_instr->inputs()) {
+ input_generators.push_back(operand_to_generator.at(instr));
+ }
+
+ std::vector<llvm_ir::ElementGenerator> initial_value_generators;
+ for (const HloInstruction* instr : reduce_instr->init_values()) {
+ initial_value_generators.push_back(operand_to_generator.at(instr));
+ }
+ return EmitElementalReduce(reduce_instr, std::move(input_generators),
+ std::move(initial_value_generators), index);
+ };
default:
return [hlo](const IrArray::Index& index) {
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
@@ -2451,4 +2488,215 @@
return complex;
}
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
+ const HloMapInstruction* map_instr,
+ absl::Span<llvm::Value* const> elemental_operands) {
+ TF_ASSIGN_OR_RETURN(
+ std::vector<llvm::Value*> values,
+ EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
+ llvm_ir::IrName(map_instr)));
+ CHECK_EQ(values.size(), 1);
+ return values[0];
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
+ const HloReduceWindowInstruction* reduce_window,
+ const llvm_ir::ElementGenerator& input_generator,
+ const llvm_ir::ElementGenerator& initial_value_generator,
+ const llvm_ir::IrArray::Index& index) {
+ // Pseudocode:
+ // for each index O in output
+ // value = init_value
+ // for each index W in window
+ // for each dimension i from 0 to rank - 1
+ // (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
+ // if I in bounds of input
+ // value = function(value, input[I])
+ // output[O] = value
+ const HloInstruction* operand = reduce_window->operand(0);
+ const Window& window = reduce_window->window();
+
+ PrimitiveType operand_element_type = operand->shape().element_type();
+ llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
+ "reduce_window_accum_ptr", b_);
+ {
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value* const init_value,
+ initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
+ Store(init_value, accum_ptr);
+ }
+
+ llvm::Type* index_type = index.GetType();
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return index.GetConstantWithIndexType(c);
+ };
+
+ llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
+ std::vector<int64> window_size;
+ for (const auto& dim : window.dimensions()) {
+ window_size.push_back(dim.size());
+ }
+ const IrArray::Index window_index = loops.AddLoopsForShape(
+ ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+ CHECK_EQ(window_index.size(), index.size());
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
+
+ std::vector<llvm::Value*> input_multi_index(index.size());
+ llvm::Value* in_bounds = b_->getInt1(true);
+ for (size_t i = 0; i < index.size(); ++i) {
+ llvm::Value* stridden_index =
+ NSWMul(index[i], index_typed_const(window.dimensions(i).stride()));
+ input_multi_index[i] = NSWSub(
+ NSWAdd(
+ stridden_index,
+ NSWMul(window_index[i],
+ index_typed_const(window.dimensions(i).window_dilation()))),
+ index_typed_const(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition =
+ ICmpEQ(SRem(input_multi_index[i],
+ index_typed_const(window.dimensions(i).base_dilation())),
+ index_typed_const(0));
+ in_bounds = And(in_bounds, dilation_condition);
+
+ // Apply base dilation to the index.
+ input_multi_index[i] =
+ SDiv(input_multi_index[i],
+ index_typed_const(window.dimensions(i).base_dilation()));
+
+ // We must check whether 0 <= input_multi_index[i] < bound, as
+ // otherwise we are in the pad and so can skip the computation. This
+ // comparison is equivalent to the unsigned comparison
+ // input_multi_index[i] < bound, as a negative value wraps to a large
+ // positive value.
+ in_bounds = And(in_bounds,
+ ICmpULT(input_multi_index[i],
+ index_typed_const(operand->shape().dimensions(i))));
+ }
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
+ SetToFirstInsertPoint(if_data.true_block, b_);
+
+ // We are not in the pad, so do the computation.
+ IrArray::Index input_index(input_multi_index, operand->shape(), index_type);
+ TF_ASSIGN_OR_RETURN(llvm::Value * input_value, input_generator(input_index));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<llvm::Value*> accum_values,
+ EmitThreadLocalCall(*reduce_window->to_apply(),
+ {Load(accum_ptr), input_value}, "reducer_function"));
+ CHECK_EQ(accum_values.size(), 1);
+ Store(accum_values[0], accum_ptr);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
+ return Load(accum_ptr);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
+ const HloReduceInstruction* reduce,
+ std::vector<llvm_ir::ElementGenerator> input_generators,
+ std::vector<llvm_ir::ElementGenerator> initial_value_generators,
+ const llvm_ir::IrArray::Index& index) {
+ const Shape& out_shape = reduce->shape();
+ bool is_variadic = !out_shape.IsArray();
+ int accumulators_count = 1;
+ if (is_variadic) {
+ CHECK(out_shape.IsTuple());
+ accumulators_count = out_shape.tuple_shapes_size();
+ }
+
+ absl::Span<const int64> reduced_dimensions(reduce->dimensions());
+
+ std::vector<llvm::Value*> accumulator_addrs;
+ std::vector<llvm::Type*> accumulator_types;
+ llvm::Type* index_type = index.GetType();
+ for (int i = 0; i < accumulators_count; i++) {
+ const Shape& element_shape =
+ is_variadic ? out_shape.tuple_shapes(i) : out_shape;
+ PrimitiveType accumulator_type = element_shape.element_type();
+ llvm::Type* accumulator_llvm_type =
+ llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
+ accumulator_types.push_back(accumulator_llvm_type);
+
+ // Initialize an accumulator with init_value.
+ llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ accumulator_llvm_type, "accumulator_" + std::to_string(i), b());
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value* const init_value,
+ initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
+ Store(init_value, accumulator_addr);
+ accumulator_addrs.push_back(accumulator_addr);
+ }
+
+ // The enclosing loops go over all the target elements. Now we have to compute
+ // the actual target element. For this, we build a new loop nest to iterate
+ // over all the reduction dimensions in the argument.
+ // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
+ // are placed for each dimension in dimensions, and all the rest are nullptrs.
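+ // For example, when reducing a rank-3 operand over dimension {1},
+ // input_multi_index is {nullptr, loop_iv, nullptr}; the nullptr slots are
+ // filled in from 'index' below.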
+ llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
+ const HloInstruction* arg = reduce->operand(0);
+ std::vector<llvm::Value*> input_multi_index =
+ loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
+ "reduction_dim");
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
+
+ // Build a full index for the input argument, using input_multi_index as the
+ // base. In input_multi_index only the reduction dimensions are filled in. We
+ // fill in the rest of the dimensions with induction Value*s taken from
+ // 'index' which iterates over the target array. See the high-level
+ // description in the XLA documentation for details.
+ auto it = index.begin();
+
+ for (auto& i : input_multi_index) {
+ if (i == nullptr) {
+ i = *it++;
+ }
+ }
+ CHECK(index.end() == it);
+ llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
+ index_type);
+
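+ // The reducer is invoked with all accumulators first, then one element per
+ // input: (acc_0, ..., acc_{n-1}, x_0, ..., x_{n-1}).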
+ std::vector<llvm::Value*> reduction_operands;
+ for (llvm::Value* accum : accumulator_addrs) {
+ llvm::Value* accum_value = Load(accum);
+ reduction_operands.push_back(accum_value);
+ }
+
+ for (int i = 0; i < accumulators_count; i++) {
+ TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
+ input_generators[i](input_index));
+ reduction_operands.push_back(input_element);
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<llvm::Value*> results,
+ EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
+ "reduce_function"));
+
+ CHECK(results.size() == accumulators_count);
+ for (int i = 0; i < accumulators_count; i++) {
+ Store(results[i], accumulator_addrs[i]);
+ }
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
+
+ if (is_variadic) {
+ // Emit a structure, as that is what the LoopEmitter expects.
+ llvm::Value* returned_structure = llvm::UndefValue::get(
+ llvm::StructType::get(b()->getContext(), accumulator_types));
+ for (int i = 0; i < accumulators_count; i++) {
+ llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
+ returned_structure =
+ b()->CreateInsertValue(returned_structure, accumulator_value, i);
+ }
+ return returned_structure;
+ } else {
+ CHECK_EQ(accumulator_addrs.size(), 1);
+ return Load(accumulator_addrs[0]);
+ }
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 94e8f1d..06a9d7b 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -17,12 +17,17 @@
#define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
#include <unordered_map>
+#include <vector>
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -220,6 +225,26 @@
const HloToElementGeneratorMap& operand_to_generator,
const llvm_ir::IrArray::Index& dot_result_index);
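+ // Emits a call to `callee`, passing `parameters` as arguments. Returns one
+ // llvm::Value* per element of the callee's result: a single value for
+ // scalar results, one per tuple element for variadic reducers.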
+ virtual StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
+ const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+ absl::string_view name) = 0;
+
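+ // Element generators for ops whose elemental emission invokes the attached
+ // HLO computation through EmitThreadLocalCall.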
+ StatusOr<llvm::Value*> EmitElementalMap(
+ const HloMapInstruction* map_instr,
+ absl::Span<llvm::Value* const> elemental_operands);
+
+ StatusOr<llvm::Value*> EmitElementalReduceWindow(
+ const HloReduceWindowInstruction* reduce_window,
+ const llvm_ir::ElementGenerator& input_generator,
+ const llvm_ir::ElementGenerator& initial_value_generator,
+ const llvm_ir::IrArray::Index& index);
+
+ StatusOr<llvm::Value*> EmitElementalReduce(
+ const HloReduceInstruction* reduce,
+ std::vector<llvm_ir::ElementGenerator> input_generators,
+ std::vector<llvm_ir::ElementGenerator> initial_value_generators,
+ const llvm_ir::IrArray::Index& index);
+
llvm::IRBuilder<>* const b_;
llvm::Module* module_;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 61bc412..bff8734 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -17,15 +17,15 @@
"tf_cuda_library",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
+load(
+ "@local_config_rocm//rocm:build_defs.bzl",
+ "if_rocm",
+ "if_rocm_is_configured",
+)
load(
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)
-load(
- "@local_config_rocm//rocm:build_defs.bzl",
- "if_rocm_is_configured",
-)
load("//tensorflow:tensorflow.bzl", "if_nccl")
package(
@@ -86,6 +86,7 @@
name = "gpu_types",
hdrs = ["gpu_types.h"],
deps = [
+ "//tensorflow/compiler/xla:types",
"@com_google_absl//absl/types:variant",
],
)
@@ -405,6 +406,7 @@
deps = [
":buffer_allocations",
":gpu_executable_run_options",
+ ":gpu_types",
":hlo_execution_profiler",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla/service:hlo",
@@ -684,7 +686,7 @@
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
- "//tensorflow/core:autotuning_proto_cc",
+ "//tensorflow/core/protobuf:autotuning_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/util/proto:proto_utils",
@@ -720,7 +722,7 @@
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
- "//tensorflow/core:autotuning_proto_cc",
+ "//tensorflow/core/protobuf:autotuning_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
@@ -1674,7 +1676,7 @@
protodeps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
- "//tensorflow/core:autotuning_proto",
+ "//tensorflow/core/protobuf:autotuning_proto",
],
)
@@ -1685,8 +1687,8 @@
deps = [
":gpu_autotuning_proto_cc",
"//tensorflow/compiler/xla:debug_options_flags",
- "//tensorflow/core:autotuning_proto_cc",
"//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core/protobuf:autotuning_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
index 974db02..485aff0 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -104,11 +104,9 @@
return isa_version;
}
-StatusOr<std::pair<std::string, std::vector<uint8>>>
-AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
- llvm::Module* llvm_module,
- GpuVersion gpu_version,
- se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> AMDGPUCompiler::CompileTargetBinary(
+ const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+ se::StreamExecutor* stream_exec) {
if (rocdl_dir_.empty()) {
// Compute rocdl_dir_ just once and cache it in this member.
rocdl_dir_ = GetROCDLDir(module->config());
@@ -129,7 +127,7 @@
user_post_optimization_hook_(*llvm_module);
}
- return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
+ return GpuTargetBinary{"", std::move(hsaco)};
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
index acc5e02..9033585 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
@@ -39,7 +39,7 @@
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
- StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+ StatusOr<GpuTargetBinary> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index e31f459..5e7d89c 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -50,7 +50,7 @@
}
}
-Status ConditionalThunk::Initialize(const GpuExecutable& executable,
+Status ConditionalThunk::Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) {
if (branch_index_is_bool_) {
TF_RET_CHECK(branch_thunks_.size() == 2);
@@ -58,7 +58,7 @@
TF_RET_CHECK(!branch_thunks_.empty());
}
for (auto& branch_thunk : branch_thunks_) {
- TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor));
+ TF_RETURN_IF_ERROR(branch_thunk->Initialize(target_binary, executor));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
index 404e213..ba69e1a 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -52,7 +52,7 @@
ConditionalThunk& operator=(const ConditionalThunk&) = delete;
void ComputeAnnotations() override;
- Status Initialize(const GpuExecutable& executable,
+ Status Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const ExecuteParams& params) override;
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index c6df786..1be0b1b 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -305,168 +305,5 @@
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}
-llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
- const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) {
- switch (hlo->opcode()) {
- case HloOpcode::kMap:
- return [=, &operand_to_generator](
- const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- TF_RET_CHECK(!hlo->operands().empty())
- << "Zero operand map not implemented in GPU backend.";
- TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0);
- std::vector<llvm::Value*> operand_elements;
- for (HloInstruction* operand : hlo->operands()) {
- TF_ASSIGN_OR_RETURN(llvm::Value * value,
- operand_to_generator.at(operand)(index));
- operand_elements.push_back(value);
- }
- return compute_nested_(*hlo->to_apply(), operand_elements);
- };
- case HloOpcode::kReduceWindow:
- // Pseudocode:
- // for each index I in output
- // value = init_value
- // for each index W in window
- // for each dimension i from 0 to rank - 1
- // (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
- // if I in bounds of input
- // value = function(value, input[I])
- // output[O] = value
- return [=, &operand_to_generator](
- const IrArray::Index& index) -> StatusOr<llvm::Value*> {
- const HloInstruction* operand = hlo->operand(0);
- const Window& window = hlo->window();
-
- PrimitiveType operand_element_type = operand->shape().element_type();
- llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
- "reduce_window_accum_ptr", b_);
- {
- TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
- operand_to_generator.at(hlo->operand(1))(
- IrArray::Index(index.GetType())));
- Store(init_value, accum_ptr);
- }
-
- llvm::Type* index_type = index.GetType();
- auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
- return index.GetConstantWithIndexType(c);
- };
-
- llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
- std::vector<int64> window_size;
- for (const auto& dim : window.dimensions()) {
- window_size.push_back(dim.size());
- }
- const IrArray::Index window_index = loops.AddLoopsForShape(
- ShapeUtil::MakeShape(operand_element_type, window_size), "window");
- CHECK_EQ(window_index.size(), index.size());
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
-
- std::vector<llvm::Value*> input_multi_index(index.size());
- llvm::Value* in_bounds = b_->getInt1(true);
- for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* stridden_index = NSWMul(
- index[i], index_typed_const(window.dimensions(i).stride()));
- input_multi_index[i] = NSWSub(
- NSWAdd(stridden_index,
- NSWMul(window_index[i],
- index_typed_const(
- window.dimensions(i).window_dilation()))),
- index_typed_const(window.dimensions(i).padding_low()));
-
- // We need to verify that we are not in the dilated base area.
- llvm::Value* dilation_condition = ICmpEQ(
- SRem(input_multi_index[i],
- index_typed_const(window.dimensions(i).base_dilation())),
- index_typed_const(0));
- in_bounds = And(in_bounds, dilation_condition);
-
- // Apply base dilation to the index.
- input_multi_index[i] =
- SDiv(input_multi_index[i],
- index_typed_const(window.dimensions(i).base_dilation()));
-
- // We must check whether 0 <= input_multi_index[i] < bound, as
- // otherwise we are in the pad and so can skip the computation. This
- // comparison is equivalent to the unsigned comparison
- // input_multi_index[i] < bound, as a negative value wraps to a large
- // positive value.
- in_bounds =
- And(in_bounds,
- ICmpULT(input_multi_index[i],
- index_typed_const(operand->shape().dimensions(i))));
- }
-
- llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
- SetToFirstInsertPoint(if_data.true_block, b_);
-
- // We are not in pad, so do the computation.
- IrArray::Index input_index(input_multi_index, operand->shape(),
- index_type);
- TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
- operand_to_generator.at(operand)(input_index));
- TF_ASSIGN_OR_RETURN(
- llvm::Value * accum_value,
- compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value}));
- Store(accum_value, accum_ptr);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
- return Load(accum_ptr);
- };
- case HloOpcode::kReduce:
- // TODO(b/118332391): This should be supported.
- CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce";
- return [=, &operand_to_generator](
- const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
- const HloInstruction* operand = hlo->operand(0);
- llvm::Value* accum_ptr =
- b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
- hlo->shape().element_type(), module_));
- llvm::Type* index_type = output_index.GetType();
- TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
- operand_to_generator.at(hlo->operand(1))(
- IrArray::Index(index_type)));
- b()->CreateStore(init_value, accum_ptr);
-
- llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
- std::vector<llvm::Value*> input_multi_index =
- loops.AddLoopsForShapeOnDimensions(
- operand->shape(), hlo->dimensions(), "reduction_dim");
- if (!ShapeUtil::IsScalar(hlo->shape())) {
- // Here only input_multi_index[hlo->dimensions()] are non-null, so we
- // must set the rest.
- size_t j = 0;
- for (auto& i : input_multi_index) {
- if (i == nullptr) {
- i = output_index[j++];
- }
- }
- CHECK_EQ(output_index.size(), j);
- }
- llvm_ir::IrArray::Index input_index(
- input_multi_index, hlo->operand(0)->shape(), index_type);
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
- TF_ASSIGN_OR_RETURN(
- llvm::Value * input_value,
- operand_to_generator.at(hlo->operand(0))(input_index));
- TF_ASSIGN_OR_RETURN(
- llvm::Value * accum_value,
- compute_nested_(*hlo->to_apply(),
- {b()->CreateLoad(accum_ptr), input_value}));
- b()->CreateStore(accum_value, accum_ptr);
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
- return b()->CreateLoad(accum_ptr);
- };
- default:
- return ElementalIrEmitter::MakeElementGenerator(hlo,
- operand_to_generator);
- }
-}
-
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index c8a58a2..3c4e9f7 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -47,10 +47,6 @@
llvm::Module* module, llvm::IRBuilder<>* b,
NestedComputer compute_nested);
- llvm_ir::ElementGenerator MakeElementGenerator(
- const HloInstruction* hlo,
- const HloToElementGeneratorMap& operand_to_generator) override;
-
protected:
StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
@@ -92,6 +88,17 @@
StatusOr<llvm::Value*> EmitComplexAbs(PrimitiveType prim_type,
llvm::Value* value) override;
+ StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
+ const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
+ absl::string_view) override {
+ // TODO(b/118332391): Support variadic return values.
+ auto result = compute_nested_(callee, parameters);
+ if (!result.ok()) {
+ return result.status();
+ }
+ return std::vector<llvm::Value*>{result.ValueOrDie()};
+ }
+
llvm::Value* EmitThreadId() override;
private:
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index 0a97f66..aacc9de 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -39,9 +39,9 @@
body_thunk_sequence_->ComputeAnnotations();
}
-Status ForThunk::Initialize(const GpuExecutable& executable,
+Status ForThunk::Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) {
- TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h
index 57402f7..57657b6 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -38,7 +38,7 @@
ForThunk& operator=(const ForThunk&) = delete;
void ComputeAnnotations() override;
- Status Initialize(const GpuExecutable& executable,
+ Status Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const ExecuteParams& params) override;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5f6dfd7..533ff52 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -565,8 +565,7 @@
GpuVersion gpu_version = GetGpuVersion(stream_exec);
- using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
- TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result,
+ TF_ASSIGN_OR_RETURN(GpuTargetBinary backend_result,
CompileTargetBinary(module.get(), &llvm_module,
gpu_version, stream_exec));
@@ -578,6 +577,11 @@
thunk_schedule->ToString());
}
+ std::vector<Thunk*> thunks;
+ for (Thunk* thunk : thunk_schedule->TotalOrder()) {
+ thunks.push_back(thunk);
+ }
+
std::unique_ptr<HloProfileIndexMap> profile_index_map;
std::unique_ptr<HloProfilePrinterData> profile_printer;
@@ -597,14 +601,19 @@
}
auto* gpu_executable = new GpuExecutable(
- backend_result.first, backend_result.second, gpu_version,
- std::move(thunk_schedule), std::move(module),
- std::move(buffer_assignment), std::move(profile_printer),
- std::move(profile_index_map));
+ std::move(backend_result), gpu_version, std::move(thunk_schedule),
+ std::move(module), std::move(buffer_assignment),
+ std::move(profile_printer), std::move(profile_index_map));
if (embed_ir_in_executable) {
DCHECK_NE("", ir_module_string_before_opt);
gpu_executable->set_ir_module_string(ir_module_string_before_opt);
}
+
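+ // Initialize the thunks with the compiled target binary up front, rather
+ // than deferring initialization to the first ExecuteOnStream() call.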
+ for (Thunk* thunk : thunks) {
+ TF_RETURN_IF_ERROR(
+ thunk->Initialize(gpu_executable->target_binary(), stream_exec));
+ }
+
return std::unique_ptr<Executable>(gpu_executable);
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
index b52af53..deb5d78 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -74,10 +74,9 @@
virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0;
- virtual StatusOr<std::pair<std::string, std::vector<uint8>>>
- CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module,
- GpuVersion gpu_version,
- se::StreamExecutor* stream_exec) = 0;
+ virtual StatusOr<GpuTargetBinary> CompileTargetBinary(
+ const HloModule* hlo_module, llvm::Module* llvm_module,
+ GpuVersion gpu_version, se::StreamExecutor* stream_exec) = 0;
Status PrepareHloModuleForIrEmitting(HloModule* hlo_module);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 2df6b50..ebd3630 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -52,16 +52,15 @@
// Implementation note: HLO profiling is always enabled for GPU executables,
// since we can use timers around thunks.
GpuExecutable::GpuExecutable(
- const string& text, const std::vector<uint8>& binary,
- GpuVersion gpu_version, std::unique_ptr<const ThunkSchedule> thunk_schedule,
+ GpuTargetBinary target_binary, GpuVersion gpu_version,
+ std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> assignment,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
std::move(hlo_profile_index_map)),
- text_(text),
- binary_(binary),
+ target_binary_(std::move(target_binary)),
gpu_version_(gpu_version),
thunk_schedule_(std::move(thunk_schedule)),
assignment_(std::move(assignment)) {
@@ -176,7 +175,6 @@
// module, we won't get any data, but that's probably an OK trade-off.
ScopedAnnotation annotation([&] { return thunk->profile_annotation(); });
- TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor));
int32 stream_no =
thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction());
se::Stream* stream =
@@ -469,7 +467,7 @@
int64 GpuExecutable::SizeOfGeneratedCodeInBytes() {
// Non-empty PTX but empty cubin: compilation must have failed, return
// "unknown".
- if (binary().empty() && !text_.empty()) {
+ if (binary().empty() && !text().empty()) {
return -1;
}
return binary().size();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 045a36c..29441c6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -52,8 +52,7 @@
// We need to share ownership of hlo_module and assignment with profiler to
// safely keep a reference to these objects during tracing period, thus they
// are passed as shared pointers.
- GpuExecutable(const string& text, const std::vector<uint8>& binary,
- GpuVersion gpu_version,
+ GpuExecutable(GpuTargetBinary target_binary, GpuVersion gpu_version,
std::unique_ptr<const ThunkSchedule> thunk_schedule,
std::shared_ptr<HloModule> hlo_module,
std::shared_ptr<const BufferAssignment> assignment,
@@ -73,12 +72,14 @@
// Returns the compiled code for the computation. The compiled code is PTX in
// Cuda and unused empty string in ROCm.
- const string& text() const { return text_; }
+ const string& text() const { return target_binary_.text; }
// Returns the binary stored in this GpuExecutable. The binary is cubin in
// Cuda, and HSA code object in ROCm. It may be empty, in which case
// compilation is left up to the GPU driver.
- const std::vector<uint8>& binary() const { return binary_; }
+ const std::vector<uint8>& binary() const { return target_binary_.binary; }
+
+ const GpuTargetBinary& target_binary() const { return target_binary_; }
// ExecuteAsyncOnStream will fail if the compute capability of the stream
// doesn't match the compute capability passed to this object's constructor.
@@ -131,14 +132,7 @@
// This string should be modified only before ExecuteOnStream.
string ir_module_string_;
- // The compiled code for the computation.
- const string text_;
-
- // The GPU machine code for the computation, targeting GPUs at
- // compute_capability_.
- //
- // May be empty, in which case we leave compilation up to the GPU driver.
- const std::vector<uint8> binary_;
+ const GpuTargetBinary target_binary_;
// The GPU version for compute compatibility check.
GpuVersion gpu_version_;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_types.h b/tensorflow/compiler/xla/service/gpu/gpu_types.h
index 1c51040..5c8b809 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_types.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_types.h
@@ -16,7 +16,11 @@
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TYPES_H_
+#include <string>
+#include <vector>
+
#include "absl/types/variant.h"
+#include "tensorflow/compiler/xla/types.h"
namespace xla {
namespace gpu {
@@ -25,6 +29,19 @@
// it comprises a pair of integers denoting major and minor version.
// On ROCm platform, it comprises one integer for AMD GCN ISA version.
using GpuVersion = absl::variant<std::pair<int, int>, int>;
+
+// A struct to carry around the compiled results produced by the GPU
+// assembler.
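+// For example, the NVPTX backend produces GpuTargetBinary{ptx, cubin}, while
+// the AMDGPU backend produces GpuTargetBinary{"", hsaco}.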
+struct GpuTargetBinary {
+ GpuTargetBinary(const GpuTargetBinary& other) = delete;
+ GpuTargetBinary(GpuTargetBinary&& other) = default;
+
+ // The text format of the compiled result, e.g. PTX.
+ std::string text;
+
+ // The actual compiled binary.
+ std::vector<tensorflow::uint8> binary;
+};
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index d976b5d..0b5010e 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -18,7 +18,6 @@
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -42,7 +41,7 @@
kernel_name_(kernel_name),
unroll_factor_(unroll_factor) {}
-Status KernelThunk::Initialize(const GpuExecutable& executable,
+Status KernelThunk::Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) {
tensorflow::mutex_lock lock(mutex_);
@@ -55,8 +54,10 @@
if (kernel_cache_.end() == it) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<se::KernelBase> kernel,
- CreateKernel(kernel_name_, args_.size(), executable.text(),
- executable.binary(), executor));
+ CreateKernel(kernel_name_, args_.size(), target_binary.text,
+ target_binary.binary, executor));
+ CHECK(kernel);
kernel_cache_.emplace(executor, std::move(kernel));
}
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index 8835188..97a1d08 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -35,8 +35,6 @@
namespace xla {
namespace gpu {
-class GpuExecutable;
-
// This class stores everything that StreamExecutor needs for launching a
// kernel. It implements the ExecuteOnStream interface for GpuExecutable to
// invoke the corresponding kernel.
@@ -58,7 +56,7 @@
int unroll_factor() const { return unroll_factor_; }
void SetLaunchDimensions(const LaunchDimensions& launch_dims);
- Status Initialize(const GpuExecutable& executable,
+ Status Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const ExecuteParams& params) override;
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 0196267..cf6fe92 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -295,11 +295,9 @@
return std::make_pair(cc_major, cc_minor);
}
-StatusOr<std::pair<std::string, std::vector<uint8>>>
-NVPTXCompiler::CompileTargetBinary(const HloModule* module,
- llvm::Module* llvm_module,
- GpuVersion gpu_version,
- se::StreamExecutor* stream_exec) {
+StatusOr<GpuTargetBinary> NVPTXCompiler::CompileTargetBinary(
+ const HloModule* module, llvm::Module* llvm_module, GpuVersion gpu_version,
+ se::StreamExecutor* stream_exec) {
std::pair<int, int> compute_capability =
absl::get<std::pair<int, int>>(gpu_version);
@@ -340,8 +338,7 @@
stream_exec, ptx, compute_capability.first, compute_capability.second,
module->config());
- return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
- std::move(cubin));
+ return GpuTargetBinary{std::move(ptx), std::move(cubin)};
}
std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index e69be94..ec550b5 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -48,7 +48,7 @@
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
- StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+ StatusOr<GpuTargetBinary> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
index 025ca60..bd26033 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -34,10 +34,10 @@
}
}
-Status SequentialThunk::Initialize(const GpuExecutable& executable,
+Status SequentialThunk::Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) {
for (auto& thunk : thunks_) {
- TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor));
+ TF_RETURN_IF_ERROR(thunk->Initialize(target_binary, executor));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
index 3abb82c..b547566 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
@@ -40,7 +40,7 @@
const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; }
void ComputeAnnotations() override;
- Status Initialize(const GpuExecutable& executable,
+ Status Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const ExecuteParams& params) override;
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index e9be41b..7aff9ca 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -22,6 +22,7 @@
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
@@ -30,8 +31,6 @@
namespace xla {
namespace gpu {
-class GpuExecutable;
-
// Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
// metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
//
@@ -97,7 +96,7 @@
// This may be called multiple times. Its main purpose is to give us a chance
// to do initialization outside of ExecuteOnStream() so that the
// time spent initializing doesn't count towards our execution profile.
- virtual Status Initialize(const GpuExecutable& /*executable*/,
+ virtual Status Initialize(const GpuTargetBinary& /*target_binary*/,
se::StreamExecutor* /*executor*/) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index 4134cd3..2650508 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -45,11 +45,11 @@
body_thunk_sequence_->ComputeAnnotations();
}
-Status WhileThunk::Initialize(const GpuExecutable& executable,
+Status WhileThunk::Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) {
TF_RETURN_IF_ERROR(
- condition_thunk_sequence_->Initialize(executable, executor));
- TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor));
+ condition_thunk_sequence_->Initialize(target_binary, executor));
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(target_binary, executor));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h
index 31db01b..77ee010 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h
@@ -47,7 +47,7 @@
WhileThunk& operator=(const WhileThunk&) = delete;
void ComputeAnnotations() override;
- Status Initialize(const GpuExecutable& executable,
+ Status Initialize(const GpuTargetBinary& target_binary,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const ExecuteParams& params) override;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index de65ed9..9722d5c 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -420,6 +420,8 @@
if (execution_options->num_partitions() > 0) {
module_config.set_num_partitions(execution_options->num_partitions());
}
+ module_config.set_use_spmd_partitioning(
+ execution_options->use_spmd_partitioning());
if (execution_options->has_device_assignment()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
DeviceAssignment::Deserialize(
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index b31a9ae..833d0fe 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -128,6 +128,11 @@
}
int64 num_partitions() const { return num_partitions_; }
+ void set_use_spmd_partitioning(bool use_spmd_partitioning) {
+ use_spmd_partitioning_ = use_spmd_partitioning;
+ }
+ bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
+
// Return a string which unambiguously represents all the fields of this data
// structure. Used for generating a cache key for storing the compiled
// executable.
@@ -199,6 +204,14 @@
std::vector<std::vector<int64>>* mutable_dot_config() { return &dot_config_; }
+ absl::Span<const std::vector<std::vector<int64>>> layout_config() const {
+ return layout_config_;
+ }
+
+ std::vector<std::vector<std::vector<int64>>>* mutable_layout_config() {
+ return &layout_config_;
+ }
+
private:
// If you add new members, be sure to update compilation_cache_key.
@@ -216,6 +229,10 @@
// The number of partitions (model parallelism) to compile this binary for.
int64 num_partitions_ = 1;
+ // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
+ // needs to partition the module.
+ bool use_spmd_partitioning_ = false;
+
// The target maximum parallelism at which to partition HLOs for parallel
// execution on the CPU backend.
int64 intra_op_parallelism_threads_ = -1;
@@ -232,6 +249,9 @@
FusionConfigCollection fusion_config_collection_ =
FusionConfigCollection::kOff;
+ // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto
+ // similar to backend config.
+
// Custom fusion configuration, where fusion_config_[c][v] control if node v
// in computation c must be fused to all its consumers (true) or not (false).
std::vector<std::vector<bool>> fusion_config_;
@@ -240,6 +260,10 @@
// how to convert dot operation v (sorted topologically and by computation) to
// convolution.
std::vector<std::vector<int64>> dot_config_;
+
+ // Layout configuration, where layout_config_[v][i] controls the layout
+ // decision i of operation v.
+ std::vector<std::vector<std::vector<int64>>> layout_config_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 360c8e5..d15a365 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -662,9 +662,11 @@
shape_size_function_(bitcast->operand(0)->shape())) {
return InternalError(
"Bitcast cannot have different shape sizes of output (%d) and operand "
- "(%d)",
+ "(%d) (%s) (%s)",
shape_size_function_(bitcast->shape()),
- shape_size_function_(bitcast->operand(0)->shape()));
+ shape_size_function_(bitcast->operand(0)->shape()),
+ bitcast->shape().ToString(true),
+ bitcast->operand(0)->shape().ToString(true));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 84654bf..13699f3 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -951,7 +951,8 @@
if (!Shape::Equal()
.IgnoreDynamicDimension()
.MinorToMajorOnlyInLayout()(instruction_subshape,
- buffer->shape())) {
+ buffer->shape()) &&
+ instruction->opcode() != HloOpcode::kBitcast) {
return InternalError(
"Layout of instruction %s at index {%s} does not match "
"source LogicalBuffer %s: %s vs %s",
@@ -1798,13 +1799,6 @@
// potential bugs in the layout assignment pass that may accidentally use the
// existing layout.
for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kBitcast) {
- // bitcasts are inherently layout sensitive and so a bitcast instruction
- // present in the IR before layout assignment is a bug.
- return InternalError(
- "Unexpected bitcast operation seen during layout assignment: %s.",
- instruction->ToString());
- }
// Some instructions carry mandatory layouts in their shape.
if (instruction->opcode() != HloOpcode::kInfeed &&
!IsLayoutConstrainedCustomCall(instruction) &&
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 304a80c..6e57524 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -814,27 +814,6 @@
EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}
-TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
- auto builder = HloComputation::Builder(TestName());
- auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
- {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
- builder.AddInstruction(
- HloInstruction::CreateBitcast(constant0->shape(), constant0));
- auto m = CreateNewVerifiedModule();
- m->AddEntryComputation(builder.Build());
-
- ComputationLayout computation_layout(
- m->entry_computation()->ComputeProgramShape());
- LayoutAssignment layout_assignment(&computation_layout);
- Status error_status = layout_assignment.Run(m.get()).status();
- EXPECT_FALSE(error_status.ok());
- EXPECT_THAT(
- error_status.error_message(),
- ::testing::HasSubstr(
- "Unexpected bitcast operation seen during layout assignment"));
-}
-
TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
// Pin non matching layouts to parameter and root.
const char* module_str = R"(
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index ef8ddfc..c80646e 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -112,6 +112,8 @@
}
execution_options.set_num_replicas(build_options.num_replicas());
execution_options.set_num_partitions(build_options.num_partitions());
+ execution_options.set_use_spmd_partitioning(
+ build_options.use_spmd_partitioning());
if (build_options.has_device_assignment()) {
TF_CHECK_OK(build_options.device_assignment().Serialize(
execution_options.mutable_device_assignment()));
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
index cd679f7..a57e430 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD
+++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
@@ -185,11 +185,11 @@
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgToLLVM",
"@llvm-project//mlir:LinalgTransforms",
- "@llvm-project//mlir:LoopOps",
- "@llvm-project//mlir:LoopOpsTransforms",
"@llvm-project//mlir:LoopsToGPUPass",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
index 33d3690..847ad91 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc
@@ -31,9 +31,9 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
-#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project
-#include "mlir/Dialect/LoopOps/Passes.h" // from @llvm-project
-#include "mlir/Dialect/LoopOps/Transforms.h" // from @llvm-project
+#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project
+#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
+#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
@@ -45,6 +45,7 @@
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
@@ -60,34 +61,6 @@
using ::mlir::xla_lhlo::FusionOp;
-// Following are some small transformations that are required to clean up code
-// after lowering from linalg to loops.
-
-// A simple pass that applies lowering of HLO to LHLO only within LHLO ops that
-// contain regions with HLO ops, e.g. FusionOp, ReduceOp, SelectAndScatterOp.
-// This is needed, as these ops are not closed from above and hence nested pass
-// managers can not be applied.
-struct NestedHloRegionsConverter
- : public mlir::PassWrapper<NestedHloRegionsConverter,
- ::mlir::FunctionPass> {
- void runOnFunction() override {
- auto& ctx = getContext();
- mlir::OwningRewritePatternList patterns;
- mlir::ConversionTarget target(ctx);
- target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>();
- ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns);
-
- getFunction().walk([&](mlir::Operation* op) {
- if (op->getNumRegions() == 0) {
- return;
- }
- if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
- signalPassFailure();
- }
- });
- }
-};
-
// Replaces a FusionOp by the operations contained in its region.
struct FusionOpRemover
: public mlir::PassWrapper<FusionOpRemover, ::mlir::FunctionPass> {
@@ -132,7 +105,7 @@
// No store operation found. Continue search outside of the parallel
// loop if block is in a parallel loop.
if (auto parallelOp =
- llvm::dyn_cast<mlir::loop::ParallelOp>(block->getParentOp())) {
+ llvm::dyn_cast<mlir::scf::ParallelOp>(block->getParentOp())) {
return findStore(parallelOp.getOperation(), matches);
}
return {};
@@ -388,8 +361,8 @@
struct FuseInnerParallelLoops
: public mlir::PassWrapper<FuseInnerParallelLoops, mlir::FunctionPass> {
void runOnFunction() override {
- getFunction().walk([](mlir::loop::ParallelOp op) {
- mlir::loop::naivelyFuseParallelOps(op.region());
+ getFunction().walk([](mlir::scf::ParallelOp op) {
+ mlir::scf::naivelyFuseParallelOps(op.region());
});
}
};
@@ -401,7 +374,7 @@
void runOnOperation() override {
mlir::Operation* module = getOperation();
- module->walk([&](mlir::loop::ParallelOp op) {
+ module->walk([&](mlir::scf::ParallelOp op) {
unsigned num_loops = op.getNumLoops();
std::vector<unsigned> combinedLoops;
combinedLoops.reserve(num_loops);
@@ -436,8 +409,10 @@
tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end());
}
- // First, lower bodies of LHLO operations that contain HLO ops.
- pm.addPass(absl::make_unique<NestedHloRegionsConverter>());
+ // Legalize from HLO to LHLO.
+ pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass());
+ // Move `AllocOp`s and insert missing `DeallocOp`s.
+ pm.addPass(::mlir::createBufferPlacementPass());
// Next, we can strip the outer fusion operation.
pm.addPass(absl::make_unique<FusionOpRemover>());
// Remove unnecessary LHLO copies.
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
index 35ac3b2..667cdef 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc
@@ -549,10 +549,11 @@
}
// TODO(b/137624192): Add profiling support.
+
return {absl::make_unique<GpuExecutable>(
- ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule),
- emission_context.releaseHloModule(), std::move(buffer_assignment),
- nullptr, nullptr)};
+ xla::gpu::GpuTargetBinary{ptx, cubin}, GetGpuVersion(stream_exec),
+ std::move(thunk_schedule), emission_context.releaseHloModule(),
+ std::move(buffer_assignment), nullptr, nullptr)};
}
StatusOr<std::vector<std::unique_ptr<Executable>>> MlirCompilerImpl::Compile(
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ab71c30..2ed5e70 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -313,6 +313,8 @@
if (execution_options->num_partitions() > 0) {
config->set_num_partitions(execution_options->num_partitions());
}
+ config->set_use_spmd_partitioning(
+ execution_options->use_spmd_partitioning());
config->set_seed(execution_options->seed());
config->set_launch_id(execution_options->launch_id());
config->set_debug_options(execution_options->debug_options());
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index c8a242c..1ad1f83 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1909,7 +1909,7 @@
# This test is tagged "manual" because it requires multiple GPUs, and
# Forge only supports single-GPU tests. Guitar skips "manual" tests
# unless they're also tagged "guitar".
- "guitar",
+ # "guitar", # Re-enable after b/156405690 is fixed.
"manual",
"multi_gpu",
"no_oss",
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 1947f51..16ed022 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -55,16 +55,15 @@
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { return 0; }
- StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
+ StatusOr<GpuTargetBinary> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module,
- GpuVersion gpu_version, se::StreamExecutor* stream_exec) {
+ GpuVersion gpu_version, se::StreamExecutor* stream_exec) override {
if (user_post_optimization_hook_) {
user_post_optimization_hook_(*llvm_module);
}
std::vector<uint8> compiled_results;
- return std::pair<std::string, std::vector<uint8>>(
- "", std::move(compiled_results));
+ return GpuTargetBinary{"", std::move(compiled_results)};
}
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index a015af6..f4b08f45 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -333,6 +333,10 @@
// Used to identify a set of programs that should be launched together.
int32 launch_id = 10;
+
+ // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
+ // num_partitions > 1 and XLA is requested to partition the input program.
+ bool use_spmd_partitioning = 11;
}
message GetDeviceHandlesRequest {
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 28ef943..6b4874a 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -105,19 +105,17 @@
# For platform specific build config
load(
"//tensorflow/core/platform:build_config.bzl",
- "tf_additional_all_protos",
"tf_additional_lib_deps",
"tf_additional_test_deps",
"tf_jspb_proto_library",
"tf_kernel_tests_linkstatic",
"tf_lib_proto_parsing_deps",
"tf_portable_deps_no_runtime",
+ "tf_portable_proto_lib",
"tf_proto_library",
- "tf_proto_library_cc",
"tf_protos_all_impl",
"tf_protos_grappler_impl",
"tf_protos_profiler_impl",
- "tf_pyclif_proto_library",
)
load(
"//tensorflow/core/platform:rules_cc.bzl",
@@ -180,18 +178,18 @@
# filegroup; e.g. ones with individual proto_library targets.
# LINT.IfChange
COMMON_PROTO_SRCS = [
- "protobuf/bfc_memory_map.proto",
- "protobuf/config.proto",
- "protobuf/cluster.proto",
- "protobuf/debug.proto",
- "protobuf/device_filters.proto",
- "protobuf/device_properties.proto",
- "protobuf/graph_debug_info.proto",
- "protobuf/queue_runner.proto",
- "protobuf/rewriter_config.proto",
- "protobuf/tensor_bundle.proto",
- "protobuf/saver.proto",
- "protobuf/verifier_config.proto",
+ "//tensorflow/core/protobuf:bfc_memory_map.proto",
+ "//tensorflow/core/protobuf:config.proto",
+ "//tensorflow/core/protobuf:cluster.proto",
+ "//tensorflow/core/protobuf:debug.proto",
+ "//tensorflow/core/protobuf:device_filters.proto",
+ "//tensorflow/core/protobuf:device_properties.proto",
+ "//tensorflow/core/protobuf:graph_debug_info.proto",
+ "//tensorflow/core/protobuf:queue_runner.proto",
+ "//tensorflow/core/protobuf:rewriter_config.proto",
+ "//tensorflow/core/protobuf:tensor_bundle.proto",
+ "//tensorflow/core/protobuf:saver.proto",
+ "//tensorflow/core/protobuf:verifier_config.proto",
]
EXAMPLE_PROTO_SRCS = [
@@ -238,7 +236,7 @@
]
ERROR_CODES_PROTO_SRCS = [
- "protobuf/error_codes.proto",
+ "//tensorflow/core/protobuf:error_codes.proto",
"//tensorflow/core/lib/core:error_codes.proto",
]
# LINT.ThenChange(//tensorflow/core/portable_proto_config.asciipb)
@@ -251,11 +249,13 @@
cc_api_version = 2,
make_default_target_header_only = True,
protodeps = [
- ":core_protos",
- ":error_codes_proto_impl",
"//tensorflow/core/example:protos_all",
"//tensorflow/core/framework:protos_all",
"//tensorflow/core/lib/core:error_codes_proto",
+ "//tensorflow/core/profiler/protobuf:xplane_proto",
+ "//tensorflow/core/profiler:profiler_options_proto",
+ "//tensorflow/core/protobuf:error_codes_proto_impl",
+ "//tensorflow/core/protobuf:for_core_protos",
"//tensorflow/core/util:protos_all",
"//tensorflow/core/util:test_log_proto_impl",
],
@@ -1270,7 +1270,7 @@
"//tensorflow/core/platform:mobile_srcs_no_runtime",
"//tensorflow/core/public:mobile_srcs_no_runtime",
"//tensorflow/core/util:mobile_srcs_no_runtime",
- "//tensorflow/core/util/ctc:android_srcs",
+ "//tensorflow/core/util/ctc:mobile_srcs",
] + glob(
[
"client/**/*.cc",
@@ -1300,12 +1300,12 @@
"//tensorflow/core/common_runtime/eager:srcs",
"//tensorflow/core/framework:mobile_srcs_only_runtime",
"//tensorflow/core/graph:mobile_srcs_only_runtime",
- "//tensorflow/core/kernels:android_srcs",
+ "//tensorflow/core/kernels:mobile_srcs",
"//tensorflow/core/lib/io:mobile_srcs_only_runtime",
"//tensorflow/core/profiler:mobile_srcs",
"//tensorflow/core/public:mobile_srcs_only_runtime",
"//tensorflow/core/util/sparse:mobile_srcs_only_runtime",
- "//tensorflow/core/util/tensor_bundle:android_srcs",
+ "//tensorflow/core/util/tensor_bundle:mobile_srcs",
"//tensorflow/core/util:mobile_srcs_only_runtime",
# Sources for which we already have granular targets.
@@ -1378,10 +1378,9 @@
],
visibility = ["//visibility:public"],
deps = [
- ":protos_all_cc_impl",
"//tensorflow/core/util:stats_calculator_portable",
"//tensorflow/core:mobile_additional_lib_deps",
- ] + tf_portable_deps_no_runtime(),
+ ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(),
alwayslink = 1,
)
@@ -1603,20 +1602,13 @@
[
alias(
name = "protobuf_%s_pyclif%s" % (proto_name, target_suffix),
- actual = ":protobuf/%s_pyclif%s" % (proto_name, target_suffix),
+ actual = "//tensorflow/core/protobuf:%s_pyclif%s" % (proto_name, target_suffix),
visibility = ["//visibility:public"],
)
for target_suffix in [
"",
"_pb2",
]
- ] + [
- tf_pyclif_proto_library(
- name = "protobuf/%s_pyclif" % proto_name,
- proto_lib = ":protos_all",
- proto_srcfile = "protobuf/%s.proto" % proto_name,
- visibility = ["//visibility:public"],
- ),
]
for proto_name in [
"config",
@@ -1630,77 +1622,74 @@
# -----------------------------------------------------------------------------
# Internal targets
-tf_proto_library(
+alias(
name = "autotuning_proto",
- srcs = ["protobuf/autotuning.proto"],
- cc_api_version = 2,
- make_default_target_header_only = True,
+ actual = "//tensorflow/core/protobuf:autotuning_proto",
visibility = [
"//tensorflow:internal",
],
)
-tf_proto_library(
+alias(
+ name = "autotuning_proto_cc",
+ actual = "//tensorflow/core/protobuf:autotuning_proto_cc",
+ visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+alias(
name = "conv_autotuning_proto",
- srcs = ["protobuf/conv_autotuning.proto"],
- cc_api_version = 2,
- make_default_target_header_only = True,
- protodeps = [
- "//tensorflow/stream_executor:dnn_proto",
- ],
+ actual = "//tensorflow/core/protobuf:conv_autotuning_proto",
visibility = [
"//tensorflow:internal",
],
)
-tf_proto_library_cc(
- name = "worker_proto",
- srcs = ["protobuf/worker.proto"],
- cc_api_version = 2,
- protodeps = tf_additional_all_protos(),
- visibility = ["//visibility:public"],
-)
-
-tf_proto_library_cc(
- name = "worker_service_proto",
- srcs = ["protobuf/worker_service.proto"],
- has_services = 1,
- cc_api_version = 2,
- cc_stubby_versions = ["2"],
- protodeps = [":worker_proto"],
+alias(
+ name = "conv_autotuning_proto_cc",
+ actual = "//tensorflow/core/protobuf:conv_autotuning_proto_cc",
visibility = [
"//tensorflow:internal",
],
)
-tf_proto_library_cc(
- name = "master_proto",
- srcs = ["protobuf/master.proto"],
- cc_api_version = 2,
- protodeps = tf_additional_all_protos(),
- visibility = ["//tensorflow:internal"],
-)
-
-tf_proto_library_cc(
- name = "master_service_proto",
- srcs = ["protobuf/master_service.proto"],
- has_services = 1,
- cc_api_version = 2,
- cc_stubby_versions = ["2"],
- protodeps = [":master_proto"],
+alias(
+ name = "worker_proto_cc",
+ actual = "//tensorflow/core/protobuf:worker_proto_cc",
visibility = [
"//tensorflow:internal",
],
)
-tf_proto_library_cc(
- name = "eager_service_proto",
- srcs = ["protobuf/eager_service.proto"],
- has_services = 1,
- cc_api_version = 2,
- cc_grpc_version = 1,
- cc_stubby_versions = ["2"],
- protodeps = tf_additional_all_protos(),
+alias(
+ name = "worker_service_proto_cc",
+ actual = "//tensorflow/core/protobuf:worker_service_proto_cc",
+ visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+alias(
+ name = "master_proto_cc",
+ actual = "//tensorflow/core/protobuf:master_proto_cc",
+ visibility = [
+ "//learning/brain/frameworks/uptc:__subpackages__",
+ "//tensorflow:internal",
+ ],
+)
+
+alias(
+ name = "master_service_proto_cc",
+ actual = "//tensorflow/core/protobuf:master_service_proto_cc",
+ visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+alias(
+ name = "eager_service_proto_cc",
+ actual = "//tensorflow/core/protobuf:eager_service_proto_cc",
visibility = [
"//tensorflow:internal",
],
@@ -2112,49 +2101,14 @@
],
)
-tf_proto_library(
+alias(
name = "error_codes_proto_impl",
- srcs = ["protobuf/error_codes.proto"],
- cc_api_version = 2,
- make_default_target_header_only = True,
+ actual = "//tensorflow/core/protobuf:error_codes_proto_impl",
)
-tf_proto_library(
- name = "core_protos",
- srcs = COMMON_PROTO_SRCS + [
- # Protos which are not needed on mobile builds, but should be included
- # in protos_all.
- #
- # Note that some protos are in neither core_proto_srcs nor this
- # filegroup; e.g. ones with individual proto_library targets.
- "protobuf/control_flow.proto",
- # TODO(ebrevdo): Re-enable once CriticalSection is in core.
- # "protobuf/critical_section.proto",
- "protobuf/data/experimental/snapshot.proto",
- "protobuf/debug_event.proto",
- "protobuf/meta_graph.proto",
- "protobuf/named_tensor.proto",
- "protobuf/remote_tensor_handle.proto",
- "protobuf/saved_model.proto",
- "protobuf/saved_object_graph.proto",
- "protobuf/struct.proto",
- "protobuf/tensorflow_server.proto",
- "protobuf/trackable_object_graph.proto",
- "protobuf/transport_options.proto",
- ],
- cc_api_version = 2,
- make_default_target_header_only = True,
- protodeps = [
- ":error_codes_proto_impl",
- "//tensorflow/core/example:protos_all",
- "//tensorflow/core/framework:protos_all",
- "//tensorflow/core/lib/core:error_codes_proto",
- "//tensorflow/core/profiler/protobuf:xplane_proto",
- "//tensorflow/core/profiler:profiler_options_proto",
- "//tensorflow/core/util:protos_all",
- "//tensorflow/core/util:test_log_proto_impl",
- ],
- visibility = ["//visibility:private"],
+alias(
+ name = "error_codes_proto_impl_cc",
+ actual = "//tensorflow/core/protobuf:error_codes_proto_impl_cc",
)
alias(
@@ -2446,13 +2400,9 @@
visibility = ["//visibility:public"],
)
-tf_proto_library_cc(
- name = "replay_log_proto",
- srcs = ["protobuf/replay_log.proto"],
- cc_api_version = 2,
- protodeps = [
- ":master_proto",
- ] + tf_additional_all_protos(),
+alias(
+ name = "replay_log_proto_cc",
+ actual = "//tensorflow/core/protobuf:replay_log_proto_cc",
visibility = [
"//tensorflow:internal",
],
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 484bdee..016896b 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -243,6 +243,7 @@
"memory_types.h",
"mkl_cpu_allocator.h",
"mkl_layout_pass.h",
+ "mkl_tfconversion_pass.h",
"optimization_registry.h",
"partitioning_utils.h",
"placer.h",
@@ -1028,9 +1029,13 @@
cc_library(
name = "mkl_layout_pass",
srcs = ["mkl_layout_pass.cc"],
- hdrs = ["mkl_layout_pass.h"],
+ hdrs = [
+ "mkl_layout_pass.h",
+ "//tensorflow/core/graph:mkl_graph_util_header",
+ ],
copts = tf_copts(),
deps = [
+ ":function",
":optimization_registry",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@@ -1043,9 +1048,13 @@
cc_library(
name = "mkl_tfconversion_pass",
srcs = ["mkl_tfconversion_pass.cc"],
- hdrs = ["mkl_tfconversion_pass.h"],
+ hdrs = [
+ "mkl_tfconversion_pass.h",
+ "//tensorflow/core/graph:mkl_graph_util_header",
+ ],
copts = tf_copts(),
deps = [
+ ":function",
":optimization_registry",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@@ -2286,7 +2295,7 @@
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/core/kernels:cwise_op",
- ] + if_mkl([":mkl_array_ops_op_lib"]),
+ ] + if_mkl(["//tensorflow/core:mkl_array_ops_op_lib"]),
)
tf_cc_test(
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index 13630a0..7850978 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -98,7 +98,7 @@
Status EagerExecutor::SyncExecute(EagerNode* node) {
if (Async()) {
- return errors::Internal("Executor does not support sync execution");
+ return errors::Internal("Executor does not support async execution");
}
if (node->AsAsync() != nullptr) {
return errors::Internal("Executor does not support executing async nodes");
diff --git a/tensorflow/core/common_runtime/eager/execute_node_test.cc b/tensorflow/core/common_runtime/eager/execute_node_test.cc
index 970307d..99f0303 100644
--- a/tensorflow/core/common_runtime/eager/execute_node_test.cc
+++ b/tensorflow/core/common_runtime/eager/execute_node_test.cc
@@ -15,6 +15,8 @@
#include "tensorflow/core/common_runtime/eager/execute_node.h"
+#include <memory>
+
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@@ -49,9 +51,12 @@
StaticDeviceMgr device_mgr(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
Device* device0 = device_mgr.ListDevices().at(0);
- StaticDeviceMgr remote_device_mgr(
+ auto remote_device_mgr = absl::make_unique<DynamicDeviceMgr>();
+ std::vector<std::unique_ptr<Device>> remote_devices;
+ remote_devices.emplace_back(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
- Device* device1 = remote_device_mgr.ListDevices().at(0);
+ TF_ASSERT_OK(remote_device_mgr->AddDevices(std::move(remote_devices)));
+ Device* device1 = remote_device_mgr->ListDevices().at(0);
Status s;
std::unique_ptr<CompositeDevice> composite_device =
@@ -65,6 +70,17 @@
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
&device_mgr, false, nullptr, nullptr, nullptr);
+ // Set a RemoteMgr to the EagerContext.
+ auto remote_mgr = absl::make_unique<eager::RemoteMgr>(
+ /*is_master=*/true, ctx);
+ TF_ASSERT_OK(ctx->InitializeRemoteMaster(
+ /*server=*/nullptr, /*worker_env=*/nullptr,
+ /*worker_session=*/nullptr, /*remote_eager_workers=*/nullptr,
+ std::move(remote_device_mgr), /*remote_contexts=*/{},
+ EagerContext::NewContextId(),
+ /*r=*/nullptr, &device_mgr, /*keep_alive_secs*/ 600,
+ /*cluster_flr=*/nullptr, std::move(remote_mgr)));
+
DataType dtype = DT_FLOAT;
Tensor t0(dtype, TensorShape({}));
// Create two local TensorHandles
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
index 2d4ae33..f233980 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -17,7 +17,7 @@
#include <unordered_map>
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
-#include "tensorflow/core/common_runyime/mkl_layout_pass.h"
+#include "tensorflow/core/common_runtime/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/mkl_util.h"
diff --git a/tensorflow/core/data/service/master_impl.cc b/tensorflow/core/data/service/master_impl.cc
index 6e2c95c..336ab06 100644
--- a/tensorflow/core/data/service/master_impl.cc
+++ b/tensorflow/core/data/service/master_impl.cc
@@ -169,7 +169,11 @@
if (job != nullptr) {
TF_RETURN_IF_ERROR(ValidateMatchingJob(**job, requested_processing_mode,
request->dataset_id()));
- response->set_job_id((*job)->job_id());
+ int64 job_id = (*job)->job_id();
+ response->set_job_id(job_id);
+ VLOG(3) << "Found existing job for name=" << request->job_name()
+ << ", index=" << request->job_name_index()
+ << ". job_id: " << job_id;
return Status::OK();
}
int64 job_id;
@@ -177,6 +181,8 @@
request->job_name(), &job_id));
named_jobs_[key] = jobs_[job_id];
response->set_job_id(job_id);
+ VLOG(3) << "Created job " << job_id << " for dataset "
+ << request->dataset_id() << " and name " << request->job_name();
return Status::OK();
}
diff --git a/tensorflow/core/data/service/master_impl.h b/tensorflow/core/data/service/master_impl.h
index de25ea0..e8b70e8 100644
--- a/tensorflow/core/data/service/master_impl.h
+++ b/tensorflow/core/data/service/master_impl.h
@@ -75,7 +75,7 @@
}
std::string DebugString() {
- return absl::StrCat("id: ", worker_id_, "address: ", address_);
+ return absl::StrCat("id: ", worker_id_, " address: ", address_);
}
private:
diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc
index 7395244..8d00825 100644
--- a/tensorflow/core/data/service/worker_impl.cc
+++ b/tensorflow/core/data/service/worker_impl.cc
@@ -84,6 +84,7 @@
Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ VLOG(3) << "Received request to process task " << task_def.task_id();
standalone::Dataset::Params params;
std::unique_ptr<standalone::Dataset> dataset;
TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
@@ -100,6 +101,7 @@
task.id = task_def.task_id();
task.dataset = std::move(dataset);
task.iterator = std::move(iterator);
+ VLOG(3) << "Began processing for task " << task_def.task_id();
return Status::OK();
}
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index 40a6d53..361f7ed 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -16,6 +16,7 @@
#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_
#include <vector>
+
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
@@ -90,7 +91,7 @@
::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \
op, i, "e") \
.error_message(); \
- const std::string& substring = error_substring; \
+ const std::string substring = error_substring; \
EXPECT_NE("", error_message); \
EXPECT_TRUE(absl::StrContains(error_message, substring)) \
<< "Expected to see '" << substring << "' in '" << error_message \
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index f47265f..cd0d44e 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -114,12 +114,12 @@
}
uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
- const uint64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000;
+ const uint64 kTwentyMinutesInUsec = 20 * 60 * 1000 * 1000;
if (cfg.meta_optimizer_timeout_ms() < 0) {
return 0;
} else {
return cfg.meta_optimizer_timeout_ms() == 0
- ? Env::Default()->NowMicros() + kFiveMinutesInUsec
+ ? Env::Default()->NowMicros() + kTwentyMinutesInUsec
: Env::Default()->NowMicros() +
cfg.meta_optimizer_timeout_ms() * 1000;
}
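
The hunk above raises the meta-optimizer's default budget from five to twenty minutes. The policy reduces to a small function: a negative timeout disables the deadline, zero selects the default, and positive values are explicit budgets in milliseconds. A minimal sketch, assuming std::chrono in place of Env::Default()->NowMicros():

#include <chrono>
#include <cstdint>
#include <iostream>

// Deadline policy sketched from the hunk above. Returning 0 signals "no
// deadline" (inferred from the shape of the code, not confirmed here).
// std::chrono is an assumption standing in for TensorFlow's Env clock.
uint64_t DeadlineMicros(int64_t timeout_ms) {
  constexpr uint64_t kTwentyMinutesInUsec = 20ULL * 60 * 1000 * 1000;
  if (timeout_ms < 0) return 0;
  const uint64_t now_us =
      std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();
  return timeout_ms == 0 ? now_us + kTwentyMinutesInUsec
                         : now_us + static_cast<uint64_t>(timeout_ms) * 1000;
}

int main() {
  std::cout << DeadlineMicros(-1) << "\n";    // 0: never time out
  std::cout << DeadlineMicros(0) << "\n";     // now + 20-minute default
  std::cout << DeadlineMicros(1500) << "\n";  // now + 1.5 s
}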
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index e47c681..7cfb6fc 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -7096,7 +7096,7 @@
build_test(
name = "android_tensorflow_kernels_build_test",
- targets = [":android_tensorflow_kernels"],
+ targets = [":portable_tensorflow_kernels"],
)
cc_library(
@@ -7109,7 +7109,7 @@
"//tensorflow/core:android_gif_internal",
"//tensorflow/core:android_jpeg_internal",
"//tensorflow/core:android_png_internal",
- "//tensorflow/core:android_tensorflow_lib_lite",
+ "//tensorflow/core:portable_tensorflow_lib_lite",
],
alwayslink = 1,
)
@@ -7126,7 +7126,7 @@
linkopts = ["-ldl"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:android_tensorflow_lib_lite",
+ "//tensorflow/core:portable_tensorflow_lib_lite",
],
alwayslink = 1,
)
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index d61c574..4ddfd99 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -138,6 +138,7 @@
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils",
"//tensorflow/core/kernels/data:serialization_utils",
+ "//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
index 8c33668..697f4d9 100644
--- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc
@@ -37,6 +37,7 @@
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/snappy.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
@@ -178,7 +179,10 @@
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params, int64 iterator_index)
- : DatasetIterator<Dataset>(params), iterator_index_(iterator_index) {}
+ : DatasetIterator<Dataset>(params),
+ iterator_index_(iterator_index),
+ max_outstanding_requests_(params.dataset->max_outstanding_requests_) {
+ }
~Iterator() override {
mutex_lock l(mu_);
@@ -390,21 +394,23 @@
// TODO(aaudibert): add backoff and max retries.
int64 deadline_micros =
Env::Default()->NowMicros() + kRetryTimeoutMicros;
- Status s = FetchElement(task_thread, deadline_micros);
+ Status s = GetElement(task_thread, deadline_micros);
if (!s.ok()) {
- LOG(WARNING) << "Failed to fetch element from worker at "
+ LOG(WARNING) << "Failed to get element from worker at "
<< task_thread->address << ": " << s;
}
}
}
- // Fetches an element from a task and adds the element to `results_`.
+ // Gets an element from a task and adds the element to `results_`.
//
// If the task reaches end_of_sequence or is cancelled (e.g. due to a
- // worker dying), FetchElement returns Status::OK() without adding to
+ // worker dying), GetElement returns Status::OK() without adding to
// `results_`.
- Status FetchElement(TaskThread* task_thread, int64 deadline_micros) {
- VLOG(3) << "Fetching an element for task id " << task_thread->task_id;
+ Status GetElement(TaskThread* task_thread, int64 deadline_micros) {
+ VLOG(3) << "Getting an element for task id " << task_thread->task_id;
+ tensorflow::profiler::TraceMe activity(
+ "GetElement", tensorflow::profiler::TraceMeLevel::kInfo);
CompressedElement compressed;
bool end_of_sequence;
for (int num_retries = 0;; ++num_retries) {
@@ -453,7 +459,7 @@
}
results_.push(std::move(element));
cv_.notify_all();
- VLOG(3) << "Fetched an element for task id " << task_thread->task_id;
+ VLOG(3) << "Got an element for task id " << task_thread->task_id;
return Status::OK();
}
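
The TODO above still asks for backoff and max retries on top of the per-call deadline. One way such a loop could look, as a hedged sketch: RetryableCall, IsTransient, kMaxRetries, and the 100 ms base delay are all illustrative assumptions, not the tf.data service API.

#include <algorithm>
#include <chrono>
#include <string>
#include <thread>

// Illustrative status type standing in for tensorflow::Status.
struct Status {
  bool ok() const { return msg.empty(); }
  std::string msg;
};

// Hypothetical RPC and transient-error check; both are assumptions.
Status RetryableCall() { return Status{}; }
bool IsTransient(const Status& s) { return !s.ok(); }

// Deadline-bounded retry with capped exponential backoff, in the spirit of
// the TODO. Gives up on success, a permanent error, too many attempts, or
// when the next sleep would overrun the deadline.
Status CallWithRetries(std::chrono::steady_clock::time_point deadline) {
  constexpr int kMaxRetries = 10;
  auto delay = std::chrono::milliseconds(100);
  for (int attempt = 0;; ++attempt) {
    Status s = RetryableCall();
    if (s.ok() || !IsTransient(s) || attempt >= kMaxRetries ||
        std::chrono::steady_clock::now() + delay >= deadline) {
      return s;
    }
    std::this_thread::sleep_for(delay);
    delay = std::min(delay * 2, std::chrono::milliseconds(5000));
  }
}

int main() {
  auto deadline = std::chrono::steady_clock::now() + std::chrono::minutes(1);
  return CallWithRetries(deadline).ok() ? 0 : 1;
}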
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 5dae096..7b8f697 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -621,6 +621,11 @@
return false;
}
if (!deterministic_) {
+ // Iterate through in-flight results and return the first one that is
+ // available and not end-of-input. If the first result (in order) is
+ // end-of-input, we know that all earlier iterations have already
+ // completed, so it is safe to return that result and let the caller
+ // handle the end of iteration.
for (auto it = invocation_results_.begin();
it != invocation_results_.end(); ++it) {
if ((*it)->notification.HasBeenNotified() &&
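
The new comment describes the non-deterministic fast path: scan the in-flight results in order and hand back the first one that is ready and not end-of-input, treating a ready end-of-input at the front of the queue as genuine end of iteration. A minimal sketch of that selection over a plain std::deque, with Result and Poll as stand-ins for the op's internal types:

#include <deque>

// Stand-in for the op's per-invocation result slot.
struct Result {
  bool ready = false;         // notification.HasBeenNotified()
  bool end_of_input = false;  // invocation produced end-of-sequence
  int value = 0;              // payload placeholder
};

enum class Poll { kNothingReady, kEndOfInput, kValue };

// Take the first ready, non-end-of-input result, removing it from the queue.
// If the *front* result is a ready end-of-input, every earlier invocation has
// already finished, so end-of-iteration can be surfaced safely.
Poll TakeFirstAvailable(std::deque<Result>& results, int* out) {
  for (auto it = results.begin(); it != results.end(); ++it) {
    if (it->ready && !it->end_of_input) {
      *out = it->value;
      results.erase(it);
      return Poll::kValue;
    }
    if (it == results.begin() && it->ready && it->end_of_input) {
      return Poll::kEndOfInput;
    }
  }
  return Poll::kNothingReady;  // caller blocks on the condition variable
}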
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 0035677..42364e4 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -435,9 +435,9 @@
for (const string& dump_root : dump_roots_) {
tfdbg::DebugEventsWriter* debug_events_writer =
tfdbg::DebugEventsWriter::GetDebugEventsWriter(dump_root);
- debug_events_writer->WriteGraphExecutionTrace(
- tfdbg_context_id_, device_name_, op_name_, output_slot_,
- tensor_debug_mode_, tensor);
+ OP_REQUIRES_OK(context, debug_events_writer->WriteGraphExecutionTrace(
+ tfdbg_context_id_, device_name_, op_name_,
+ output_slot_, tensor_debug_mode_, tensor));
}
context->set_output(0, tensor);
}
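
The debug_ops.h change stops discarding the Status returned by WriteGraphExecutionTrace: OP_REQUIRES_OK records the failure on the kernel context and returns from Compute early. A reduced sketch of that pattern, assuming simplified Status and Context types in place of the real OpKernelContext machinery:

#include <iostream>
#include <string>
#include <utility>

struct Status {
  bool ok() const { return msg.empty(); }
  std::string msg;
};

// Reduced stand-in for OpKernelContext: records the first failure only.
struct Context {
  Status status;
  void SetStatus(Status s) { if (status.ok()) status = std::move(s); }
};

// Same shape as TensorFlow's macro: on failure, record the error and bail
// out of the enclosing function instead of continuing with bad state.
#define OP_REQUIRES_OK_SKETCH(ctx, expr)            \
  do {                                              \
    Status _s = (expr);                             \
    if (!_s.ok()) {                                 \
      (ctx)->SetStatus(_s);                         \
      return;                                       \
    }                                               \
  } while (0)

Status WriteTrace(bool fail) { return fail ? Status{"disk full"} : Status{}; }

void Compute(Context* ctx, bool fail) {
  OP_REQUIRES_OK_SKETCH(ctx, WriteTrace(fail));
  std::cout << "trace written\n";  // skipped when WriteTrace fails
}

int main() {
  Context ctx;
  Compute(&ctx, /*fail=*/true);
  std::cout << (ctx.status.ok() ? "ok" : ctx.status.msg) << "\n";
}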
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 0f5a701..3b38daf 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -61,7 +61,9 @@
" is '" +
DataTypeString(ctx->output_type(0)) + "'"));
+ need_cast_ = true;
if (ctx->output_type(0) == DT_FLOAT) {
+ need_cast_ = false;
OP_REQUIRES(ctx,
(mode_string == "MIN_COMBINED" ||
mode_string == "MIN_FIRST" || mode_string == "SCALED"),
@@ -98,8 +100,9 @@
}
Tensor* output = nullptr;
- Tensor float_output = tensorflow::Tensor(DT_FLOAT, input.shape());
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
+ Tensor float_output =
+ need_cast_ ? tensorflow::Tensor(DT_FLOAT, input.shape()) : *output;
if (num_slices == 1) {
const float min_range = input_min_tensor.flat<float>()(0);
const float max_range = input_max_tensor.flat<float>()(0);
@@ -128,10 +131,12 @@
max_ranges(i), output_tensor.template chip<1>(i));
}
}
- S* out_ptr = output->flat<S>().data();
- float* in_ptr = float_output.flat<float>().data();
- for (int64 i = 0; i < float_output.NumElements(); ++i) {
- out_ptr[i] = static_cast<S>(in_ptr[i]);
+ if (need_cast_) {
+ S* out_ptr = output->flat<S>().data();
+ float* in_ptr = float_output.flat<float>().data();
+ for (int64 i = 0; i < float_output.NumElements(); ++i) {
+ out_ptr[i] = static_cast<S>(in_ptr[i]);
+ }
}
}
@@ -219,6 +224,7 @@
int mode_;
int axis_;
bool narrow_range_;
+ bool need_cast_;
};
REGISTER_KERNEL_BUILDER(Name("Dequantize")
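
The new need_cast_ flag lets the float output path skip the scratch tensor and the element-wise cast entirely, since float_output can alias the real output. A hedged sketch of the same idea with plain vectors instead of Tensors; the MIN_COMBINED-style formula is illustrative, not the kernel's exact math:

#include <cstdint>
#include <type_traits>
#include <vector>

// Dequantize straight into the caller's buffer when it is already float, and
// only use a scratch buffer plus one final cast loop for other output types.
template <typename OutT>
void DequantizeSketch(const std::vector<uint8_t>& in, float min_range,
                      float max_range, std::vector<OutT>* out) {
  out->resize(in.size());
  const float scale = (max_range - min_range) / 255.0f;
  if constexpr (std::is_same_v<OutT, float>) {
    // No cast needed: the "float view" IS the output buffer.
    for (size_t i = 0; i < in.size(); ++i)
      (*out)[i] = min_range + in[i] * scale;
  } else {
    // Cast needed: compute in float, then convert once, as the kernel does.
    std::vector<float> scratch(in.size());
    for (size_t i = 0; i < in.size(); ++i)
      scratch[i] = min_range + in[i] * scale;
    for (size_t i = 0; i < in.size(); ++i)
      (*out)[i] = static_cast<OutT>(scratch[i]);
  }
}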
diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD
index 80ad494..491e4c5 100644
--- a/tensorflow/core/lib/core/BUILD
+++ b/tensorflow/core/lib/core/BUILD
@@ -138,10 +138,13 @@
cc_api_version = 2,
make_default_target_header_only = True,
protodeps = [
- "//tensorflow/core:error_codes_proto_impl",
+ "//tensorflow/core/protobuf:error_codes_proto_impl",
],
- visibility = ["//tensorflow/core:__subpackages__"],
- exports = ["//tensorflow/core:error_codes_proto_impl"],
+ visibility = [
+ "//tensorflow/core:__subpackages__",
+ "//tensorflow/core/protobuf:__subpackages__",
+ ],
+ exports = ["//tensorflow/core/protobuf:error_codes_proto_impl"],
)
# Export source files needed for mobile builds, which do not use granular targets.
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 819f8fc..c7ff378 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -621,7 +621,7 @@
":stringpiece",
":stringprintf",
":types",
- "//tensorflow/core:error_codes_proto_impl_cc",
+ "//tensorflow/core/protobuf:error_codes_proto_impl_cc",
"@com_google_absl//absl/base",
],
)
diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl
index f0613cd..ab45256 100644
--- a/tensorflow/core/platform/build_config.bzl
+++ b/tensorflow/core/platform/build_config.bzl
@@ -26,6 +26,7 @@
_tf_platform_alias = "tf_platform_alias",
_tf_platform_deps = "tf_platform_deps",
_tf_portable_deps_no_runtime = "tf_portable_deps_no_runtime",
+ _tf_portable_proto_lib = "tf_portable_proto_lib",
_tf_proto_library = "tf_proto_library",
_tf_proto_library_cc = "tf_proto_library_cc",
_tf_proto_library_py = "tf_proto_library_py",
@@ -65,6 +66,7 @@
tf_logging_deps = _tf_logging_deps
tf_platform_alias = _tf_platform_alias
tf_platform_deps = _tf_platform_deps
+tf_portable_proto_lib = _tf_portable_proto_lib
tf_portable_deps_no_runtime = _tf_portable_deps_no_runtime
tf_proto_library = _tf_proto_library
tf_proto_library_cc = _tf_proto_library_cc
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 18a8285..2dc4fdc 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -577,8 +577,8 @@
def tf_protos_all_impl():
return [
- clean_dep("//tensorflow/core:autotuning_proto_cc_impl"),
- clean_dep("//tensorflow/core:conv_autotuning_proto_cc_impl"),
+ clean_dep("//tensorflow/core/protobuf:autotuning_proto_cc_impl"),
+ clean_dep("//tensorflow/core/protobuf:conv_autotuning_proto_cc_impl"),
clean_dep("//tensorflow/core:protos_all_cc_impl"),
]
@@ -727,6 +727,9 @@
otherwise = [clean_dep("@com_google_protobuf//:protobuf_headers")],
)
+def tf_portable_proto_lib():
+ return ["//tensorflow/core:protos_all_cc_impl"]
+
def tf_protobuf_compiler_deps():
return if_static(
[
@@ -764,7 +767,7 @@
"@nsync//:nsync_cpp",
"@com_googlesource_code_re2//:re2",
"@farmhash_archive//:farmhash",
- ] + tf_protobuf_deps()
+ ]
def tf_google_mobile_srcs_no_runtime():
return []
diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.cc b/tensorflow/core/platform/profile_utils/cpu_utils.cc
index 587c978..b22123a 100644
--- a/tensorflow/core/platform/profile_utils/cpu_utils.cc
+++ b/tensorflow/core/platform/profile_utils/cpu_utils.cc
@@ -88,6 +88,8 @@
defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__))
retval = sscanf(line.c_str(), "clock : %lfMHz", &cpu_freq);
freq_factor = 1.0;
+#elif defined(__s390x__)
+ retval = sscanf(line.c_str(), "bogomips per cpu: %lf", &cpu_freq);
#else
retval = sscanf(line.c_str(), "bogomips : %lf", &cpu_freq);
#endif
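
The new s390x branch follows the file's existing pattern of one sscanf format per platform-specific label in /proc/cpuinfo. A small standalone illustration of that approach; the exact set of patterns is paraphrased from the surrounding code, not authoritative:

#include <cstdio>

// Try each architecture's frequency label in turn, one sscanf pattern per
// platform, the way cpu_utils.cc selects a single pattern via #ifdef.
double ParseCpuFreqLine(const char* line) {
  double freq = 0.0;
  if (std::sscanf(line, "clock : %lfMHz", &freq) == 1) return freq;          // ppc
  if (std::sscanf(line, "bogomips per cpu: %lf", &freq) == 1) return freq;   // s390x
  if (std::sscanf(line, "bogomips : %lf", &freq) == 1) return freq;          // default
  return 0.0;
}

int main() {
  std::printf("%f\n", ParseCpuFreqLine("bogomips per cpu: 3033.00"));
  std::printf("%f\n", ParseCpuFreqLine("clock : 2400.000MHz"));
}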
diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD
index 15a1ad0..369d26a 100644
--- a/tensorflow/core/profiler/convert/BUILD
+++ b/tensorflow/core/profiler/convert/BUILD
@@ -17,15 +17,18 @@
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:cost_utils",
- "//tensorflow/core/profiler/utils:event_span",
+ "//tensorflow/core/profiler/utils:op_metrics_db_utils",
"//tensorflow/core/profiler/utils:op_utils",
"//tensorflow/core/profiler/utils:tf_op_utils",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
"//tensorflow/core/profiler/utils:timespan",
"//tensorflow/core/profiler/utils:trace_utils",
+ "//tensorflow/core/profiler/utils:xplane_schema",
+ "//tensorflow/core/profiler/utils:xplane_visitor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -35,9 +38,11 @@
srcs = ["xplane_to_op_metrics_db_test.cc"],
deps = [
":xplane_to_op_metrics_db",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:op_metrics_db_utils",
"//tensorflow/core/profiler/utils:time_utils",
"//tensorflow/core/profiler/utils:xplane_builder",
@@ -86,13 +91,15 @@
":op_stats_to_input_pipeline_analysis",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "//tensorflow/core/platform:logging",
"//tensorflow/core/profiler/protobuf:hardware_types_proto_cc",
"//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc",
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
"//tensorflow/core/profiler/protobuf:overview_page_proto_cc",
+ "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
+ "//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
"//tensorflow/core/profiler/utils:errors",
+ "//tensorflow/core/profiler/utils:html_utils",
"//tensorflow/core/profiler/utils:math_utils",
"//tensorflow/core/profiler/utils:op_metrics_db_utils",
"//tensorflow/core/profiler/utils:time_utils",
@@ -118,11 +125,11 @@
"//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
"//tensorflow/core/profiler/utils:errors",
"//tensorflow/core/profiler/utils:event_span",
+ "//tensorflow/core/profiler/utils:html_utils",
"//tensorflow/core/profiler/utils:math_utils",
"//tensorflow/core/profiler/utils:tf_op_utils",
"//tensorflow/core/profiler/utils:time_utils",
"//tensorflow/core/util:stats_calculator_portable",
- "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -135,13 +142,12 @@
hdrs = ["op_stats_to_tf_stats.h"],
deps = [
":op_metrics_to_record",
+ "//tensorflow/core:lib",
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
"//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
"//tensorflow/core/profiler/utils:op_metrics_db_utils",
- "//tensorflow/core/profiler/utils:tf_op_utils",
"//tensorflow/core/profiler/utils:time_utils",
- "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -152,13 +158,18 @@
deps = [
":op_stats_to_tf_stats",
":xplane_to_op_stats",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+ "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
+ "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:op_metrics_db_utils",
"//tensorflow/core/profiler/utils:time_utils",
"//tensorflow/core/profiler/utils:xplane_builder",
"//tensorflow/core/profiler/utils:xplane_schema",
+ "@com_google_absl//absl/strings",
],
)
@@ -171,6 +182,9 @@
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
"//tensorflow/core/profiler/utils:event_span",
+ "//tensorflow/core/profiler/utils:timespan",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -205,6 +219,7 @@
srcs = ["xplane_to_op_stats.cc"],
hdrs = ["xplane_to_op_stats.h"],
deps = [
+ ":op_metrics_db_combiner",
":step_events_to_steps_db",
":xplane_to_kernel_stats_db",
":xplane_to_op_metrics_db",
@@ -213,15 +228,20 @@
"//tensorflow/core:lib",
"//tensorflow/core/profiler/protobuf:hardware_types_proto_cc",
"//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc",
+ "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
+ "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
"//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:event_span",
"//tensorflow/core/profiler/utils:hardware_type_utils",
"//tensorflow/core/profiler/utils:kernel_stats_utils",
+ "//tensorflow/core/profiler/utils:tf_op_utils",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
+ "//tensorflow/core/profiler/utils:xplane_visitor",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -239,11 +259,15 @@
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+ "//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
"//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
+ "//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:group_events",
"//tensorflow/core/profiler/utils:xplane_builder",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
+ "@com_google_absl//absl/strings",
],
)
@@ -259,7 +283,6 @@
":xplane_to_memory_profile",
":xplane_to_op_stats",
":xplane_to_trace_events",
- "//tensorflow/core:human_readable_json",
"//tensorflow/core:lib",
"//tensorflow/core/profiler:profiler_service_proto_cc",
"//tensorflow/core/profiler/protobuf:hardware_types_proto_cc",
@@ -269,6 +292,7 @@
"//tensorflow/core/profiler/protobuf:op_stats_proto_cc",
"//tensorflow/core/profiler/protobuf:overview_page_proto_cc",
"//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
+ "//tensorflow/core/profiler/protobuf:trace_events_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/rpc/client:save_profile",
"//tensorflow/core/profiler/utils:xplane_schema",
@@ -284,12 +308,14 @@
srcs = ["xplane_to_profile_response_test.cc"],
deps = [
":xplane_to_profile_response",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler:profiler_service_proto_cc",
"//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc",
"//tensorflow/core/profiler/protobuf:overview_page_proto_cc",
"//tensorflow/core/profiler/protobuf:tf_stats_proto_cc",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:group_events",
"//tensorflow/core/profiler/utils:xplane_builder",
"//tensorflow/core/profiler/utils:xplane_schema",
@@ -303,14 +329,16 @@
hdrs = ["xplane_to_step_events.h"],
deps = [
"//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/profiler/protobuf:steps_db_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:event_span",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
+ "//tensorflow/core/profiler/utils:timespan",
"//tensorflow/core/profiler/utils:trace_utils",
"//tensorflow/core/profiler/utils:xplane_schema",
+ "//tensorflow/core/profiler/utils:xplane_visitor",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -320,12 +348,16 @@
srcs = ["xplane_to_step_events_test.cc"],
deps = [
":xplane_to_step_events",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+ "//tensorflow/core/profiler/utils:event_span",
"//tensorflow/core/profiler/utils:group_events",
"//tensorflow/core/profiler/utils:xplane_builder",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -339,7 +371,9 @@
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
"//tensorflow/core/profiler/utils:xplane_schema",
+ "//tensorflow/core/profiler/utils:xplane_visitor",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -355,6 +389,8 @@
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "//tensorflow/core/profiler/protobuf:trace_events_proto_cc",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:xplane_builder",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
@@ -370,14 +406,14 @@
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
- "//tensorflow/core/profiler/utils:event_span",
"//tensorflow/core/profiler/utils:kernel_stats_utils",
"//tensorflow/core/profiler/utils:tf_op_utils",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
"//tensorflow/core/profiler/utils:trace_utils",
"//tensorflow/core/profiler/utils:xplane_schema",
- "//tensorflow/core/profiler/utils:xplane_utils",
"//tensorflow/core/profiler/utils:xplane_visitor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -390,14 +426,14 @@
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+ "//tensorflow/core/profiler/utils:math_utils",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
"//tensorflow/core/profiler/utils:timespan",
"//tensorflow/core/profiler/utils:xplane_schema",
- "//tensorflow/core/profiler/utils:xplane_utils",
"//tensorflow/core/profiler/utils:xplane_visitor",
"@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -414,10 +450,13 @@
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/profiler/protobuf:tf_function_proto_cc",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
"//tensorflow/core/profiler/utils:xplane_builder",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
+ "//tensorflow/core/profiler/utils:xplane_visitor",
+ "@com_google_absl//absl/strings",
],
)
@@ -429,15 +468,18 @@
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/framework:protos_all_cc",
"//tensorflow/core/platform:protobuf",
"//tensorflow/core/profiler/protobuf:memory_profile_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
"//tensorflow/core/profiler/utils:xplane_schema",
- "//tensorflow/core/profiler/utils:xplane_utils",
+ "//tensorflow/core/profiler/utils:xplane_visitor",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -447,10 +489,14 @@
srcs = ["xplane_to_memory_profile_test.cc"],
deps = [
":xplane_to_memory_profile",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core/profiler/protobuf:memory_profile_proto_cc",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:xplane_builder",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
index 3f601bb..8229d10 100644
--- a/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
+++ b/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc
@@ -16,6 +16,7 @@
#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/op_metrics_to_record.cc b/tensorflow/core/profiler/convert/op_metrics_to_record.cc
index b51c679..8e28199 100644
--- a/tensorflow/core/profiler/convert/op_metrics_to_record.cc
+++ b/tensorflow/core/profiler/convert/op_metrics_to_record.cc
@@ -15,7 +15,9 @@
#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
+#include <iterator>
#include <tuple>
+#include <vector>
#include "absl/algorithm/container.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
index ca2a6c2..8367345 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc
@@ -15,11 +15,12 @@
#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
+#include <math.h>
+
#include <algorithm>
-#include <utility>
+#include <vector>
#include "google/protobuf/any.pb.h"
-#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@@ -27,7 +28,6 @@
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
@@ -38,6 +38,7 @@
#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
#include "tensorflow/core/profiler/utils/errors.h"
#include "tensorflow/core/profiler/utils/event_span.h"
+#include "tensorflow/core/profiler/utils/html_utils.h"
#include "tensorflow/core/profiler/utils/math_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
@@ -103,7 +104,7 @@
avg = sdv = min = max = 0.0;
} else {
avg = sample_stats.avg();
- sdv = std::sqrt(sample_stats.sample_variance());
+ sdv = sqrt(sample_stats.sample_variance());
min = sample_stats.min();
max = sample_stats.max();
}
@@ -243,7 +244,7 @@
kPreprocessing // data preprocessing.
};
-string InputOpCategoryString(InputOpCategory category) {
+std::string InputOpCategoryString(InputOpCategory category) {
switch (category) {
case InputOpCategory::kEnqueue:
return "Enqueue";
@@ -327,10 +328,6 @@
return details;
}
-string AnchorElement(absl::string_view url, absl::string_view text) {
- return absl::StrCat("<a href=\"", url, "\" target=\"_blank\">", text, "</a>");
-}
-
// Returns the ratio of the host-to-device time in each step to the step-time.
double RatioOfHostToDeviceTimeToStepTime(
const OpMetricsDb& host_tf_metrics_db,
@@ -362,9 +359,9 @@
}
void KernelLaunchAnalysis(bool tfdata_used, double kernel_launch_percent,
- string* kernel_launch_classification,
- string* kernel_launch_statement) {
- string percent_str = absl::StrFormat("%.1lf", kernel_launch_percent);
+ std::string* kernel_launch_classification,
+ std::string* kernel_launch_statement) {
+ std::string percent_str = absl::StrFormat("%.1lf", kernel_launch_percent);
if (kernel_launch_percent >= kHighlyKernelLaunchBoundThresholdInPercent) {
*kernel_launch_classification = "high";
*kernel_launch_statement = absl::StrCat(
@@ -389,14 +386,14 @@
}
void AllOtherAnalysis(bool all_other_reported, double all_other_percent,
- string* all_other_classification,
- string* all_other_statement) {
+ std::string* all_other_classification,
+ std::string* all_other_statement) {
if (all_other_reported) {
*all_other_classification = "no";
*all_other_statement = "";
return;
}
- string percent_str = absl::StrFormat("%.1lf", all_other_percent);
+ std::string percent_str = absl::StrFormat("%.1lf", all_other_percent);
if (all_other_percent >= kHighlyAllOtherBoundThresholdInPercent) {
*all_other_classification = "high";
*all_other_statement =
@@ -588,9 +585,10 @@
}
bool InputAnalysis(double input_percent, double all_other_percent,
- string* input_classification, string* input_statement) {
+ std::string* input_classification,
+ std::string* input_statement) {
absl::string_view non_input_time = "other time";
- string infeed_percent_str = absl::StrFormat("%.1lf", input_percent);
+ std::string infeed_percent_str = absl::StrFormat("%.1lf", input_percent);
if (input_percent >= kHighlyInfeedBoundThresholdInPercent) {
*input_classification = "host";
*input_statement = absl::StrCat(
@@ -610,7 +608,8 @@
// Input analysis says it is not input-bound, but "All-Other" time
// is significant. It could still be input-bound (or Python overhead).
*input_classification = "both";
- string all_other_percent_str = absl::StrFormat("%.1lf", all_other_percent);
+ std::string all_other_percent_str =
+ absl::StrFormat("%.1lf", all_other_percent);
*input_statement = absl::StrCat(
"Your program is POTENTIALLY input-bound because ",
all_other_percent_str,
@@ -630,8 +629,8 @@
}
}
-void OutputAnalysis(double output_percent, string* output_classification,
- string* output_statement) {
+void OutputAnalysis(double output_percent, std::string* output_classification,
+ std::string* output_statement) {
string tc_outfeed_percent_str = absl::StrFormat("%.1lf", output_percent);
if (output_percent >= kHighlyOutfeedBoundThresholdInPercent) {
*output_classification = "host";
@@ -703,19 +702,19 @@
double kernel_launch_percent =
100.0 * total_host_prepare_ms / total_step_time_ms;
double all_other_percent = 100.0 * total_unknown_ms / total_step_time_ms;
- string input_classification;
- string input_statement;
+ std::string input_classification;
+ std::string input_statement;
bool all_other_reported =
InputAnalysis(input_percent, all_other_percent, &input_classification,
&input_statement);
- string kernel_launch_classification;
- string kernel_launch_statement;
+ std::string kernel_launch_classification;
+ std::string kernel_launch_statement;
KernelLaunchAnalysis(TfDataInUse(input_time_breakdown), kernel_launch_percent,
&kernel_launch_classification, &kernel_launch_statement);
- string all_other_classification;
- string all_other_statement;
+ std::string all_other_classification;
+ std::string all_other_statement;
AllOtherAnalysis(all_other_reported, all_other_percent,
&all_other_classification, &all_other_statement);
@@ -729,9 +728,9 @@
return analysis;
}
-string GetSummaryNextStep(absl::string_view input_classification,
- const InputTimeBreakdown& breakdown) {
- string summary_next_step;
+std::string GetSummaryNextStep(absl::string_view input_classification,
+ const InputTimeBreakdown& breakdown) {
+ std::string summary_next_step;
if (input_classification == "host" || input_classification == "both") {
if (!TfDataInUse(breakdown)) {
summary_next_step = absl::StrCat(
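
InputAnalysis and its siblings bucket a step into host / both / device from a pair of percentages and emit matching prose. A compact sketch of that thresholding; the numeric thresholds below are placeholders, not the constants defined in this file:

#include <string>

// Placeholder thresholds; the real kHighlyInfeedBoundThresholdInPercent etc.
// live in op_stats_to_input_pipeline_analysis.cc and may differ.
constexpr double kHighInput = 20.0;
constexpr double kModerateInput = 5.0;
constexpr double kHighAllOther = 20.0;

std::string ClassifyInput(double input_percent, double all_other_percent) {
  if (input_percent >= kHighInput) return "host";          // clearly input-bound
  if (input_percent >= kModerateInput) return "both";      // partly input-bound
  if (all_other_percent >= kHighAllOther) return "both";   // maybe Python overhead
  return "device";                                         // not input-bound
}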
diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
index 738daea..93b4df0 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
+++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h
@@ -16,12 +16,15 @@
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_
+#include <string>
+
#include "google/protobuf/any.pb.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
@@ -50,16 +53,18 @@
// Returns true if explanation for "All Others" time is also included in
// input_statement.
bool InputAnalysis(double input_percent, double all_other_percent,
- string* input_classification, string* input_statement);
+ std::string* input_classification,
+ std::string* input_statement);
-void OutputAnalysis(double output_percent, string* output_classification,
- string* output_statement);
+void OutputAnalysis(double output_percent, std::string* output_classification,
+ std::string* output_statement);
string GetSummaryNextStep(absl::string_view input_classification,
const InputTimeBreakdown& breakdown);
void AddErrorMessages(const OpStats& op_stats,
InputPipelineAnalysisResult* result);
+
} // namespace profiler
} // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
index e19690a..bec92e0 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc
@@ -15,13 +15,11 @@
#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h"
-#include <algorithm>
-#include <utility>
+#include <string>
#include "google/protobuf/any.pb.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
@@ -30,7 +28,10 @@
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+#include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
#include "tensorflow/core/profiler/utils/errors.h"
+#include "tensorflow/core/profiler/utils/html_utils.h"
#include "tensorflow/core/profiler/utils/math_utils.h"
#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
@@ -44,24 +45,23 @@
// statement of suggestion will be made.
constexpr double kLowPrecisionPercentThreshold = 10;
-OverviewPageTip MakeOverviewPageTip(const string& text) {
- OverviewPageTip tip;
- tip.set_link(text);
- return tip;
-}
+struct TfFunctionInfo {
+ absl::string_view function_name;
+ double expensive_call_percent;
+};
-string AnchorElement(const string& url, const string& text) {
- return absl::StrCat("<a href=\"", url, "\" target=\"_blank\">", text, "</a>");
+OverviewPageTip MakeOverviewPageTip(std::string text) {
+ OverviewPageTip tip;
+ tip.set_link(std::move(text));
+ return tip;
}
// Makes a recommendation for looking up a document.
// doc_url is expected to already be escaped suitably for use in an HTML
// attribute.
-OverviewPageTip MakeOverviewPageTipDocLink(const string& doc_url,
- const string& text) {
- OverviewPageTip tip;
- tip.set_link(AnchorElement(doc_url, text));
- return tip;
+OverviewPageTip MakeOverviewPageTipDocLink(absl::string_view doc_url,
+ absl::string_view text) {
+ return MakeOverviewPageTip(AnchorElement(doc_url, text));
}
void ComputeHostTips(OverviewPageRecommendation* re) {
@@ -75,12 +75,13 @@
void ComputeDeviceTips(HardwareType hardware_type,
OverviewPageRecommendation* re) {
- const string& device_name = HardwareType_Name(hardware_type);
- string timeline_name =
- (hardware_type == tensorflow::profiler::TPU) ? "TPU core" : device_name;
- string op_stats_toolname = (hardware_type == tensorflow::profiler::TPU)
- ? "op_profile"
- : "tensorflow_stats";
+ absl::string_view device_name = HardwareType_Name(hardware_type);
+ absl::string_view timeline_name = device_name;
+ absl::string_view op_stats_toolname = "tensorflow_stats";
+ if (hardware_type == tensorflow::profiler::TPU) {
+ timeline_name = "TPU core";
+ op_stats_toolname = "op_profile";
+ }
*re->add_device_tips() = MakeOverviewPageTip(
absl::StrCat(op_stats_toolname,
" (identify the time-consuming operations "
@@ -121,14 +122,16 @@
} // namespace
-void SetCommonRecommendation(const string& input_classification,
- const string& input_statement,
- const string& output_statement,
+void SetCommonRecommendation(absl::string_view input_classification,
+ absl::string_view input_statement,
+ absl::string_view output_statement,
HardwareType hardware_type,
+ absl::string_view tf_function_statement_html,
OverviewPageRecommendation* re) {
- re->set_bottleneck(input_classification);
- re->set_statement(input_statement);
- re->set_output_statement(output_statement);
+ re->set_bottleneck(std::string(input_classification));
+ re->set_statement(std::string(input_statement));
+ re->set_output_statement(std::string(output_statement));
+ re->set_tf_function_statement_html(std::string(tf_function_statement_html));
ComputeHostTips(re);
ComputeDeviceTips(hardware_type, re);
ComputeDocumentationTips(re);
@@ -245,6 +248,35 @@
return re;
}
+std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db) {
+ std::vector<TfFunctionInfo> candidates;
+ for (const auto& name_fun : tf_function_db.tf_functions()) {
+ const auto& fun = name_fun.second;
+ if (fun.expensive_call_percent() >= kTfFunctionReportThresholdInPercent) {
+ candidates.push_back({name_fun.first, fun.expensive_call_percent()});
+ }
+ }
+ if (candidates.empty()) return "";
+ auto cmp = [](const TfFunctionInfo& a, const TfFunctionInfo& b) {
+ return a.expensive_call_percent > b.expensive_call_percent;
+ };
+ // Sorts candidates in descending order of expensive_call_percent.
+ absl::c_sort(candidates, cmp);
+ std::string expensive_functions = "";
+ auto num_functions_shown = std::min(
+ static_cast<decltype(candidates)::size_type>(3), candidates.size());
+
+ for (auto i = 0; i < num_functions_shown; i++) {
+ if (i > 0) absl::StrAppend(&expensive_functions, ", ");
+ absl::StrAppend(&expensive_functions, "\"", candidates[i].function_name,
+ "\"");
+ }
+ if (candidates.size() > num_functions_shown)
+ absl::StrAppend(&expensive_functions, " and more");
+ return absl::StrCat("Expensive tf-functions detected (", expensive_functions,
+ ") due to either retracing or eager execution.");
+}
+
OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats,
HardwareType hardware_type) {
OverviewPage overview_page;
@@ -258,9 +290,10 @@
overview_page.input_analysis().step_details());
*overview_page.mutable_recommendation() = ComputeGenericRecommendation(
bottleneck, op_stats.device_op_metrics_db().precision_stats());
- SetCommonRecommendation(bottleneck.input_classification(),
- bottleneck.input_statement(), "", hardware_type,
- overview_page.mutable_recommendation());
+ SetCommonRecommendation(
+ bottleneck.input_classification(), bottleneck.input_statement(), "",
+ hardware_type, TfFunctionRecommendationHtml(op_stats.tf_function_db()),
+ overview_page.mutable_recommendation());
return overview_page;
}
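
TfFunctionRecommendationHtml only ever reports the three worst offenders, so a partial sort would suffice in place of the full absl::c_sort. A sketch of the same top-k selection using std::partial_sort; this is an alternative shape under that assumption, not the code above:

#include <algorithm>
#include <string>
#include <vector>

struct ExpensiveFn {
  std::string function_name;
  double expensive_call_percent;
};

// Keep only the k most expensive functions, highest first; equivalent to
// sorting everything and truncating, without ordering the whole vector.
std::vector<ExpensiveFn> TopExpensive(std::vector<ExpensiveFn> v,
                                      size_t k = 3) {
  k = std::min(k, v.size());
  std::partial_sort(v.begin(), v.begin() + k, v.end(),
                    [](const ExpensiveFn& a, const ExpensiveFn& b) {
                      return a.expensive_call_percent >
                             b.expensive_call_percent;
                    });
  v.resize(k);
  return v;
}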
diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h b/tensorflow/core/profiler/convert/op_stats_to_overview_page.h
index e6d1270..b4b3991 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h
+++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.h
@@ -17,9 +17,7 @@
#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_
#include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
@@ -29,10 +27,16 @@
namespace tensorflow {
namespace profiler {
-void SetCommonRecommendation(const string& input_classification,
- const string& input_statement,
- const string& output_statement,
+// Reports tf-function optimization opportunity in the Overview Page if the
+// expensive-call-time percentage is over this threshold for at least one of
+// the tf-functions profiled.
+const double kTfFunctionReportThresholdInPercent = 20;
+
+void SetCommonRecommendation(absl::string_view input_classification,
+ absl::string_view input_statement,
+ absl::string_view output_statement,
HardwareType hardware_type,
+ absl::string_view tf_function_statement_html,
OverviewPageRecommendation* re);
OverviewPageRecommendation ComputeGenericRecommendation(
@@ -47,6 +51,9 @@
OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats,
HardwareType hardware_type);
+// Returns an HTML snippet that provides tf-function related recommendations.
+std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db);
+
void SetRemarks(const OpStats& op_stats, OverviewPageAnalysis* analysis);
} // namespace profiler
diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
index da409f8..e23813a 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats.cc
@@ -15,13 +15,12 @@
#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h"
-#include "absl/container/flat_hash_set.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/convert/op_metrics_to_record.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
-#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
index 3e098da..9ca83b5 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
@@ -15,10 +15,14 @@
#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h"
-#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
-#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
+#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
@@ -75,8 +79,8 @@
kKernel3DurationNs, /*on_device=*/true, kKernel3,
&device_plane, &stream2);
- const OpStats& op_stats = ConvertXSpaceToOpStats(space);
- const TfStatsDatabase& tf_stats = ConvertOpStatsToTfStats(op_stats);
+ const OpStats op_stats = ConvertXSpaceToOpStats(space);
+ const TfStatsDatabase tf_stats = ConvertOpStatsToTfStats(op_stats);
// TfOp1, TfOp2, Idle
EXPECT_EQ(3, tf_stats.with_idle().tf_stats_record_size());
diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc
index ed0d83a..e4713cd 100644
--- a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc
+++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc
@@ -15,10 +15,18 @@
#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
#include <sstream>
+#include <utility>
+#include <vector>
#include "google/protobuf/any.pb.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+#include "tensorflow/core/profiler/utils/event_span.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.h b/tensorflow/core/profiler/convert/step_events_to_steps_db.h
index b3ea74e..9db6516 100644
--- a/tensorflow/core/profiler/convert/step_events_to_steps_db.h
+++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.h
@@ -16,6 +16,7 @@
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_
#define TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
#include "tensorflow/core/profiler/utils/event_span.h"
diff --git a/tensorflow/core/profiler/convert/trace_events_to_json.cc b/tensorflow/core/profiler/convert/trace_events_to_json.cc
index 9c8176c..07e32ce 100644
--- a/tensorflow/core/profiler/convert/trace_events_to_json.cc
+++ b/tensorflow/core/profiler/convert/trace_events_to_json.cc
@@ -15,9 +15,14 @@
#include "tensorflow/core/profiler/convert/trace_events_to_json.h"
+#include <algorithm>
+#include <map>
+#include <utility>
+
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "include/json/json.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc
index 785902e..023d6a7 100644
--- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc
@@ -15,16 +15,20 @@
#include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h"
+#include <functional>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
-#include "tensorflow/core/profiler/utils/event_span.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/trace_utils.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h
index 04bd0e8..9c7fca2 100644
--- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h
+++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h
@@ -17,9 +17,7 @@
#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_
#include <functional>
-#include <vector>
-#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
@@ -31,6 +29,7 @@
const XPlane& device_trace,
const std::function<void(const XEventVisitor&, KernelReport*)>&
on_kernel_fn);
+
} // namespace profiler
} // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
index 1695bd3..5b2a748 100644
--- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc
@@ -15,21 +15,28 @@
#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h"
-#include <cstddef>
+#include <algorithm>
#include <string>
#include <tuple>
+#include <type_traits>
#include <utility>
+#include <vector>
+#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc
index 1173e4d..e0d87ac 100644
--- a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc
@@ -15,7 +15,11 @@
#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
index 09df59e..4a369b8 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc
@@ -15,21 +15,31 @@
#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h"
+#include <algorithm>
+#include <memory>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
#include "tensorflow/core/profiler/convert/op_stack.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/cost_utils.h"
+#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
#include "tensorflow/core/profiler/utils/op_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
+#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/trace_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
index 1a785d0..f2d7fc7 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h
@@ -21,10 +21,9 @@
#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
-#include "tensorflow/core/profiler/utils/event_span.h"
#include "tensorflow/core/profiler/utils/op_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
-#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
index 3e577d0..8bd0443 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db_test.cc
@@ -15,9 +15,12 @@
#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
index 7fdd6ff..f008219 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc
@@ -15,7 +15,11 @@
#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h"
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
#include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h"
#include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h"
@@ -23,12 +27,19 @@
#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
#include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/event_span.h"
#include "tensorflow/core/profiler/utils/hardware_type_utils.h"
#include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
+#include "tensorflow/core/profiler/utils/tf_op_utils.h"
+#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
index c7b140b..7b4652f 100644
--- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc
@@ -15,10 +15,14 @@
#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+#include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response.cc b/tensorflow/core/profiler/convert/xplane_to_profile_response.cc
index 74dd343..e6fe749 100644
--- a/tensorflow/core/profiler/convert/xplane_to_profile_response.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_profile_response.cc
@@ -15,8 +15,10 @@
#include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/human_readable_json.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h"
#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h"
@@ -33,6 +35,7 @@
#include "tensorflow/core/profiler/protobuf/op_stats.pb.h"
#include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
@@ -65,20 +68,26 @@
}
template <typename Proto>
-Status AddJsonToolData(absl::string_view tool_name, const Proto& tool_output,
- ProfileResponse* response) {
- std::string json_output;
- TF_RETURN_IF_ERROR(ProtoToHumanReadableJson(tool_output, &json_output,
- /*ignore_accuracy_loss=*/true));
- auto* tool_data = response->add_tool_data();
- tool_data->set_name(string(tool_name));
- tool_data->mutable_data()->append(json_output.data(), json_output.size());
+Status ConvertProtoToJson(const Proto& proto_output, std::string* json_output) {
+ protobuf::util::JsonPrintOptions json_options;
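+ // Emit primitive fields even when they hold their default values, so the
+ // JSON output always has a complete, stable schema.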
+ json_options.always_print_primitive_fields = true;
+ auto status = protobuf::util::MessageToJsonString(proto_output, json_output,
+ json_options);
+ if (!status.ok()) {
+ // Convert error_msg google::protobuf::StringPiece (or absl::string_view) to
+ // tensorflow::StringPiece.
+ auto error_msg = status.message();
+ return errors::Internal(
+ strings::StrCat("Could not convert proto to JSON string: ",
+ StringPiece(error_msg.data(), error_msg.length())));
+ }
return Status::OK();
}
// Returns the tool name with extension.
string ToolName(absl::string_view tool) {
if (tool == kTraceViewer) return "trace.json.gz";
+ if (tool == kMemoryProfile) return "memory_profile.json.gz";
return absl::StrCat(tool, ".pb");
}
@@ -130,8 +139,11 @@
if (tools.contains(kMemoryProfile)) {
if (const XPlane* host_plane = FindPlaneWithName(xspace, kHostThreads)) {
MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane);
- TF_RETURN_IF_ERROR(
- AddJsonToolData(ToolName(kMemoryProfile), memory_profile, response));
+ std::string json_output;
+ TF_RETURN_IF_ERROR(ConvertProtoToJson(memory_profile, &json_output));
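+ // Save the gzipped JSON under the profile run directory rather than
+ // embedding it in the ProfileResponse.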
+ TF_RETURN_IF_ERROR(SaveGzippedToolDataToTensorboardProfile(
+ req.repository_root(), req.session_id(), req.host_name(),
+ ToolName(kMemoryProfile), json_output));
}
}
return Status::OK();
diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response.h b/tensorflow/core/profiler/convert/xplane_to_profile_response.h
index 84b9fdd..03ca13f 100644
--- a/tensorflow/core/profiler/convert/xplane_to_profile_response.h
+++ b/tensorflow/core/profiler/convert/xplane_to_profile_response.h
@@ -15,8 +15,6 @@
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_PROFILE_RESPONSE_H_
#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_PROFILE_RESPONSE_H_
-#include "absl/container/flat_hash_set.h"
-#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc b/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc
index d4965a9..ad9ca10 100644
--- a/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_profile_response_test.cc
@@ -14,13 +14,14 @@
==============================================================================*/
#include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h"
#include "tensorflow/core/profiler/protobuf/overview_page.pb.h"
#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc
index c7dcd62..7bb7cd6 100644
--- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc
@@ -15,10 +15,18 @@
#include "tensorflow/core/profiler/convert/xplane_to_step_events.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/event_span.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/trace_utils.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.h b/tensorflow/core/profiler/convert/xplane_to_step_events.h
index a7ac3b9..62fc898 100644
--- a/tensorflow/core/profiler/convert/xplane_to_step_events.h
+++ b/tensorflow/core/profiler/convert/xplane_to_step_events.h
@@ -18,7 +18,7 @@
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/event_span.h"
-#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc
index 3e1610c..36e6a2c 100644
--- a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc
@@ -15,7 +15,13 @@
#include "tensorflow/core/profiler/convert/xplane_to_step_events.h"
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/event_span.h"
#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc
index f768d3b..b25cdc4 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc
@@ -15,20 +15,25 @@
#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
+#include <algorithm>
#include <stack>
+#include <string>
+#include <utility>
+#include <vector>
#include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/math_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
-#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
@@ -54,6 +59,21 @@
DCHECK(false);
}
+double ComputeExpensiveCallPercent(const TfFunction& tf_function) {
+ // Computes how expensive this function is in terms of time spent (rather
+ // than call count).
+ uint64 total_call_time_ps = 0;
+ uint64 expensive_call_time_ps = 0;
+ for (const auto& mode_metrics : tf_function.metrics()) {
+ const auto mode = mode_metrics.first;
+ const auto& metrics = mode_metrics.second;
+ total_call_time_ps += metrics.self_time_ps();
+ if (mode == TRACED_MODE || mode == EAGER_MODE) {
+ expensive_call_time_ps += metrics.self_time_ps();
+ }
+ }
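+ // SafeDivide returns 0 when total_call_time_ps is 0, so a function with no
+ // recorded time reports 0% expensive calls.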
+ return SafeDivide(100.0 * expensive_call_time_ps, total_call_time_ps);
+}
+
// Each invocation of a tf-function creates an ActivationRecord.
struct ActivationRecord {
std::string function_name; // name of the tf-function.
@@ -133,6 +153,7 @@
CombineTfFunctionMetrics(src_metrics, dst_metrics);
}
}
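+ // Percentages are not additive across merges, so recompute the value from
+ // the combined metrics.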
+ dst->set_expensive_call_percent(ComputeExpensiveCallPercent(*dst));
}
// Execution history of all tf-functions invoked.
@@ -210,6 +231,10 @@
metrics->set_count(metrics->count() + 1);
metrics->set_self_time_ps(metrics->self_time_ps() + self_time_ps);
}
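+ // With all activations folded in, derive each function's expensive-call
+ // percentage.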
+ for (auto& name_fun : *result.mutable_tf_functions()) {
+ TfFunction& fun = name_fun.second;
+ fun.set_expensive_call_percent(ComputeExpensiveCallPercent(fun));
+ }
return result;
}
diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h b/tensorflow/core/profiler/convert/xplane_to_tf_functions.h
index 470b22d..df55ac7 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.h
+++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions.h
@@ -16,8 +16,9 @@
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_
#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_
+#include <string>
+
#include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc
index 253ef1a..25e56d1 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc
@@ -15,12 +15,17 @@
#include "tensorflow/core/profiler/convert/xplane_to_tf_functions.h"
+#include <string>
+
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/profiler/protobuf/tf_function.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
@@ -33,6 +38,8 @@
const absl::string_view kNotTracedNonXla = "notTraced-nonXla";
const absl::string_view kNotTracedXla = "notTraced-xla";
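+// Absolute tolerance when checking expensive_call_percent in these tests.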
+constexpr double kMaxError = 0.001;
+
TfFunctionDb ConvertXSpaceToTfFunctionDb(const XSpace& space) {
TfFunctionDb result;
const XPlane* host_plane = FindPlaneWithName(space, kHostThreads);
@@ -75,6 +82,8 @@
tf_function_db.tf_functions().at(kFunctionName);
EXPECT_EQ(tf_function.total_tracing_count(), 4);
EXPECT_EQ(tf_function.compiler(), MIXED_COMPILER);
+ EXPECT_NEAR(tf_function.expensive_call_percent(), 90, kMaxError);
+
const auto& metrics = tf_function.metrics();
EXPECT_EQ(metrics.size(), 2);
EXPECT_EQ(metrics.count(TRACED_MODE), 1);
@@ -108,6 +117,7 @@
tf_function_db.tf_functions().at(kOuterFunctionName);
EXPECT_EQ(outer.total_tracing_count(), 1);
EXPECT_EQ(outer.compiler(), OTHER_COMPILER);
+ EXPECT_NEAR(outer.expensive_call_percent(), 100, kMaxError);
const auto& outer_metrics = outer.metrics();
EXPECT_EQ(outer_metrics.size(), 1);
EXPECT_EQ(outer_metrics.count(TRACED_MODE), 1);
@@ -118,6 +128,7 @@
tf_function_db.tf_functions().at(kInnerFunctionName);
EXPECT_EQ(inner.total_tracing_count(), 0);
EXPECT_EQ(inner.compiler(), XLA_COMPILER);
+ EXPECT_NEAR(inner.expensive_call_percent(), 0, kMaxError);
const auto& inner_metrics = inner.metrics();
EXPECT_EQ(inner_metrics.size(), 1);
EXPECT_EQ(inner_metrics.count(NOT_TRACED_MODE), 1);
@@ -148,6 +159,7 @@
tf_function_db.tf_functions().at(kEagerFunctionName);
EXPECT_EQ(eager.total_tracing_count(), 0);
EXPECT_EQ(eager.compiler(), INVALID_COMPILER);
+ EXPECT_NEAR(eager.expensive_call_percent(), 100, kMaxError);
const auto& eager_metrics = eager.metrics();
EXPECT_EQ(eager_metrics.size(), 1);
EXPECT_EQ(eager_metrics.count(EAGER_MODE), 1);
@@ -158,6 +170,7 @@
tf_function_db.tf_functions().at(kConcreteFunctionName);
EXPECT_EQ(concrete.total_tracing_count(), 0);
EXPECT_EQ(concrete.compiler(), INVALID_COMPILER);
+ EXPECT_NEAR(concrete.expensive_call_percent(), 0, kMaxError);
const auto& concrete_metrics = concrete.metrics();
EXPECT_EQ(concrete_metrics.size(), 1);
EXPECT_EQ(concrete_metrics.count(CONCRETE_MODE), 1);
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
index 901f3be..c404f7b 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.cc
@@ -15,8 +15,21 @@
#include "tensorflow/core/profiler/convert/xplane_to_trace_events.h"
+#include <stddef.h>
+
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events.h b/tensorflow/core/profiler/convert/xplane_to_trace_events.h
index 5c6fbea..b7bddb7 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events.h
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events.h
@@ -16,7 +16,8 @@
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_EVENTS_H_
#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_EVENTS_H_
-#include "absl/strings/str_split.h"
+#include <string>
+
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
index afff5e6..b9a9fe0 100644
--- a/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_trace_events_test.cc
@@ -16,8 +16,9 @@
#include "tensorflow/core/profiler/convert/xplane_to_trace_events.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/profiler/protobuf/trace_events.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
-#include "tensorflow/core/profiler/utils/xplane_schema.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD
index 9fab42c..85fa4e7 100644
--- a/tensorflow/core/profiler/internal/BUILD
+++ b/tensorflow/core/profiler/internal/BUILD
@@ -423,8 +423,10 @@
deps = [
":traceme_recorder",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest_main",
+ "@com_google_googletest//:gtest",
],
)
@@ -434,7 +436,6 @@
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
],
)
@@ -444,6 +445,7 @@
hdrs = ["profiler_factory.h"],
deps = [
":profiler_interface",
+ "//tensorflow/core/profiler:profiler_options_proto_cc",
] + if_static([
":profiler_factory_impl",
]),
@@ -461,8 +463,7 @@
deps = [
":profiler_interface",
"//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+ "//tensorflow/core/profiler:profiler_options_proto_cc",
],
alwayslink = True,
)
@@ -513,15 +514,10 @@
srcs = ["scoped_annotation_test.cc"],
deps = [
":annotation_stack",
- "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
"//tensorflow/core/profiler/lib:scoped_annotation",
"@com_google_absl//absl/strings",
],
@@ -544,6 +540,6 @@
":parse_annotation",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/core/profiler/internal/annotation_stack.cc b/tensorflow/core/profiler/internal/annotation_stack.cc
index 4cfd102..4c15ca4 100644
--- a/tensorflow/core/profiler/internal/annotation_stack.cc
+++ b/tensorflow/core/profiler/internal/annotation_stack.cc
@@ -15,6 +15,10 @@
#include "tensorflow/core/profiler/internal/annotation_stack.h"
+#include <atomic>
+
+#include "tensorflow/core/platform/types.h"
+
namespace tensorflow {
namespace profiler {
namespace internal {
diff --git a/tensorflow/core/profiler/internal/annotation_stack.h b/tensorflow/core/profiler/internal/annotation_stack.h
index 38cd962..e626c4c 100644
--- a/tensorflow/core/profiler/internal/annotation_stack.h
+++ b/tensorflow/core/profiler/internal/annotation_stack.h
@@ -18,6 +18,7 @@
#include <stddef.h>
#include <atomic>
+#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD
index e156667..c24c8c7 100644
--- a/tensorflow/core/profiler/internal/cpu/BUILD
+++ b/tensorflow/core/profiler/internal/cpu/BUILD
@@ -18,6 +18,7 @@
"//tensorflow/core/profiler/utils:tf_op_utils",
"//tensorflow/core/profiler/utils:xplane_builder",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings",
],
)
@@ -26,10 +27,10 @@
srcs = ["host_tracer.cc"],
deps = [
":host_tracer_utils",
- "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler/internal:profiler_factory",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/internal:traceme_recorder",
@@ -50,14 +51,17 @@
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_visitor",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
- "@com_google_googletest//:gtest_main",
+ "@com_google_googletest//:gtest",
],
)
@@ -67,17 +71,14 @@
copts = ["-fexceptions"],
features = ["-use_header_modules"],
deps = [
- "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler/internal:profiler_factory",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
- "//tensorflow/core/profiler/utils:xplane_schema",
- "//tensorflow/core/profiler/utils:xplane_utils",
"//tensorflow/python/profiler/internal:python_hooks",
- "@com_google_absl//absl/strings",
],
alwayslink = True,
)
@@ -86,9 +87,12 @@
name = "metadata_collector",
srcs = ["metadata_collector.cc"],
deps = [
+ "//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/compiler/xla/service/gpu:gpu_debug_info_manager",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler/internal:profiler_factory",
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc
index 30b87c8..be1a7a2 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc
@@ -12,18 +12,23 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <memory>
+#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/internal/cpu/host_tracer_utils.h"
#include "tensorflow/core/profiler/internal/profiler_factory.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/internal/traceme_recorder.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
@@ -119,8 +124,8 @@
std::vector<absl::string_view> parts =
absl::StrSplit(event.name, kUserMetadataMarker);
if (parts.size() >= 2) {
- ns->set_node_name(string(parts[0]));
- ns->set_timeline_label(string(parts[1]));
+ ns->set_node_name(std::string(parts[0]));
+ ns->set_timeline_label(std::string(parts[1]));
} else {
ns->set_node_name(std::move(event.name));
}
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
index e32ba92..499b7b6 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
@@ -12,17 +12,23 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <memory>
+#include <ostream>
#include <string>
#include <gmock/gmock.h>
-#include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
@@ -38,13 +44,13 @@
using ::testing::UnorderedElementsAre;
-NodeExecStats MakeNodeStats(const string& name, uint32 thread_id,
- const string& label = "") {
+NodeExecStats MakeNodeStats(absl::string_view name, uint32 thread_id,
+ absl::string_view label = "") {
NodeExecStats ns;
- ns.set_node_name(name);
+ ns.set_node_name(std::string(name));
ns.set_thread_id(thread_id);
if (!label.empty()) {
- ns.set_timeline_label(label);
+ ns.set_timeline_label(std::string(label));
}
return ns;
}
@@ -109,7 +115,7 @@
TEST(HostTracerTest, CollectsTraceMeEventsAsXSpace) {
uint32 thread_id;
- string thread_name = "MyThreadName";
+ std::string thread_name = "MyThreadName";
XSpace space;
// We start a thread with a known and controlled name. As of the time of
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
index a4709ae..2e5d8ac 100644
--- a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc
@@ -14,10 +14,13 @@
==============================================================================*/
#include "tensorflow/core/profiler/internal/cpu/host_tracer_utils.h"
+#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/internal/parse_annotation.h"
#include "tensorflow/core/profiler/internal/traceme_recorder.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc
index c6aa784..58da20a 100644
--- a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc
+++ b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc
@@ -13,17 +13,23 @@
limitations under the License.
==============================================================================*/
+#include <memory>
+#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/internal/profiler_factory.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
index 103db6e..d684cb8 100644
--- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc
+++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
@@ -12,18 +12,17 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <utility>
-#include <vector>
+#include <memory>
-#include "absl/strings/str_split.h"
-#include "tensorflow/core/framework/step_stats.pb.h"
-#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/internal/profiler_factory.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/python/profiler/internal/python_hooks.h"
diff --git a/tensorflow/core/profiler/internal/parse_annotation.cc b/tensorflow/core/profiler/internal/parse_annotation.cc
index 2a3fa3f..32c26be 100644
--- a/tensorflow/core/profiler/internal/parse_annotation.cc
+++ b/tensorflow/core/profiler/internal/parse_annotation.cc
@@ -15,6 +15,9 @@
#include "tensorflow/core/profiler/internal/parse_annotation.h"
#include <stack>
+#include <string>
+#include <utility>
+#include <vector>
#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"
diff --git a/tensorflow/core/profiler/internal/parse_annotation.h b/tensorflow/core/profiler/internal/parse_annotation.h
index 6c2e536..bb0f122 100644
--- a/tensorflow/core/profiler/internal/parse_annotation.h
+++ b/tensorflow/core/profiler/internal/parse_annotation.h
@@ -16,7 +16,6 @@
#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PARSE_ANNOTATION_H_
#define TENSORFLOW_CORE_PROFILER_INTERNAL_PARSE_ANNOTATION_H_
-#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
diff --git a/tensorflow/core/profiler/internal/parse_annotation_test.cc b/tensorflow/core/profiler/internal/parse_annotation_test.cc
index 4d4a2d5..e5d876a 100644
--- a/tensorflow/core/profiler/internal/parse_annotation_test.cc
+++ b/tensorflow/core/profiler/internal/parse_annotation_test.cc
@@ -14,6 +14,9 @@
==============================================================================*/
#include "tensorflow/core/profiler/internal/parse_annotation.h"
+#include <vector>
+
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/internal/profiler_factory.cc b/tensorflow/core/profiler/internal/profiler_factory.cc
index e2bae59..5152e79 100644
--- a/tensorflow/core/profiler/internal/profiler_factory.cc
+++ b/tensorflow/core/profiler/internal/profiler_factory.cc
@@ -14,8 +14,14 @@
==============================================================================*/
#include "tensorflow/core/profiler/internal/profiler_factory.h"
+#include <memory>
+#include <utility>
+#include <vector>
+
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/internal/profiler_factory.h b/tensorflow/core/profiler/internal/profiler_factory.h
index 6bcdcf2..c223d72 100644
--- a/tensorflow/core/profiler/internal/profiler_factory.h
+++ b/tensorflow/core/profiler/internal/profiler_factory.h
@@ -19,6 +19,7 @@
#include <vector>
#include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/internal/profiler_interface.h
index 79dfc7a..9fe85e3 100644
--- a/tensorflow/core/profiler/internal/profiler_interface.h
+++ b/tensorflow/core/profiler/internal/profiler_interface.h
@@ -16,7 +16,6 @@
#define TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_
#include "tensorflow/core/platform/status.h"
-#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
diff --git a/tensorflow/core/profiler/internal/scoped_annotation_test.cc b/tensorflow/core/profiler/internal/scoped_annotation_test.cc
index 70a627f..50c1244 100644
--- a/tensorflow/core/profiler/internal/scoped_annotation_test.cc
+++ b/tensorflow/core/profiler/internal/scoped_annotation_test.cc
@@ -15,10 +15,11 @@
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
+#include <string>
+
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/internal/annotation_stack.h"
namespace tensorflow {
@@ -48,11 +49,13 @@
EXPECT_EQ(AnnotationStack::Get(), ""); // not enabled
}
-string GenerateRandomString(int length) { return string(length, 'a'); }
+std::string GenerateRandomString(int length) {
+ return std::string(length, 'a');
+}
void BM_ScopedAnnotationDisabled(int iters, int annotation_size) {
testing::StopTiming();
- string annotation = GenerateRandomString(annotation_size);
+ std::string annotation = GenerateRandomString(annotation_size);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
ScopedAnnotation trace(annotation);
@@ -64,7 +67,7 @@
void BM_ScopedAnnotationEnabled(int iters, int annotation_size) {
testing::StopTiming();
- string annotation = GenerateRandomString(annotation_size);
+ std::string annotation = GenerateRandomString(annotation_size);
AnnotationStack::Enable(true);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
@@ -78,7 +81,7 @@
void BM_ScopedAnnotationEnabled_Nested(int iters, int annotation_size) {
testing::StopTiming();
- string annotation = GenerateRandomString(annotation_size);
+ std::string annotation = GenerateRandomString(annotation_size);
AnnotationStack::Enable(true);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
diff --git a/tensorflow/core/profiler/internal/tfprof_stats.cc b/tensorflow/core/profiler/internal/tfprof_stats.cc
index 22b3bdc..56e6e2b 100644
--- a/tensorflow/core/profiler/internal/tfprof_stats.cc
+++ b/tensorflow/core/profiler/internal/tfprof_stats.cc
@@ -58,7 +58,6 @@
ckpt_reader_(std::move(ckpt_reader)) {
CHECK(graph) << "Must at least have GraphDef";
- absl::PrintF("Parsing Inputs...\n");
AddGraph(std::move(graph));
if (run_meta && run_meta->has_step_stats()) {
AddRunMeta(0, std::move(run_meta));
diff --git a/tensorflow/core/profiler/internal/traceme_recorder.cc b/tensorflow/core/profiler/internal/traceme_recorder.cc
index 365e399..268585b 100644
--- a/tensorflow/core/profiler/internal/traceme_recorder.cc
+++ b/tensorflow/core/profiler/internal/traceme_recorder.cc
@@ -16,8 +16,18 @@
#include <stddef.h>
+#include <algorithm>
+#include <atomic>
+#include <new>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/internal/traceme_recorder.h b/tensorflow/core/profiler/internal/traceme_recorder.h
index 8b5b32c..1da7d4c 100644
--- a/tensorflow/core/profiler/internal/traceme_recorder.h
+++ b/tensorflow/core/profiler/internal/traceme_recorder.h
@@ -15,8 +15,6 @@
#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_
#define TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_
-#include <stddef.h>
-
#include <atomic>
#include <vector>
diff --git a/tensorflow/core/profiler/internal/traceme_recorder_test.cc b/tensorflow/core/profiler/internal/traceme_recorder_test.cc
index 9047888..8d7abc9 100644
--- a/tensorflow/core/profiler/internal/traceme_recorder_test.cc
+++ b/tensorflow/core/profiler/internal/traceme_recorder_test.cc
@@ -15,19 +15,28 @@
#include "tensorflow/core/profiler/internal/traceme_recorder.h"
#include <atomic>
+#include <istream>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
#include <gmock/gmock.h>
-#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
-#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace profiler {
namespace {
+using ::testing::ElementsAre;
+
MATCHER_P(Named, name, "") { return arg.name == name; }
constexpr static uint64 kNanosInSec = 1000000000;
@@ -45,7 +54,7 @@
ASSERT_EQ(results.size(), 1);
EXPECT_THAT(results[0].events,
- ::testing::ElementsAre(Named("during1"), Named("during2")));
+ ElementsAre(Named("during1"), Named("during2")));
}
void SpinNanos(int nanos) {
diff --git a/tensorflow/core/profiler/profiler_service.proto b/tensorflow/core/profiler/profiler_service.proto
index 37ca408..a096a10e 100644
--- a/tensorflow/core/profiler/profiler_service.proto
+++ b/tensorflow/core/profiler/profiler_service.proto
@@ -10,6 +10,10 @@
service ProfilerService {
// Starts a profiling session, blocks until it completes, and returns data.
rpc Profile(ProfileRequest) returns (ProfileResponse) {}
+ // Signals the server to terminate the Profile rpc for an ongoing profiling
+ // session; the Profile rpc then returns successfully without waiting for
+ // the full profiling duration. This is used by programmatic mode to end the
+ // session on workers.
+ rpc Terminate(TerminateRequest) returns (TerminateResponse) {}
// Collects profiling data and returns user-friendly metrics.
rpc Monitor(MonitorRequest) returns (MonitorResponse) {}
}
@@ -81,6 +85,13 @@
// next-field: 8
}
+message TerminateRequest {
+ // Which session id to terminate.
+ string session_id = 1;
+}
+
+message TerminateResponse {}
+
message MonitorRequest {
// Duration for which to profile between each update.
uint64 duration_ms = 1;
diff --git a/tensorflow/core/profiler/protobuf/overview_page.proto b/tensorflow/core/profiler/protobuf/overview_page.proto
index 8c83dbd..018aa75 100644
--- a/tensorflow/core/profiler/protobuf/overview_page.proto
+++ b/tensorflow/core/profiler/protobuf/overview_page.proto
@@ -84,6 +84,9 @@
// A statement for output that recommends the next steps for investigating the
// bottleneck.
string output_statement = 9;
+ // A statement that recommends the next steps for investigating tf-function
+ // related bottlenecks (it is HTML so that it can link to other tools/docs).
+ string tf_function_statement_html = 10;
// A list of tips for improving host performance.
repeated OverviewPageTip host_tips = 3;
// A list of tips for improving device performance.
diff --git a/tensorflow/core/profiler/protobuf/tf_function.proto b/tensorflow/core/profiler/protobuf/tf_function.proto
index fe07c00..1f5e153 100644
--- a/tensorflow/core/profiler/protobuf/tf_function.proto
+++ b/tensorflow/core/profiler/protobuf/tf_function.proto
@@ -49,6 +49,9 @@
int64 total_tracing_count = 2;
// Compiler used to compile this function.
TfFunctionCompiler compiler = 3;
+ // Percentage of time spent in the expensive calls to this function in the
+ // profiled period.
+ double expensive_call_percent = 4;
}
// Statistics for all tf-functions.
diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD
index b5b631f..1e572df 100644
--- a/tensorflow/core/profiler/rpc/BUILD
+++ b/tensorflow/core/profiler/rpc/BUILD
@@ -19,9 +19,7 @@
"//tensorflow/core/profiler/convert:xplane_to_profile_response",
"//tensorflow/core/profiler/lib:profiler_session_headers",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
- "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
tf_grpc_cc_dependency(),
],
)
diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD
index bde5708..609f98a 100644
--- a/tensorflow/core/profiler/rpc/client/BUILD
+++ b/tensorflow/core/profiler/rpc/client/BUILD
@@ -12,7 +12,9 @@
deps = [
":save_profile",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler:profiler_analysis_proto_cc",
+ "//tensorflow/core/profiler:profiler_options_proto_cc",
"//tensorflow/core/profiler:profiler_service_proto_cc",
"@com_google_absl//absl/strings",
tf_grpc_cc_dependency(),
diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.cc b/tensorflow/core/profiler/rpc/client/capture_profile.cc
index 5335d18..a8642af 100644
--- a/tensorflow/core/profiler/rpc/client/capture_profile.cc
+++ b/tensorflow/core/profiler/rpc/client/capture_profile.cc
@@ -14,18 +14,25 @@
==============================================================================*/
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
+#include <iostream>
+#include <limits>
+#include <memory>
#include <vector>
#include "grpcpp/grpcpp.h"
-#include "absl/strings/escaping.h"
-#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/profiler_analysis.grpc.pb.h"
+#include "tensorflow/core/profiler/profiler_analysis.pb.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
+#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
+#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.h b/tensorflow/core/profiler/rpc/client/capture_profile.h
index 1bde73f..c809d20 100644
--- a/tensorflow/core/profiler/rpc/client/capture_profile.h
+++ b/tensorflow/core/profiler/rpc/client/capture_profile.h
@@ -19,7 +19,7 @@
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/rpc/client/save_profile.cc b/tensorflow/core/profiler/rpc/client/save_profile.cc
index e328bf1..9cf2e29 100644
--- a/tensorflow/core/profiler/rpc/client/save_profile.cc
+++ b/tensorflow/core/profiler/rpc/client/save_profile.cc
@@ -17,7 +17,6 @@
#include <initializer_list>
#include <memory>
-#include <ostream>
#include <sstream>
#include <string>
#include <vector>
diff --git a/tensorflow/core/profiler/rpc/profiler_server.cc b/tensorflow/core/profiler/rpc/profiler_server.cc
index 4d2f3c3..f05a829 100644
--- a/tensorflow/core/profiler/rpc/profiler_server.cc
+++ b/tensorflow/core/profiler/rpc/profiler_server.cc
@@ -16,17 +16,19 @@
#include "tensorflow/core/profiler/rpc/profiler_server.h"
#include <memory>
-#include <utility>
+#include <string>
#include "grpcpp/grpcpp.h"
#include "absl/strings/str_cat.h"
-#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
namespace tensorflow {
void ProfilerServer::StartProfilerServer(int32 port) {
- string server_address = absl::StrCat("0.0.0.0:", port);
+ std::string server_address = absl::StrCat("0.0.0.0:", port);
service_ = CreateProfilerService();
::grpc::ServerBuilder builder;
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
index 555f4c3..0a234d7 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
@@ -15,18 +15,23 @@
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
+#include <memory>
+
#include "grpcpp/support/status.h"
-#include "absl/container/flat_hash_set.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
-#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
+#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
namespace tensorflow {
@@ -61,11 +66,16 @@
}
Env* env = Env::Default();
- for (size_t i = 0; i < req->duration_ms(); ++i) {
+ for (uint64 i = 0; i < req->duration_ms(); ++i) {
env->SleepForMicroseconds(EnvTime::kMillisToMicros);
if (ctx->IsCancelled()) {
return ::grpc::Status::CANCELLED;
}
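+ // A Terminate() request for this session ends the wait loop early; clear
+ // the stop signal and proceed to collect whatever data was captured.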
+ if (TF_PREDICT_FALSE(IsStopped(req->session_id()))) {
+ mutex_lock lock(mutex_);
+ stop_signals_per_session_.erase(req->session_id());
+ break;
+ }
}
status = CollectDataToResponse(*req, profiler.get(), response);
@@ -76,6 +86,25 @@
return ::grpc::Status::OK;
}
+
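+ // Implements the Terminate rpc: records a stop signal that the polling loop
+ // in Profile() observes.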
+ ::grpc::Status Terminate(::grpc::ServerContext* ctx,
+ const TerminateRequest* req,
+ TerminateResponse* response) override {
+ mutex_lock lock(mutex_);
+ stop_signals_per_session_[req->session_id()] = true;
+ return ::grpc::Status::OK;
+ }
+
+ private:
+ bool IsStopped(const std::string& session_id) {
+ mutex_lock lock(mutex_);
+ auto it = stop_signals_per_session_.find(session_id);
+ return it != stop_signals_per_session_.end() && it->second;
+ }
+
+ mutex mutex_;
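+ // Session ids with a pending stop request; written by Terminate() and
+ // polled by Profile().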
+ absl::flat_hash_map<std::string, bool> stop_signals_per_session_
+ GUARDED_BY(mutex_);
};
} // namespace
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.h b/tensorflow/core/profiler/rpc/profiler_service_impl.h
index 4a7636c..00a850a 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.h
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.h
@@ -15,10 +15,8 @@
#ifndef TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_
-#include "grpcpp/grpcpp.h"
-#include "grpcpp/server_context.h"
-#include "grpcpp/support/status.h"
-#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include <memory>
+
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD
index ad26dcc..ca20236 100644
--- a/tensorflow/core/profiler/utils/BUILD
+++ b/tensorflow/core/profiler/utils/BUILD
@@ -30,6 +30,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/protobuf:op_metrics_proto_cc",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
@@ -51,6 +52,14 @@
)
cc_library(
+ name = "html_utils",
+ hdrs = ["html_utils.h"],
+ deps = [
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
name = "op_metrics_db_utils",
srcs = ["op_metrics_db_utils.cc"],
hdrs = ["op_metrics_db_utils.h"],
@@ -83,7 +92,6 @@
hdrs = ["tf_op_utils.h"],
deps = [
"//tensorflow/core:regexp_internal",
- "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
@@ -96,6 +104,7 @@
":tf_op_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
],
)
@@ -156,6 +165,7 @@
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+ "@com_google_absl//absl/strings",
],
)
@@ -170,7 +180,6 @@
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
- "@com_google_absl//absl/types:span",
],
)
@@ -196,7 +205,6 @@
name = "xplane_utils_test",
srcs = ["xplane_utils_test.cc"],
deps = [
- ":time_utils",
":xplane_builder",
":xplane_utils",
":xplane_visitor",
@@ -205,6 +213,8 @@
"//tensorflow/core:test_main",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -232,6 +242,7 @@
deps = [
":xplane_schema",
":xplane_visitor",
+ "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
],
)
@@ -243,9 +254,11 @@
deps = [
":tf_op_utils",
":tf_xplane_visitor",
+ ":xplane_builder",
":xplane_schema",
":xplane_utils",
":xplane_visitor",
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
@@ -263,10 +276,13 @@
":xplane_builder",
":xplane_schema",
":xplane_utils",
+ ":xplane_visitor",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -281,10 +297,13 @@
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler/costs:cost_estimator",
+ "//tensorflow/core/grappler/costs:op_context",
"//tensorflow/core/grappler/costs:op_level_cost_estimator",
"//tensorflow/core/grappler/costs:op_performance_data_cc",
- "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -296,6 +315,7 @@
":group_events",
":tf_op_utils",
":tf_xplane_visitor",
+ ":time_utils",
":timespan",
":trace_utils",
":xplane_builder",
@@ -305,8 +325,10 @@
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -321,6 +343,8 @@
":xplane_builder",
":xplane_schema",
":xplane_utils",
+ ":xplane_visitor",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
@@ -347,10 +371,10 @@
":xplane_builder",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/framework:protos_all_cc",
"//tensorflow/core/profiler/protobuf:tfstreamz_proto_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/tensorflow/core/profiler/utils/cost_utils.cc b/tensorflow/core/profiler/utils/cost_utils.cc
index 754aa65..a94f09b 100644
--- a/tensorflow/core/profiler/utils/cost_utils.cc
+++ b/tensorflow/core/profiler/utils/cost_utils.cc
@@ -15,12 +15,27 @@
#include "tensorflow/core/profiler/utils/cost_utils.h"
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/cost_utils.h b/tensorflow/core/profiler/utils/cost_utils.h
index f109555..a778bca 100644
--- a/tensorflow/core/profiler/utils/cost_utils.h
+++ b/tensorflow/core/profiler/utils/cost_utils.h
@@ -15,12 +15,13 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_
-#include <set>
+#include <string>
-#include "absl/strings/string_view.h"
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
@@ -46,7 +47,8 @@
OpRoofLineStats Predict(const XEventVisitor& event);
private:
- std::set<string> unsupported_ops_; // summary for unsupported ops.
+ absl::flat_hash_set<std::string>
+ unsupported_ops_; // summary for unsupported ops.
TF_DISALLOW_COPY_AND_ASSIGN(TfOpRoofLineCostEstimator);
};
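Replacing std::set<string> with absl::flat_hash_set<std::string> here trades ordered iteration for O(1) average-case insert and lookup, which fits a container whose only job is deduplicating op names for a summary. A hedged sketch of that usage (names are illustrative):

    #include <iostream>
    #include <string>

    #include "absl/container/flat_hash_set.h"
    #include "absl/strings/str_join.h"

    int main() {
      absl::flat_hash_set<std::string> unsupported_ops;
      unsupported_ops.insert("FusedCustomOp");  // first sighting
      unsupported_ops.insert("FusedCustomOp");  // duplicate: no-op
      // Iteration order is unspecified, which is acceptable for a summary.
      std::cout << "unsupported: " << absl::StrJoin(unsupported_ops, ", ") << "\n";
    }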
diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc
index c99d8e8..112c097 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.cc
+++ b/tensorflow/core/profiler/utils/derived_timeline.cc
@@ -14,15 +14,27 @@
==============================================================================*/
#include "tensorflow/core/profiler/utils/derived_timeline.h"
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/time_utils.h"
#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/trace_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h
index 61b62bd..cd4da79 100644
--- a/tensorflow/core/profiler/utils/derived_timeline.h
+++ b/tensorflow/core/profiler/utils/derived_timeline.h
@@ -15,7 +15,13 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_
+#include <functional>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc
index f3e6b66..76a0188 100644
--- a/tensorflow/core/profiler/utils/derived_timeline_test.cc
+++ b/tensorflow/core/profiler/utils/derived_timeline_test.cc
@@ -15,8 +15,9 @@
#include "tensorflow/core/profiler/utils/derived_timeline.h"
-#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
@@ -24,6 +25,7 @@
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/errors.cc b/tensorflow/core/profiler/utils/errors.cc
index d829ee0..9c678e9 100644
--- a/tensorflow/core/profiler/utils/errors.cc
+++ b/tensorflow/core/profiler/utils/errors.cc
@@ -15,6 +15,8 @@
#include "tensorflow/core/profiler/utils/errors.h"
+#include "absl/strings/string_view.h"
+
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc
index 9768331..5e0413c 100644
--- a/tensorflow/core/profiler/utils/event_span.cc
+++ b/tensorflow/core/profiler/utils/event_span.cc
@@ -14,11 +14,19 @@
==============================================================================*/
#include "tensorflow/core/profiler/utils/event_span.h"
+#include <string>
+#include <utility>
#include <vector>
+#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/utils/timespan.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h
index 36b3172..1adc6a7 100644
--- a/tensorflow/core/profiler/utils/event_span.h
+++ b/tensorflow/core/profiler/utils/event_span.h
@@ -16,10 +16,11 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_
+#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
-#include "tensorflow/core/platform/logging.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/utils/timespan.h"
diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc
index 60d12c0..4296149 100644
--- a/tensorflow/core/profiler/utils/group_events.cc
+++ b/tensorflow/core/profiler/utils/group_events.cc
@@ -15,13 +15,25 @@
#include "tensorflow/core/profiler/utils/group_events.h"
-#include <stack>
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
+#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
diff --git a/tensorflow/core/profiler/utils/group_events.h b/tensorflow/core/profiler/utils/group_events.h
index 1140f2d..4b6fc58 100644
--- a/tensorflow/core/profiler/utils/group_events.h
+++ b/tensorflow/core/profiler/utils/group_events.h
@@ -16,9 +16,16 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_GROUP_EVENTS_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_GROUP_EVENTS_H_
+#include <functional>
#include <memory>
+#include <string>
+#include <vector>
#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
diff --git a/tensorflow/core/profiler/utils/group_events_test.cc b/tensorflow/core/profiler/utils/group_events_test.cc
index 6b6a0d2..11996ba 100644
--- a/tensorflow/core/profiler/utils/group_events_test.cc
+++ b/tensorflow/core/profiler/utils/group_events_test.cc
@@ -16,12 +16,15 @@
#include "tensorflow/core/profiler/utils/group_events.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc
index 75896c0..e2a4004 100644
--- a/tensorflow/core/profiler/utils/hardware_type_utils.cc
+++ b/tensorflow/core/profiler/utils/hardware_type_utils.cc
@@ -17,6 +17,7 @@
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/html_utils.h b/tensorflow/core/profiler/utils/html_utils.h
new file mode 100644
index 0000000..215d9f5
--- /dev/null
+++ b/tensorflow/core/profiler/utils/html_utils.h
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_
+#define TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_
+
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Creates an HTML anchor element that links to the given URL with the given
+// text.
+inline std::string AnchorElement(absl::string_view url,
+ absl::string_view text) {
+ return absl::StrCat("<a href=\"", url, "\" target=\"_blank\">", text, "</a>");
+}
+
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_
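For reference, a caller of the new helper produces a fixed-format anchor tag; a minimal usage sketch (the URL and link text are placeholders):

    #include <iostream>

    #include "tensorflow/core/profiler/utils/html_utils.h"

    int main() {
      // Prints: <a href="https://example.com/doc" target="_blank">profiler guide</a>
      std::cout << tensorflow::profiler::AnchorElement("https://example.com/doc",
                                                       "profiler guide")
                << "\n";
    }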
diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc
index 14038d5..c40c3a8 100644
--- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc
+++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc
@@ -15,15 +15,17 @@
#include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
+#include <algorithm>
+#include <string>
#include <tuple>
#include <vector>
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
-#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
namespace tensorflow {
@@ -34,15 +36,15 @@
const std::vector<absl::string_view> params =
absl::StrSplit(xstat_kernel_details, absl::ByAnyChar(":\n"));
- constexpr uint32_t kNumDimensions = 3;
- for (uint32_t dim = 0; dim < kNumDimensions; ++dim) {
+ constexpr uint32 kNumDimensions = 3;
+ for (uint32 dim = 0; dim < kNumDimensions; ++dim) {
kernel->add_block_dim(1);
kernel->add_grid_dim(1);
}
// Process value pairs.
- for (uint32_t ii = 0; ii < params.size(); ii += 2) {
- uint32_t value = 0;
+ for (uint32 ii = 0; ii < params.size(); ii += 2) {
+ uint32 value = 0;
if (params[ii] == "registers_per_thread" &&
absl::SimpleAtoi(params[ii + 1], &value)) {
kernel->set_registers_per_thread(value);
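The loop above relies on absl::StrSplit with ByAnyChar(":\n") flattening "key:value" lines into an alternating key/value token stream, consumed two tokens at a time. A self-contained sketch of that parse, using uint32_t where TensorFlow's uint32 typedef appears above (the input string is a made-up example):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    #include "absl/strings/numbers.h"
    #include "absl/strings/str_split.h"
    #include "absl/strings/string_view.h"

    int main() {
      absl::string_view kernel_details =
          "registers_per_thread:16\nstatic_shared_memory_usage:0";
      // Splitting on ':' and '\n' yields {key, value, key, value, ...}.
      std::vector<absl::string_view> params =
          absl::StrSplit(kernel_details, absl::ByAnyChar(":\n"));
      for (size_t i = 0; i + 1 < params.size(); i += 2) {
        uint32_t value = 0;
        if (absl::SimpleAtoi(params[i + 1], &value)) {
          std::cout << params[i] << " = " << value << "\n";
        }
      }
    }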
diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc
index 06307d6..863d2f7 100644
--- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc
+++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc
@@ -15,8 +15,13 @@
#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
+#include <algorithm>
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/utils/math_utils.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
@@ -40,7 +45,7 @@
/*hlo_module_id=*/0, tf_op_name);
if (tf_op_metrics->category().empty()) {
tf_op_metrics->set_category(
- tf_op_type == kUnknownOp ? "Unknown" : string(tf_op_type));
+ tf_op_type == kUnknownOp ? "Unknown" : std::string(tf_op_type));
}
tf_op_metrics->set_is_eager(device_op_metrics.is_eager());
// The occurrences of a TF-op is the maximum among the occurrences of all
@@ -89,8 +94,8 @@
void AddIdleOp(OpMetricsDb* db) {
uint64 idle_time_ps = IdleTimePs(*db);
OpMetrics* metrics = db->add_metrics_db();
- metrics->set_name(string(kIdle));
- metrics->set_category(string(kIdle));
+ metrics->set_name(std::string(kIdle));
+ metrics->set_category(std::string(kIdle));
metrics->set_occurrences(0);
metrics->set_time_ps(idle_time_ps);
metrics->set_self_time_ps(idle_time_ps);
diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc
index 74ce13d..921e061 100644
--- a/tensorflow/core/profiler/utils/op_utils.cc
+++ b/tensorflow/core/profiler/utils/op_utils.cc
@@ -15,8 +15,14 @@
#include "tensorflow/core/profiler/utils/op_utils.h"
+#include <algorithm>
+#include <string>
+
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
+#include "tensorflow/core/profiler/utils/tf_op_utils.h"
namespace tensorflow {
namespace profiler {
@@ -69,9 +75,9 @@
OpMetrics* op_metrics = LookupOrInsertNewOpMetrics(program_id, name);
if (op_metrics->category().empty())
op_metrics->set_category(category == kUnknownOp ? "unknown"
- : string(category));
+ : std::string(category));
if (op_metrics->provenance().empty())
- op_metrics->set_provenance(string(provenance));
+ op_metrics->set_provenance(std::string(provenance));
op_metrics->set_is_eager(op_metrics->is_eager() || is_eager);
op_metrics->set_occurrences(op_metrics->occurrences() + occurrences);
op_metrics->set_time_ps(op_metrics->time_ps() + time_ps);
diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h
index 8aaa0f4..f94328d 100644
--- a/tensorflow/core/profiler/utils/op_utils.h
+++ b/tensorflow/core/profiler/utils/op_utils.h
@@ -16,13 +16,10 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_
-#include <string>
-
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
-#include "tensorflow/core/profiler/utils/tf_op_utils.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/tf_op_utils.cc b/tensorflow/core/profiler/utils/tf_op_utils.cc
index 5a42044..630a74c 100644
--- a/tensorflow/core/profiler/utils/tf_op_utils.cc
+++ b/tensorflow/core/profiler/utils/tf_op_utils.cc
@@ -15,11 +15,14 @@
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
+#include <string>
+#include <vector>
+
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
-#include "absl/strings/strip.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/tf_op_utils.h b/tensorflow/core/profiler/utils/tf_op_utils.h
index d1ac69e..b8af946 100644
--- a/tensorflow/core/profiler/utils/tf_op_utils.h
+++ b/tensorflow/core/profiler/utils/tf_op_utils.h
@@ -16,9 +16,9 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
+#include <string>
#include <vector>
-#include "absl/base/attributes.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
diff --git a/tensorflow/core/profiler/utils/tf_op_utils_test.cc b/tensorflow/core/profiler/utils/tf_op_utils_test.cc
index fa51695..136dbee 100644
--- a/tensorflow/core/profiler/utils/tf_op_utils_test.cc
+++ b/tensorflow/core/profiler/utils/tf_op_utils_test.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/tf_xplane_visitor.h b/tensorflow/core/profiler/utils/tf_xplane_visitor.h
index 33a170f..17a7b94 100644
--- a/tensorflow/core/profiler/utils/tf_xplane_visitor.h
+++ b/tensorflow/core/profiler/utils/tf_xplane_visitor.h
@@ -16,6 +16,7 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TF_XPLANE_VISITOR_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_TF_XPLANE_VISITOR_H_
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.cc b/tensorflow/core/profiler/utils/tfstreamz_utils.cc
index 493420f..f4cbaa8 100644
--- a/tensorflow/core/profiler/utils/tfstreamz_utils.cc
+++ b/tensorflow/core/profiler/utils/tfstreamz_utils.cc
@@ -14,36 +14,46 @@
==============================================================================*/
#include "tensorflow/core/profiler/utils/tfstreamz_utils.h"
+#include <map>
#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
-#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/monitoring/collected_metrics.h"
-#include "tensorflow/core/lib/monitoring/collection_registry.h"
-#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/lib/monitoring/metric_def.h"
+#include "tensorflow/core/lib/monitoring/types.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/tfstreamz.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/utils/xplane_builder.h"
namespace tensorflow {
namespace profiler {
namespace {
-string ConstructXStatName(const string& name, const monitoring::Point& point) {
+
+std::string ConstructXStatName(absl::string_view name,
+ const monitoring::Point& point) {
if (point.labels.empty()) {
- return name;
+ return std::string(name);
}
return absl::Substitute(
"$0{$1}", name,
- absl::StrJoin(point.labels, ", ",
- [](string* out, const monitoring::Point::Label& label) {
- absl::StrAppend(out, label.name, "=", label.value);
- }));
+ absl::StrJoin(
+ point.labels, ", ",
+ [](std::string* out, const monitoring::Point::Label& label) {
+ absl::StrAppend(out, label.name, "=", label.value);
+ }));
}
-string SerializePercentile(const monitoring::Percentiles& percentiles) {
+std::string SerializePercentile(const monitoring::Percentiles& percentiles) {
tfstreamz::Percentiles output;
output.set_unit_of_measure(
static_cast<tfstreamz::UnitOfMeasure>(percentiles.unit_of_measure));
@@ -81,11 +91,11 @@
xevent.SetEndTimestampNs(snapshot.end_time_ns);
auto& metric_descriptor_map = snapshot.metrics->metric_descriptor_map;
for (const auto& point_set : snapshot.metrics->point_set_map) {
- const string& metric_name = point_set.first;
+ const std::string& metric_name = point_set.first;
// Each metric has multiple points corresponding to different labels.
for (const auto& point : point_set.second->points) {
// Generates one KPI metric for each point.
- string stat_name = ConstructXStatName(metric_name, *point);
+ std::string stat_name = ConstructXStatName(metric_name, *point);
auto* metadata = xplane.GetOrCreateStatMetadata(stat_name);
auto it = metric_descriptor_map.find(metric_name);
if (it != metric_descriptor_map.end()) {
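ConstructXStatName above renders a metric point as name{label1=value1, label2=value2}, or the bare name when there are no labels. A sketch of the same Substitute/StrJoin combination, with plain string pairs standing in for monitoring::Point::Label (metric and label names invented for illustration):

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"
    #include "absl/strings/substitute.h"

    int main() {
      std::vector<std::pair<std::string, std::string>> labels = {
          {"device", "gpu:0"}, {"phase", "train"}};
      std::string stat_name = absl::Substitute(
          "$0{$1}", "/tensorflow/example/metric",
          absl::StrJoin(labels, ", ",
                        [](std::string* out,
                           const std::pair<std::string, std::string>& label) {
                          absl::StrAppend(out, label.first, "=", label.second);
                        }));
      // Prints: /tensorflow/example/metric{device=gpu:0, phase=train}
      std::cout << stat_name << "\n";
    }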
diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.h b/tensorflow/core/profiler/utils/tfstreamz_utils.h
index ae8e407..1ab21ed 100644
--- a/tensorflow/core/profiler/utils/tfstreamz_utils.h
+++ b/tensorflow/core/profiler/utils/tfstreamz_utils.h
@@ -15,11 +15,13 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_
+#include <memory>
+#include <vector>
+
#include "tensorflow/core/lib/monitoring/collected_metrics.h"
-#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
-#include "tensorflow/core/profiler/utils/xplane_builder.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/timespan.h b/tensorflow/core/profiler/utils/timespan.h
index bccbeaa..82775af 100644
--- a/tensorflow/core/profiler/utils/timespan.h
+++ b/tensorflow/core/profiler/utils/timespan.h
@@ -16,6 +16,9 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TIMESPAN_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_TIMESPAN_H_
+#include <algorithm>
+#include <string>
+
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/profiler/utils/xplane_builder.cc b/tensorflow/core/profiler/utils/xplane_builder.cc
index 9e66a15..f923f39 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.cc
+++ b/tensorflow/core/profiler/utils/xplane_builder.cc
@@ -14,6 +14,14 @@
==============================================================================*/
#include "tensorflow/core/profiler/utils/xplane_builder.h"
+#include <algorithm>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
namespace tensorflow {
@@ -54,7 +62,7 @@
return metadata;
}
-XEventMetadata* XPlaneBuilder::GetOrCreateEventMetadata(string&& name) {
+XEventMetadata* XPlaneBuilder::GetOrCreateEventMetadata(std::string&& name) {
XEventMetadata*& metadata = event_metadata_by_name_[name];
if (metadata == nullptr) {
metadata =
diff --git a/tensorflow/core/profiler/utils/xplane_builder.h b/tensorflow/core/profiler/utils/xplane_builder.h
index 803cc7b..b0d743a 100644
--- a/tensorflow/core/profiler/utils/xplane_builder.h
+++ b/tensorflow/core/profiler/utils/xplane_builder.h
@@ -15,10 +15,15 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_BUILDER_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_BUILDER_H_
+#include <stddef.h>
+
+#include <string>
+#include <utility>
+
#include "absl/container/flat_hash_map.h"
#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
@@ -53,12 +58,12 @@
void AddStatValue(const XStatMetadata& metadata, absl::string_view value,
bool is_bytes = false) {
if (is_bytes) {
- AddStat(metadata)->set_bytes_value(string(value));
+ AddStat(metadata)->set_bytes_value(std::string(value));
} else {
- AddStat(metadata)->set_str_value(string(value));
+ AddStat(metadata)->set_str_value(std::string(value));
}
}
- void AddStatValue(const XStatMetadata& metadata, string&& value,
+ void AddStatValue(const XStatMetadata& metadata, std::string&& value,
bool is_bytes = false) {
if (is_bytes) {
AddStat(metadata)->set_bytes_value(std::move(value));
@@ -160,7 +165,7 @@
int64 NumEvents() { return line_->events_size(); }
- void SetName(absl::string_view name) { line_->set_name(string(name)); }
+ void SetName(absl::string_view name) { line_->set_name(std::string(name)); }
void SetNameIfEmpty(absl::string_view name) {
if (line_->name().empty()) SetName(name);
@@ -205,7 +210,7 @@
int64 Id() { return plane_->id(); }
void SetId(int64 id) { plane_->set_id(id); }
- void SetName(absl::string_view name) { plane_->set_name(string(name)); }
+ void SetName(absl::string_view name) { plane_->set_name(std::string(name)); }
void ReserveLines(size_t num_lines) {
plane_->mutable_lines()->Reserve(num_lines);
@@ -222,7 +227,7 @@
XEventMetadata* GetOrCreateEventMetadata(int64 metadata_id);
XEventMetadata* GetOrCreateEventMetadata(absl::string_view name);
- XEventMetadata* GetOrCreateEventMetadata(string&& name);
+ XEventMetadata* GetOrCreateEventMetadata(std::string&& name);
inline XEventMetadata* GetOrCreateEventMetadata(const char* name) {
return GetOrCreateEventMetadata(absl::string_view(name));
}
@@ -251,7 +256,7 @@
if (stat.value_case() == XStat::kRefValue) {
const auto& stat_metadata_map = src.stat_metadata();
const auto it = stat_metadata_map.find(stat.ref_value());
- if (ABSL_PREDICT_FALSE(it == stat_metadata_map.end())) {
+ if (TF_PREDICT_FALSE(it == stat_metadata_map.end())) {
// the reference value in stat is not found in XStatMetadata from src.
return;
}
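The AddStatValue and GetOrCreateEventMetadata changes above preserve an overload pair: absl::string_view for borrowed data (exactly one copy into owned storage) and std::string&& for owned temporaries (a move), plus a const char* overload so string literals are not ambiguous between the two. A minimal sketch of that pattern (the class and member names are illustrative):

    #include <string>
    #include <utility>

    #include "absl/strings/string_view.h"

    class StatSink {
     public:
      // Borrowed data: pay exactly one copy into owned storage.
      void Set(absl::string_view value) { value_ = std::string(value); }
      // Owned temporary (e.g. an absl::StrCat result): move, no copy.
      void Set(std::string&& value) { value_ = std::move(value); }
      // String literals would otherwise be ambiguous between the two above.
      void Set(const char* value) { Set(absl::string_view(value)); }

     private:
      std::string value_;
    };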
diff --git a/tensorflow/core/profiler/utils/xplane_builder_test.cc b/tensorflow/core/profiler/utils/xplane_builder_test.cc
index cb87497..e55e01d 100644
--- a/tensorflow/core/profiler/utils/xplane_builder_test.cc
+++ b/tensorflow/core/profiler/utils/xplane_builder_test.cc
@@ -14,7 +14,11 @@
==============================================================================*/
#include "tensorflow/core/profiler/utils/xplane_builder.h"
+#include <string>
+
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc
index 51bc4d0..f8ff31b 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.cc
+++ b/tensorflow/core/profiler/utils/xplane_schema.cc
@@ -17,7 +17,10 @@
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h
index 97e54a7..31ff901 100644
--- a/tensorflow/core/profiler/utils/xplane_schema.h
+++ b/tensorflow/core/profiler/utils/xplane_schema.h
@@ -16,11 +16,10 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_
-#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
-#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc
index b2cc1fd..7f5221c 100644
--- a/tensorflow/core/profiler/utils/xplane_utils.cc
+++ b/tensorflow/core/profiler/utils/xplane_utils.cc
@@ -14,12 +14,21 @@
==============================================================================*/
#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
+#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h
index 4f0a8b8..49087c4 100644
--- a/tensorflow/core/profiler/utils/xplane_utils.h
+++ b/tensorflow/core/profiler/utils/xplane_utils.h
@@ -17,6 +17,7 @@
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
diff --git a/tensorflow/core/profiler/utils/xplane_utils_test.cc b/tensorflow/core/profiler/utils/xplane_utils_test.cc
index b9b15b2..04e06fc 100644
--- a/tensorflow/core/profiler/utils/xplane_utils_test.cc
+++ b/tensorflow/core/profiler/utils/xplane_utils_test.cc
@@ -15,9 +15,14 @@
#include "tensorflow/core/profiler/utils/xplane_utils.h"
+#include <string>
+
#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
diff --git a/tensorflow/core/profiler/utils/xplane_visitor.cc b/tensorflow/core/profiler/utils/xplane_visitor.cc
index ab97271..42068b7 100644
--- a/tensorflow/core/profiler/utils/xplane_visitor.cc
+++ b/tensorflow/core/profiler/utils/xplane_visitor.cc
@@ -14,7 +14,16 @@
==============================================================================*/
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
namespace tensorflow {
namespace profiler {
diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h
index 8cd805c..4120a28 100644
--- a/tensorflow/core/profiler/utils/xplane_visitor.h
+++ b/tensorflow/core/profiler/utils/xplane_visitor.h
@@ -15,8 +15,11 @@
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_
+#include <stddef.h>
+
#include <functional>
-#include <utility>
+#include <string>
+#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD
new file mode 100644
index 0000000..a374c80
--- /dev/null
+++ b/tensorflow/core/protobuf/BUILD
@@ -0,0 +1,182 @@
+# For platform-specific build config
+load(
+ "//tensorflow/core/platform:build_config.bzl",
+ "tf_additional_all_protos",
+ "tf_proto_library",
+ "tf_proto_library_cc",
+ "tf_pyclif_proto_library",
+)
+
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ "//tensorflow/core:__subpackages__",
+ "//tensorflow_models:__subpackages__",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+COMMON_PROTO_SRCS = [
+ "bfc_memory_map.proto",
+ "config.proto",
+ "cluster.proto",
+ "debug.proto",
+ "device_filters.proto",
+ "device_properties.proto",
+ "graph_debug_info.proto",
+ "queue_runner.proto",
+ "rewriter_config.proto",
+ "tensor_bundle.proto",
+ "saver.proto",
+ "verifier_config.proto",
+]
+
+[
+ [
+ tf_pyclif_proto_library(
+ name = "%s_pyclif" % proto_name,
+ proto_lib = ":for_core_protos",
+ proto_srcfile = "%s.proto" % proto_name,
+ visibility = ["//visibility:public"],
+ ),
+ ]
+ for proto_name in [
+ "config",
+ "device_properties",
+ "graph_debug_info",
+ "meta_graph",
+ "saved_model",
+ ]
+]
+
+tf_proto_library(
+ name = "autotuning_proto",
+ srcs = ["autotuning.proto"],
+ cc_api_version = 2,
+ make_default_target_header_only = True,
+)
+
+tf_proto_library(
+ name = "conv_autotuning_proto",
+ srcs = ["conv_autotuning.proto"],
+ cc_api_version = 2,
+ make_default_target_header_only = True,
+ protodeps = [
+ "//tensorflow/stream_executor:dnn_proto",
+ ],
+)
+
+tf_proto_library_cc(
+ name = "worker_proto",
+ srcs = ["worker.proto"],
+ cc_api_version = 2,
+ protodeps = tf_additional_all_protos(),
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library_cc(
+ name = "worker_service_proto",
+ srcs = ["worker_service.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ cc_stubby_versions = ["2"],
+ protodeps = [":worker_proto"],
+)
+
+tf_proto_library_cc(
+ name = "master_proto",
+ srcs = ["master.proto"],
+ cc_api_version = 2,
+ protodeps = tf_additional_all_protos(),
+ visibility = ["//tensorflow:internal"],
+)
+
+tf_proto_library_cc(
+ name = "master_service_proto",
+ srcs = ["master_service.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ cc_stubby_versions = ["2"],
+ protodeps = [":master_proto"],
+)
+
+tf_proto_library_cc(
+ name = "eager_service_proto",
+ srcs = ["eager_service.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ cc_grpc_version = 1,
+ cc_stubby_versions = ["2"],
+ protodeps = tf_additional_all_protos(),
+)
+
+tf_proto_library_cc(
+ name = "replay_log_proto",
+ srcs = ["replay_log.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ ":master_proto",
+ ] + tf_additional_all_protos(),
+)
+
+tf_proto_library(
+ name = "error_codes_proto_impl",
+ srcs = ["error_codes.proto"],
+ cc_api_version = 2,
+ make_default_target_header_only = True,
+)
+
+exports_files(
+ srcs = ["error_codes.proto"] + COMMON_PROTO_SRCS + [
+ # Protos which are not needed on mobile builds, but should be included
+ # in protos_all.
+ #
+ # Note that some protos are in neither core_proto_srcs nor this
+ # filegroup; e.g. ones with individual proto_library targets.
+ "control_flow.proto",
+ # TODO(ebrevdo): Re-enable once CriticalSection is in core.
+ # "critical_section.proto",
+ "data/experimental/snapshot.proto",
+ "debug_event.proto",
+ "meta_graph.proto",
+ "named_tensor.proto",
+ "remote_tensor_handle.proto",
+ "saved_model.proto",
+ "saved_object_graph.proto",
+ "struct.proto",
+ "tensorflow_server.proto",
+ "trackable_object_graph.proto",
+ "transport_options.proto",
+ ],
+)
+
+tf_proto_library(
+ name = "for_core_protos",
+ srcs = COMMON_PROTO_SRCS + [
+ # Protos which are not needed on mobile builds, but should be included
+ # in protos_all.
+ #
+ # Note that some protos are in neither core_proto_srcs nor this
+ # filegroup; e.g. ones with individual proto_library targets.
+ "control_flow.proto",
+ # TODO(ebrevdo): Re-enable once CriticalSection is in core.
+ # "critical_section.proto",
+ "data/experimental/snapshot.proto",
+ "debug_event.proto",
+ "meta_graph.proto",
+ "named_tensor.proto",
+ "remote_tensor_handle.proto",
+ "saved_model.proto",
+ "saved_object_graph.proto",
+ "struct.proto",
+ "tensorflow_server.proto",
+ "trackable_object_graph.proto",
+ "transport_options.proto",
+ ],
+ cc_api_version = 2,
+ make_default_target_header_only = True,
+ protodeps = [
+ ":error_codes_proto_impl",
+ "//tensorflow/core/framework:protos_all",
+ ],
+)
diff --git a/tensorflow/core/protobuf/remote_tensor_handle.proto b/tensorflow/core/protobuf/remote_tensor_handle.proto
index 1099522..36e3f81 100644
--- a/tensorflow/core/protobuf/remote_tensor_handle.proto
+++ b/tensorflow/core/protobuf/remote_tensor_handle.proto
@@ -21,11 +21,11 @@
int64 op_id = 1;
// The index into the outputs of the operation that produced this tensor.
int32 output_num = 2;
- // Device of the operation that produced this tensor. Cannot be empty.
+ // Device where the tensor is located. Cannot be empty.
// For multi-device functions, it's the default device passed to placer.
string device = 3;
- // Device where the tensor is located. Can be empty if the operation producing
- // this tensor is a multi-device function.
+ // Device of the operation producing this tensor. Can be empty if the
+ // operation producing this tensor is a multi-device function.
string op_device = 4;
// Tensor type.
DataType dtype = 5;
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index b7f2b76..23e6138 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -21,7 +21,7 @@
// Also update tensorflow/tensorflow.bzl and
// tensorflow/tools/pip_package/setup.py
#define TF_MAJOR_VERSION 2
-#define TF_MINOR_VERSION 1
+#define TF_MINOR_VERSION 2
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
@@ -108,7 +108,7 @@
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 397 // Updated: 2020/5/10
+#define TF_GRAPH_DEF_VERSION 399 // Updated: 2020/5/12
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 43b2d93..4ea5fc3 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -51,3 +51,11 @@
"//tensorflow/core:lib",
],
)
+
+cc_library(
+ name = "tpu_config_c_api",
+ hdrs = ["tpu_config_c_api.h"],
+ deps = [
+ "//tensorflow/c:tf_status",
+ ],
+)
diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h
new file mode 100644
index 0000000..334a6a1
--- /dev/null
+++ b/tensorflow/core/tpu/tpu_config_c_api.h
@@ -0,0 +1,54 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
+#define TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
+
+#include <cstddef>
+
+#include "tensorflow/c/tf_status.h"
+
+typedef struct TpuSerializedProto TpuSerializedProto;
+
+extern "C" {
+
+bool TPUHostInitialized();
+
+// TODO(frankchn): Modify API to take in raw values instead of Tensors.
+void ConfigureDistributedTpuOp_DoWork(size_t input_size,
+ TpuSerializedProto** inputs,
+ TpuSerializedProto* output,
+ TF_Status* status);
+
+void WaitForDistributedTpuOp_DoWork(size_t input_size,
+ TpuSerializedProto** inputs,
+ TpuSerializedProto* output,
+ TF_Status* status);
+
+void ShutdownDistributedTpuOp_DoWork(TF_Status* status);
+
+void InitializeHostForDistributedTpuOp_DoWork(
+ size_t input_size, TpuSerializedProto** inputs,
+ bool enable_whole_mesh_compilations, TpuSerializedProto* output,
+ TF_Status* status);
+
+void SetGlobalTPUArrayOp_DoWork(size_t input_size, TpuSerializedProto** inputs,
+ TF_Status* status);
+
+void DisconnectDistributedTpuChipsOp_DoWork(TpuSerializedProto* output,
+ TF_Status* status);
+}
+
+#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_
diff --git a/tensorflow/core/util/debug_events_writer.cc b/tensorflow/core/util/debug_events_writer.cc
index 595f92d..d9c3393 100644
--- a/tensorflow/core/util/debug_events_writer.cc
+++ b/tensorflow/core/util/debug_events_writer.cc
@@ -179,7 +179,7 @@
metadata->set_tensorflow_version(TF_VERSION_STRING);
metadata->set_file_version(
strings::Printf("%s%d", kVersionPrefix, kCurrentFormatVersion));
- SerializeAndWriteDebugEvent(&debug_event, METADATA);
+ TF_RETURN_IF_ERROR(SerializeAndWriteDebugEvent(&debug_event, METADATA));
TF_RETURN_WITH_CONTEXT_IF_ERROR(
metadata_writer_->Flush(), "Failed to flush debug event metadata writer");
@@ -189,38 +189,38 @@
return Status::OK();
}
-void DebugEventsWriter::WriteSourceFile(SourceFile* source_file) {
+Status DebugEventsWriter::WriteSourceFile(SourceFile* source_file) {
DebugEvent debug_event;
debug_event.set_allocated_source_file(source_file);
- SerializeAndWriteDebugEvent(&debug_event, SOURCE_FILES);
+ return SerializeAndWriteDebugEvent(&debug_event, SOURCE_FILES);
}
-void DebugEventsWriter::WriteStackFrameWithId(
+Status DebugEventsWriter::WriteStackFrameWithId(
StackFrameWithId* stack_frame_with_id) {
DebugEvent debug_event;
debug_event.set_allocated_stack_frame_with_id(stack_frame_with_id);
- SerializeAndWriteDebugEvent(&debug_event, STACK_FRAMES);
+ return SerializeAndWriteDebugEvent(&debug_event, STACK_FRAMES);
}
-void DebugEventsWriter::WriteGraphOpCreation(
+Status DebugEventsWriter::WriteGraphOpCreation(
GraphOpCreation* graph_op_creation) {
DebugEvent debug_event;
debug_event.set_allocated_graph_op_creation(graph_op_creation);
- SerializeAndWriteDebugEvent(&debug_event, GRAPHS);
+ return SerializeAndWriteDebugEvent(&debug_event, GRAPHS);
}
-void DebugEventsWriter::WriteDebuggedGraph(DebuggedGraph* debugged_graph) {
+Status DebugEventsWriter::WriteDebuggedGraph(DebuggedGraph* debugged_graph) {
DebugEvent debug_event;
debug_event.set_allocated_debugged_graph(debugged_graph);
- SerializeAndWriteDebugEvent(&debug_event, GRAPHS);
+ return SerializeAndWriteDebugEvent(&debug_event, GRAPHS);
}
-void DebugEventsWriter::WriteExecution(Execution* execution) {
+Status DebugEventsWriter::WriteExecution(Execution* execution) {
if (circular_buffer_size_ <= 0) {
// No cyclic-buffer behavior.
DebugEvent debug_event;
debug_event.set_allocated_execution(execution);
- SerializeAndWriteDebugEvent(&debug_event, EXECUTION);
+ return SerializeAndWriteDebugEvent(&debug_event, EXECUTION);
} else {
// Circular buffer behavior.
DebugEvent debug_event;
@@ -234,16 +234,18 @@
if (execution_buffer_.size() > circular_buffer_size_) {
execution_buffer_.pop_front();
}
+ return Status::OK();
}
}
-void DebugEventsWriter::WriteGraphExecutionTrace(
+Status DebugEventsWriter::WriteGraphExecutionTrace(
GraphExecutionTrace* graph_execution_trace) {
+ TF_RETURN_IF_ERROR(Init());
if (circular_buffer_size_ <= 0) {
// No cyclic-buffer behavior.
DebugEvent debug_event;
debug_event.set_allocated_graph_execution_trace(graph_execution_trace);
- SerializeAndWriteDebugEvent(&debug_event, GRAPH_EXECUTION_TRACES);
+ return SerializeAndWriteDebugEvent(&debug_event, GRAPH_EXECUTION_TRACES);
} else {
// Circular buffer behavior.
DebugEvent debug_event;
@@ -257,15 +259,14 @@
if (graph_execution_trace_buffer_.size() > circular_buffer_size_) {
graph_execution_trace_buffer_.pop_front();
}
+ return Status::OK();
}
}
-void DebugEventsWriter::WriteGraphExecutionTrace(const string& tfdbg_context_id,
- const string& device_name,
- const string& op_name,
- int32 output_slot,
- int32 tensor_debug_mode,
- const Tensor& tensor_value) {
+Status DebugEventsWriter::WriteGraphExecutionTrace(
+ const string& tfdbg_context_id, const string& device_name,
+ const string& op_name, int32 output_slot, int32 tensor_debug_mode,
+ const Tensor& tensor_value) {
std::unique_ptr<GraphExecutionTrace> trace(new GraphExecutionTrace());
trace->set_tfdbg_context_id(tfdbg_context_id);
if (!op_name.empty()) {
@@ -279,7 +280,7 @@
}
trace->set_device_name(device_name);
tensor_value.AsProtoTensorContent(trace->mutable_tensor_proto());
- WriteGraphExecutionTrace(trace.release());
+ return WriteGraphExecutionTrace(trace.release());
}
void DebugEventsWriter::WriteSerializedNonExecutionDebugEvent(
@@ -487,8 +488,8 @@
return Status::OK();
}
-void DebugEventsWriter::SerializeAndWriteDebugEvent(DebugEvent* debug_event,
- DebugEventFileType type) {
+Status DebugEventsWriter::SerializeAndWriteDebugEvent(DebugEvent* debug_event,
+ DebugEventFileType type) {
std::unique_ptr<SingleDebugEventFileWriter>* writer = nullptr;
SelectWriter(type, &writer);
if (writer != nullptr) {
@@ -497,6 +498,11 @@
string str;
debug_event->AppendToString(&str);
(*writer)->WriteSerializedDebugEvent(str);
+ return Status::OK();
+ } else {
+ return errors::Internal(
+ "Unable to find debug events file writer for DebugEventsFileType ",
+ type);
}
}
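The conversion above turns fire-and-forget void writers into Status-returning ones, so callers can propagate failures with TF_RETURN_IF_ERROR, as the metadata write at the top of this hunk now does. A standalone sketch of the pattern with a stand-in Status type (in TensorFlow, TF_RETURN_IF_ERROR expands to roughly the explicit early return shown here):

    #include <string>

    // Stand-in for tensorflow::Status; an empty message means OK.
    struct Status {
      std::string message;
      bool ok() const { return message.empty(); }
      static Status OK() { return {}; }
    };

    Status WriteRecord(const std::string& record) {
      if (record.empty()) return {"empty record"};
      return Status::OK();  // pretend the record reached disk
    }

    Status WriteMetadataThenPayload(const std::string& metadata,
                                    const std::string& payload) {
      Status s = WriteRecord(metadata);
      if (!s.ok()) return s;  // what TF_RETURN_IF_ERROR does, roughly
      return WriteRecord(payload);
    }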
diff --git a/tensorflow/core/util/debug_events_writer.h b/tensorflow/core/util/debug_events_writer.h
index 6d219d7..39835ad 100644
--- a/tensorflow/core/util/debug_events_writer.h
+++ b/tensorflow/core/util/debug_events_writer.h
@@ -119,27 +119,27 @@
// The four DebugEvent fields below are written _without_ the circular buffer.
// Source file contents are written to the *.source_files file.
// Takes ownership of source_file.
- void WriteSourceFile(SourceFile* source_file);
+ Status WriteSourceFile(SourceFile* source_file);
// Stack frames are written to the *.code_locations file.
// Takes ownership of stack_frame_with_id.
- void WriteStackFrameWithId(StackFrameWithId* stack_frame_with_id);
+ Status WriteStackFrameWithId(StackFrameWithId* stack_frame_with_id);
// Graph op creation events are written to the *.graphs file.
// Takes ownership of graph_op_creation.
- void WriteGraphOpCreation(GraphOpCreation* graph_op_creation);
+ Status WriteGraphOpCreation(GraphOpCreation* graph_op_creation);
// Debugged graphs are written to the *.graphs file.
// Takes ownership of debugged_graph.
- void WriteDebuggedGraph(DebuggedGraph* debugged_graph);
+ Status WriteDebuggedGraph(DebuggedGraph* debugged_graph);
// The two DebugEvent fields below are written to the circular buffer
// and saved to disk only at the FlushExecutionFiles() call.
// Execution events (eager execution of an op or a tf.function) are written to
// the *.execution file.
// Takes ownership of execution.
- void WriteExecution(Execution* execution);
+ Status WriteExecution(Execution* execution);
// Graph execution traces (graph-internal tensor values or their summaries)
// are written to the *.graph_execution_traces file.
// Takes ownership of graph_execution_trace.
- void WriteGraphExecutionTrace(GraphExecutionTrace* graph_execution_trace);
+ Status WriteGraphExecutionTrace(GraphExecutionTrace* graph_execution_trace);
// Write a graph execution trace without using a protocol buffer.
// Instead, pass the raw values related to the graph execution trace.
@@ -155,11 +155,11 @@
// tensor_value: The value of the tensor that describes the tensor(s)
// that this trace is concerned with. The semantics of this tensor value
// depends on the value of `tensor_debug_mode`.
- void WriteGraphExecutionTrace(const string& tfdbg_context_id,
- const string& device_name,
- const string& op_name, int32 output_slot,
- int32 tensor_debug_mode,
- const Tensor& tensor_value);
+ Status WriteGraphExecutionTrace(const string& tfdbg_context_id,
+ const string& device_name,
+ const string& op_name, int32 output_slot,
+ int32 tensor_debug_mode,
+ const Tensor& tensor_value);
// Writes a serialized DebugEvent to one of the debug-events files
// concerned with the non-execution events: the SOURCE_FILES, STACK_FRAMES
@@ -217,8 +217,8 @@
// Initialize the TFRecord writer for non-metadata file type.
Status InitNonMetadataFile(DebugEventFileType type);
- void SerializeAndWriteDebugEvent(DebugEvent* debug_event,
- DebugEventFileType type);
+ Status SerializeAndWriteDebugEvent(DebugEvent* debug_event,
+ DebugEventFileType type);
void SelectWriter(DebugEventFileType type,
std::unique_ptr<SingleDebugEventFileWriter>** writer);
diff --git a/tensorflow/core/util/debug_events_writer_test.cc b/tensorflow/core/util/debug_events_writer_test.cc
index 66cde55..bd0c731 100644
--- a/tensorflow/core/util/debug_events_writer_test.cc
+++ b/tensorflow/core/util/debug_events_writer_test.cc
@@ -263,7 +263,7 @@
source_file_1->add_lines("");
source_file_1->add_lines("print(tf.constant([42.0]))");
source_file_1->add_lines("");
- writer->WriteSourceFile(source_file_1);
+ TF_ASSERT_OK(writer->WriteSourceFile(source_file_1));
SourceFile* source_file_2 = new SourceFile();
source_file_2->set_file_path("/home/tf_programs/train.py");
@@ -271,7 +271,7 @@
source_file_2->add_lines("import tensorflow.keras as keras");
source_file_2->add_lines("");
source_file_2->add_lines("model = keras.Sequential()");
- writer->WriteSourceFile(source_file_2);
+ TF_ASSERT_OK(writer->WriteSourceFile(source_file_2));
TF_ASSERT_OK(writer->FlushNonExecutionFiles());
TF_ASSERT_OK(writer->Close());
@@ -336,8 +336,8 @@
file_line_col->set_func("my_func");
file_line_col->set_code(" x = x ** 2.0");
- writer->WriteStackFrameWithId(stack_frame_1);
- writer->WriteStackFrameWithId(stack_frame_2);
+ TF_ASSERT_OK(writer->WriteStackFrameWithId(stack_frame_1));
+ TF_ASSERT_OK(writer->WriteStackFrameWithId(stack_frame_2));
TF_ASSERT_OK(writer->FlushNonExecutionFiles());
TF_ASSERT_OK(writer->Close());
@@ -382,12 +382,12 @@
GraphOpCreation* graph_op_creation = new GraphOpCreation();
graph_op_creation->set_op_type("MatMul");
graph_op_creation->set_op_name("Dense_1/MatMul");
- writer->WriteGraphOpCreation(graph_op_creation);
+ TF_ASSERT_OK(writer->WriteGraphOpCreation(graph_op_creation));
DebuggedGraph* debugged_graph = new DebuggedGraph();
debugged_graph->set_graph_id("deadbeaf");
debugged_graph->set_graph_name("my_func_graph");
- writer->WriteDebuggedGraph(debugged_graph);
+ TF_ASSERT_OK(writer->WriteDebuggedGraph(debugged_graph));
TF_ASSERT_OK(writer->FlushNonExecutionFiles());
TF_ASSERT_OK(writer->Close());
@@ -428,7 +428,7 @@
SourceFile* source_file = new SourceFile();
source_file->set_file_path(file_path);
source_file->set_host_name("localhost.localdomain");
- writer->WriteSourceFile(source_file);
+ TF_ASSERT_OK(writer->WriteSourceFile(source_file));
};
for (size_t i = 0; i < kConcurrentWrites; ++i) {
thread_pool->Schedule(fn);
@@ -469,7 +469,7 @@
SourceFile* source_file = new SourceFile();
source_file->set_file_path(file_path);
source_file->set_host_name("localhost.localdomain");
- writer->WriteSourceFile(source_file);
+ TF_ASSERT_OK(writer->WriteSourceFile(source_file));
TF_ASSERT_OK(writer->FlushNonExecutionFiles());
};
for (size_t i = 0; i < kConcurrentWrites; ++i) {
@@ -512,16 +512,16 @@
source_file->set_file_path(
strings::Printf("/home/tf_programs/program_%.2d.py", index));
source_file->set_host_name("localhost.localdomain");
- writer->WriteSourceFile(source_file);
+ TF_ASSERT_OK(writer->WriteSourceFile(source_file));
} else if (index % 3 == 1) {
StackFrameWithId* stack_frame = new StackFrameWithId();
stack_frame->set_id(strings::Printf("e%.2d", index));
- writer->WriteStackFrameWithId(stack_frame);
+ TF_ASSERT_OK(writer->WriteStackFrameWithId(stack_frame));
} else {
GraphOpCreation* op_creation = new GraphOpCreation();
op_creation->set_op_type("Log");
op_creation->set_op_name(strings::Printf("Log_%.2d", index));
- writer->WriteGraphOpCreation(op_creation);
+ TF_ASSERT_OK(writer->WriteGraphOpCreation(op_creation));
}
};
for (size_t i = 0; i < kConcurrentWrites; ++i) {
@@ -586,7 +586,7 @@
Execution* execution = new Execution();
execution->set_op_type("Log");
execution->add_input_tensor_ids(i);
- writer->WriteExecution(execution);
+ TF_ASSERT_OK(writer->WriteExecution(execution));
}
std::vector<DebugEvent> actuals;
@@ -611,7 +611,7 @@
Execution* execution = new Execution();
execution->set_op_type("Log");
execution->add_input_tensor_ids(i);
- writer->WriteExecution(execution);
+ TF_ASSERT_OK(writer->WriteExecution(execution));
}
TF_ASSERT_OK(writer->FlushExecutionFiles());
@@ -637,7 +637,7 @@
Execution* execution = new Execution();
execution->set_op_type("Abs");
execution->add_input_tensor_ids(counter.fetch_add(1));
- writer->WriteExecution(execution);
+ TF_ASSERT_OK(writer->WriteExecution(execution));
};
for (size_t i = 0; i < kCyclicBufferSize * 2; ++i) {
thread_pool->Schedule(fn);
@@ -682,7 +682,7 @@
for (size_t i = 0; i < kCyclicBufferSize * 2; ++i) {
GraphExecutionTrace* trace = new GraphExecutionTrace();
trace->set_tfdbg_context_id(strings::Printf("graph_%.2ld", i));
- writer->WriteGraphExecutionTrace(trace);
+ TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
}
std::vector<DebugEvent> actuals;
@@ -695,6 +695,31 @@
TF_ASSERT_OK(writer->Close());
}
+TEST_F(DebugEventsWriterTest, WriteGraphExecutionTraceWithoutPreviousInitCall) {
+ const size_t kCyclicBufferSize = -1;
+ DebugEventsWriter* writer =
+ DebugEventsWriter::GetDebugEventsWriter(dump_root_, kCyclicBufferSize);
+ // NOTE(cais): `writer->Init()` is not called here before
+ // WriteGraphExecutionTrace() is called. This test checks that this is okay
+ // and the `GraphExecutionTrace` gets written correctly even without `Init()`
+ // being called first. This scenario can happen when a TF Graph with tfdbg
+  // debug ops is executed on a remote TF server.
+
+ GraphExecutionTrace* trace = new GraphExecutionTrace();
+ trace->set_tfdbg_context_id(strings::Printf("graph_0"));
+ TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
+ TF_ASSERT_OK(writer->FlushExecutionFiles());
+
+ std::vector<DebugEvent> actuals;
+ ReadDebugEventProtos(writer, DebugEventFileType::GRAPH_EXECUTION_TRACES,
+ &actuals);
+ EXPECT_EQ(actuals.size(), 1);
+ EXPECT_EQ(actuals[0].graph_execution_trace().tfdbg_context_id(), "graph_0");
+
+ // Close the writer so the files can be safely deleted.
+ TF_ASSERT_OK(writer->Close());
+}
+
TEST_F(DebugEventsWriterTest, WriteGrahExecutionTraceWithCyclicBufferFlush) {
const size_t kCyclicBufferSize = 10;
DebugEventsWriter* writer =
@@ -706,7 +731,7 @@
for (size_t i = 0; i < kCyclicBufferSize * 2; ++i) {
GraphExecutionTrace* trace = new GraphExecutionTrace();
trace->set_tfdbg_context_id(strings::Printf("graph_%.2ld", i));
- writer->WriteGraphExecutionTrace(trace);
+ TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
}
TF_ASSERT_OK(writer->FlushExecutionFiles());
@@ -731,7 +756,7 @@
GraphExecutionTrace* trace = new GraphExecutionTrace();
trace->set_tfdbg_context_id(
strings::Printf("new_graph_%.2ld", counter.fetch_add(1)));
- writer->WriteGraphExecutionTrace(trace);
+ TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
};
for (size_t i = 0; i < kCyclicBufferSize * 2; ++i) {
thread_pool->Schedule(fn);
@@ -818,7 +843,7 @@
Execution* execution = new Execution();
execution->set_op_type("Log");
execution->add_input_tensor_ids(i);
- writer->WriteExecution(execution);
+ TF_ASSERT_OK(writer->WriteExecution(execution));
}
TF_ASSERT_OK(writer->FlushExecutionFiles());
@@ -834,7 +859,7 @@
for (size_t i = 0; i < kNumEvents; ++i) {
GraphExecutionTrace* trace = new GraphExecutionTrace();
trace->set_tfdbg_context_id(strings::Printf("graph_%.2ld", i));
- writer->WriteGraphExecutionTrace(trace);
+ TF_ASSERT_OK(writer->WriteGraphExecutionTrace(trace));
}
TF_ASSERT_OK(writer->FlushExecutionFiles());
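
The pattern applied throughout this test file: the writer's Write* methods return a Status that was previously discarded, and wrapping each call in TF_ASSERT_OK makes a failed write abort the test at the call site instead of surfacing later as a confusing read mismatch. A minimal sketch of the idiom (the file path is illustrative, not from this patch):

    SourceFile* source_file = new SourceFile();
    source_file->set_file_path("/home/tf_programs/example.py");  // hypothetical path
    // Aborts the test immediately if the returned Status is non-OK.
    TF_ASSERT_OK(writer->WriteSourceFile(source_file));
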
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 53aa48b..a90fc2e 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -12059,7 +12059,7 @@
//
// value: The cropped area of the image must have an aspect ratio =
// width / height within this range.
-// If not specified, defaults to {f:0.75 f:1.33}
+// If not specified, defaults to {f:0.75 f:1.33}
func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
@@ -12070,7 +12070,7 @@
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
-// If not specified, defaults to {f:0.05 f:1}
+// If not specified, defaults to {f:0.05 f:1}
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["area_range"] = value
@@ -18975,7 +18975,7 @@
//
// value: The cropped area of the image must have an aspect ratio =
// width / height within this range.
-// If not specified, defaults to {f:0.75 f:1.33}
+// If not specified, defaults to {f:0.75 f:1.33}
func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
@@ -18986,7 +18986,7 @@
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
-// If not specified, defaults to {f:0.05 f:1}
+// If not specified, defaults to {f:0.05 f:1}
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
m["area_range"] = value
@@ -19390,7 +19390,7 @@
// ImageSummaryBadColor sets the optional bad_color attribute to value.
//
// value: Color to use for pixels with non-finite values.
-// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
+// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
return func(m optionalAttr) {
m["bad_color"] = value
@@ -20461,7 +20461,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -21633,7 +21633,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22341,7 +22341,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DDilations(value []int64) Conv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22537,7 +22537,7 @@
// QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22606,7 +22606,7 @@
// QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22721,7 +22721,7 @@
// QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22780,7 +22780,7 @@
// QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -22954,7 +22954,7 @@
// QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value.
//
// value: list of dilation values.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -23331,7 +23331,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -25651,7 +25651,7 @@
type Conv3DBackpropFilterAttr func(optionalAttr)
// Conv3DBackpropFilterDilations sets the optional dilations attribute to value.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -25714,7 +25714,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DDilations(value []int64) Conv3DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -25965,7 +25965,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -26449,7 +26449,7 @@
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -45537,7 +45537,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -47477,7 +47477,7 @@
type Conv3DBackpropInputAttr func(optionalAttr)
// Conv3DBackpropInputDilations sets the optional dilations attribute to value.
-// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -47548,7 +47548,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr {
return func(m optionalAttr) {
m["dilations"] = value
@@ -48537,7 +48537,7 @@
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
-// If not specified, defaults to {i:1 i:1 i:1 i:1}
+// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
diff --git a/tensorflow/go/saved_model.go b/tensorflow/go/saved_model.go
index 7aa1e83..64ae82e 100644
--- a/tensorflow/go/saved_model.go
+++ b/tensorflow/go/saved_model.go
@@ -22,7 +22,7 @@
"unsafe"
"github.com/golang/protobuf/proto"
- corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"
+ corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
)
// #include <stdlib.h>
diff --git a/tensorflow/go/signature.go b/tensorflow/go/signature.go
index 8aac0e2..c2db0c7 100644
--- a/tensorflow/go/signature.go
+++ b/tensorflow/go/signature.go
@@ -16,7 +16,7 @@
package tensorflow
-import corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"
+import corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
// #include "tensorflow/c/c_api.h"
import "C"
diff --git a/tensorflow/go/signature_test.go b/tensorflow/go/signature_test.go
index e6927f3..f9fa8427 100644
--- a/tensorflow/go/signature_test.go
+++ b/tensorflow/go/signature_test.go
@@ -20,9 +20,9 @@
"fmt"
"testing"
- corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"
tspb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"
typb "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"
+ corepb "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"
)
func TestSignatureFromProto(t *testing.T) {
diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c
index f70a600..e6b4789 100644
--- a/tensorflow/lite/c/common.c
+++ b/tensorflow/lite/c/common.c
@@ -79,7 +79,8 @@
void TfLiteFloatArrayFree(TfLiteFloatArray* a) { free(a); }
void TfLiteTensorDataFree(TfLiteTensor* t) {
- if (t->allocation_type == kTfLiteDynamic) {
+ if (t->allocation_type == kTfLiteDynamic ||
+ t->allocation_type == kTfLitePersistentRo) {
free(t->data.raw);
}
t->data.raw = NULL;
@@ -172,7 +173,8 @@
}
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
- if (tensor->allocation_type != kTfLiteDynamic) {
+ if (tensor->allocation_type != kTfLiteDynamic &&
+ tensor->allocation_type != kTfLitePersistentRo) {
return;
}
// TODO(b/145340303): Tensor data should be aligned.
diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h
index 9657c7e..ab150e8 100644
--- a/tensorflow/lite/c/common.h
+++ b/tensorflow/lite/c/common.h
@@ -321,15 +321,23 @@
void* data;
} TfLitePtrUnion;
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+// Memory allocation strategies.
+// * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated.
+// * kTfLiteArenaRw: Arena allocated with no guarantees about persistence,
+// and available during eval.
+// * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and
+// only available during eval.
+// * kTfLiteDynamic: Allocated during eval, or for string tensors.
+// * kTfLitePersistentRo: Allocated and populated during prepare. This is
+// useful for tensors that can be computed during prepare and treated
+// as constant inputs for downstream ops (also in prepare).
typedef enum TfLiteAllocationType {
kTfLiteMemNone = 0,
kTfLiteMmapRo,
kTfLiteArenaRw,
kTfLiteArenaRwPersistent,
kTfLiteDynamic,
+ kTfLitePersistentRo,
} TfLiteAllocationType;
// The delegates should use zero or positive integers to represent handles.
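
A hedged sketch of how a kernel might use the new allocation type: mark an output as kTfLitePersistentRo in Prepare, resize it (TfLiteTensorRealloc now handles this type, per the common.c change above), and fill it once so downstream ops can treat it as a constant input. All names besides the TFLite C API calls are illustrative:

    TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
      TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
      // Computed once here; read-only for the rest of the graph's lifetime.
      output->allocation_type = kTfLitePersistentRo;
      TfLiteIntArray* shape = TfLiteIntArrayCreate(1);
      shape->data[0] = 4;
      // ResizeTensor heap-allocates for kTfLitePersistentRo, as for kTfLiteDynamic.
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, shape));
      // Populate output->data here; downstream ops may read it in their own
      // Prepare, as if it were a constant tensor.
      return kTfLiteOk;
    }
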
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index 4cebd05..7f4e0e2 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -1183,7 +1183,8 @@
// Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
if (tensor->allocation_type == kTfLiteArenaRw ||
tensor->allocation_type == kTfLiteDynamic ||
- tensor->allocation_type == kTfLiteArenaRwPersistent) {
+ tensor->allocation_type == kTfLiteArenaRwPersistent ||
+ tensor->allocation_type == kTfLitePersistentRo) {
tensor_resized_since_op_invoke_ |=
TfLiteIntArrayEqual(tensor->dims, new_size) == 0;
if (tensor->type != kTfLiteString) {
@@ -1195,14 +1196,16 @@
return kTfLiteError;
}
- // Realloc space for kTfLiteDynamic tensors.
+ // Realloc space for heap-allocated tensors.
TfLiteTensorRealloc(bytesRequired, tensor);
tensor->bytes = bytesRequired;
}
if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
tensor->dims = new_size;
- if (tensor->allocation_type != kTfLiteDynamic) {
+ // Reset arena-allocated tensors; they will be allocated later.
+ if (tensor->allocation_type == kTfLiteArenaRw ||
+ tensor->allocation_type == kTfLiteArenaRwPersistent) {
tensor->data.raw = nullptr;
}
} else {
diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD
index 9fe8060..d69d220 100644
--- a/tensorflow/lite/delegates/flex/BUILD
+++ b/tensorflow/lite/delegates/flex/BUILD
@@ -26,7 +26,7 @@
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib_lite",
+ "//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/c:c_api_internal",
@@ -66,7 +66,7 @@
"//tensorflow/core:android_tensorflow_lib",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:tensorflow",
@@ -103,7 +103,7 @@
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib_lite",
+ "//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:lib",
@@ -137,7 +137,7 @@
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib_lite",
+ "//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core/common_runtime/eager:context",
@@ -183,7 +183,7 @@
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib_lite",
+ "//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core/common_runtime/eager:context",
@@ -211,7 +211,7 @@
"//tensorflow/core:android_tensorflow_lib",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:tensorflow",
@@ -245,7 +245,7 @@
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib_lite",
+ "//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/c:c_api_internal",
diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD
index 099f653..2581232 100644
--- a/tensorflow/lite/delegates/gpu/BUILD
+++ b/tensorflow/lite/delegates/gpu/BUILD
@@ -167,7 +167,7 @@
"metal_delegate.h",
"metal_delegate_internal.h",
],
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
deps = [":metal_delegate"],
)
diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc
index 7ffb560..475eed4 100644
--- a/tensorflow/lite/delegates/gpu/cl/api.cc
+++ b/tensorflow/lite/delegates/gpu/cl/api.cc
@@ -352,10 +352,10 @@
};
TensorObject TensorToObj(const Tensor& tensor) {
- if (tensor.StorageType() == TensorStorageType::BUFFER) {
+ if (tensor.GetStorageType() == TensorStorageType::BUFFER) {
return OpenClBuffer{tensor.GetMemoryPtr()};
}
- if (tensor.StorageType() == TensorStorageType::IMAGE_BUFFER) {
+ if (tensor.GetStorageType() == TensorStorageType::IMAGE_BUFFER) {
return OpenClBuffer{tensor.GetMemoryPtrForWriting()};
}
return OpenClTexture{tensor.GetMemoryPtr()};
@@ -516,9 +516,9 @@
def.dimensions.h = tensor.Height();
def.dimensions.w = tensor.Width();
def.dimensions.c = tensor.Channels();
- def.object_def.data_layout = ToDataLayout(tensor.StorageType());
- def.object_def.data_type = tensor.DataType();
- def.object_def.object_type = ToObjectType(tensor.StorageType());
+ def.object_def.data_layout = ToDataLayout(tensor.GetStorageType());
+ def.object_def.data_type = tensor.GetDataType();
+ def.object_def.object_type = ToObjectType(tensor.GetStorageType());
def.object_def.user_provided = false;
return def;
}
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc
index f01975e..4a52508 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc
@@ -29,7 +29,7 @@
namespace {
absl::Status CreateImageBufferFromBuffer(const CLContext& context,
- cl_mem memory, enum DataType data_type,
+ cl_mem memory, DataType data_type,
int width, cl_mem* result) {
cl_image_format format;
cl_image_desc desc;
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h
index d59ef83..cb7d426 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.h
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.h
@@ -75,8 +75,8 @@
int4 GetWHSB() const { return int4(shape_.w, shape_.h, Slices(), shape_.b); }
int4 GetWHDS() const { return int4(shape_.w, shape_.h, shape_.d, Slices()); }
- enum DataType DataType() const { return descriptor_.data_type; }
- TensorStorageType StorageType() const { return descriptor_.storage_type; }
+ DataType GetDataType() const { return descriptor_.data_type; }
+ TensorStorageType GetStorageType() const { return descriptor_.storage_type; }
// for profiling and memory statistics
uint64_t GetMemorySizeInBytes() const;
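
Why the rename: a member function named DataType() hides the enum type DataType inside the class, which is what forced the elaborated-type-specifier spelling `enum DataType` in the old declaration. A self-contained sketch of the collision (types assumed for illustration):

    enum class DataType { FLOAT32, FLOAT16 };

    struct TensorSketch {
      // With a member named DataType(), the type has to be written as
      // "enum DataType" elsewhere in the class; GetDataType() keeps the
      // plain type name usable.
      DataType GetDataType() const { return type_; }
      DataType type_ = DataType::FLOAT32;
    };
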
diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc
index 3924f91..bdcf6f6 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.cc
+++ b/tensorflow/lite/delegates/gpu/common/operations.cc
@@ -506,6 +506,14 @@
attr.appended.c + attr.prepended.c + input.c);
}
+BHWDC CalculateOutputShape(const BHWDC& input, const Pad3DAttributes& attr) {
+ return BHWDC(attr.appended.b + attr.prepended.b + input.b,
+ attr.appended.h + attr.prepended.h + input.h,
+ attr.appended.w + attr.prepended.w + input.w,
+ attr.appended.d + attr.prepended.d + input.d,
+ attr.appended.c + attr.prepended.c + input.c);
+}
+
BHWC CalculateOutputShape(const BHWC& input,
const FullyConnectedAttributes& attr) {
return BHWC(input.b, 1, 1, attr.weights.shape.o);
diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h
index f8bfc77..d0268ee 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.h
+++ b/tensorflow/lite/delegates/gpu/common/operations.h
@@ -431,6 +431,17 @@
// @return shape of a tensor after Pad operation is applied to the given input.
BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr);
+struct Pad3DAttributes {
+ PaddingContentType type = PaddingContentType::ZEROS;
+
+ BHWDC prepended;
+ BHWDC appended;
+};
+
+// @return shape of a tensor after Pad3D operation is applied to the given
+// input.
+BHWDC CalculateOutputShape(const BHWDC& input, const Pad3DAttributes& attr);
+
struct ConstTensorAttributes {
Tensor<BHWC, DataType::FLOAT32> tensor;
};
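
Per the definition above, each output dimension of Pad3D is prepended + input + appended. A small worked example with assumed shapes (inside namespace tflite::gpu): padding one voxel on each side of h, w, and d of a BHWDC(1, 8, 8, 8, 4) tensor:

    Pad3DAttributes attr;
    attr.prepended = BHWDC(0, 1, 1, 1, 0);
    attr.appended = BHWDC(0, 1, 1, 1, 0);
    BHWDC out = CalculateOutputShape(BHWDC(1, 8, 8, 8, 4), attr);
    // out == BHWDC(1, 10, 10, 10, 4)
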
diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc
index 58da886..4b6727e 100644
--- a/tensorflow/lite/delegates/gpu/delegate.cc
+++ b/tensorflow/lite/delegates/gpu/delegate.cc
@@ -263,12 +263,12 @@
input_refs->clear();
output_refs->clear();
- const auto& inputs = graph->inputs();
+ const auto inputs = graph->inputs();
input_refs->reserve(inputs.size());
for (const auto& input : inputs) {
input_refs->push_back(input->tensor.ref);
}
- const auto& outputs = graph->outputs();
+ const auto outputs = graph->outputs();
output_refs->reserve(outputs.size());
for (const auto& output : outputs) {
output_refs->push_back(output->tensor.ref);
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index 192c787..4db8f3d 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -80,7 +80,7 @@
ios_unit_test(
name = "common_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -117,7 +117,7 @@
ios_unit_test(
name = "compiled_model_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -205,7 +205,7 @@
ios_unit_test(
name = "inference_context_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -235,7 +235,7 @@
"iphone",
],
infoplists = ["Info.plist"],
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
provisioning_profile = "//tensorflow/lite/delegates/gpu/metal:provisioning_profile.mobileprovision",
tags = tf_gpu_tests_tags() + [
"local",
@@ -267,7 +267,7 @@
ios_unit_test(
name = "ComponentsTests",
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + ["notap"],
test_host = ":TestApplication",
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index a1052b8..657e9b5 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -71,7 +71,7 @@
ios_unit_test(
name = "add_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -109,7 +109,7 @@
ios_unit_test(
name = "concat_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -151,7 +151,7 @@
ios_unit_test(
name = "conv_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -213,7 +213,7 @@
ios_unit_test(
name = "depthwise_conv_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -253,7 +253,7 @@
ios_unit_test(
name = "elementwise_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -293,7 +293,7 @@
ios_unit_test(
name = "fully_connected_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -332,7 +332,7 @@
ios_unit_test(
name = "max_unpooling_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -371,7 +371,7 @@
ios_unit_test(
name = "mean_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = [
"notap",
@@ -450,7 +450,7 @@
ios_unit_test(
name = "padding_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -490,7 +490,7 @@
ios_unit_test(
name = "pooling_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -530,7 +530,7 @@
ios_unit_test(
name = "prelu_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -569,7 +569,7 @@
ios_unit_test(
name = "relu_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -608,7 +608,7 @@
ios_unit_test(
name = "resize_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -648,7 +648,7 @@
ios_unit_test(
name = "reshape_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -687,7 +687,7 @@
ios_unit_test(
name = "slice_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -727,7 +727,7 @@
ios_unit_test(
name = "softmax_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -764,7 +764,7 @@
ios_unit_test(
name = "space_to_depth_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -804,7 +804,7 @@
ios_unit_test(
name = "transpose_conv_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
@@ -885,7 +885,7 @@
ios_unit_test(
name = "winograd_test",
testonly = 1,
- minimum_os_version = "10.0",
+ minimum_os_version = "11.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index e790d42..002c299 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -3769,7 +3769,8 @@
}
} else if (reg->builtin_code == kTfLiteBuiltinMaximum ||
reg->builtin_code == kTfLiteBuiltinMinimum) {
- const TfLiteTensor& operand_tensor = context->tensors[input_pos];
+ const TfLiteTensor& operand_tensor =
+ context->tensors[node->inputs->data[input_pos]];
if (operand_tensor.dims->size == 0) {
int tensor_index;
@@ -3814,7 +3815,8 @@
reg->builtin_code == kTfLiteBuiltinSum) &&
(input_pos == 1)) {
      // The axis needs to be converted to a tensor if specified as a scalar
- const TfLiteTensor& axis_tensor = context->tensors[1];
+ const TfLiteTensor& axis_tensor =
+ context->tensors[node->inputs->data[input_pos]];
if (axis_tensor.dims->size == 0) {
TF_LITE_ENSURE_STATUS(
builder.AddVectorInt32Operand(axis_tensor.data.i32, 1));
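
Both hunks fix the same bug: input_pos indexes the node's input list, not the interpreter's global tensor array, so the global id must be looked up through node->inputs first. A minimal sketch of the distinction:

    // Wrong: treats a node-relative position as a global tensor index.
    // const TfLiteTensor& t = context->tensors[input_pos];

    // Right: map the node-relative position to the global index first.
    const int tensor_index = node->inputs->data[input_pos];
    const TfLiteTensor& t = context->tensors[tensor_index];
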
diff --git a/tensorflow/lite/delegates/utils/BUILD b/tensorflow/lite/delegates/utils/BUILD
new file mode 100644
index 0000000..069da16
--- /dev/null
+++ b/tensorflow/lite/delegates/utils/BUILD
@@ -0,0 +1,36 @@
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "simple_delegate",
+ srcs = [
+ "simple_delegate.cc",
+ ],
+ hdrs = [
+ "simple_delegate.h",
+ ],
+ deps = [
+ "//tensorflow/lite:kernel_api",
+ "//tensorflow/lite:minimal_logging",
+ "//tensorflow/lite/c:common",
+ "//tensorflow/lite/delegates:utils",
+ "//tensorflow/lite/kernels/internal:compatibility",
+ ],
+)
+
+cc_test(
+ name = "simple_delegate_test",
+ srcs = ["simple_delegate_test.cc"],
+ deps = [
+ ":simple_delegate",
+ "//tensorflow/lite:framework",
+ "//tensorflow/lite:kernel_api",
+ "//tensorflow/lite/c:common",
+ "//tensorflow/lite/kernels:builtin_ops",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/lite/delegates/utils/simple_delegate.cc b/tensorflow/lite/delegates/utils/simple_delegate.cc
new file mode 100644
index 0000000..51736e5
--- /dev/null
+++ b/tensorflow/lite/delegates/utils/simple_delegate.cc
@@ -0,0 +1,140 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/utils/simple_delegate.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/context_util.h"
+#include "tensorflow/lite/delegates/utils.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/minimal_logging.h"
+
+namespace tflite {
+namespace {
+TfLiteRegistration GetDelegateKernelRegistration(
+ SimpleDelegateInterface* delegate) {
+ TfLiteRegistration kernel_registration;
+ kernel_registration.profiling_string = nullptr;
+ kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
+ kernel_registration.custom_name = delegate->name();
+ kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
+ delete reinterpret_cast<SimpleDelegateKernelInterface*>(buffer);
+ };
+ kernel_registration.init = [](TfLiteContext* context, const char* buffer,
+ size_t length) -> void* {
+ const TfLiteDelegateParams* params =
+ reinterpret_cast<const TfLiteDelegateParams*>(buffer);
+ if (params == nullptr) {
+ TF_LITE_KERNEL_LOG(context, "NULL TfLiteDelegateParams passed.");
+ return nullptr;
+ }
+ auto* delegate =
+ reinterpret_cast<SimpleDelegateInterface*>(params->delegate->data_);
+ std::unique_ptr<SimpleDelegateKernelInterface> delegate_kernel(
+ delegate->CreateDelegateKernelInterface());
+ if (delegate_kernel->Init(context, params) != kTfLiteOk) {
+ return nullptr;
+ }
+ return delegate_kernel.release();
+ };
+ kernel_registration.prepare = [](TfLiteContext* context,
+ TfLiteNode* node) -> TfLiteStatus {
+ if (node->user_data == nullptr) {
+ TF_LITE_KERNEL_LOG(context, "Delegate kernel was not initialized");
+ return kTfLiteError;
+ }
+ SimpleDelegateKernelInterface* delegate_kernel =
+ reinterpret_cast<SimpleDelegateKernelInterface*>(node->user_data);
+ return delegate_kernel->Prepare(context, node);
+ };
+ kernel_registration.invoke = [](TfLiteContext* context,
+ TfLiteNode* node) -> TfLiteStatus {
+ SimpleDelegateKernelInterface* delegate_kernel =
+ reinterpret_cast<SimpleDelegateKernelInterface*>(node->user_data);
+ TFLITE_DCHECK(delegate_kernel != nullptr);
+ return delegate_kernel->Invoke(context, node);
+ };
+
+ return kernel_registration;
+}
+
+TfLiteStatus DelegatePrepare(TfLiteContext* context,
+ TfLiteDelegate* base_delegate) {
+ auto* delegate =
+ reinterpret_cast<SimpleDelegateInterface*>(base_delegate->data_);
+ delegates::IsNodeSupportedFn node_supported_fn =
+ [=](TfLiteContext* context, TfLiteNode* node,
+ TfLiteRegistration* registration,
+ std::string* unsupported_details) -> bool {
+ return delegate->IsNodeSupportedByDelegate(registration, node, context);
+ };
+ // TODO(b/149484598): Update to have method that gets all supported nodes.
+ delegates::GraphPartitionHelper helper(context, node_supported_fn);
+ TF_LITE_ENSURE_STATUS(helper.Partition(nullptr));
+
+ const auto delegate_partitions = helper.GetFirstNLargestPartitions();
+
+  // To avoid creating a new TfLiteIntArray and freeing it later, we reserve
+  // one element to represent TfLiteIntArray.size, which is the first element
+  // of the TfLiteIntArray C struct.
+ std::vector<int> supported_nodes(1);
+ for (const auto partition : delegate_partitions) {
+ auto* nodes = partition->nodes_to_replace;
+ supported_nodes.insert(supported_nodes.end(), nodes->data,
+ nodes->data + nodes->size);
+ }
+ // Set first element to the number of nodes to replace.
+ supported_nodes[0] = supported_nodes.size() - 1;
+
+ TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
+ "%s delegate: %d nodes delegated out of %d nodes with "
+ "%d partitions.\n",
+ delegate->name(), supported_nodes[0],
+ helper.num_total_nodes(), delegate_partitions.size());
+ TfLiteRegistration delegate_kernel_registration =
+ GetDelegateKernelRegistration(delegate);
+
+ return context->ReplaceNodeSubsetsWithDelegateKernels(
+ context, delegate_kernel_registration,
+ reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()), base_delegate);
+}
+} // namespace
+
+TfLiteDelegate* TfLiteDelegateFactory::CreateSimpleDelegate(
+ std::unique_ptr<SimpleDelegateInterface> simple_delegate) {
+ if (simple_delegate == nullptr) {
+ return nullptr;
+ }
+ auto delegate = new TfLiteDelegate();
+ delegate->Prepare = &DelegatePrepare;
+ delegate->flags = kTfLiteDelegateFlagsNone;
+ delegate->CopyFromBufferHandle = nullptr;
+ delegate->CopyToBufferHandle = nullptr;
+ delegate->FreeBufferHandle = nullptr;
+ delegate->data_ = simple_delegate.release();
+ return delegate;
+}
+
+void TfLiteDelegateFactory::DeleteSimpleDelegate(TfLiteDelegate* delegate) {
+ if (!delegate) return;
+ SimpleDelegateInterface* simple_delegate =
+ reinterpret_cast<SimpleDelegateInterface*>(delegate->data_);
+ delete simple_delegate;
+ delete delegate;
+}
+
+} // namespace tflite
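
A hedged usage sketch for the factory above: MyDelegate stands in for a client's SimpleDelegateInterface implementation and is not part of this patch.

    auto* delegate = tflite::TfLiteDelegateFactory::CreateSimpleDelegate(
        std::make_unique<MyDelegate>());
    if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
      // Delegation failed; the interpreter falls back to the built-in kernels.
    }
    // ... run inference ...
    tflite::TfLiteDelegateFactory::DeleteSimpleDelegate(delegate);
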
diff --git a/tensorflow/lite/delegates/utils/simple_delegate.h b/tensorflow/lite/delegates/utils/simple_delegate.h
new file mode 100644
index 0000000..bf35fbc
--- /dev/null
+++ b/tensorflow/lite/delegates/utils/simple_delegate.h
@@ -0,0 +1,109 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file has utilities that facilitate creating new delegates.
+// - SimpleDelegateKernelInterface: Represents a kernel that handles a subgraph
+// to be delegated. It has Init/Prepare/Invoke methods, which are called during
+// inference, similar to TFLite kernels. The delegate owner should implement
+// this interface to build/prepare/invoke the delegated subgraph.
+// - SimpleDelegateInterface: This class wraps a TfLiteDelegate. Users implement
+// the interface and then call TfLiteDelegateFactory::CreateSimpleDelegate() to
+// get a TfLiteDelegate* that can be passed to ModifyGraphWithDelegate.
+#ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
+#define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
+
+#include <memory>
+
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+// Users should inherit from this class and implement the interface below.
+// Each instance represents a single part of the graph (subgraph).
+class SimpleDelegateKernelInterface {
+ public:
+ virtual ~SimpleDelegateKernelInterface() {}
+
+ // Initializes a delegated subgraph.
+  // The nodes in the subgraph are inside TfLiteDelegateParams->nodes_to_replace.
+ virtual TfLiteStatus Init(TfLiteContext* context,
+ const TfLiteDelegateParams* params) = 0;
+
+  // Called by the framework. Should handle any needed preparation for the
+  // subgraph, e.g. allocating buffers or compiling the model.
+  // Returns a status, signalling any errors.
+ virtual TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) = 0;
+
+ // Actual subgraph inference should happen on this call.
+  // Returns a status, signalling any errors.
+ virtual TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) = 0;
+};
+
+// Pure interface that clients should implement.
+// The interface represents a delegate's capabilities and provides a factory
+// for SimpleDelegateKernelInterface.
+//
+// Clients should implement the following methods:
+// - IsNodeSupportedByDelegate
+// - name
+// - CreateDelegateKernelInterface
+class SimpleDelegateInterface {
+ public:
+ SimpleDelegateInterface() {}
+
+ virtual ~SimpleDelegateInterface() {}
+
+ // Returns true if 'node' is supported by the delegate. False otherwise.
+ virtual bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
+ const TfLiteNode* node,
+ TfLiteContext* context) const = 0;
+
+ // Returns a name that identifies the delegate.
+ // This name is used for debugging/logging/profiling.
+ virtual const char* name() const = 0;
+
+ // Returns instance of an object that implements the interface
+ // SimpleDelegateKernelInterface.
+ // An instance of SimpleDelegateKernelInterface represents one subgraph to
+ // be delegated.
+ // Caller takes ownership of the returned object.
+ virtual std::unique_ptr<SimpleDelegateKernelInterface>
+ CreateDelegateKernelInterface() = 0;
+};
+
+// Factory class that provides two static methods:
+//   CreateSimpleDelegate
+//   DeleteSimpleDelegate
+// The former constructs a TfLiteDelegate from a SimpleDelegateInterface; the
+// latter destroys the TfLiteDelegate along with the wrapped
+// SimpleDelegateInterface, given a TfLiteDelegate* created by
+// CreateSimpleDelegate.
+// Users should use these methods to create and destroy the delegate.
+class TfLiteDelegateFactory {
+ public:
+ // Creates TfLiteDelegate from the provided SimpleDelegateInterface.
+ // The returned TfLiteDelegate should be deleted using DeleteSimpleDelegate.
+ static TfLiteDelegate* CreateSimpleDelegate(
+ std::unique_ptr<SimpleDelegateInterface> simple_delegate);
+
+  // Deletes 'delegate'. The passed pointer must be the one returned
+  // from CreateSimpleDelegate.
+ // This function will destruct the SimpleDelegate object too.
+ static void DeleteSimpleDelegate(TfLiteDelegate* delegate);
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_
diff --git a/tensorflow/lite/delegates/utils/simple_delegate_test.cc b/tensorflow/lite/delegates/utils/simple_delegate_test.cc
new file mode 100644
index 0000000..fa6d528
--- /dev/null
+++ b/tensorflow/lite/delegates/utils/simple_delegate_test.cc
@@ -0,0 +1,194 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/delegates/utils/simple_delegate.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/builtin_op_kernels.h"
+
+namespace tflite {
+namespace {
+// Delegate options.
+struct TestSimpleDelegateOptions {
+ // Allowed ops to delegate.
+ int allowed_builtin_code;
+ // Report error during init.
+ bool error_during_init = false;
+ // Report error during prepare.
+ bool error_during_prepare = false;
+ // Report error during invoke.
+ bool error_during_invoke = false;
+};
+
+// Dummy delegate kernel.
+class TestSimpleDelegateKernel : public SimpleDelegateKernelInterface {
+ public:
+ explicit TestSimpleDelegateKernel(TestSimpleDelegateOptions options)
+ : options_(options) {}
+
+ TfLiteStatus Init(TfLiteContext* context,
+ const TfLiteDelegateParams* params) override {
+ return !options_.error_during_init ? kTfLiteOk : kTfLiteError;
+ }
+
+ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) override {
+ return !options_.error_during_prepare ? kTfLiteOk : kTfLiteError;
+ }
+
+ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) override {
+ return !options_.error_during_invoke ? kTfLiteOk : kTfLiteError;
+ }
+
+ private:
+ TestSimpleDelegateOptions options_;
+};
+
+// Simple delegate that implements SimpleDelegateInterface.
+// This holds the delegate's capabilities.
+class TestSimpleDelegate : public SimpleDelegateInterface {
+ public:
+ explicit TestSimpleDelegate(TestSimpleDelegateOptions options)
+ : options_(options) {}
+ bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration,
+ const TfLiteNode* node,
+ TfLiteContext* context) const override {
+ return options_.allowed_builtin_code == registration->builtin_code;
+ }
+
+ const char* name() const override { return "TestSimpleDelegate"; }
+
+ std::unique_ptr<SimpleDelegateKernelInterface> CreateDelegateKernelInterface()
+ override {
+ return std::make_unique<TestSimpleDelegateKernel>(options_);
+ }
+
+ private:
+ TestSimpleDelegateOptions options_;
+};
+
+class TestDelegate : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ interpreter_.reset(new Interpreter);
+ interpreter_->AddTensors(5);
+ interpreter_->SetInputs({0, 1});
+ interpreter_->SetOutputs({3, 4});
+ TfLiteQuantizationParams quant;
+ interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
+ quant);
+ TfLiteRegistration* reg = ops::builtin::Register_ADD();
+ void* builtin_data_1 = malloc(sizeof(int));
+ void* builtin_data_2 = malloc(sizeof(int));
+ void* builtin_data_3 = malloc(sizeof(int));
+ interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, builtin_data_1,
+ reg);
+ interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, builtin_data_2,
+ reg);
+ interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, builtin_data_3,
+ reg);
+ }
+
+ void TearDown() override {
+ interpreter_.reset();
+ TfLiteDelegateFactory::DeleteSimpleDelegate(delegate_);
+ }
+
+ protected:
+ std::unique_ptr<Interpreter> interpreter_;
+ TfLiteDelegate* delegate_ = nullptr;
+};
+
+TEST_F(TestDelegate, BasicDelegate) {
+ TestSimpleDelegateOptions options;
+ options.allowed_builtin_code = kTfLiteBuiltinAdd;
+ delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+ std::make_unique<TestSimpleDelegate>(options));
+ interpreter_->ModifyGraphWithDelegate(delegate_);
+
+ ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+ int node = interpreter_->execution_plan()[0];
+ const auto* node_and_reg = interpreter_->node_and_registration(node);
+ EXPECT_EQ("TestSimpleDelegate", node_and_reg->second.custom_name);
+
+ const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>(
+ node_and_reg->first.builtin_data);
+ ASSERT_EQ(params->nodes_to_replace->size, 3);
+ EXPECT_EQ(params->nodes_to_replace->data[0], 0);
+ EXPECT_EQ(params->nodes_to_replace->data[1], 1);
+ EXPECT_EQ(params->nodes_to_replace->data[2], 2);
+
+ ASSERT_EQ(params->input_tensors->size, 2);
+ EXPECT_EQ(params->input_tensors->data[0], 0);
+ EXPECT_EQ(params->input_tensors->data[1], 1);
+
+ ASSERT_EQ(params->output_tensors->size, 2);
+ EXPECT_EQ(params->output_tensors->data[0], 3);
+ EXPECT_EQ(params->output_tensors->data[1], 4);
+}
+
+TEST_F(TestDelegate, NoNodesToDelegate) {
+ TestSimpleDelegateOptions options;
+ options.allowed_builtin_code = kTfLiteBuiltinSub;
+ delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+ std::make_unique<TestSimpleDelegate>(options));
+ interpreter_->ModifyGraphWithDelegate(delegate_);
+
+ ASSERT_EQ(interpreter_->execution_plan().size(), 3);
+}
+
+TEST_F(TestDelegate, DelegateFailedPrepare) {
+ TestSimpleDelegateOptions options;
+ options.allowed_builtin_code = kTfLiteBuiltinAdd;
+ options.error_during_prepare = true;
+ delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+ std::make_unique<TestSimpleDelegate>(options));
+ ASSERT_EQ(kTfLiteDelegateError,
+ interpreter_->ModifyGraphWithDelegate(delegate_));
+}
+
+TEST_F(TestDelegate, DelegateFailedInvoke) {
+ TestSimpleDelegateOptions options;
+ options.allowed_builtin_code = kTfLiteBuiltinAdd;
+ options.error_during_invoke = true;
+ delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+ std::make_unique<TestSimpleDelegate>(options));
+ ASSERT_EQ(kTfLiteOk, interpreter_->ModifyGraphWithDelegate(delegate_));
+ ASSERT_EQ(kTfLiteError, interpreter_->Invoke());
+}
+
+TEST_F(TestDelegate, DelegateFailedInit) {
+ TestSimpleDelegateOptions options;
+ options.allowed_builtin_code = kTfLiteBuiltinAdd;
+ options.error_during_init = true;
+ delegate_ = TfLiteDelegateFactory::CreateSimpleDelegate(
+ std::make_unique<TestSimpleDelegate>(options));
+ ASSERT_EQ(kTfLiteDelegateError,
+ interpreter_->ModifyGraphWithDelegate(delegate_));
+}
+} // namespace
+} // namespace tflite
diff --git a/tensorflow/lite/experimental/delegates/hexagon/README.md b/tensorflow/lite/experimental/delegates/hexagon/README.md
index 5cf71fd..07f1a92 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/README.md
+++ b/tensorflow/lite/experimental/delegates/hexagon/README.md
@@ -80,6 +80,7 @@
* L2Normalization (without any activation)
* Logistic (aka Sigmoid)
* MaxPool2D (without any activation) (b/129276536)
+* MirrorPad
* Mul (without any activation) (b/129276536)
* Neg
* Pad: Only supports 0 padding (b/139277813)
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD
index ae8ffe2..ff76498 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD
@@ -19,6 +19,7 @@
"hardswish_builder.cc",
"l2_normalization_builder.cc",
"matmul_builder.cc",
+ "mirror_pad_builder.cc",
"neg_op_builder.cc",
"op_builder.cc",
"pad_builder.cc",
@@ -45,6 +46,7 @@
"hardswish_builder.h",
"l2_normalization_builder.h",
"matmul_builder.h",
+ "mirror_pad_builder.h",
"neg_op_builder.h",
"op_builder.h",
"pad_builder.h",
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.cc
new file mode 100644
index 0000000..2a04088
--- /dev/null
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.cc
@@ -0,0 +1,112 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.h"
+
+#include <stdint.h>
+
+#include <limits>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace delegates {
+namespace hexagon {
+TfLiteStatus MirrorPadOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
+ const TfLiteIntArray* outputs,
+ TfLiteContext* context) {
+ static int quant_bound_shape[] = {1, 1, 1, 1};
+ int tensor_id;
+
+ // Input data tensor.
+ tensor_id = inputs->data[0];
+ const auto& input_tensor = context->tensors[tensor_id];
+ AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
+
+ // Padding tensor.
+ // Should be a constant.
+ tensor_id = inputs->data[1];
+ const auto& padding_tensor = context->tensors[tensor_id];
+ if (padding_tensor.dims->size != 2 || padding_tensor.dims->data[0] > 4 ||
+ padding_tensor.dims->data[1] != 2) {
+ TF_LITE_KERNEL_LOG(context, "Invalid padding tensor shape");
+ return kTfLiteError;
+ }
+ paddings_shape_ = {1, 1, 4, 2};
+ std::vector<int> padding_data(8, 0);
+ // Hexagon always expects padding data for each dimension in order {b, h, w,
+ // d}. This start value ensures we pad the non-relevant dimensions with 0.
+ int padding_data_start = 8 - padding_tensor.dims->data[0] * 2;
+ for (int i = 0; i < padding_tensor.dims->data[0] * 2; ++i) {
+ padding_data[padding_data_start + i] = padding_tensor.data.i32[i];
+ }
+ auto* const_padding_node = graph_builder_->AddConstNodeWithData(
+ paddings_shape_.data(), reinterpret_cast<char*>(padding_data.data()),
+ padding_data.size() * sizeof(padding_data[0]));
+ AddInput(TensorID(const_padding_node->GetID(), 0));
+ // Padding type.
+ const TfLiteMirrorPaddingParams* params =
+ reinterpret_cast<const TfLiteMirrorPaddingParams*>(builtin_data_);
+ if (params->mode == kTfLiteMirrorPaddingReflect) {
+ SetPaddingType(NN_PAD_MIRROR_REFLECT);
+ } else if (params->mode == kTfLiteMirrorPaddingSymmetric) {
+ SetPaddingType(NN_PAD_MIRROR_SYMMETRIC);
+ }
+
+ // Min/max values for input tensor.
+ TF_LITE_ENSURE_STATUS(
+ ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_));
+ auto* input_min_const = graph_builder_->AddConstNodeWithData(
+ quant_bound_shape, reinterpret_cast<char*>(&input_min_),
+ sizeof(input_min_));
+ auto* input_max_const = graph_builder_->AddConstNodeWithData(
+ quant_bound_shape, reinterpret_cast<char*>(&input_max_),
+ sizeof(input_max_));
+ AddInput(TensorID(input_min_const->GetID(), 0));
+ AddInput(TensorID(input_max_const->GetID(), 0));
+
+ // Hexagon outputs for this node.
+ int output_batch_size, output_height_size, output_width_size,
+ output_depth_size;
+ GetDims(&output_batch_size, &output_height_size, &output_width_size,
+ &output_depth_size, context->tensors[outputs->data[0]].dims);
+ node_output_ = AddOutput(sizeof(uint8_t), 4,
+ {output_batch_size, output_height_size,
+ output_width_size, output_depth_size});
+ AddOutput(sizeof(float), 4, {1, 1, 1, 1});
+ AddOutput(sizeof(float), 4, {1, 1, 1, 1});
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus MirrorPadOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
+ TfLiteContext* context) {
+ // Should be only 1 output.
+ graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first,
+ node_output_.second);
+ return kTfLiteOk;
+}
+
+MirrorPadOpBuilder::~MirrorPadOpBuilder() {}
+
+OpBuilder* CreateMirrorPadBuilder(GraphBuilder* graph_builder, int op_type) {
+ return new MirrorPadOpBuilder(graph_builder, op_type);
+}
+
+} // namespace hexagon
+} // namespace delegates
+} // namespace tflite
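
A worked example of the padding-reordering loop in PopulateSubGraph above (values assumed): Hexagon wants two entries per dimension in {b, h, w, d} order, with unsupplied leading dimensions left at zero.

    #include <vector>
    // Mirrors the reordering in PopulateSubGraph for a paddings tensor of
    // shape [2, 2] with values {1, 1, 2, 2} (row-major).
    std::vector<int> padding_data(8, 0);
    const int padded_dims = 2;              // padding_tensor.dims->data[0]
    const int raw[] = {1, 1, 2, 2};         // TFLite paddings data
    const int start = 8 - padded_dims * 2;  // = 4; leading dims stay 0
    for (int i = 0; i < padded_dims * 2; ++i) padding_data[start + i] = raw[i];
    // padding_data == {0, 0, 0, 0, 1, 1, 2, 2}
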
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.h
new file mode 100644
index 0000000..6fcb260
--- /dev/null
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/mirror_pad_builder.h
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_MIRROR_PAD_BUILDER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_MIRROR_PAD_BUILDER_H_
+
+#include <vector>
+
+#include "tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h"
+
+namespace tflite {
+namespace delegates {
+namespace hexagon {
+
+class MirrorPadOpBuilder : public OpBuilder {
+ public:
+ explicit MirrorPadOpBuilder(GraphBuilder* graph_builder, int op_type)
+ : OpBuilder(graph_builder, op_type) {}
+ TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs,
+ const TfLiteIntArray* outputs,
+ TfLiteContext* context) override;
+
+ TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs,
+ TfLiteContext* context) override;
+
+ ~MirrorPadOpBuilder() override;
+
+ private:
+ TensorID node_output_;
+ float input_min_, input_max_;
+ std::vector<int> paddings_shape_;
+};
+
+} // namespace hexagon
+} // namespace delegates
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_MIRROR_PAD_BUILDER_H_
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc
index e20127a..c7432e6 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc
@@ -43,6 +43,8 @@
return CreateReduceBuilder(this, OP_QuantizedSum_8to32);
case kTfLiteBuiltinPad:
return CreatePadBuilder(this, OP_QuantizedPad_8);
+ case kTfLiteBuiltinMirrorPad:
+ return CreateMirrorPadBuilder(this, OP_MirrorPad_8);
case kTfLiteBuiltinFullyConnected:
return CreateMatMulBuilder(this, OP_QuantizedMatMul_8x8to32);
case kTfLiteBuiltinAveragePool2d:
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h b/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h
index e7236fb..0beb88c 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h
@@ -35,6 +35,7 @@
OpBuilder* CreateReshapeBuilder(GraphBuilder* graph_builder, int op_type);
OpBuilder* CreateSoftmaxBuilder(GraphBuilder* graph_builder, int op_type);
OpBuilder* CreateReduceBuilder(GraphBuilder* graph_builder, int op_type);
+OpBuilder* CreateMirrorPadBuilder(GraphBuilder* graph_builder, int op_type);
OpBuilder* CreatePadBuilder(GraphBuilder* graph_builder, int op_type);
OpBuilder* CreateResizeNearestNeighborBuilder(GraphBuilder* graph_builder,
int op_type);
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD
index b1df59c..47a78dc 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD
@@ -30,6 +30,7 @@
"conv_test.cc",
"l2_norm_test.cc",
"matmul_test.cc",
+ "mirror_pad_test.cc",
"mul_test.cc",
"neg_test.cc",
"pad_test.cc",
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mirror_pad_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mirror_pad_test.cc
new file mode 100644
index 0000000..4caf96a
--- /dev/null
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/mirror_pad_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h"
+
+namespace tflite {
+using testing::ElementsAreArray;
+
+template <typename T>
+class MirrorPadOpModel : public SingleOpModelWithHexagon {
+ public:
+ MirrorPadOpModel(const TensorData& input,
+ std::initializer_list<int> paddings_shape,
+ std::initializer_list<int> paddings,
+ const TensorData& output, const tflite::MirrorPadMode mode) {
+ input_id_ = AddInput(input);
+ padding_matrix_id_ =
+ AddConstInput(TensorType_INT32, paddings, paddings_shape);
+ output_id_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_MIRROR_PAD, BuiltinOptions_MirrorPadOptions,
+ CreateMirrorPadOptions(builder_, mode).Union());
+ BuildInterpreter({GetShape(input_id_), GetShape(padding_matrix_id_)});
+ }
+
+ int input_tensor_id() { return input_id_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
+
+ protected:
+ int input_id_;
+ int padding_matrix_id_;
+ int output_id_;
+};
+
+TEST(MirrorPadTest, EmptyPad_UInt8) {
+ MirrorPadOpModel<uint8_t> model(
+ {TensorType_UINT8, {2, 3}, -1.0, 1.0}, {2, 2}, {0, 0, 0, 0},
+ {TensorType_UINT8, {}, -1.0, 1.0}, tflite::MirrorPadMode_REFLECT);
+ model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+ model.ApplyDelegateAndInvoke();
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(MirrorPadTest, PadBothSides_Symmetric_Int8) {
+ MirrorPadOpModel<int8_t> model({TensorType_INT8, {2, 3}, -1.0, 1.0}, {2, 2},
+ {1, 1, 1, 1}, {TensorType_INT8, {}, -1.0, 1.0},
+ tflite::MirrorPadMode_SYMMETRIC);
+ model.PopulateTensor<int8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+ model.ApplyDelegateAndInvoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 1, 2, 3, 3, 1, 1, 2, 3, 3,
+ 4, 4, 5, 6, 6, 4, 4, 5, 6, 6}));
+}
+
+TEST(MirrorPadTest, PadBothSides_Reflect_UInt8) {
+ MirrorPadOpModel<uint8_t> model(
+ {TensorType_UINT8, {2, 3}, -1.0, 1.0}, {2, 2}, {1, 1, 1, 1},
+ {TensorType_UINT8, {}, -1.0, 1.0}, tflite::MirrorPadMode_REFLECT);
+ model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+ model.ApplyDelegateAndInvoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({5, 4, 5, 6, 5, 2, 1, 2, 3, 2,
+ 5, 4, 5, 6, 5, 2, 1, 2, 3, 2}));
+}
+
+TEST(MirrorPadTest, PadOneSide_left_Reflect_Int8) {
+ MirrorPadOpModel<int8_t> model({TensorType_INT8, {2, 3}, -1.0, 1.0}, {2, 2},
+ {1, 0, 1, 0}, {TensorType_INT8, {}, -1.0, 1.0},
+ tflite::MirrorPadMode_REFLECT);
+ model.PopulateTensor<int8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+ model.ApplyDelegateAndInvoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({5, 4, 5, 6, 2, 1, 2, 3, 5, 4, 5, 6}));
+}
+
+TEST(MirrorPadTest, PadOneSide_right_Symmetric_UInt8) {
+ MirrorPadOpModel<uint8_t> model(
+ {TensorType_UINT8, {2, 3}, -1.0, 1.0}, {2, 2}, {0, 1, 0, 1},
+ {TensorType_UINT8, {}, -1.0, 1.0}, tflite::MirrorPadMode_SYMMETRIC);
+ model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+ model.ApplyDelegateAndInvoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1, 2, 3, 3, 4, 5, 6, 6, 4, 5, 6, 6}));
+}
+
+TEST(MirrorPadTest, Pad_1D_Reflect_Int8) {
+ MirrorPadOpModel<int8_t> model({TensorType_INT8, {3}, -1.0, 1.0}, {1, 2},
+ {0, 2}, {TensorType_INT8, {}, -1.0, 1.0},
+ tflite::MirrorPadMode_REFLECT);
+ model.PopulateTensor<int8_t>(model.input_tensor_id(), {1, 2, 3});
+ model.ApplyDelegateAndInvoke();
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 2, 1}));
+}
+
+TEST(MirrorPadTest, Pad_1D_Symmetric_UInt8) {
+ MirrorPadOpModel<uint8_t> model({TensorType_UINT8, {3}, -1.0, 1.0}, {1, 2},
+ {0, 2}, {TensorType_UINT8, {}, -1.0, 1.0},
+ tflite::MirrorPadMode_SYMMETRIC);
+ model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3});
+ model.ApplyDelegateAndInvoke();
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 3, 2}));
+}
+
+TEST(MirrorPadTest, PadBothSides_Reflect_Whole_UInt8) {
+ MirrorPadOpModel<uint8_t> model(
+ {TensorType_UINT8, {2, 3}, -1.0, 1.0}, {2, 2}, {1, 1, 2, 2},
+ {TensorType_UINT8, {}, -1.0, 1.0}, tflite::MirrorPadMode_REFLECT);
+ model.PopulateTensor<uint8_t>(model.input_tensor_id(), {1, 2, 3, 4, 5, 6});
+ model.ApplyDelegateAndInvoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1,
+ 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}));
+}
+
+} // namespace tflite
diff --git a/tensorflow/lite/experimental/delegates/hexagon/utils.cc b/tensorflow/lite/experimental/delegates/hexagon/utils.cc
index df7d742..d9d1480 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/utils.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/utils.cc
@@ -80,6 +80,7 @@
case kTfLiteBuiltinL2Normalization:
case kTfLiteBuiltinLogistic:
case kTfLiteBuiltinMaxPool2d:
+ case kTfLiteBuiltinMirrorPad:
case kTfLiteBuiltinMul:
case kTfLiteBuiltinPad:
case kTfLiteBuiltinQuantize:
@@ -159,6 +160,17 @@
// causes an unexpected shift in dequantized values.
return false;
}
+ case kTfLiteBuiltinMirrorPad: {
+ if (!InputsWithCorrectTypes(
+ node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteInt32}}) ||
+ !IsConstantTensor(GetInput(context, node, 1)))
+ return false;
+ const TfLiteMirrorPaddingParams* params =
+ reinterpret_cast<const TfLiteMirrorPaddingParams*>(
+ node->builtin_data);
+ return params->mode == kTfLiteMirrorPaddingReflect ||
+ params->mode == kTfLiteMirrorPaddingSymmetric;
+ }
case kTfLiteBuiltinPad: {
// TODO(b/139277813): Currently we only support padding with the default
// of 0. Add support for user-defined constant if required.
diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple
index faa3f12..8e7b32e 100644
--- a/tensorflow/lite/experimental/ios/BUILD.apple
+++ b/tensorflow/lite/experimental/ios/BUILD.apple
@@ -22,13 +22,6 @@
""",
)
-TFL_LIBRARY_HDRS = [
- "//tensorflow/lite/delegates/gpu:metal_delegate.h",
- "//tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h",
- "//tensorflow/lite/c:c_api.h",
- "//tensorflow/lite/c:common.h",
-]
-
TFL_FRAMEWORK_HDRS = [
"//tensorflow/lite/delegates/gpu:metal_delegate.h",
":coreml_delegate.h",
@@ -43,19 +36,6 @@
bundle_name = "TensorFlowLiteC",
minimum_os_version = TFL_MINIMUM_OS_VERSION,
deps = [
- ":TensorFlowLiteC",
- ],
-)
-
-objc_library(
- name = "TensorFlowLiteC",
- hdrs = TFL_LIBRARY_HDRS,
- module_name = "TensorFlowLiteC",
- weak_sdk_frameworks = [
- "Metal",
- "CoreML",
- ],
- deps = [
":tensorflow_lite_c",
],
)
@@ -78,20 +58,22 @@
],
)
-# Using this intermediate target is a workaround for a bug in bazel build rules
-# involving mixed objc_library & cc_library deps mentioned in (b/74809458).
-# When these dependencies are declared directly under the "TensorFlowLiteC"
-# target above, the resulting static library incorrectly contains duplicate
-# symbols from some ObjC code in the transitive dependencies.
-#
-# When a new dependency should be added to the TensorFlowLiteC framework, the
-# dependency should be added under this target instead.
-# When a new header file needs to be exposed, the header should be added to the
-# TFL_LIBRARY_HDRS list above.
cc_library(
name = "tensorflow_lite_c",
- hdrs = TFL_LIBRARY_HDRS,
- tags = ["nobuilder"],
+ hdrs = [
+ "//tensorflow/lite/c:c_api.h",
+ "//tensorflow/lite/c:common.h",
+ "//tensorflow/lite/delegates/gpu:metal_delegate.h",
+ "//tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h",
+ ],
+ linkopts = [
+ "-Wl,-weak_framework,CoreML",
+ "-Wl,-weak_framework,Metal",
+ ],
+ tags = [
+ "nobuilder",
+ "swift_module=TensorFlowLiteC",
+ ],
deps = [
"//tensorflow/lite/c:c_api",
"//tensorflow/lite/delegates/gpu:metal_delegate",
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
index b19ef2e..bced23e 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
@@ -231,6 +231,26 @@
return container.getDataType();
}
+ /**
+ * Gets the image width.
+ *
+   * @throws IllegalStateException if the TensorImage has never loaded data.
+ * @throws IllegalArgumentException if the container data is corrupted.
+ */
+ public int getWidth() {
+ return container.getWidth();
+ }
+
+ /**
+ * Gets the image height.
+ *
+   * @throws IllegalStateException if the TensorImage has never loaded data.
+ * @throws IllegalArgumentException if the container data is corrupted.
+ */
+ public int getHeight() {
+ return container.getHeight();
+ }
+
// Requires tensor shape [h, w, 3] or [1, h, w, 3].
static void checkImageTensorShape(int[] shape) {
SupportPreconditions.checkArgument(
@@ -273,6 +293,41 @@
isBufferUpdated = true;
}
+ int getWidth() {
+ SupportPreconditions.checkState(
+ isBitmapUpdated || isBufferUpdated,
+ "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
+ if (isBitmapUpdated) {
+ return bitmapImage.getWidth();
+ }
+ return getBufferDimensionSize(-2);
+ }
+
+ int getHeight() {
+ SupportPreconditions.checkState(
+ isBitmapUpdated || isBufferUpdated,
+ "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
+ if (isBitmapUpdated) {
+ return bitmapImage.getHeight();
+ }
+ return getBufferDimensionSize(-3);
+ }
+
+ // Internal helper method to get the size of one dimension in the shape of the `bufferImage`.
+  // Requires `isBufferUpdated` to be true.
+ // Throws `IllegalArgumentException` if data is corrupted.
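+  // For example, with a buffer of shape {1, h, w, 3}, dim = -2 normalizes to
+  // index 2 and returns w (the image width).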
+ private int getBufferDimensionSize(int dim) {
+ int[] shape = bufferImage.getShape();
+    // The defensive check is needed because bufferImage might have been
+    // modified by the user in an invalid way (i.e. the internal data is
+    // corrupted).
+ TensorImage.checkImageTensorShape(shape);
+ dim = dim % shape.length;
+ if (dim < 0) {
+ dim += shape.length;
+ }
+ return shape[dim];
+ }
+
public DataType getDataType() {
return dataType;
}
@@ -284,7 +339,8 @@
return bitmapImage;
}
if (!isBufferUpdated) {
- throw new IllegalStateException("Both buffer and bitmap data are obsolete.");
+ throw new IllegalStateException(
+ "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
}
if (bufferImage.getDataType() != DataType.UINT8) {
throw new IllegalStateException(
@@ -310,7 +366,8 @@
return bufferImage;
}
SupportPreconditions.checkArgument(
- isBitmapUpdated, "Both buffer and bitmap data are obsolete.");
+ isBitmapUpdated,
+ "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
int requiredFlatSize = bitmapImage.getWidth() * bitmapImage.getHeight() * 3;
if (bufferImage == null
|| (!bufferImage.isDynamic() && bufferImage.getFlatSize() != requiredFlatSize)) {
diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
index 16622a2..fa05be3 100644
--- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
+++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
@@ -379,13 +379,13 @@
// Check if the new shape is the same as current shape.
int newFlatSize = computeFlatSize(shape);
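+    // Update the shape unconditionally: reshapes such as {2, 3} -> {3, 2}
+    // keep the same flat size but still need the new shape to be recorded.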
+ this.shape = shape.clone();
if (flatSize == newFlatSize) {
return;
}
// Update to the new shape.
flatSize = newFlatSize;
- this.shape = shape.clone();
buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
buffer.order(ByteOrder.nativeOrder());
}
diff --git a/tensorflow/lite/experimental/support/metadata/java/BUILD b/tensorflow/lite/experimental/support/metadata/java/BUILD
index f1cd617..82b6e98 100644
--- a/tensorflow/lite/experimental/support/metadata/java/BUILD
+++ b/tensorflow/lite/experimental/support/metadata/java/BUILD
@@ -25,6 +25,10 @@
name = "tensorflow-lite-support-metadata-lib",
srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
javacopts = JAVACOPTS,
+ resource_jars = [
+ "//tensorflow/lite/experimental/support/metadata:libmetadata_schema_java.jar",
+ "//tensorflow/lite/experimental/support/metadata:libschema_fbs_java.jar",
+ ],
deps = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
"//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
diff --git a/tensorflow/lite/experimental/swift/BUILD.apple b/tensorflow/lite/experimental/swift/BUILD.apple
index 2ce8428..50130fc 100644
--- a/tensorflow/lite/experimental/swift/BUILD.apple
+++ b/tensorflow/lite/experimental/swift/BUILD.apple
@@ -13,11 +13,15 @@
swift_library(
name = "TensorFlowLite",
srcs = glob(["Sources/*.swift"]),
+ linkopts = [
+ "-Wl,-weak_framework,CoreML",
+ "-Wl,-weak_framework,Metal",
+ ],
module_name = "TensorFlowLite",
tags = TFL_DEFAULT_TAGS,
visibility = ios_visibility_whitelist(),
deps = [
- "//tensorflow/lite/experimental/ios:TensorFlowLiteC",
+ "//tensorflow/lite/experimental/ios:tensorflow_lite_c",
],
)
diff --git a/tensorflow/lite/g3doc/convert/1x_compatibility.md b/tensorflow/lite/g3doc/convert/1x_compatibility.md
index adb2af4..9f9f277 100644
--- a/tensorflow/lite/g3doc/convert/1x_compatibility.md
+++ b/tensorflow/lite/g3doc/convert/1x_compatibility.md
@@ -1,30 +1,32 @@
-# TensorFlow 1.x compatibility
+# TensorFlow 1.x Compatibility <a name="differences"></a>
-The `tf.lite.TFLiteConverter` was updated between TensorFlow 1.X and 2.0. This
-document explains the differences between the 1.X and 2.0 versions of the
-converter, and provides information about how to use the 1.X version if
-required.
+The `tf.lite.TFLiteConverter` Python API was updated between TensorFlow 1.x and
+2.x. This document explains the differences between the two versions, and
+provides information about how to use the 1.x version if required.
-## Summary of changes in Python API between 1.X and 2.0 <a name="differences"></a>
-
-The following section summarizes the changes in the Python API from 1.X to 2.0.
If any of the changes raise concerns, please file a
-[GitHub issue](https://github.com/tensorflow/tensorflow/issues).
+[GitHub Issue](https://github.com/tensorflow/tensorflow/issues).
-### Formats supported by `TFLiteConverter`
+Note: We highly recommend that you
+[migrate your TensorFlow 1.x code to TensorFlow 2.x code](https://www.tensorflow.org/guide/migrate).
-The 2.0 version of the converter supports SavedModel and Keras model files
-generated in both 1.X and 2.0. However, the conversion process no longer
-supports "frozen graph" `GraphDef` files generated in 1.X.
+## Model formats
-#### Converting frozen graphs
+#### SavedModel and Keras
-Users who want to convert frozen graph `GraphDef` files (`.pb` files) to
-TensorFlow Lite should use `tf.compat.v1.lite.TFLiteConverter`.
+The `tf.lite.TFLiteConverter` API supports SavedModel and Keras HDF5 files
+generated in both TensorFlow 1.x and 2.x.
-The following snippet shows a frozen graph file being converted:
+#### Frozen Graph
+
+Note: TensorFlow 2.x no longer supports the generation of frozen graph models.
+
+The `tf.compat.v1.lite.TFLiteConverter` API supports frozen graph models
+generated in TensorFlow 1.x, as shown below:
```python
+import tensorflow as tf
# Path to the frozen graph file
graph_def_file = 'frozen_graph.pb'
# A list of the names of the model's input tensors
@@ -32,70 +34,68 @@
# A list of the names of the model's output tensors
output_arrays = ['output_name']
# Load and convert the frozen graph
-converter = lite.TFLiteConverter.from_frozen_graph(
+converter = tf.lite.TFLiteConverter.from_frozen_graph(
graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
# Write the converted model to disk
open("converted_model.tflite", "wb").write(tflite_model)
```
-### Quantization-aware training
+## Converter attributes
-The following attributes and methods associated with
-[quantization-aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize)
-have been removed from `TFLiteConverter` in TensorFlow 2.0:
+#### Renamed attributes
-* `inference_type`
-* `inference_input_type`
-* `quantized_input_stats`
-* `default_ranges_stats`
-* `reorder_across_fake_quant`
-* `change_concat_input_ranges`
-* `post_training_quantize` - Deprecated in the 1.X API
-* `get_input_arrays()`
+The following 1.x attribute has been renamed in 2.x.
-The rewriter function that supports quantization-aware training does not support
-models generated by TensorFlow 2.0. Additionally, TensorFlow Lite’s quantization
-API is being reworked and streamlined in a direction that supports
-quantization-aware training through the Keras API. These attributes will be
-removed in the 2.0 API until the new quantization API is launched. Users who
-want to convert models generated by the rewriter function can use
-`tf.compat.v1.lite.TFLiteConverter`.
+*   `target_ops` has been renamed to `target_spec.supported_ops` - In 2.x, in
+    line with future additions to the optimization framework, it has become
+    the `supported_ops` attribute of `TargetSpec`, as sketched below.
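+
+As a minimal sketch (assuming `saved_model_dir` points to a SavedModel; the op
+set shown is illustrative):
+
+```python
+import tensorflow as tf
+
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+# 1.x: converter.target_ops = set([tf.lite.OpsSet.TFLITE_BUILTINS])
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+tflite_model = converter.convert()
+```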
-### Changes to `TFLiteConverter` attributes
+#### Unsupported attributes
-The `target_ops` attribute has become an attribute of `TargetSpec` and renamed
-to `supported_ops` in line with future additions to the optimization framework.
+The following 1.x attributes have been removed in 2.x.
-Additionally, the following attributes have been removed:
-
-* `drop_control_dependency` (default: `True`)
-* _Graph visualization_ - The recommended approach for visualizing a
- TensorFlow Lite graph in TensorFlow 2.0 will be to use
- [visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py).
- Unlike GraphViz, it enables users to visualize the graph after post training
- quantization has occurred. The following attributes related to graph
- visualization will be removed:
+*   _Quantization_ - In 2.x,
+    [quantization aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training)
+    is supported through the Keras API and
+    [post training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)
+    uses fewer, streamlined converter flags. Thus, the following attributes
+    and methods related to quantization have been removed:
+ * `inference_type`
+ * `quantized_input_stats`
+ * `post_training_quantize`
+ * `default_ranges_stats`
+ * `reorder_across_fake_quant`
+ * `change_concat_input_ranges`
+ * `get_input_arrays()`
+*   _Visualization_ - In 2.x, the recommended approach for visualizing a
+    TensorFlow Lite graph is to use
+    [visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py).
+    Unlike GraphViz, it enables users to visualize the graph after post
+    training quantization has occurred. Thus, the following attributes related
+    to graph visualization have been removed:
* `output_format`
* `dump_graphviz_dir`
* `dump_graphviz_video`
+* _Frozen graph_ - In 2.x, the frozen graph model format has been removed.
+ Thus, the following attribute related to frozen graphs has been removed:
+ * `drop_control_dependency`
-### General API changes
+## Unsupported APIs
-The following section explains several significant API changes between
-TensorFlow 1.X and 2.0.
+The following section explains several significant features in 1.x that have
+been removed in 2.x.
-#### Conversion methods
+#### Conversion APIs
-The following methods that were previously deprecated in 1.X will no longer be
-exported in 2.0:
+The following methods were deprecated in 1.x and have been removed in 2.x:
* `lite.toco_convert`
* `lite.TocoConverter`
-#### `lite.constants`
+#### `lite.constants` API
-The `lite.constants` API was removed in 2.0 in order to decrease duplication
+The `lite.constants` API was removed in 2.x in order to decrease duplication
between TensorFlow and TensorFlow Lite. The following list maps the
`lite.constant` type to the TensorFlow type:
@@ -106,12 +106,15 @@
* `lite.constants.STRING`: `tf.string`
* `lite.constants.QUANTIZED_UINT8`: `tf.uint8`
-Additionally, `lite.constants.TFLITE` and `lite.constants.GRAPHVIZ_DOT` were
-removed due to the deprecation of the `output_format` flag in `TFLiteConverter`.
+Additionally, the deprecation of the `output_format` flag in `TFLiteConverter`
+led to the removal of the following constants:
-#### `lite.OpHint`
+* `lite.constants.TFLITE`
+* `lite.constants.GRAPHVIZ_DOT`
-The `OpHint` API is currently not available in 2.0 due to an incompatibility
-with the 2.0 APIs. This API enables conversion of LSTM based models. Support for
-LSTMs in 2.0 is being investigated. All related `lite.experimental` APIs have
-been removed due to this issue.
+#### `lite.OpHint` API
+
+The `OpHint` API is currently unsupported due to an incompatibility with the
+2.x APIs. This API enables conversion of LSTM-based models. Support for LSTMs
+in 2.x is being investigated. All related `lite.experimental` APIs have been
+removed due to this issue.
diff --git a/tensorflow/lite/g3doc/performance/hexagon_delegate.md b/tensorflow/lite/g3doc/performance/hexagon_delegate.md
index 51af598..60fe946 100644
--- a/tensorflow/lite/g3doc/performance/hexagon_delegate.md
+++ b/tensorflow/lite/g3doc/performance/hexagon_delegate.md
@@ -259,43 +259,7 @@
* This is tentatively planned for a future release, though there is no
concrete timeline.
* Which ops are supported by the delegate?
- * Initial list of supported ops:
- * Add
- * ArgMax
- * ArgMin
- * AveragePool2D (without any activation)
- * Concat
- * Conv2D with following constraints:
- * stride width/height <= 3
- * DepthToSpace
- * DepthwiseConv2D with following constraints:
- * Filter width == 3
- * depth_multiplier == 1
- * dilation only supported when stride == 1
- * Otherwise, stride height/width <= 3
- * FullyConnected (without any activation)
- * Hardswish
- * L2Normalization (without any activation)
- * Logistic (aka Sigmoid)
- * MaxPool2D (without any activation)
- * Mul (without any activation)
- * Neg
- * Pad: Only supports 0 padding
- * Relu
- * Relu6
- * Reshape
- * Resize Bilinear with following constraints:
- * Requested size <= 65
- * Resize Nearest Neighbor
- * SoftMax
- * SpaceToDepth
- * Split
- * Sub
- * Tanh
- * Transpose
- * TransposeConv2D with following constraints:
- * stride height/width <= 3
- * dilation height/width == 1
+ * See the current list of [supported ops and constraints](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/delegates/hexagon/README.md)
* How can I tell that the model is using the DSP when I enable the delegate?
* Two log messages will be printed when you enable the delegate - one to
indicate if the delegate was created and another to indicate how many
diff --git a/tensorflow/lite/g3doc/performance/post_training_quantization.md b/tensorflow/lite/g3doc/performance/post_training_quantization.md
index 194d102..a526be7 100644
--- a/tensorflow/lite/g3doc/performance/post_training_quantization.md
+++ b/tensorflow/lite/g3doc/performance/post_training_quantization.md
@@ -4,51 +4,44 @@
while also improving CPU and hardware accelerator latency, with little
degradation in model accuracy. You can perform these techniques using an
already-trained float TensorFlow model when you convert it to TensorFlow Lite
-format.
+format using the [TensorFlow Lite Converter](../convert/).
Note: The procedures on this page require TensorFlow 1.15 or higher.
-
-### Optimization options
+### Optimization methods
There are several post-training quantization options to choose from. Here is a
summary table of the choices and the benefits they provide:
-| Technique | Benefits | Hardware |
-| ------------------------- | ------------------------- | ------------------- |
-| Dynamic range | 4x smaller, 2-3x speedup, | CPU |
-: quantization : accuracy : :
-| Full integer quantization | 4x smaller, 3x+ speedup | CPU, Edge TPU, etc. |
-| Float16 quantization | 2x smaller, potential GPU | CPU/GPU |
-: : acceleration : :
+| Technique | Benefits | Hardware |
+| -------------------- | ------------------------- | ---------------- |
+| Dynamic range | 4x smaller, 2-3x speedup | CPU |
+: quantization : : :
+| Full integer | 4x smaller, 3x+ speedup | CPU, Edge TPU, |
+: quantization : : Microcontrollers :
+| Float16 quantization | 2x smaller, potential GPU | CPU, GPU |
+: : acceleration : :
This decision tree can help determine which post-training quantization method is
best for your use case:

-Alternatively, you might achieve higher accuracy if you perform
-[quantization-aware training](
-https://github.com/tensorflow/tensorflow/tree/r1.14/tensorflow/contrib/quantize).
-However, doing so requires some model modifications to add fake quantization
-nodes, whereas the post-training quantization techniques on this page use an
-existing pre-trained model.
-
### Dynamic range quantization
The simplest form of post-training quantization statically quantizes only the
-weights from floating point to 8-bits of precision. This technique is enabled as
-an option in the [TensorFlow Lite converter](../convert/):
+weights from floating point to integer, which has 8 bits of precision:
-```
+<pre>
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
+<b>converter.optimizations = [tf.lite.Optimize.DEFAULT]</b>
tflite_quant_model = converter.convert()
-```
+</pre>
-At inference, weights are converted from 8-bits of precision to floating point and
-computed using floating-point kernels. This conversion is done once and cached to reduce latency.
+At inference, weights are converted from 8 bits of precision to floating point
+and computed using floating-point kernels. This conversion is done once and
+cached to reduce latency.
To further improve latency, "dynamic-range" operators dynamically quantize
activations based on their range to 8-bits and perform computations with 8-bit
@@ -58,89 +51,105 @@
fixed-point computation. Dynamic-range ops are available for the most
compute-intensive operators in a network:
-* [tf.contrib.layers.fully_connected](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/fully_connected)
-* [tf.nn.conv2d](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d)
-* [tf.nn.embedding_lookup](https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup)
-* [BasicRNN](https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicRNNCell)
-* [tf.nn.bidirectional_dynamic_rnn for BasicRNNCell type](https://www.tensorflow.org/api_docs/python/tf/nn/bidirectional_dynamic_rnn)
-* [tf.nn.dynamic_rnn for LSTM and BasicRNN Cell types](https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
+* `tf.keras.layers.Dense`
+* `tf.keras.layers.Conv2D`
+* `tf.keras.layers.LSTM`
+* `tf.nn.embedding_lookup`
+* `tf.compat.v1.nn.rnn_cell.BasicRNNCell`
+* `tf.compat.v1.nn.bidirectional_dynamic_rnn`
+* `tf.compat.v1.nn.dynamic_rnn`
-
-### Full integer quantization of weights and activations
+### Full integer quantization
You can get further latency improvements, reductions in peak memory usage, and
-access to integer only hardware accelerators by making sure all model math is
-quantized.
+access to integer only hardware devices or accelerators by making sure all model
+math is integer quantized.
To do this, you need to measure the dynamic range of activations and inputs by
-supplying a representative data set. You can simply create an input data
-generator and provide it to our converter. For example:
+supplying sample input data to the converter. Refer to the
+`representative_dataset_gen()` function used in the following code.
-```
+#### Integer with float fallback (using default float input/output)
+
+To fully integer quantize a model while falling back to float operators when
+an integer implementation is unavailable (which ensures conversion occurs
+smoothly), use the following steps:
+
+<pre>
import tensorflow as tf
-
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+<b>converter.optimizations = [tf.lite.Optimize.DEFAULT]
def representative_dataset_gen():
for _ in range(num_calibration_steps):
# Get sample input data as a numpy array in a method of your choosing.
yield [input]
-
-converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.representative_dataset = representative_dataset_gen
+converter.representative_dataset = representative_dataset_gen</b>
tflite_quant_model = converter.convert()
-```
+</pre>
-The resulting model should be fully quantized, but any
-ops that do not have quantized implementations are left in
-floating point. This allows conversion to occur smoothly, but the model won't be
-compatible with accelerators that require full integer quantization.
+Note: This won't be compatible with integer only devices (such as 8-bit
+microcontrollers) and accelerators (such as the Coral Edge TPU). For convenience
+during inference, the input and output still remain float in order to have the
+same interface as the original float only model.
-Additionally, the model still uses float input and output for convenience.
+#### Integer only
-To ensure compatibility with some accelerators (such as the Coral Edge TPU), you
-can enforce full integer quantization for all ops and use integer input and
-output by adding the following lines before you convert:
+*This is a common use case for
+[TensorFlow Lite for Microcontrollers](https://www.tensorflow.org/lite/microcontrollers)
+and [Coral Edge TPUs](https://coral.ai/).*
-```
-converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-converter.inference_input_type = tf.uint8
-converter.inference_output_type = tf.uint8
-```
+Additionally, to ensure compatibility with integer only devices (such as 8-bit
+microcontrollers) and accelerators (such as the Coral Edge TPU), you can
+enforce full integer quantization for all ops, including the input and output,
+by using the following steps:
-The first line makes the converter throw an error if it encounters an operation
-it cannot currently quantize.
-
-Note: `target_spec.supported_ops` was previously `target_ops` in the Python API.
-
-
-### Float16 quantization of weights
-
-You can reduce the size of a floating point model by quantizing the weights to
-float16, the IEEE standard for 16-bit floating point numbers. The advantages of
-this quantization are as follows:
-
-- reduce model size by up to half (since all weights are now half the original
- size)
-- minimal loss in accuracy
-- some delegates (e.g. the GPU delegate) can operate directly on float16 data,
- which results in faster execution than float32 computations.
-
-This quantization may not be a good choice if you need maximum performance (a
-quantization to fixed point math would be better in that case). To enable
-float16 quantization of weights, specify "DEFAULT" optimization as above and
-then specify that float16 is in supported types for the target_spec:
-
-```
+<pre>
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
-converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]
+def representative_dataset_gen():
+ for _ in range(num_calibration_steps):
+ # Get sample input data as a numpy array in a method of your choosing.
+ yield [input]
+converter.representative_dataset = representative_dataset_gen
+<b>converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]</b>
+<b>converter.inference_input_type = tf.int8</b> # or tf.uint8
+<b>converter.inference_output_type = tf.int8</b> # or tf.uint8
tflite_quant_model = converter.convert()
-```
+</pre>
-By default, a float16 quantized model will "dequantize" the weights values to
-float32 when run on the CPU. The GPU delegate will not perform this
-dequantization, since it can operate on float16 data.
+Note: The converter will throw an error if it encounters an operation it cannot
+currently quantize.
+
+### Float16 quantization
+
+You can reduce the size of a floating point model by quantizing the weights to
+float16, the IEEE standard for 16-bit floating point numbers. To enable float16
+quantization of weights, use the following steps:
+
+<pre>
+import tensorflow as tf
+converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+<b>converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]</b>
+tflite_quant_model = converter.convert()
+</pre>
+
+The advantages of this quantization are as follows:
+
+*   Reduces model size by up to half (since all weights are now half the
+    original size).
+*   Minimal loss in accuracy.
+*   Some delegates (e.g. the GPU delegate) can operate directly on float16
+    data, which results in faster execution than float32 computations.
+
+The disadvantages of this quantization are as follows:
+
+* Not a good choice for maximum performance (a quantization to fixed point
+ math would be better in that case).
+*   By default, a float16 quantized model will "dequantize" the weight values
+    to float32 when run on the CPU. (Note that the GPU delegate will not
+    perform this dequantization, since it can operate on float16 data.)
### Model accuracy
@@ -152,13 +161,18 @@
within acceptable limits. There is a tool to evaluate
[TensorFlow Lite model accuracy](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/accuracy/ilsvrc/README.md){:.external}.
-If the accuracy drop is too high, consider using
-[quantization aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize){:.external}.
+Alternatively, if the accuracy drop is too high, consider using
+[quantization aware training](https://www.tensorflow.org/model_optimization/guide/quantization/training).
+However, doing so requires modifications during model training to add fake
+quantization nodes, whereas the post-training quantization techniques on this
+page use an existing pre-trained model.
### Representation for quantized tensors
8-bit quantization approximates floating point values using the following
-formula. `real_value = (int8_value - zero_point) * scale`.
+formula.
+
+$$real\_value = (int8\_value - zero\_point) \times scale$$
The representation has two main parts:
diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb
index ee46795..464a5d1 100644
--- a/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb
+++ b/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb
@@ -49,7 +49,7 @@
"metadata": {
"colab_type": "text",
"id": "nDABAblytltI"
- },
+ },
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
diff --git a/tensorflow/lite/java/ovic/BUILD b/tensorflow/lite/java/ovic/BUILD
index 947fbee..e64bd30 100644
--- a/tensorflow/lite/java/ovic/BUILD
+++ b/tensorflow/lite/java/ovic/BUILD
@@ -58,7 +58,6 @@
deps = [
"//tensorflow/lite/java:tensorflowlite",
"//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
- "@org_checkerframework_qual",
],
)
@@ -75,7 +74,6 @@
"//tensorflow/lite/java:tensorflowlite_java",
"//tensorflow/lite/java/src/main/native",
"//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
- "@org_checkerframework_qual",
],
)
@@ -114,7 +112,6 @@
deps = [
"//tensorflow/lite/java:tensorflowlite",
"//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
- "@org_checkerframework_qual",
],
)
@@ -131,6 +128,5 @@
"//tensorflow/lite/java:tensorflowlite_java",
"//tensorflow/lite/java/src/main/native",
"//tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
- "@org_checkerframework_qual",
],
)
diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h
index ad068dd..5793b08 100644
--- a/tensorflow/lite/kernels/kernel_util.h
+++ b/tensorflow/lite/kernels/kernel_util.h
@@ -87,6 +87,10 @@
}
// Determines whether tensor is constant.
+// TODO(b/138199592): Introduce new query which checks for constant OR
+// persistent-read-only, which would be useful for most tensor kernels that
+// are potentially dynamic based on the input tensor value availability at the
+// time of prepare.
inline bool IsConstantTensor(const TfLiteTensor* tensor) {
return tensor->allocation_type == kTfLiteMmapRo;
}
@@ -105,6 +109,14 @@
}
}
+// Sets tensor to persistent and read-only.
+inline void SetTensorToPersistentRo(TfLiteTensor* tensor) {
+ if (tensor->allocation_type != kTfLitePersistentRo) {
+ tensor->allocation_type = kTfLitePersistentRo;
+ tensor->data.raw = nullptr;
+ }
+}
+
// Determines whether it is a hybrid op - one that has float inputs and
// quantized weights.
inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) {
diff --git a/tensorflow/lite/kernels/rank.cc b/tensorflow/lite/kernels/rank.cc
index 8e27ebc..53fd92f 100644
--- a/tensorflow/lite/kernels/rank.cc
+++ b/tensorflow/lite/kernels/rank.cc
@@ -30,19 +30,23 @@
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
output->type = kTfLiteInt32;
+ // By design, the input shape is always known at the time of Prepare, even
+ // if the preceding op that generates |input| is dynamic. Thus, we can
+ // always compute the rank immediately, without waiting for Eval.
+ SetTensorToPersistentRo(output);
+
// Rank produces a 0-D int32 Tensor representing the rank of input.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(0);
- return context->ResizeTensor(context, output, output_size);
-}
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size));
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 0);
+ // Immediately propagate the known rank to the output tensor. This allows
+ // downstream ops that rely on the value to use it during prepare.
if (output->type == kTfLiteInt32) {
int32_t* output_data = GetTensorData<int32_t>(output);
*output_data = NumDimensions(input);
@@ -53,6 +57,10 @@
return kTfLiteOk;
}
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
} // namespace rank
TfLiteRegistration* Register_RANK() {
diff --git a/tensorflow/lite/kernels/rank_test.cc b/tensorflow/lite/kernels/rank_test.cc
index f3dc971..5373a0a 100644
--- a/tensorflow/lite/kernels/rank_test.cc
+++ b/tensorflow/lite/kernels/rank_test.cc
@@ -43,6 +43,9 @@
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+ TfLiteAllocationType GetOutputAllocationType() const {
+ return interpreter_->tensor(interpreter_->outputs()[0])->allocation_type;
+ }
private:
int input_;
@@ -51,6 +54,13 @@
TEST(RankOpTest, InputTypeFloat) {
RankOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32);
+ ASSERT_EQ(model.GetOutputAllocationType(), kTfLitePersistentRo);
+
+ // Unlike most ops, Rank populates outputs in Prepare().
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
+ EXPECT_TRUE(model.GetOutputShape().empty());
+
+ // Invoke is superfluous and shouldn't change the output.
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
@@ -59,7 +69,6 @@
TEST(RankOpTest, InputTypeInt) {
RankOpModel model({1, 3, 1, 3, 5}, TensorType_INT32);
- model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
EXPECT_TRUE(model.GetOutputShape().empty());
@@ -67,7 +76,6 @@
TEST(RankOpTest, ScalarTensor) {
RankOpModel model({}, TensorType_FLOAT32);
- model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
EXPECT_TRUE(model.GetOutputShape().empty());
@@ -75,7 +83,6 @@
TEST(RankOpTest, EmptyTensor) {
RankOpModel model({1, 0}, TensorType_FLOAT32);
- model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
EXPECT_TRUE(model.GetOutputShape().empty());
diff --git a/tensorflow/lite/kernels/shape.cc b/tensorflow/lite/kernels/shape.cc
index 88794fe..d979f08 100644
--- a/tensorflow/lite/kernels/shape.cc
+++ b/tensorflow/lite/kernels/shape.cc
@@ -54,19 +54,22 @@
return kTfLiteError;
}
+ // By design, the input shape is always known at the time of Prepare, even
+ // if the preceding op that generates |input| is dynamic. Thus, we can
+ // always compute the shape immediately, without waiting for Eval.
+ SetTensorToPersistentRo(output);
+
// Shape always produces a 1-dimensional output tensor, where each output
// element is the length of the corresponding input tensor's dimension.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(1);
output_size->data[0] = NumDimensions(input);
- return context->ResizeTensor(context, output, output_size);
-}
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size));
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TFLITE_DCHECK_EQ(NumDimensions(output), 1);
TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input));
+ // Immediately propagate the known shape to the output tensor. This allows
+ // downstream ops that rely on the value to use it during prepare.
switch (output->type) {
case kTfLiteInt32:
ExtractShape(input, GetTensorData<int32_t>(output));
@@ -81,6 +84,10 @@
return kTfLiteOk;
}
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
} // namespace shape
TfLiteRegistration* Register_SHAPE() {
diff --git a/tensorflow/lite/kernels/shape_test.cc b/tensorflow/lite/kernels/shape_test.cc
index 6a7dad4..3eeb83f 100644
--- a/tensorflow/lite/kernels/shape_test.cc
+++ b/tensorflow/lite/kernels/shape_test.cc
@@ -45,6 +45,9 @@
int32_t GetOutputSize() { return GetTensorSize(output_); }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+ TfLiteAllocationType GetOutputAllocationType() const {
+ return interpreter_->tensor(interpreter_->outputs()[0])->allocation_type;
+ }
private:
int input_;
@@ -54,6 +57,13 @@
TEST(ShapeOpTest, OutTypeInt) {
ShapeOpModel<int32_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
TensorType_INT32);
+ ASSERT_EQ(model.GetOutputAllocationType(), kTfLitePersistentRo);
+
+  // Unlike most ops, Shape populates outputs in Prepare().
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
+
+ // Invoke is superfluous and shouldn't change the output.
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
@@ -63,7 +73,6 @@
TEST(ShapeOpTest, OutTypeInt64) {
ShapeOpModel<int64_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
TensorType_INT64);
- model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
@@ -71,7 +80,6 @@
TEST(ShapeOpTest, ScalarTensor) {
ShapeOpModel<int32_t> model({}, TensorType_FLOAT32, TensorType_INT32);
- model.Invoke();
EXPECT_EQ(model.GetOutputSize(), 0);
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0}));
@@ -79,7 +87,6 @@
TEST(ShapeOpTest, EmptyTensor) {
ShapeOpModel<int32_t> model({1, 0}, TensorType_FLOAT32, TensorType_INT32);
- model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
diff --git a/tensorflow/lite/micro/apollo3evb/micro_time.cc b/tensorflow/lite/micro/apollo3evb/micro_time.cc
new file mode 100644
index 0000000..12c9ae5
--- /dev/null
+++ b/tensorflow/lite/micro/apollo3evb/micro_time.cc
@@ -0,0 +1,72 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Apollo3 EVB implementation of the timer functions. Platforms are not
+// required to implement these timer methods, but they must be implemented to
+// enable profiling.
+
+// On platforms that have a POSIX stack or C library, it can be written using
+// methods from <sys/time.h> or clock() from <time.h>.
+
+// To add an equivalent function for your own platform, create your own
+// implementation file and place it in a subfolder named after the OS you're
+// targeting. For example, see the Cortex M bare metal version in
+// tensorflow/lite/micro/bluepill/micro_timer.cc or the mbed one in
+// tensorflow/lite/micro/mbed/micro_timer.cc.
+
+#include "tensorflow/lite/micro/micro_time.h"
+
+// These are headers from Ambiq's Apollo3 SDK.
+#include "am_bsp.h" // NOLINT
+#include "am_mcu_apollo.h" // NOLINT
+#include "am_util.h" // NOLINT
+
+namespace tflite {
+namespace {
+
+// Select CTIMER 1 as benchmarking timer on Sparkfun Edge. This timer must not
+// be used elsewhere.
+constexpr int kTimerNum = 1;
+
+// Clock set to operate at 12MHz.
+constexpr int kClocksPerSecond = 12e6;
+
+} // namespace
+
+int32_t ticks_per_second() { return kClocksPerSecond; }
+
+// Calling this method enables a timer that runs for eternity. The user is
+// responsible for avoiding trampling on this timer's config; otherwise,
+// timing measurements may no longer be valid.
+int32_t GetCurrentTimeTicks() {
+  // TODO(b/150808076): Split out initialization, initialize in interpreter.
+ static bool is_initialized = false;
+ if (!is_initialized) {
+ am_hal_ctimer_config_t timer_config;
+ // Operate as a 32-bit timer.
+ timer_config.ui32Link = 1;
+ // Set timer A to continuous mode at 12MHz.
+ timer_config.ui32TimerAConfig =
+ AM_HAL_CTIMER_FN_CONTINUOUS | AM_HAL_CTIMER_HFRC_12MHZ;
+
+ am_hal_ctimer_stop(kTimerNum, AM_HAL_CTIMER_BOTH);
+ am_hal_ctimer_clear(kTimerNum, AM_HAL_CTIMER_BOTH);
+ am_hal_ctimer_config(kTimerNum, &timer_config);
+ am_hal_ctimer_start(kTimerNum, AM_HAL_CTIMER_TIMERA);
+ is_initialized = true;
+ }
+ return CTIMERn(kTimerNum)->TMR0;
+}
+
+} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h
index 2ed3e45..918192c 100644
--- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h
+++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h
@@ -65,7 +65,11 @@
ae_q56s result_56 = AE_MULP24S_HH(x_24x2, quantized_multiplier_24x2);
// Shift right if shift amount is positive, left if shift amount is negative.
- result_56 = AE_SLAASQ56S(result_56, shift_amount);
+ if (shift_amount >= 0) {
+ result_56 = AE_Q56S_SRA(result_56, shift_amount);
+ } else {
+ result_56 = AE_Q56S_SLA(result_56, -shift_amount);
+ }
// Round off the bottom 16 bits.
// Q48.0 / 2^16 -> Q32.0 aligned to 48 bits.
diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc
index c2b3216..54ce338 100644
--- a/tensorflow/lite/micro/micro_allocator.cc
+++ b/tensorflow/lite/micro/micro_allocator.cc
@@ -388,10 +388,8 @@
return kTfLiteError;
}
subgraph_ = (*subgraphs)[0];
- tensors_ = subgraph_->tensors();
- operators_ = subgraph_->operators();
- context_->tensors_size = tensors_->size();
+ context_->tensors_size = subgraph_->tensors()->size();
context_->tensors =
reinterpret_cast<TfLiteTensor*>(memory_allocator_->AllocateFromTail(
sizeof(TfLiteTensor) * context_->tensors_size,
@@ -405,9 +403,9 @@
}
// Initialize runtime tensors in context_ using the flatbuffer.
- for (size_t i = 0; i < tensors_->size(); ++i) {
+ for (size_t i = 0; i < subgraph_->tensors()->size(); ++i) {
TfLiteStatus status = internal::InitializeRuntimeTensor(
- memory_allocator_, *tensors_->Get(i), model_->buffers(),
+ memory_allocator_, *subgraph_->tensors()->Get(i), model_->buffers(),
error_reporter_, &context_->tensors[i]);
if (status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter_, "Failed to initialize tensor %d",
@@ -472,7 +470,7 @@
auto* output = reinterpret_cast<NodeAndRegistration*>(
memory_allocator_->AllocateFromTail(
- sizeof(NodeAndRegistration) * operators_->size(),
+ sizeof(NodeAndRegistration) * subgraph_->operators()->size(),
alignof(NodeAndRegistration)));
if (output == nullptr) {
TF_LITE_REPORT_ERROR(
@@ -483,8 +481,8 @@
TfLiteStatus status = kTfLiteOk;
auto* opcodes = model_->operator_codes();
MicroBuiltinDataAllocator builtin_data_allocator(memory_allocator_);
- for (size_t i = 0; i < operators_->size(); ++i) {
- const auto* op = operators_->Get(i);
+ for (size_t i = 0; i < subgraph_->operators()->size(); ++i) {
+ const auto* op = subgraph_->operators()->Get(i);
size_t index = op->opcode_index();
if (index >= opcodes->size()) {
TF_LITE_REPORT_ERROR(error_reporter_,
@@ -567,7 +565,7 @@
AllocationInfoBuilder builder(error_reporter_, &tmp_allocator);
TF_LITE_ENSURE_STATUS(
- builder.Init(tensors_->size(), scratch_buffer_count_));
+ builder.Init(subgraph_->tensors()->size(), scratch_buffer_count_));
TF_LITE_ENSURE_STATUS(builder.AddTensors(subgraph_, context_->tensors));
TF_LITE_ENSURE_STATUS(builder.AddScratchBuffers(scratch_buffer_handles_));
const AllocationInfo* allocation_info = builder.Finish();
@@ -606,8 +604,8 @@
// Data in variables need to be kept for the next invocation so allocating
// them from the tail (persistent area).
- if (AllocateVariables(tensors_, context_->tensors, memory_allocator_) !=
- kTfLiteOk) {
+ if (AllocateVariables(subgraph_->tensors(), context_->tensors,
+ memory_allocator_) != kTfLiteOk) {
TF_LITE_REPORT_ERROR(
error_reporter_,
"Failed to allocate variables. Please increase arena size.");
diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h
index b16f814..6a6e1e0 100644
--- a/tensorflow/lite/micro/micro_allocator.h
+++ b/tensorflow/lite/micro/micro_allocator.h
@@ -135,8 +135,6 @@
size_t scratch_buffer_count_ = 0;
const SubGraph* subgraph_;
- const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators_;
- const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors_;
};
} // namespace tflite
diff --git a/tensorflow/lite/micro/micro_optional_debug_tools.cc b/tensorflow/lite/micro/micro_optional_debug_tools.cc
index 70f16c7..42c42ae 100644
--- a/tensorflow/lite/micro/micro_optional_debug_tools.cc
+++ b/tensorflow/lite/micro/micro_optional_debug_tools.cc
@@ -95,6 +95,8 @@
return "kTfLiteArenaRw";
case kTfLiteArenaRwPersistent:
return "kTfLiteArenaRwPersistent";
+ case kTfLitePersistentRo:
+ return "kTfLitePersistentRo";
}
return "(invalid)";
}
diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc
index c5ccdb9..2e25b0a 100644
--- a/tensorflow/lite/optional_debug_tools.cc
+++ b/tensorflow/lite/optional_debug_tools.cc
@@ -77,6 +77,8 @@
return "kTfLiteArenaRw";
case kTfLiteArenaRwPersistent:
return "kTfLiteArenaRwPersistent";
+ case kTfLitePersistentRo:
+ return "kTfLitePersistentRo";
}
return "(invalid)";
}
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index 9ddd09e..1bcb2ce 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -269,9 +269,7 @@
[out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
converter.inference_type = lite_constants.FLOAT
- converter.quantized_input_stats = {
- 'Placeholder': (0., 1.)
- } # mean, std_dev
+ converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -1327,6 +1325,41 @@
tflite_model = converter.convert()
self.assertTrue(tflite_model)
+ def testResizeWithShape(self):
+ with ops.Graph().as_default():
+ # Construct a graph with a dynamically shaped input and an internal node
+ # that relies on the output of that input's shape.
+ in_tensor = array_ops.placeholder(
+ shape=[None, None], dtype=dtypes.float32)
+ in_tensor2 = [[1, 2], [3, 4]]
+ out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor))
+ sess = session.Session()
+
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
+ converter.experimental_new_converter = True
+ tflite_model = converter.convert()
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ input_details = interpreter.get_input_details()
+ self.assertLen(input_details, 1)
+ self.assertTrue(([1, 1] == input_details[0]['shape']).all())
+ self.assertTrue(([-1, -1] == input_details[0]['shape_signature']).all())
+
+ # Resize tensor and invoke.
+ interpreter.resize_tensor_input(0, [4])
+ interpreter.allocate_tensors()
+ interpreter.invoke()
+
+ # The output should be reshaped properly according to the resized input.
+ output_details = interpreter.get_output_details()
+ self.assertLen(output_details, 1)
+ self.assertEqual(np.int32, output_details[0]['dtype'])
+ self.assertTrue(([4] == output_details[0]['shape']).all())
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ self.assertTrue(([1, 2, 3, 4] == output_data).all())
+
def testResizingIntermediateDynamicTensor(self):
# This is a regression test for the case where shape of dynamic output
# tensors changes between invocations.
diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD
index 9d50f1a..df85f65 100644
--- a/tensorflow/lite/testing/BUILD
+++ b/tensorflow/lite/testing/BUILD
@@ -329,7 +329,7 @@
"//tensorflow/core:android_tensorflow_lib",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
],
}),
)
@@ -368,7 +368,7 @@
"//tensorflow/core:android_tensorflow_lib",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
],
}),
)
@@ -408,7 +408,7 @@
"//tensorflow/core:android_tensorflow_lib",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
],
}),
)
@@ -443,7 +443,7 @@
"//tensorflow/core:android_tensorflow_lib",
],
"//tensorflow:ios": [
- "//tensorflow/core:ios_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
],
}),
)
diff --git a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
index 9816cc1..171d522 100644
--- a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -271,8 +271,8 @@
const double magnitude =
std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min);
const double tolerated = 1e-6 * magnitude;
- return std::abs(minmax1.min - minmax2.min) < tolerated &&
- std::abs(minmax1.max - minmax2.max) < tolerated;
+ return std::abs(minmax1.min - minmax2.min) <= tolerated &&
+ std::abs(minmax1.max - minmax2.max) <= tolerated;
}
// Propagates MinMax from any of the listed arrays, to all others.
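The `<` to `<=` change above matters for degenerate ranges: when both ranges are identical and constant (min == max), `magnitude` is 0, so `tolerated` is 0 and the strict comparison `0 < 0` would reject an exact match. A minimal Python paraphrase of the comparison (a sketch, not the TFLite implementation):

```python
def min_max_approximately_equal(minmax1, minmax2):
    # Mirrors the C++ comparison above; args are {"min": ..., "max": ...}.
    magnitude = min(minmax1["max"] - minmax1["min"],
                    minmax2["max"] - minmax2["min"])
    tolerated = 1e-6 * magnitude
    return (abs(minmax1["min"] - minmax2["min"]) <= tolerated and
            abs(minmax1["max"] - minmax2["max"]) <= tolerated)

# Identical constant ranges: tolerated == 0, so only `<=` reports a match.
same = {"min": 1.0, "max": 1.0}
assert min_max_approximately_equal(same, same)
```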
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index 57b791a..917fd24 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1118,6 +1118,7 @@
GetVersioningOpSig(builtin_op(), op_signature);
op_sig.options.resize.half_pixel_centers =
resize_bilinear_op.half_pixel_centers;
+ op_sig.options.resize.align_corners = resize_bilinear_op.align_corners;
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};
@@ -1147,6 +1148,7 @@
::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature);
op_sig.options.resize.half_pixel_centers = resize_nn_op.half_pixel_centers;
+ op_sig.options.resize.align_corners = resize_nn_op.align_corners;
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};
diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD
index 3570722..f6cb717 100644
--- a/tensorflow/lite/tools/benchmark/BUILD
+++ b/tensorflow/lite/tools/benchmark/BUILD
@@ -142,6 +142,7 @@
":profiling_listener",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
+ "//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/profiling:platform_profiler",
"//tensorflow/lite/profiling:profile_summary_formatter",
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 489780e..969713c 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -29,6 +29,7 @@
#include "absl/base/attributes.h"
#include "absl/strings/numbers.h"
#include "ruy/profiler/profiler.h" // from @ruy
+#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
@@ -596,17 +597,20 @@
return kTfLiteOk;
}
-TfLiteStatus BenchmarkTfLiteModel::Init() {
- TF_LITE_ENSURE_STATUS(LoadModel());
-
+TfLiteStatus BenchmarkTfLiteModel::InitInterpreter() {
auto resolver = GetOpResolver();
-
const int32_t num_threads = params_.Get<int32_t>("num_threads");
tflite::InterpreterBuilder(*model_, *resolver)(&interpreter_, num_threads);
if (!interpreter_) {
- TFLITE_LOG(ERROR) << "Failed to construct interpreter";
+ TFLITE_LOG(ERROR) << "Failed to initialize the interpreter";
return kTfLiteError;
}
+ return kTfLiteOk;
+}
+
+TfLiteStatus BenchmarkTfLiteModel::Init() {
+ TF_LITE_ENSURE_STATUS(LoadModel());
+ TF_LITE_ENSURE_STATUS(InitInterpreter());
// Install profilers if necessary right after interpreter is created so that
// any memory allocations inside the TFLite runtime could be recorded if the
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
index b56390b..cc87743 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
@@ -74,6 +74,9 @@
// Allow subclasses to create a customized Op resolver during init.
virtual std::unique_ptr<tflite::OpResolver> GetOpResolver() const;
+ // Allow subclasses to initialize a customized TFLite interpreter.
+ virtual TfLiteStatus InitInterpreter();
+
// Create a BenchmarkListener that's specifically for TFLite profiling if
// necessary.
virtual std::unique_ptr<BenchmarkListener> MayCreateProfilingListener() const;
diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
index 9657c7e..ab150e8 100644
--- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
+++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h
@@ -321,15 +321,23 @@
void* data;
} TfLitePtrUnion;
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+// Memory allocation strategies.
+// * kTfLiteMmapRo: Read-only memory-mapped data, or data externally allocated.
+// * kTfLiteArenaRw: Arena allocated with no guarantees about persistence,
+// and available during eval.
+// * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval, and
+// only available during eval.
+// * kTfLiteDynamic: Allocated during eval, or for string tensors.
+// * kTfLitePersistentRo: Allocated and populated during prepare. This is
+// useful for tensors that can be computed during prepare and treated
+// as constant inputs for downstream ops (also in prepare).
typedef enum TfLiteAllocationType {
kTfLiteMemNone = 0,
kTfLiteMmapRo,
kTfLiteArenaRw,
kTfLiteArenaRwPersistent,
kTfLiteDynamic,
+ kTfLitePersistentRo,
} TfLiteAllocationType;
// The delegates should use zero or positive integers to represent handles.
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
index 56aa9d5..9022afc 100644
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ -363,13 +363,20 @@
}
return 1;
case BuiltinOperator_RESIZE_BILINEAR:
- case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
if (op_sig.options.resize.half_pixel_centers) {
return 3;
} else if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
+ case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
+ if (op_sig.options.resize.half_pixel_centers ||
+ op_sig.options.resize.align_corners) {
+ return 3;
+ } else if (op_sig.input_types.at(0) == TensorType_INT8) {
+ return 2;
+ }
+ return 1;
case BuiltinOperator_MAXIMUM:
case BuiltinOperator_MINIMUM:
@@ -612,6 +619,8 @@
if (resize_bilinear_option) {
op_sig.options.resize.half_pixel_centers =
resize_bilinear_option->half_pixel_centers();
+ op_sig.options.resize.align_corners =
+ resize_bilinear_option->align_corners();
}
} break;
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: {
@@ -620,6 +629,7 @@
if (resize_nn_option) {
op_sig.options.resize.half_pixel_centers =
resize_nn_option->half_pixel_centers();
+ op_sig.options.resize.align_corners = resize_nn_option->align_corners();
}
} break;
// TODO(b/150176627): Add tests for GetOpSignature.
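With the cases split above, `RESIZE_NEAREST_NEIGHBOR` now reaches version 3 through either `half_pixel_centers` or `align_corners`, while `RESIZE_BILINEAR` keeps its existing rule. A compact Python sketch of the resulting selection logic (illustrative only, not the C++ source):

```python
def resize_nearest_neighbor_version(half_pixel_centers, align_corners,
                                    input_type):
    # Sketch of the RESIZE_NEAREST_NEIGHBOR branch added above.
    if half_pixel_centers or align_corners:
        return 3
    if input_type == "INT8":
        return 2
    return 1

assert resize_nearest_neighbor_version(False, True, "FLOAT32") == 3
assert resize_nearest_neighbor_version(False, False, "INT8") == 2
assert resize_nearest_neighbor_version(False, False, "FLOAT32") == 1
```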
diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h
index fba6c94..4b0fe88 100644
--- a/tensorflow/lite/tools/versioning/op_version.h
+++ b/tensorflow/lite/tools/versioning/op_version.h
@@ -48,6 +48,7 @@
} lstm;
struct {
bool half_pixel_centers;
+ bool align_corners;
} resize;
struct {
int32_t num_dims;
diff --git a/tensorflow/lite/tools/versioning/op_version_test.cc b/tensorflow/lite/tools/versioning/op_version_test.cc
index 7d9039f..f0d8259 100644
--- a/tensorflow/lite/tools/versioning/op_version_test.cc
+++ b/tensorflow/lite/tools/versioning/op_version_test.cc
@@ -594,4 +594,64 @@
TensorType_INT32}};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
}
+TEST(OpVersionTest, VersioningResizeBilinearTest) {
+ // Default.
+ OpSignature fake_op_sig = {
+ .op = BuiltinOperator_RESIZE_BILINEAR,
+ .input_types =
+ std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT32},
+ .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+ // align_corners=true is still version 1.
+ fake_op_sig.options.resize.align_corners = true;
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+ // half_pixel_centers=true must be version 3.
+ fake_op_sig.options.resize.align_corners = false;
+ fake_op_sig.options.resize.half_pixel_centers = true;
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+ // int8 input is version 2.
+ fake_op_sig = {
+ .op = BuiltinOperator_RESIZE_BILINEAR,
+ .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT32},
+ .output_types = std::vector<TensorType>{TensorType_INT8},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+ fake_op_sig.options.resize.half_pixel_centers = true;
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+}
+TEST(OpVersionTest, VersioningResizeNearestNeighborTest) {
+ // Default.
+ OpSignature fake_op_sig = {
+ .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
+ .input_types =
+ std::vector<TensorType>{TensorType_FLOAT32, TensorType_INT32},
+ .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 1);
+
+ // align_corners=true is version 3.
+ fake_op_sig.options.resize.align_corners = true;
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+ // half_pixel_centers=true must be version 3.
+ fake_op_sig.options.resize.align_corners = false;
+ fake_op_sig.options.resize.half_pixel_centers = true;
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+ // int8 input is version 2.
+ fake_op_sig = {
+ .op = BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
+ .input_types = std::vector<TensorType>{TensorType_INT8, TensorType_INT32},
+ .output_types = std::vector<TensorType>{TensorType_INT8},
+ };
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+ fake_op_sig.options.resize.align_corners = true;
+ EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+}
} // namespace tflite
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 0b046ea..a49e4b7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -230,6 +230,7 @@
"//tensorflow/python/tools:module_util",
"//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python/tpu:tpu_noestimator",
+ "//tensorflow/python/types",
"//third_party/py/numpy",
],
)
@@ -655,15 +656,15 @@
"@com_google_absl//absl/types:optional",
] + if_static(
extra_deps = [
- "//tensorflow/core:eager_service_proto_cc",
- "//tensorflow/core:master_proto_cc",
- "//tensorflow/core:worker_proto_cc",
+ "//tensorflow/core/protobuf:eager_service_proto_cc",
+ "//tensorflow/core/protobuf:master_proto_cc",
+ "//tensorflow/core/protobuf:worker_proto_cc",
"//tensorflow/core:version_lib",
],
otherwise = [
- "//tensorflow/core:eager_service_proto_cc_headers_only",
- "//tensorflow/core:master_proto_cc_headers_only",
- "//tensorflow/core:worker_proto_cc_headers_only",
+ "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only",
+ "//tensorflow/core/protobuf:master_proto_cc_headers_only",
+ "//tensorflow/core/protobuf:worker_proto_cc_headers_only",
],
),
)
@@ -8049,14 +8050,14 @@
"//tensorflow/core/platform",
] + if_static(
extra_deps = [
- "//tensorflow/core:eager_service_proto_cc",
- "//tensorflow/core:master_proto_cc",
- "//tensorflow/core:worker_proto_cc",
+ "//tensorflow/core/protobuf:eager_service_proto_cc",
+ "//tensorflow/core/protobuf:master_proto_cc",
+ "//tensorflow/core/protobuf:worker_proto_cc",
],
otherwise = [
- "//tensorflow/core:eager_service_proto_cc_headers_only",
- "//tensorflow/core:master_proto_cc_headers_only",
- "//tensorflow/core:worker_proto_cc_headers_only",
+ "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only",
+ "//tensorflow/core/protobuf:master_proto_cc_headers_only",
+ "//tensorflow/core/protobuf:worker_proto_cc_headers_only",
],
),
)
diff --git a/tensorflow/python/autograph/g3doc/reference/control_flow.md b/tensorflow/python/autograph/g3doc/reference/control_flow.md
index 79cc0f3..cf580af 100644
--- a/tensorflow/python/autograph/g3doc/reference/control_flow.md
+++ b/tensorflow/python/autograph/g3doc/reference/control_flow.md
@@ -164,7 +164,7 @@
#### Python values modified in TensorFlow control flow become Tensors
If a symbol is modified in a TensorFlow control flow statement, then it becomes
-a `tf.Tensor`, even if it started off as a Python promitive value.
+a `tf.Tensor`, even if it started off as a Python primitive value.
For example, the conditional below will run as a `tf.cond` (its condition is a
`tf.Tensor`), which in turn will cause `i` to become a `tf.Tensor`.
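A minimal sketch of this rule (illustrative; `n` is assumed to be a Tensor argument):

```python
import tensorflow as tf

@tf.function
def f(n):
    i = 0            # starts as a Python int
    if n > 0:        # runs as tf.cond because `n` is a Tensor
        i = i + 1    # `i` is modified in TF control flow...
    return i         # ...so it is returned as a tf.Tensor

print(f(tf.constant(1)))  # tf.Tensor(1, shape=(), dtype=int32)
```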
diff --git a/tensorflow/python/autograph/g3doc/reference/generated_code.md b/tensorflow/python/autograph/g3doc/reference/generated_code.md
index b62911b..389fa53 100644
--- a/tensorflow/python/autograph/g3doc/reference/generated_code.md
+++ b/tensorflow/python/autograph/g3doc/reference/generated_code.md
@@ -66,7 +66,7 @@
```
`tf.autograph.to_code` is a shortcut to obtain the generated code, and it's
-equivalent with calling `inspect.getsource(tf.autograph.to_code(f))`.
+equivalent to calling `inspect.getsource(tf.autograph.to_graph(f))`.
#### Recording diagnostic information: `tf.autograph.set_verbosity`
diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
index c5a3a3d..5d835fd 100644
--- a/tensorflow/python/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -106,11 +106,12 @@
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(t), [[1, 2, 3]])
- @test_util.run_v1_only("b/117943489")
+ @test_util.run_deprecated_v1
def test_append_tensorarray(self):
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l1 = data_structures.list_append(l, 1)
l2 = data_structures.list_append(l1, 2)
+
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(l1.stack()), [1])
self.assertAllEqual(self.evaluate(l2.stack()), [1, 2])
diff --git a/tensorflow/python/autograph/utils/tensor_list_test.py b/tensorflow/python/autograph/utils/tensor_list_test.py
index bbbc3bf..017d97b 100644
--- a/tensorflow/python/autograph/utils/tensor_list_test.py
+++ b/tensorflow/python/autograph/utils/tensor_list_test.py
@@ -34,7 +34,6 @@
def _shape(self, shape_tuple):
return constant(shape_tuple, dtypes.int32)
- @test_util.run_v1_only("b/117943489")
def test_dynamic_list_append(self):
l = []
l = tl.dynamic_list_append(l, 1)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 17f559e..627979a 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -33,7 +33,7 @@
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 5, 10)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 5, 12)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py
index c1c2366..67dfadb 100644
--- a/tensorflow/python/data/experimental/ops/data_service_ops.py
+++ b/tensorflow/python/data/experimental/ops/data_service_ops.py
@@ -84,15 +84,29 @@
if task_refresh_interval_hint_ms is None:
task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE
+ self._dataset_id = ops.convert_to_tensor(
+ dataset_id, dtype=dtypes.int64, name="dataset_id")
+ self._processing_mode = ops.convert_to_tensor(
+ processing_mode, dtype=dtypes.string, name="processing_mode")
+ self._address = ops.convert_to_tensor(
+ address, dtype=dtypes.string, name="address")
+ self._protocol = ops.convert_to_tensor(
+ protocol, dtype=dtypes.string, name="protocol")
+ self._job_name = ops.convert_to_tensor(
+ job_name, dtype=dtypes.string, name="job_name")
+ self._max_outstanding_requests = ops.convert_to_tensor(
+ max_outstanding_requests,
+ dtype=dtypes.int64,
+ name="max_outstanding_requests")
self._element_spec = input_dataset.element_spec
variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
- dataset_id=dataset_id,
- processing_mode=processing_mode,
- address=address,
- protocol=protocol,
- job_name=job_name,
- max_outstanding_requests=max_outstanding_requests,
+ dataset_id=self._dataset_id,
+ processing_mode=self._processing_mode,
+ address=self._address,
+ protocol=self._protocol,
+ job_name=self._job_name,
+ max_outstanding_requests=self._max_outstanding_requests,
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
),
@@ -297,5 +311,8 @@
Returns:
Dataset: A `Dataset` of the elements produced by the data service.
"""
- return _distribute(processing_mode, service, job_name,
- max_outstanding_requests)
+ return _distribute(
+ processing_mode=processing_mode,
+ service=service,
+ job_name=job_name,
+ max_outstanding_requests=max_outstanding_requests)
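Converting each argument with `ops.convert_to_tensor` up front and keeping the results as attributes makes the dtypes explicit and gives later code a handle on the converted tensors. A standalone sketch of the idiom (hypothetical class, not the data service op itself):

```python
import tensorflow as tf

class _ConfiguredDatasetArgs:
    """Sketch: eagerly convert constructor args and hold the tensors."""

    def __init__(self, address, max_outstanding_requests):
        self._address = tf.convert_to_tensor(
            address, dtype=tf.string, name="address")
        self._max_outstanding_requests = tf.convert_to_tensor(
            max_outstanding_requests, dtype=tf.int64,
            name="max_outstanding_requests")

args = _ConfiguredDatasetArgs("grpc://localhost:5050", 1)
```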
diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
index eac1c67..217c586 100644
--- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py
@@ -216,6 +216,21 @@
self.assertEqual(i, val)
@combinations.generate(test_base.eager_only_combinations())
+ def testMaxOutstandingRequests(self):
+ num_elements = 10
+ num_workers = 3
+ service = self.create_cluster(num_workers)
+ ds = dataset_ops.Dataset.range(num_elements)
+ ds = ds.apply(
+ data_service_ops._distribute(
+ "parallel_epochs",
+ service,
+ max_outstanding_requests=1,
+ task_refresh_interval_hint_ms=20))
+ self.assertCountEqual(num_workers * list(range(num_elements)),
+ self.getDatasetOutput(ds))
+
+ @combinations.generate(test_base.eager_only_combinations())
def testInsideFunction(self):
num_workers = 3
num_elements = 10
diff --git a/tensorflow/python/distribute/multi_worker_test_base.py b/tensorflow/python/distribute/multi_worker_test_base.py
index a80ca83..408cad2 100644
--- a/tensorflow/python/distribute/multi_worker_test_base.py
+++ b/tensorflow/python/distribute/multi_worker_test_base.py
@@ -50,6 +50,7 @@
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
+from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
@@ -559,6 +560,10 @@
return subprocess.Popen(
cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
+ @deprecation.deprecated(
+ None, '`run_multiple_tasks_in_processes` is deprecated; any new test '
+ 'requiring multiple processes should use `multi_process_runner` for '
+ 'better support of log printing, streaming, and more functionality.')
def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec):
"""Run `cmd_args` in a process for each task in `cluster_spec`."""
processes = {}
@@ -570,6 +575,10 @@
processes[task_type].append(p)
return processes
+ @deprecation.deprecated(
+ None, '`join_independent_workers` is deprecated; any new test '
+ 'requiring multiple processes should use `multi_process_runner` for '
+ 'better support of log printing, streaming, and more functionality.')
def join_independent_workers(self, worker_processes):
return_codes = []
for p in nest.flatten(worker_processes):
@@ -585,6 +594,10 @@
for return_code in return_codes:
self.assertEqual(return_code, 0)
+ @deprecation.deprecated(
+ None, '`stream_stderr` is deprecated; any new test '
+ 'requiring multiple processes should use `multi_process_runner` for '
+ 'better support of log printing, streaming, and more functionality.')
def stream_stderr(self, processes, print_only_first=False):
"""Consume stderr of all processes and print to stdout.
diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD
index e7526a5..930816d 100644
--- a/tensorflow/python/distribute/parallel_device/BUILD
+++ b/tensorflow/python/distribute/parallel_device/BUILD
@@ -1,4 +1,8 @@
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+
package(
+ default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
@@ -13,6 +17,7 @@
srcs = ["parallel_device.py"],
srcs_version = "PY2AND3",
deps = [
+ ":parallel_device_ops",
":saving",
"//tensorflow/python:_pywrap_parallel_device",
],
@@ -25,6 +30,25 @@
deps = ["//tensorflow/python:framework_ops"],
)
+tf_gen_op_wrapper_py(
+ name = "parallel_device_ops_py",
+ out = "gen_parallel_device_ops.py",
+ deps = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
+)
+
+tf_custom_op_library(
+ name = "_parallel_device_ops.so",
+ srcs = ["//tensorflow/c/eager/parallel_device:parallel_device_ops_srcs"],
+)
+
+tf_custom_op_py_library(
+ name = "parallel_device_ops",
+ dso = [":_parallel_device_ops.so"],
+ kernels = ["//tensorflow/c/eager/parallel_device:parallel_device_ops"],
+ visibility = ["//tensorflow:internal"],
+ deps = [":parallel_device_ops_py"],
+)
+
py_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.py"],
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py
index 982b061..2dbdc65 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device.py
@@ -22,11 +22,17 @@
import threading
from tensorflow.python import _pywrap_parallel_device
+from tensorflow.python.distribute.parallel_device import gen_parallel_device_ops
from tensorflow.python.distribute.parallel_device import saving
from tensorflow.python.eager import context
+from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
from tensorflow.python.tpu.ops import tpu_ops
+load_library.load_op_library(
+ resource_loader.get_path_to_datafile("_parallel_device_ops.so"))
+
_next_device_number = 0
_next_device_number_lock = threading.Lock()
@@ -58,6 +64,8 @@
device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
self.name, self.components)
context.register_custom_device(device, self.name, device_info)
+ with ops.device(self.name):
+ self._device_ids = gen_parallel_device_ops.device_id()
def pack(self, tensors):
"""Create a tensor on the parallel device from a sequence of tensors.
@@ -84,6 +92,18 @@
return tpu_ops.tpu_replicated_output(
parallel_tensor, num_replicas=len(self.components))
+ @property
+ def device_ids(self):
+ """A parallel tensor with scalar integers numbering component devices.
+
+ Each device ID is placed on its corresponding device, in the same order as
+ the `components` constructor argument.
+
+ Returns:
+ A parallel tensor containing 0 on the first device, 1 on the second, etc.
+ """
+ return self._device_ids
+
# TODO(allenl): Fixing saving in Python is a bit odd. One alternative would be
# to provide a hook for the custom device to create save specs/etc., then call
# that hook from the default variable implementation if the variable is on a
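A hedged usage sketch for the new `device_ids` property (the component device names are assumptions; compare the test added below):

```python
from tensorflow.python.distribute.parallel_device import parallel_device

# Assumed two CPU components; real setups may use GPUs or TPU cores.
device = parallel_device.ParallelDevice(
    components=["/job:localhost/device:CPU:0",
                "/job:localhost/device:CPU:1"])
ids = device.unpack(device.device_ids)  # [0, 1], one id per component
```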
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
index d3f3417..e35eb60 100644
--- a/tensorflow/python/distribute/parallel_device/parallel_device_test.py
+++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py
@@ -119,6 +119,12 @@
self.assertIn(self.device.components[0], outputs[0].backing_device)
self.assertIn(self.device.components[1], outputs[1].backing_device)
+ def test_device_id(self):
+ device_ids = self.device.unpack(self.device.device_ids)
+ self.assertAllClose([0, 1], device_ids)
+ self.assertIn(self.device.components[0], device_ids[0].backing_device)
+ self.assertIn(self.device.components[1], device_ids[1].backing_device)
+
def test_collective_reduce(self):
with ops.device(self.device.name):
x = self.device.pack(
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 6e51b84..b574c52 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -96,35 +96,34 @@
@tf_export("distribute.experimental.TPUStrategy", v1=[])
class TPUStrategy(distribute_lib.Strategy):
- """TPU distribution strategy implementation."""
+ """TPU distribution strategy implementation.
+
+ To construct a TPUStrategy object, you need to run the
+ initialization code as shown below:
+
+ >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+ >>> tf.config.experimental_connect_to_cluster(resolver)
+ >>> tf.tpu.experimental.initialize_tpu_system(resolver)
+ >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
+
+ While using distribution strategies, the variables created within the
+ strategy's scope will be replicated across all the replicas and can be
+ kept in sync using all-reduce algorithms.
+
+ To run TF2 programs on TPUs, you can either use `.compile` and
+ `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
+ training loop by calling `strategy.run` directly. Note that
+ TPUStrategy doesn't support pure eager execution, so please make sure the
+ function passed into `strategy.run` is a `tf.function` or
+ `strategy.run` is called inside a `tf.function` if eager
+ behavior is enabled.
+ """
def __init__(self,
tpu_cluster_resolver=None,
device_assignment=None):
"""Synchronous training in TPU donuts or Pods.
- To construct a TPUStrategy object, you need to run the
- initialization code as below:
-
- ```python
- resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
- tf.config.experimental_connect_to_cluster(resolver)
- tf.tpu.experimental.initialize_tpu_system(resolver)
- strategy = tf.distribute.experimental.TPUStrategy(resolver)
- ```
-
- While using distribution strategies, the variables created within strategy's
- scope will be replicated across all the replicas and can be kept in sync
- using all-reduce algorithms.
-
- To run TF2 programs on TPUs, you can either use `.compile` and
- `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
- training loop by calling `strategy.run` directly. Note that
- TPUStrategy doesn't support pure eager execution, so please make sure the
- function passed into `strategy.run` is a `tf.function` or
- `strategy.run` is called inside a `tf.function` if eager
- behavior is enabled.
-
Args:
tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
which provides information about the TPU cluster.
@@ -209,26 +208,26 @@
Users can pass strategy specific options to `options` argument. An example
to enable bucketizing dynamic shapes in `TPUStrategy.run`
is:
- ```python
- resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
- tf.config.experimental_connect_to_cluster(resolver)
- tf.tpu.experimental.initialize_tpu_system(resolver)
- strategy = tf.distribute.experimental.TPUStrategy(tpu='')
+ >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+ >>> tf.config.experimental_connect_to_cluster(resolver)
+ >>> tf.tpu.experimental.initialize_tpu_system(resolver)
+ >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
- options = tf.distribute.RunOptions()
- options.experimental_bucketizing_dynamic_shape = True
+ >>> options = tf.distribute.RunOptions(
+ ... experimental_bucketizing_dynamic_shape=True)
- iterator = iter(inputs)
+ >>> dataset = tf.data.Dataset.range(
+ ... strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
+ ... strategy.num_replicas_in_sync, drop_remainder=True)
+ >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
- @tf.function()
- def step_fn(inputs):
- output = tf.reduce_sum(inputs)
- return output
+ >>> @tf.function()
+ ... def step_fn(inputs):
+ ... output = tf.reduce_sum(inputs)
+ ... return output
- strategy.run(step_fn, args=(next(iterator),),
- options=options)
- ```
+ >>> strategy.run(step_fn, args=(next(input_iterator),), options=options)
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 4fe3d28..444915a 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -38,6 +38,7 @@
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.types import core
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -422,7 +423,8 @@
return hash((self.name, self.graph, self.traceback, self.type))
-class DistributedVariable(DistributedDelegate, variables_lib.Variable):
+class DistributedVariable(DistributedDelegate, variables_lib.Variable,
+ core.Tensor):
"""Holds a map from replica to variables."""
# TODO(josh11b): Support changing the set of variables if e.g. if new
@@ -741,9 +743,6 @@
pass
-ops.register_dense_tensor_like_type(DistributedVariable)
-
-
def _validate_colocate_extended(v, extended):
variable_strategy = v._distribute_strategy # pylint: disable=protected-access
if variable_strategy.extended is not extended:
@@ -1380,7 +1379,7 @@
return val
-class AggregatingVariable(variables_lib.Variable):
+class AggregatingVariable(variables_lib.Variable, core.Tensor):
"""A wrapper around a variable that aggregates updates across replicas."""
def __init__(self, strategy, v, aggregation):
@@ -1649,4 +1648,3 @@
ops.register_tensor_conversion_function(AggregatingVariable,
_tensor_conversion_aggregate)
-ops.register_dense_tensor_like_type(AggregatingVariable)
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index daa7e55..67ed86b 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -56,6 +56,7 @@
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
+from tensorflow.python.types import core
from tensorflow.python.util import nest
@@ -623,10 +624,10 @@
v = variables_lib.Variable(
0., synchronization=synchronization, aggregation=aggregation)
# In cross replica context.
- self.assertTrue(ops.is_dense_tensor_like(v))
+ self.assertIsInstance(v, core.Tensor)
# In replica context.
distribution.run(
- lambda v: self.assertTrue(ops.is_dense_tensor_like(v)), args=(v,))
+ lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))
def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
aggregation):
@@ -645,9 +646,9 @@
# values is not allowed when aggregation is SUM. See
# `cross_device_ops.reduce_non_distributed_value`.
delta = array_ops.identity(1.)
- self.assertTrue(ops.is_dense_tensor_like(v.assign(delta)))
- self.assertTrue(ops.is_dense_tensor_like(v.assign_sub(delta)))
- self.assertTrue(ops.is_dense_tensor_like(v.assign_add(delta)))
+ self.assertIsInstance(v.assign(delta), core.Tensor)
+ self.assertIsInstance(v.assign_sub(delta), core.Tensor)
+ self.assertIsInstance(v.assign_add(delta), core.Tensor)
# In cross replica context we return a PerReplica which is not Tensor like
# yet.
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index f2f1327..227fca5 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -180,22 +180,18 @@
func() # Warmup.
self._run(func, 3000)
- @test_util.disable_tfrt("Scalars are not handled correctly")
def benchmark_create_float_constant(self):
self._benchmark_create_constant(42.0, dtype=None)
- @test_util.disable_tfrt("Scalars are not handled correctly")
def benchmark_create_float_constant_uncached(self):
self._benchmark_create_constant(42.0, dtype=None, cached=False)
- @test_util.disable_tfrt("Scalars are not handled correctly")
def benchmark_create_int32_constant(self):
if context.num_gpus():
return # int32 constants are always allocated on CPU.
self._benchmark_create_constant(42, dtype=dtypes.int32)
- @test_util.disable_tfrt("Scalars are not handled correctly")
def benchmark_create_int32_constant_uncached(self):
if context.num_gpus():
return # int32 constants are always allocated on CPU.
@@ -211,21 +207,17 @@
func() # Warmup.
self._run(func, 30000)
- @test_util.disable_tfrt("Scalars are not handled correctly")
def benchmark_add_float_scalars(self):
self._benchmark_add(42.0, 24.0)
- @test_util.disable_tfrt("Scalars are not handled correctly")
def benchmark_add_int32_scalars(self):
self._benchmark_add(42, 24)
- @test_util.disable_tfrt("Scalars are not handled correctly")
def benchmark_add_float_scalar_tensor(self):
tensor_a = constant_op.constant(42.0)
tensor_b = constant_op.constant(24.0)
self._benchmark_add(tensor_a, tensor_b)
- @test_util.disable_tfrt("Scalars are not handled correctly")
def benchmark_add_int32_scalar_tensor(self):
tensor_a = constant_op.constant(42)
tensor_b = constant_op.constant(24)
diff --git a/tensorflow/python/eager/benchmarks_test_base.py b/tensorflow/python/eager/benchmarks_test_base.py
index 552d844..3d81d08 100644
--- a/tensorflow/python/eager/benchmarks_test_base.py
+++ b/tensorflow/python/eager/benchmarks_test_base.py
@@ -32,4 +32,6 @@
"examples_per_sec": float("{0:.3f}".format(num_iters / total_time)),
"us_per_example": float("{0:.3f}".format(total_time * 1e6 / num_iters))
}
- self.report_benchmark(iters=num_iters, wall_time=mean_us, extras=extras)
+ benchmark_name = self._get_benchmark_name()
+ self.report_benchmark(
+ iters=num_iters, wall_time=mean_us, extras=extras, name=benchmark_name)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 9b8f7cf..43652d5 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -62,6 +62,7 @@
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.types import core as core_tf_types
from tensorflow.python.types import internal
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
@@ -213,53 +214,11 @@
return None
-_TENSOR_LIKE_TYPES = tuple()
-
-
+# Deprecated - do not use.
+# This API is kept to avoid breaking estimator and tensorflow-mesh, which
+# depend on this internal API. The stub should be safe to remove after TF 2.3
+# is released.
def is_dense_tensor_like(t):
- """EXPERIMENTAL: Returns true if `t` implements the tensor interface.
-
- See `register_dense_tensor_like_type()` for the current definition of a
- "tensor-like type".
-
- Args:
- t: An object.
-
- Returns:
- True iff `t` is an instance of one of the registered "tensor-like" types.
- """
- return isinstance(t, _TENSOR_LIKE_TYPES)
-
-
-def register_dense_tensor_like_type(tensor_type):
- """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
-
- A "tensor-like type" can represent a single dense tensor, and implements
- the `name`, `dtype` and `shape` properties.
-
- Args:
- tensor_type: A type implementing the tensor interface.
-
- Raises:
- TypeError: If `tensor_type` does not implement the tensor interface.
- """
- if not (hasattr(tensor_type, "name") and
- isinstance(tensor_type.name, property)):
- raise TypeError("Type %s does not define a `name` property" %
- tensor_type.__name__)
- if not (hasattr(tensor_type, "dtype") and
- isinstance(tensor_type.dtype, property)):
- raise TypeError("Type %s does not define a `dtype` property" %
- tensor_type.__name__)
- if not (hasattr(tensor_type, "shape") and
- isinstance(tensor_type.shape, property)):
- raise TypeError("Type %s does not define a `shape` property" %
- tensor_type.__name__)
- # We expect this list to be small, so choose quadratic complexity
- # for registration, so that we have a tuple that can be used for
- # more efficient `isinstance` checks later.
- global _TENSOR_LIKE_TYPES
- _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
+ return isinstance(t, core_tf_types.Tensor)
def uid():
@@ -304,7 +263,7 @@
# TODO(mdan): This object should subclass Symbol, not just Tensor.
@tf_export("Tensor")
-class Tensor(internal.NativeObject):
+class Tensor(internal.NativeObject, core_tf_types.Tensor):
"""A tensor is a multidimensional array of elements represented by a
`tf.Tensor` object. All elements are of a single known data type.
@@ -1305,9 +1264,6 @@
EagerTensor = pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)
-register_dense_tensor_like_type(Tensor)
-
-
@tf_export(v1=["convert_to_tensor"])
def convert_to_tensor_v1(value,
dtype=None,
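The replacement pattern for the removed registration machinery: tensor-like classes subclass `tensorflow.python.types.core.Tensor` (a marker base class) and callers use a plain `isinstance` check. A minimal sketch with a hypothetical class:

```python
from tensorflow.python.types import core

class MyTensorLike(core.Tensor):  # hypothetical tensor-like wrapper
    pass

# What is_dense_tensor_like() now reduces to:
assert isinstance(MyTensorLike(), core.Tensor)
```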
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 20f58a0..322df8f 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -3268,56 +3268,6 @@
test_ops.old()
-class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
-
- @test_util.disable_tfrt("Graph is not supported yet.")
- def testSuccess(self):
- op = ops.Operation(
- ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
- t = op.outputs[0]
- self.assertTrue(ops.is_dense_tensor_like(t))
-
- v = variables.Variable([17])
- self.assertTrue(ops.is_dense_tensor_like(v))
-
- class BadClassNoName(object):
- pass
-
- class BadClassBadName(object):
-
- def name(self):
- pass
-
- class BadClassNoDtype(object):
-
- @property
- def name(self):
- pass
-
- class BadClassBadDtype(object):
-
- @property
- def name(self):
- pass
-
- def dtype(self):
- pass
-
- def testBadClass(self):
- with self.assertRaisesRegexp(TypeError, "`name`"):
- ops.register_dense_tensor_like_type(
- DenseTensorLikeTypeTest.BadClassNoName)
- with self.assertRaisesRegexp(TypeError, "`name`"):
- ops.register_dense_tensor_like_type(
- DenseTensorLikeTypeTest.BadClassBadName)
- with self.assertRaisesRegexp(TypeError, "`dtype`"):
- ops.register_dense_tensor_like_type(
- DenseTensorLikeTypeTest.BadClassNoDtype)
- with self.assertRaisesRegexp(TypeError, "`dtype`"):
- ops.register_dense_tensor_like_type(
- DenseTensorLikeTypeTest.BadClassBadDtype)
-
-
class NameScopeTest(test_util.TensorFlowTestCase):
def testStripAndPrependScope(self):
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 5038859..968b635 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -26,6 +26,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.types import core
from tensorflow.python.types import internal
from tensorflow.python.util import compat
from tensorflow.python.util import nest
@@ -1009,7 +1010,7 @@
`True` if `x` is a tensor or "tensor-like", `False` if not.
"""
return (isinstance(x, internal.NativeObject) or
- ops.is_dense_tensor_like(x) or
+ isinstance(x, core.Tensor) or
getattr(x, "is_tensor_like", False))
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 503f6cf..2700fae 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -393,6 +393,9 @@
False, shape=(), name='keras_learning_phase')
+@deprecated('2020-10-11',
+ 'Simply pass a True/False value to the `training` argument '
+ 'of the `__call__` method of your layer or model.')
@keras_export('keras.backend.set_learning_phase')
def set_learning_phase(value):
"""Sets the learning phase to a fixed value.
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 6748a57..db326ea 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -307,14 +307,20 @@
end_hook_name = hook_name
begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
- threshold_time = 0.5 * batch_time
+ threshold_time = 1.5 * batch_time
warning_msg = ('Callbacks method `{hook}` is slow compared to '
- 'the batch time. Check your callbacks.')
+ 'the batch time (batch time: {batch_time:.4f}s vs '
+ '`{hook}` time: {cbk_time:.4f}s). Check your callbacks.')
if self._timing[begin_hook_name] > threshold_time:
- logging.warning(warning_msg.format(hook=begin_hook_name))
+ logging.warning(warning_msg.format(
+ hook=begin_hook_name,
+ batch_time=batch_time,
+ cbk_time=self._timing[begin_hook_name]))
if self._timing[end_hook_name] > threshold_time:
- logging.warning(warning_msg.format(hook=end_hook_name))
-
+ logging.warning(warning_msg.format(
+ hook=end_hook_name,
+ batch_time=batch_time,
+ cbk_time=self._timing[end_hook_name]))
self._check_timing = False
self._batch_start_time = None
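The threshold change (0.5x to 1.5x the batch time) means a hook is flagged only when it costs materially more than the batch itself; a toy check of the new condition:

```python
def should_warn(hook_time, batch_time):
    # New rule from the change above: warn only when the hook takes longer
    # than 1.5x the batch time (previously 0.5x).
    return hook_time > 1.5 * batch_time

assert not should_warn(hook_time=0.6, batch_time=1.0)  # warned before, not now
assert should_warn(hook_time=1.6, batch_time=1.0)
```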
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 9d15f87..2f1256e 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -302,8 +302,8 @@
epochs=10,
callbacks=[SleepCallback()])
warning_msg = ('Callbacks method `on_train_batch_end` is slow compared '
- 'to the batch time. Check your callbacks.')
- self.assertIn(warning_msg, warning_messages)
+ 'to the batch time')
+ self.assertIn(warning_msg, '\n'.join(warning_messages))
@keras_parameterized.run_with_all_model_types(exclude_models='functional')
@keras_parameterized.run_all_keras_modes
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 0a40ce3..16188af 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -62,6 +62,7 @@
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
+from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -3143,7 +3144,7 @@
The possibly-converted 'value'.
"""
if issparse is not None and issparse(value):
- if ops.is_dense_tensor_like(expected_input):
+ if isinstance(expected_input, core.Tensor):
if ops.executing_eagerly_outside_functions():
# In TF2 we do not silently densify sparse matrices.
raise ValueError('A SciPy sparse matrix was passed to a model '
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
index c43ca21..29e5a68 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
@@ -23,9 +23,10 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
+from tensorflow.python.types import core
-class AutoCastVariable(variables.Variable):
+class AutoCastVariable(variables.Variable, core.Tensor):
"""Variable that will cast itself to a different dtype in applicable contexts.
This class wraps a floating-point `tf.Variable`. It emulates the variable
@@ -417,7 +418,6 @@
ops.register_tensor_conversion_function(AutoCastVariable,
AutoCastVariable._dense_var_to_tensor) # pylint:disable=protected-access
-ops.register_dense_tensor_like_type(AutoCastVariable)
def create_autocast_variable(variable):
diff --git a/tensorflow/python/keras/preprocessing/BUILD b/tensorflow/python/keras/preprocessing/BUILD
index 403bc6e..24260fb 100644
--- a/tensorflow/python/keras/preprocessing/BUILD
+++ b/tensorflow/python/keras/preprocessing/BUILD
@@ -85,6 +85,7 @@
deps = [
":image",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/keras",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py
index 3af573f..953962c 100644
--- a/tensorflow/python/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/preprocessing/image.py
@@ -14,6 +14,7 @@
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-import-not-at-top
+# pylint: disable=g-classes-have-attributes
"""Set of tools for real-time data augmentation on image data.
"""
from __future__ import absolute_import
@@ -35,6 +36,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@@ -49,6 +51,7 @@
apply_affine_transform = image.apply_affine_transform
+@keras_export('keras.preprocessing.image.smart_resize', v1=[])
def smart_resize(x, size, interpolation='bilinear'):
"""Resize images to a target size without aspect ratio distortion.
@@ -65,7 +68,7 @@
```
However, if you do this, you distort the aspect ratio of your images, since
- in general they do not all have the same aspect ratio. This is
+ in general they do not all have the same aspect ratio as `size`. This is
fine in many cases, but not always (e.g. for GANs this can be a problem).
Note that passing the argument `preserve_aspect_ratio=True` to `resize`
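A hedged usage sketch for the newly exported `smart_resize` (the export path follows the `keras_export` decorator added above; availability depends on the TF version):

```python
import numpy as np
from tensorflow.keras.preprocessing.image import smart_resize

# Assumed batch of RGB images with mixed-aspect content.
batch = np.random.rand(8, 200, 300, 3).astype("float32")
resized = smart_resize(batch, size=(128, 128))  # crop to aspect, then resize
print(resized.shape)  # (8, 128, 128, 3)
```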
@@ -458,6 +461,123 @@
**kwargs)
+class DataFrameIterator(image.DataFrameIterator, Iterator):
+ """Iterator capable of reading images from a directory on disk as a dataframe.
+
+ Arguments:
+ dataframe: Pandas dataframe containing the filepaths relative to
+ `directory` (or absolute paths if `directory` is None) of the images in
+ a string column. It should include other column/s
+ depending on the `class_mode`:
+ - if `class_mode` is `"categorical"` (default value) it must include
+ the `y_col` column with the class/es of each image. Values in the
+ column can be string/list/tuple if a single class, or list/tuple if
+ multiple classes.
+ - if `class_mode` is `"binary"` or `"sparse"` it must include the
+ given `y_col` column with class values as strings.
+ - if `class_mode` is `"raw"` or `"multi_output"` it should contain
+ the columns specified in `y_col`.
+ - if `class_mode` is `"input"` or `None` no extra column is needed.
+ directory: string, path to the directory to read images from. If `None`,
+ data in `x_col` column should be absolute paths.
+ image_data_generator: Instance of `ImageDataGenerator` to use for random
+ transformations and normalization. If None, no transformations and
+ normalizations are made.
+ x_col: string, column in `dataframe` that contains the filenames (or
+ absolute paths if `directory` is `None`).
+ y_col: string or list, column/s in `dataframe` that has the target data.
+ weight_col: string, column in `dataframe` that contains the sample
+ weights. Default: `None`.
+ target_size: tuple of integers, dimensions to resize input images to.
+ color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. Color mode to read
+ images.
+ classes: Optional list of strings, classes to use (e.g. `["dogs",
+ "cats"]`). If None, all classes in `y_col` will be used.
+ class_mode: one of "binary", "categorical", "input", "multi_output",
+ "raw", "sparse" or None. Default: "categorical".
+ Mode for yielding the targets:
+ - `"binary"`: 1D numpy array of binary labels,
+ - `"categorical"`: 2D numpy array of one-hot encoded labels. Supports
+ multi-label output.
+ - `"input"`: images identical to input images (mainly used to work
+ with autoencoders),
+ - `"multi_output"`: list with the values of the different columns,
+ - `"raw"`: numpy array of values in `y_col` column(s),
+ - `"sparse"`: 1D numpy array of integer labels, - `None`, no targets
+ are returned (the generator will only yield batches of image data,
+ which is useful to use in `model.predict_generator()`).
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ seed: Random seed for data shuffling.
+ data_format: String, one of `channels_first`, `channels_last`.
+ save_to_dir: Optional directory where to save the pictures being yielded,
+ in a viewable format. This is useful for visualizing the random
+ transformations being applied, for debugging purposes.
+ save_prefix: String prefix to use for saving sample images (if
+ `save_to_dir` is set).
+ save_format: Format to use for saving sample images (if `save_to_dir` is
+ set).
+ subset: Subset of data (`"training"` or `"validation"`) if
+ validation_split is set in ImageDataGenerator.
+ interpolation: Interpolation method used to resample the image if the
+ target size is different from that of the loaded image. Supported
+ methods are "nearest", "bilinear", and "bicubic". If PIL version 1.1.3
+ or newer is installed, "lanczos" is also supported. If PIL version 3.4.0
+ or newer is installed, "box" and "hamming" are also supported. By
+ default, "nearest" is used.
+ dtype: Dtype to use for the generated arrays.
+ validate_filenames: Boolean, whether to validate image filenames in
+ `x_col`. If `True`, invalid images will be ignored. Disabling this
+ option can lead to a speed-up in the instantiation of this class.
+ Default: `True`.
+ """
+
+ def __init__(
+ self,
+ dataframe,
+ directory=None,
+ image_data_generator=None,
+ x_col='filename',
+ y_col='class',
+ weight_col=None,
+ target_size=(256, 256),
+ color_mode='rgb',
+ classes=None,
+ class_mode='categorical',
+ batch_size=32,
+ shuffle=True,
+ seed=None,
+ data_format='channels_last',
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ subset=None,
+ interpolation='nearest',
+ dtype='float32',
+ validate_filenames=True):
+ super(DataFrameIterator, self).__init__(
+ dataframe=dataframe,
+ directory=directory,
+ image_data_generator=image_data_generator,
+ x_col=x_col,
+ y_col=y_col,
+ weight_col=weight_col,
+ target_size=target_size,
+ color_mode=color_mode,
+ classes=classes,
+ class_mode=class_mode,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ seed=seed,
+ data_format=data_format,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ subset=subset,
+ interpolation=interpolation,
+ dtype=dtype,
+ validate_filenames=validate_filenames
+ )
+
+
@keras_export('keras.preprocessing.image.ImageDataGenerator')
class ImageDataGenerator(image.ImageDataGenerator):
"""Generate batches of tensor image data with real-time data augmentation.
@@ -685,6 +805,302 @@
validation_split=validation_split,
**kwargs)
+ def flow(self,
+ x,
+ y=None,
+ batch_size=32,
+ shuffle=True,
+ sample_weight=None,
+ seed=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ subset=None):
+ """Takes data & label arrays, generates batches of augmented data.
+
+ Arguments:
+ x: Input data. Numpy array of rank 4 or a tuple. If tuple, the first
+ element should contain the images and the second element another numpy
+ array or a list of numpy arrays that gets passed to the output without
+ any modifications. Can be used to feed the model miscellaneous data
+ along with the images. In case of grayscale data, the channels axis of
+ the image array should have value 1, in case of RGB data, it should
+ have value 3, and in case of RGBA data, it should have value 4.
+ y: Labels.
+ batch_size: Int (default: 32).
+ shuffle: Boolean (default: True).
+ sample_weight: Sample weights.
+ seed: Int (default: None).
+ save_to_dir: None or str (default: None). This allows you to optionally
+ specify a directory to which to save the augmented pictures being
+ generated (useful for visualizing what you are doing).
+ save_prefix: Str (default: `''`). Prefix to use for filenames of saved
+ pictures (only relevant if `save_to_dir` is set).
+ save_format: one of "png", "jpeg"
+ (only relevant if `save_to_dir` is set). Default: "png".
+ subset: Subset of data (`"training"` or `"validation"`) if
+ `validation_split` is set in `ImageDataGenerator`.
+
+ Returns:
+ An `Iterator` yielding tuples of `(x, y)`
+ where `x` is a numpy array of image data
+ (in the case of a single image input) or a list
+ of numpy arrays (in the case with
+ additional inputs) and `y` is a numpy array
+ of corresponding labels. If 'sample_weight' is not None,
+ the yielded tuples are of the form `(x, y, sample_weight)`.
+ If `y` is None, only the numpy array `x` is returned.
+ """
+ return NumpyArrayIterator(
+ x,
+ y,
+ self,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ sample_weight=sample_weight,
+ seed=seed,
+ data_format=self.data_format,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ subset=subset)
+
+ def flow_from_directory(self,
+ directory,
+ target_size=(256, 256),
+ color_mode='rgb',
+ classes=None,
+ class_mode='categorical',
+ batch_size=32,
+ shuffle=True,
+ seed=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ follow_links=False,
+ subset=None,
+ interpolation='nearest'):
+ """Takes the path to a directory & generates batches of augmented data.
+
+ Arguments:
+ directory: string, path to the target directory. It should contain one
+ subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images inside
+ each of the subdirectories' directory trees will be included in the
+ generator. See [this script](
+ https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d)
+ for more details.
+ target_size: Tuple of integers `(height, width)`, defaults to `(256,
+ 256)`. The dimensions to which all images found will be resized.
+ color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb". Whether
+ the images will be converted to have 1, 3, or 4 channels.
+ classes: Optional list of class subdirectories
+ (e.g. `['dogs', 'cats']`). Default: None. If not provided, the list
+ of classes will be automatically inferred from the subdirectory
+ names/structure under `directory`, where each subdirectory will be
+ treated as a different class (and the order of the classes, which
+ will map to the label indices, will be alphanumeric). The
+ dictionary containing the mapping from class names to class
+ indices can be obtained via the attribute `class_indices`.
+ class_mode: One of "categorical", "binary", "sparse",
+ "input", or None. Default: "categorical".
+ Determines the type of label arrays that are returned:
+ - "categorical" will be 2D one-hot encoded labels,
+ - "binary" will be 1D binary labels,
+ - "sparse" will be 1D integer labels,
+ - "input" will be images identical to input images (mainly used to
+ work with autoencoders),
+ - If None, no labels are returned (the generator will only yield
+ batches of image data, which is useful to use with
+ `model.predict_generator()`). Please note that in case of
+ class_mode None, the data still needs to reside in a subdirectory
+ of `directory` for it to work correctly.
+ batch_size: Size of the batches of data (default: 32).
+ shuffle: Whether to shuffle the data (default: True). If set to False,
+ sorts the data in alphanumeric order.
+ seed: Optional random seed for shuffling and transformations.
+ save_to_dir: None or str (default: None). This allows you to optionally
+ specify a directory to which to save the augmented pictures being
+ generated (useful for visualizing what you are doing).
+ save_prefix: Str. Prefix to use for filenames of saved pictures (only
+ relevant if `save_to_dir` is set).
+ save_format: One of "png", "jpeg"
+ (only relevant if `save_to_dir` is set). Default: "png".
+ follow_links: Whether to follow symlinks inside
+ class subdirectories (default: False).
+ subset: Subset of data (`"training"` or `"validation"`) if
+ `validation_split` is set in `ImageDataGenerator`.
+ interpolation: Interpolation method used to resample the image if the
+ target size is different from that of the loaded image. Supported
+ methods are `"nearest"`, `"bilinear"`, and `"bicubic"`. If PIL version
+ 1.1.3 or newer is installed, `"lanczos"` is also supported. If PIL
+ version 3.4.0 or newer is installed, `"box"` and `"hamming"` are also
+ supported. By default, `"nearest"` is used.
+
+ Returns:
+ A `DirectoryIterator` yielding tuples of `(x, y)`
+ where `x` is a numpy array containing a batch
+ of images with shape `(batch_size, *target_size, channels)`
+ and `y` is a numpy array of corresponding labels.
+ """
+ return DirectoryIterator(
+ directory,
+ self,
+ target_size=target_size,
+ color_mode=color_mode,
+ classes=classes,
+ class_mode=class_mode,
+ data_format=self.data_format,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ seed=seed,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ follow_links=follow_links,
+ subset=subset,
+ interpolation=interpolation)
+
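A minimal sketch of `flow_from_directory` with a validation split (the directory layout below is hypothetical):

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Assumes data/train/cats/*.jpg and data/train/dogs/*.jpg exist.
datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)
train_it = datagen.flow_from_directory(
    'data/train', target_size=(128, 128), batch_size=32,
    class_mode='binary', subset='training')
val_it = datagen.flow_from_directory(
    'data/train', target_size=(128, 128), batch_size=32,
    class_mode='binary', subset='validation')
print(train_it.class_indices)  # e.g. {'cats': 0, 'dogs': 1}
```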
+ def flow_from_dataframe(self,
+ dataframe,
+ directory=None,
+ x_col='filename',
+ y_col='class',
+ weight_col=None,
+ target_size=(256, 256),
+ color_mode='rgb',
+ classes=None,
+ class_mode='categorical',
+ batch_size=32,
+ shuffle=True,
+ seed=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ subset=None,
+ interpolation='nearest',
+ validate_filenames=True,
+ **kwargs):
+ """Takes the dataframe and the path to a directory + generates batches.
+
+ The generated batches contain augmented/normalized data.
+
+ A simple tutorial can be found [here](
+ http://bit.ly/keras_flow_from_dataframe).
+
+ Arguments:
+ dataframe: Pandas dataframe containing the filepaths relative to
+ `directory` (or absolute paths if `directory` is None) of the images
+ in a string column. It should include other column(s)
+ depending on the `class_mode`:
+ - if `class_mode` is `"categorical"` (default value) it must include
+ the `y_col` column with the class(es) of each image. Values in the
+ column can be string/list/tuple if a single class, or list/tuple if
+ multiple classes,
+ - if `class_mode` is `"binary"` or `"sparse"` it must include the
+ given `y_col` column with class values as strings,
+ - if `class_mode` is `"raw"` or `"multi_output"` it should contain
+ the columns specified in `y_col`,
+ - if `class_mode` is `"input"` or `None` no extra column is needed.
+ directory: string, path to the directory to read images from. If `None`,
+ data in `x_col` column should be absolute paths.
+ x_col: string, column in `dataframe` that contains the filenames (or
+ absolute paths if `directory` is `None`).
+ y_col: string or list, column(s) in `dataframe` that contain the target
+ data.
+ weight_col: string, column in `dataframe` that contains the sample
+ weights. Default: `None`.
+ target_size: tuple of integers `(height, width)`, default: `(256, 256)`.
+ The dimensions to which all images found will be resized.
+ color_mode: one of "grayscale", "rgb", "rgba". Default: "rgb". Whether
+ the images will be converted to have 1, 3, or 4 color channels.
+ classes: optional list of classes (e.g. `['dogs', 'cats']`). Default is
+ None. If not provided, the list of classes will be automatically
+ inferred from `y_col` (and the order of the classes, which will map
+ to the label indices, will be alphanumeric). The dictionary
+ containing the mapping from class names to class indices can be
+ obtained via the attribute `class_indices`.
+ class_mode: one of "binary", "categorical", "input", "multi_output",
+ "raw", sparse" or None. Default: "categorical".
+ Mode for yielding the targets:
+ - `"binary"`: 1D numpy array of binary labels,
+ - `"categorical"`: 2D numpy array of one-hot encoded labels.
+ Supports multi-label output.
+ - `"input"`: images identical to input images (mainly used to work
+ with autoencoders),
+ - `"multi_output"`: list with the values of the different columns,
+ - `"raw"`: numpy array of values in `y_col` column(s),
+ - `"sparse"`: 1D numpy array of integer labels, - `None`, no targets
+ are returned (the generator will only yield batches of image data,
+ which is useful to use in `model.predict_generator()`).
+ batch_size: size of the batches of data (default: 32).
+ shuffle: whether to shuffle the data (default: True).
+ seed: optional random seed for shuffling and transformations.
+ save_to_dir: None or str (default: None). This allows you to optionally
+ specify a directory to which to save the augmented pictures being
+ generated (useful for visualizing what you are doing).
+ save_prefix: str. Prefix to use for filenames of saved pictures (only
+ relevant if `save_to_dir` is set).
+ save_format: one of "png", "jpeg"
+ (only relevant if `save_to_dir` is set). Default: "png".
+ subset: Subset of data (`"training"` or `"validation"`) if
+ `validation_split` is set in `ImageDataGenerator`.
+ interpolation: Interpolation method used to resample the image if the
+ target size is different from that of the loaded image. Supported
+ methods are `"nearest"`, `"bilinear"`, and `"bicubic"`. If PIL version
+ 1.1.3 or newer is installed, `"lanczos"` is also supported. If PIL
+ version 3.4.0 or newer is installed, `"box"` and `"hamming"` are also
+ supported. By default, `"nearest"` is used.
+ validate_filenames: Boolean, whether to validate image filenames in
+ `x_col`. If `True`, invalid images will be ignored. Disabling this
+ option can speed up the execution of this function. Defaults to
+ `True`.
+ **kwargs: legacy arguments for raising deprecation warnings.
+
+ Returns:
+ A `DataFrameIterator` yielding tuples of `(x, y)`
+ where `x` is a numpy array containing a batch
+ of images with shape `(batch_size, *target_size, channels)`
+ and `y` is a numpy array of corresponding labels.
+ """
+ if 'has_ext' in kwargs:
+ tf_logging.warn(
+ 'has_ext is deprecated, filenames in the dataframe have '
+ 'to match the exact filenames on disk.', DeprecationWarning)
+ if 'sort' in kwargs:
+ tf_logging.warn(
+ 'sort is deprecated, batches will be created in the '
+ 'same order as the filenames provided if shuffle '
+ 'is set to False.', DeprecationWarning)
+ if class_mode == 'other':
+ tf_logging.warn(
+ '`class_mode` "other" is deprecated, please use '
+ '`class_mode` "raw".', DeprecationWarning)
+ class_mode = 'raw'
+ if 'drop_duplicates' in kwargs:
+ tf_logging.warn(
+ 'drop_duplicates is deprecated, you can drop duplicates '
+ 'by using the pandas.DataFrame.drop_duplicates method.',
+ DeprecationWarning)
+
+ return DataFrameIterator(
+ dataframe,
+ directory,
+ self,
+ x_col=x_col,
+ y_col=y_col,
+ weight_col=weight_col,
+ target_size=target_size,
+ color_mode=color_mode,
+ classes=classes,
+ class_mode=class_mode,
+ data_format=self.data_format,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ seed=seed,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ subset=subset,
+ interpolation=interpolation,
+ validate_filenames=validate_filenames)
+
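And a corresponding sketch for `flow_from_dataframe` (pandas is required; the filenames and labels below are made up for illustration):

```python
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Hypothetical dataframe: one filename column, one label column.
df = pd.DataFrame({'filename': ['img_0.png', 'img_1.png'],
                   'class': ['cat', 'dog']})
datagen = ImageDataGenerator(rescale=1. / 255)
it = datagen.flow_from_dataframe(
    df, directory='images/', x_col='filename', y_col='class',
    target_size=(64, 64), class_mode='categorical', batch_size=2)
```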
+
keras_export('keras.preprocessing.image.random_rotation')(random_rotation)
keras_export('keras.preprocessing.image.random_shift')(random_shift)
keras_export('keras.preprocessing.image.random_shear')(random_shear)
diff --git a/tensorflow/python/keras/preprocessing/image_test.py b/tensorflow/python/keras/preprocessing/image_test.py
index d7da420..a577381 100644
--- a/tensorflow/python/keras/preprocessing/image_test.py
+++ b/tensorflow/python/keras/preprocessing/image_test.py
@@ -25,6 +25,9 @@
import numpy as np
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import layers
+from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.preprocessing import image as preprocessing_image
from tensorflow.python.platform import test
@@ -52,7 +55,7 @@
return [rgb_images, gray_images]
-class TestImage(test.TestCase):
+class TestImage(keras_parameterized.TestCase):
@test_util.run_v2_only
def test_smart_resize(self):
@@ -319,14 +322,21 @@
self.assertEqual(
len(set(train_iterator.filenames) & set(filenames)), num_training)
+ model = sequential.Sequential([layers.Flatten(), layers.Dense(2)])
+ model.compile(optimizer='sgd', loss='mse')
+ model.fit(train_iterator, epochs=1)
+
shutil.rmtree(tmp_folder)
+ @keras_parameterized.run_all_keras_modes
def test_directory_iterator_with_validation_split_25_percent(self):
self.directory_iterator_with_validation_split_test_helper(0.25)
+ @keras_parameterized.run_all_keras_modes
def test_directory_iterator_with_validation_split_40_percent(self):
self.directory_iterator_with_validation_split_test_helper(0.40)
+ @keras_parameterized.run_all_keras_modes
def test_directory_iterator_with_validation_split_50_percent(self):
self.directory_iterator_with_validation_split_test_helper(0.50)
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 3387923..5d58795 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -1021,7 +1021,7 @@
# self._testWhileLoopWritePackGradients(
# dynamic_size=False, dtype=tf.int64)
- @test_util.run_v1_only("b/117943489")
+ @test_util.run_deprecated_v1
def testSkipEagerWhileLoopDynamicWritePackGradients(self):
self._testWhileLoopWritePackGradients(
dynamic_size=True, dtype=dtypes.float32)
@@ -1251,7 +1251,6 @@
with self.assertRaises(ValueError):
w1.write(4, c2)
- @test_util.run_v1_only("b/117943489")
def testUnpackShape(self):
self._testUnpackShape()
@@ -1340,11 +1339,11 @@
grad = gradients_impl.gradients(ys=[r], xs=[x])
self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0])
- @test_util.run_v1_only("b/117943489")
+ @test_util.run_deprecated_v1
def testSkipEagerTensorArrayUnpackDynamic(self):
self._testTensorArrayUnpackDynamic()
- @test_util.run_v1_only("b/117943489")
+ @test_util.run_deprecated_v1
def testSkipEagerTensorArraySplitDynamic(self):
with self.session(use_gpu=True) as sess:
ta = tensor_array_ops.TensorArray(
@@ -1422,7 +1421,7 @@
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
ta.stack().eval()
- @test_util.run_v1_only("b/120545219")
+ @test_util.run_deprecated_v1
def testSkipEagerTensorArrayEvalEmpty(self):
self._testTensorArrayEvalEmpty()
@@ -1445,11 +1444,11 @@
# first dimension of zero
self.assertAllEqual([0, 5], self.evaluate(concatenated).shape)
- @test_util.run_v1_only("b/117943489")
+ @test_util.run_deprecated_v1
def testSkipEagerTensorArrayEvalEmptyWithDefault(self):
self._testTensorArrayEvalEmptyWithDefault()
- @test_util.run_v1_only("b/117943489")
+ @test_util.run_deprecated_v1
def testSkipEagerTensorArrayScatterReadAndGradients(self):
with self.session(use_gpu=True) as session:
ta = tensor_array_ops.TensorArray(
@@ -1476,7 +1475,7 @@
self.assertAllEqual([10.0, -10.0], read_vals[1])
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
- @test_util.run_v1_only("b/117943489")
+ @test_util.run_deprecated_v1
def testSkipEagerTensorArrayScatterPartialReadAndGradients(self):
with self.session(use_gpu=True) as session:
ta = tensor_array_ops.TensorArray(
diff --git a/tensorflow/python/lib/core/pybind11_status.h b/tensorflow/python/lib/core/pybind11_status.h
index feb9747..3f9991c 100644
--- a/tensorflow/python/lib/core/pybind11_status.h
+++ b/tensorflow/python/lib/core/pybind11_status.h
@@ -69,6 +69,20 @@
}
}
+inline void MaybeRaiseRegisteredFromStatusWithGIL(
+ const tensorflow::Status& status) {
+ if (!status.ok()) {
+ // Acquire GIL for throwing exception.
+ pybind11::gil_scoped_acquire acquire;
+
+ PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
+ pybind11::make_tuple(pybind11::none(), pybind11::none(),
+ status.error_message())
+ .ptr());
+ throw pybind11::error_already_set();
+ }
+}
+
inline void MaybeRaiseFromTFStatus(TF_Status* status) {
TF_Code code = TF_GetCode(status);
if (code != TF_OK) {
diff --git a/tensorflow/python/lib/io/file_io_wrapper.cc b/tensorflow/python/lib/io/file_io_wrapper.cc
index de806a9..0a2410b 100644
--- a/tensorflow/python/lib/io/file_io_wrapper.cc
+++ b/tensorflow/python/lib/io/file_io_wrapper.cc
@@ -42,50 +42,65 @@
py::gil_scoped_release release;
status = tensorflow::Env::Default()->FileExists(filename);
}
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
});
m.def("DeleteFile", [](const std::string& filename) {
- tensorflow::MaybeRaiseRegisteredFromStatus(
- tensorflow::Env::Default()->DeleteFile(filename));
+ py::gil_scoped_release release;
+ tensorflow::Status status =
+ tensorflow::Env::Default()->DeleteFile(filename);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
});
m.def("ReadFileToString", [](const std::string& filename) {
std::string data;
+ py::gil_scoped_release release;
const auto status =
ReadFileToString(tensorflow::Env::Default(), filename, &data);
+ pybind11::gil_scoped_acquire acquire;
tensorflow::MaybeRaiseRegisteredFromStatus(status);
return py::bytes(data);
});
m.def("WriteStringToFile",
[](const std::string& filename, tensorflow::StringPiece data) {
- return WriteStringToFile(tensorflow::Env::Default(), filename, data);
+ py::gil_scoped_release release;
+ const auto status =
+ WriteStringToFile(tensorflow::Env::Default(), filename, data);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
});
m.def("GetChildren", [](const std::string& dirname) {
std::vector<std::string> results;
+ py::gil_scoped_release release;
const auto status =
tensorflow::Env::Default()->GetChildren(dirname, &results);
+ pybind11::gil_scoped_acquire acquire;
tensorflow::MaybeRaiseRegisteredFromStatus(status);
return results;
});
m.def("GetMatchingFiles", [](const std::string& pattern) {
std::vector<std::string> results;
+ py::gil_scoped_release release;
const auto status =
tensorflow::Env::Default()->GetMatchingPaths(pattern, &results);
+ pybind11::gil_scoped_acquire acquire;
tensorflow::MaybeRaiseRegisteredFromStatus(status);
return results;
});
m.def("CreateDir", [](const std::string& dirname) {
+ py::gil_scoped_release release;
const auto status = tensorflow::Env::Default()->CreateDir(dirname);
if (tensorflow::errors::IsAlreadyExists(status)) {
return;
}
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
});
m.def("RecursivelyCreateDir", [](const std::string& dirname) {
- tensorflow::MaybeRaiseRegisteredFromStatus(
- tensorflow::Env::Default()->RecursivelyCreateDir(dirname));
+ py::gil_scoped_release release;
+ const auto status =
+ tensorflow::Env::Default()->RecursivelyCreateDir(dirname);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
});
m.def("CopyFile",
[](const std::string& src, const std::string& target, bool overwrite) {
+ py::gil_scoped_release release;
auto* env = tensorflow::Env::Default();
tensorflow::Status status;
if (!overwrite && env->FileExists(target).ok()) {
@@ -93,10 +108,11 @@
} else {
status = env->CopyFile(src, target);
}
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
});
m.def("RenameFile",
[](const std::string& src, const std::string& target, bool overwrite) {
+ py::gil_scoped_release release;
auto* env = tensorflow::Env::Default();
tensorflow::Status status;
if (!overwrite && env->FileExists(target).ok()) {
@@ -104,9 +120,10 @@
} else {
status = env->RenameFile(src, target);
}
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
});
m.def("DeleteRecursively", [](const std::string& dirname) {
+ py::gil_scoped_release release;
tensorflow::int64 undeleted_files;
tensorflow::int64 undeleted_dirs;
auto status = tensorflow::Env::Default()->DeleteRecursively(
@@ -115,23 +132,25 @@
status =
tensorflow::errors::PermissionDenied("could not fully delete dir");
}
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
});
m.def("IsDirectory", [](const std::string& dirname) {
+ py::gil_scoped_release release;
const auto status = tensorflow::Env::Default()->IsDirectory(dirname);
// FAILED_PRECONDITION response means path exists but isn't a dir.
if (tensorflow::errors::IsFailedPrecondition(status)) {
return false;
}
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
return true;
});
m.def("HasAtomicMove", [](const std::string& path) {
+ py::gil_scoped_release release;
bool has_atomic_move;
const auto status =
tensorflow::Env::Default()->HasAtomicMove(path, &has_atomic_move);
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
return has_atomic_move;
});
@@ -141,9 +160,11 @@
.def_readonly("is_directory", &tensorflow::FileStatistics::is_directory);
m.def("Stat", [](const std::string& filename) {
+ py::gil_scoped_release release;
std::unique_ptr<tensorflow::FileStatistics> self(
new tensorflow::FileStatistics);
const auto status = tensorflow::Env::Default()->Stat(filename, self.get());
+ py::gil_scoped_acquire acquire;
tensorflow::MaybeRaiseRegisteredFromStatus(status);
return self.release();
});
@@ -151,66 +172,83 @@
using tensorflow::WritableFile;
py::class_<WritableFile>(m, "WritableFile")
.def(py::init([](const std::string& filename, const std::string& mode) {
+ py::gil_scoped_release release;
auto* env = tensorflow::Env::Default();
std::unique_ptr<WritableFile> self;
const auto status = mode.find("a") == std::string::npos
? env->NewWritableFile(filename, &self)
: env->NewAppendableFile(filename, &self);
+ py::gil_scoped_acquire acquire;
tensorflow::MaybeRaiseRegisteredFromStatus(status);
return self.release();
}))
.def("append",
[](WritableFile* self, tensorflow::StringPiece data) {
- tensorflow::MaybeRaiseRegisteredFromStatus(self->Append(data));
+ const auto status = self->Append(data);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
})
// TODO(slebedev): Make WritableFile::Tell const and change self
// to be a reference.
.def("tell",
[](WritableFile* self) {
tensorflow::int64 pos = -1;
+ py::gil_scoped_release release;
const auto status = self->Tell(&pos);
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
return pos;
})
.def("flush",
[](WritableFile* self) {
- tensorflow::MaybeRaiseRegisteredFromStatus(self->Flush());
+ py::gil_scoped_release release;
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(self->Flush());
})
.def("close", [](WritableFile* self) {
- tensorflow::MaybeRaiseRegisteredFromStatus(self->Close());
+ py::gil_scoped_release release;
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(self->Close());
});
using tensorflow::io::BufferedInputStream;
py::class_<BufferedInputStream>(m, "BufferedInputStream")
.def(py::init([](const std::string& filename, size_t buffer_size) {
+ py::gil_scoped_release release;
std::unique_ptr<tensorflow::RandomAccessFile> file;
const auto status =
tensorflow::Env::Default()->NewRandomAccessFile(filename, &file);
- tensorflow::MaybeRaiseRegisteredFromStatus(status);
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(status);
std::unique_ptr<tensorflow::io::RandomAccessInputStream> input_stream(
new tensorflow::io::RandomAccessInputStream(file.release(),
/*owns_file=*/true));
+ py::gil_scoped_acquire acquire;
return new BufferedInputStream(input_stream.release(), buffer_size,
/*owns_input_stream=*/true);
}))
.def("read",
[](BufferedInputStream* self, tensorflow::int64 bytes_to_read) {
+ py::gil_scoped_release release;
tensorflow::tstring result;
const auto status = self->ReadNBytes(bytes_to_read, &result);
if (!status.ok() && !tensorflow::errors::IsOutOfRange(status)) {
result.clear();
tensorflow::MaybeRaiseRegisteredFromStatus(status);
}
+ py::gil_scoped_acquire acquire;
return py::bytes(result);
})
.def("readline",
[](BufferedInputStream* self) {
- return py::bytes(self->ReadLineAsString());
+ py::gil_scoped_release release;
+ auto output = self->ReadLineAsString();
+ py::gil_scoped_acquire acquire;
+ return py::bytes(output);
})
.def("seek",
[](BufferedInputStream* self, tensorflow::int64 pos) {
- tensorflow::MaybeRaiseRegisteredFromStatus(self->Seek(pos));
+ py::gil_scoped_release release;
+ tensorflow::MaybeRaiseRegisteredFromStatusWithGIL(self->Seek(pos));
})
- .def("tell", [](BufferedInputStream* self) { return self->Tell(); });
+ .def("tell", [](BufferedInputStream* self) {
+ py::gil_scoped_release release;
+ return self->Tell();
+ });
}
} // namespace
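The practical effect of releasing the GIL around these blocking filesystem calls is that other Python threads can make progress while a slow stat or read is in flight. A rough sketch of the user-visible behavior (paths are hypothetical):

```python
import threading
import tensorflow as tf

def worker(path):
    # The C++ implementation now drops the GIL while the filesystem
    # call runs, so these threads no longer serialize on it.
    tf.io.gfile.exists(path)

threads = [threading.Thread(target=worker, args=('/tmp/f%d' % i,))
           for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```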
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 33aac84..1cb6fdb 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -39,6 +39,7 @@
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
+from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
@@ -1381,13 +1382,13 @@
if context.executing_eagerly():
# NOTE: Fast path when all the items are tensors, this doesn't do any type
# checking.
- if all(ops.is_dense_tensor_like(elem) for elem in list_or_tuple):
+ if all(isinstance(elem, core.Tensor) for elem in list_or_tuple):
return gen_array_ops.pack(list_or_tuple, name=name)
must_pack = False
converted_elems = []
with ops.name_scope(name) as scope:
for i, elem in enumerate(list_or_tuple):
- if ops.is_dense_tensor_like(elem):
+ if isinstance(elem, core.Tensor):
if dtype is not None and elem.dtype.base_dtype != dtype:
raise TypeError("Cannot convert a list containing a tensor of dtype "
"%s to %s (Tensor is: %r)" %
@@ -1396,7 +1397,7 @@
must_pack = True
elif isinstance(elem, (list, tuple)):
converted_elem = _autopacking_helper(elem, dtype, str(i))
- if ops.is_dense_tensor_like(converted_elem):
+ if isinstance(converted_elem, core.Tensor):
must_pack = True
converted_elems.append(converted_elem)
else:
@@ -1404,7 +1405,7 @@
if must_pack:
elems_as_tensors = []
for i, elem in enumerate(converted_elems):
- if ops.is_dense_tensor_like(elem):
+ if isinstance(elem, core.Tensor):
elems_as_tensors.append(elem)
else:
# NOTE(mrry): This is inefficient, but it enables us to
@@ -1429,7 +1430,7 @@
such object exists.
"""
for elem in list_or_tuple:
- if ops.is_dense_tensor_like(elem):
+ if isinstance(elem, core.Tensor):
return elem.dtype.base_dtype
elif isinstance(elem, (list, tuple)):
maybe_dtype = _get_dtype_from_nested_lists(elem)
@@ -1441,7 +1442,7 @@
def _cast_nested_seqs_to_dtype(dtype):
def _maybe_cast(elem):
- if ops.is_dense_tensor_like(elem):
+ if isinstance(elem, core.Tensor):
if dtype != elem.dtype.base_dtype:
elem = gen_math_ops.cast(elem, dtype)
return elem
@@ -1455,7 +1456,7 @@
def _should_not_autopack(v):
# The condition we really want is
- # ops.is_dense_tensor_like(...)
+ # any(isinstance(elem, core.Tensor))
# but it is >5x slower due to abc.ABCMeta.__instancecheck__.
# pylint: disable=unidiomatic-typecheck
# TODO(slebedev): add nest.all?
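The net effect of these replacements is that "dense tensor-likeness" becomes an ordinary `isinstance` check against `core.Tensor`, which variables now subclass (see the `resource_variable_ops.py` and `variables.py` hunks below). A brief sketch, assuming TF 2.x eager mode:

```python
import tensorflow as tf
from tensorflow.python.types import core

t = tf.constant([1, 2, 3])
v = tf.Variable([4, 5, 6])
# Both Tensor and (resource) Variable now satisfy the same check.
print(isinstance(t, core.Tensor), isinstance(v, core.Tensor))  # True True
```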
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index de5be20..248c57c 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -45,6 +45,7 @@
# pylint: enable=wildcard-import
from tensorflow.python.platform import device_context
from tensorflow.python.util import deprecation
+from tensorflow.python.util import dispatch
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
@@ -4513,6 +4514,7 @@
@tf_export(v1=["nn.dropout"])
+@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Please use `rate` instead of `keep_prob`. "
"Rate should be set to `rate = 1 - keep_prob`.",
"keep_prob")
@@ -4567,6 +4569,7 @@
@tf_export("nn.dropout", v1=[])
+@dispatch.add_dispatch_support
def dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
"""Computes dropout: randomly sets elements to zero to prevent overfitting.
diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py
index dd5bd78..f13bed0 100644
--- a/tensorflow/python/ops/ragged/ragged_dispatch.py
+++ b/tensorflow/python/ops/ragged/ragged_dispatch.py
@@ -30,6 +30,7 @@
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
@@ -453,6 +454,26 @@
num_partitions, name)
return [result[i] for i in range(num_partitions)]
+
+def _ragged_nn_dropout_v1(x, keep_prob=None, noise_shape=None, seed=None,
+ name=None, rate=None):
+ if noise_shape is not None:
+ raise ValueError('noise_shape is not supported yet for RaggedTensor x')
+ with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
+ x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
+ return x.with_flat_values(nn_ops.dropout(x.flat_values, keep_prob=keep_prob,
+ seed=seed, rate=rate))
+
+
+def _ragged_nn_dropout_v2(x, rate, noise_shape=None, seed=None, name=None):
+ if noise_shape is not None:
+ raise ValueError('noise_shape is not supported yet for RaggedTensor x')
+ with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
+ x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
+ return x.with_flat_values(nn_ops.dropout_v2(x.flat_values, rate=rate,
+ seed=seed))
+
+
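With these dispatch entries registered, `tf.nn.dropout` accepts `RaggedTensor` inputs by applying dropout to the flat values. A small sketch of the resulting behavior:

```python
import tensorflow as tf

rt = tf.ragged.constant([[1.0, 2.0, 3.0], [4.0]])
out = tf.nn.dropout(rt, rate=0.5, seed=1)
print(out.shape)  # (2, None); the ragged structure is preserved
```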
# (original_op, ragged_op, ragged_args)
_RAGGED_DISPATCH_OPS = [
(array_ops.batch_gather, ragged_batch_gather_ops.batch_gather,
@@ -497,6 +518,8 @@
(math_ops.reduce_mean, ragged_math_ops.reduce_mean, ['input_tensor']),
(math_ops.reduce_any, ragged_math_ops.reduce_any, ['input_tensor']),
(math_ops.reduce_all, ragged_math_ops.reduce_all, ['input_tensor']),
+ (nn_ops.dropout, _ragged_nn_dropout_v1, ['x']),
+ (nn_ops.dropout_v2, _ragged_nn_dropout_v2, ['x']),
]
diff --git a/tensorflow/python/ops/ragged/ragged_dispatch_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py
index 0ce9a6f..60d9f6c 100644
--- a/tensorflow/python/ops/ragged/ragged_dispatch_test.py
+++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py
@@ -32,6 +32,7 @@
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_dispatch
@@ -232,6 +233,10 @@
{'op': array_ops.check_numerics,
'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
'message': 'check-numerics'},
+ {'op': nn_ops.dropout,
+ 'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
+ 'rate': 0.5,
+ 'seed': 1},
]
) # pyformat: disable
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
@@ -820,7 +825,8 @@
'strings.substr', 'strings.to_hash_bucket_fast',
'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
- 'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse'
+ 'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse',
+ 'nn.dropout',
]
# Ops that should be listed as supported in v1 only.
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index f99f886..d8a7765 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -49,6 +49,7 @@
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training.tracking import base as trackable
+from tensorflow.python.types import core
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
@@ -330,7 +331,7 @@
tape.variable_accessed(variable)
-class BaseResourceVariable(variables.VariableV1):
+class BaseResourceVariable(variables.VariableV1, core.Tensor):
"""A python variable from an existing handle."""
# TODO(wangpeng): Deprecate `constraint` when callers no long pass it in.
@@ -1830,7 +1831,6 @@
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(BaseResourceVariable,
_dense_var_to_tensor)
-ops.register_dense_tensor_like_type(BaseResourceVariable)
class _UnreadVariable(BaseResourceVariable):
@@ -1955,9 +1955,6 @@
return self._parent_op
-ops.register_dense_tensor_like_type(_UnreadVariable)
-
-
@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
"""Gradient for read op."""
diff --git a/tensorflow/python/ops/signal/mel_ops.py b/tensorflow/python/ops/signal/mel_ops.py
index aa07691..b95876b 100644
--- a/tensorflow/python/ops/signal/mel_ops.py
+++ b/tensorflow/python/ops/signal/mel_ops.py
@@ -128,8 +128,6 @@
# S has shape [..., num_spectrogram_bins].
# M has shape [..., num_mel_bins].
M = tf.tensordot(S, A, 1)
- # tf.tensordot does not support shape inference for this case yet.
- M.set_shape(S.shape[:-1].concatenate(A.shape[-1:]))
Args:
num_mel_bins: Python int. How many bands in the resulting mel spectrum.
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index d65cd23..81c3f9a 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -42,6 +42,7 @@
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
@@ -1000,7 +1001,7 @@
return initializer, initializing_from_value
-class _LazyEvalTensor(object):
+class _LazyEvalTensor(core.Tensor):
"""A Tensor-like object that only evaluates its thunk when used."""
def __init__(self, thunk):
@@ -1069,8 +1070,6 @@
lambda fetch: ([fetch._master_tensor], lambda fetched_vals: fetched_vals[0]) # pylint: disable=protected-access
)
-ops.register_dense_tensor_like_type(_LazyEvalTensor)
-
# To stop regularization, use this regularizer
@tf_export(v1=["no_regularizer"])
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 1080778..d3df065 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -47,6 +47,7 @@
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.types import core
def default_variable_creator(_, **kwds):
@@ -264,6 +265,7 @@
@tf_export("Variable", v1=[])
+# TODO(mdan): This should subclass core.Tensor, and not all its subclasses?
class Variable(six.with_metaclass(VariableMetaclass, trackable.Trackable)):
"""See the [variable guide](https://tensorflow.org/guide/variable).
@@ -1551,7 +1553,7 @@
# TODO(apassos): do not repeat all comments here
-class RefVariable(VariableV1):
+class RefVariable(VariableV1, core.Tensor):
"""Ref-based implementation of variables."""
def __init__(
@@ -3032,7 +3034,6 @@
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(RefVariable,
RefVariable._TensorConversionFunction) # pylint: disable=protected-access
-ops.register_dense_tensor_like_type(RefVariable)
@tf_export(v1=["global_variables"])
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index e5ca608..b6565f5 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -226,6 +226,7 @@
deps = [
"//tensorflow/python:util",
"//tensorflow/python/profiler/internal:_pywrap_traceme",
+ "//tensorflow/python/types",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index aeca90b..b36a1f2 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -332,6 +332,11 @@
functions[fdef.signature.name] = func
renamed_functions[func.name] = func
+ if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
+ # TODO(b/150708051): Remove this hack once TensorRT SavedModel integration
+ # is fixed. Currently it's leaking memory to maintain bug compatibility
+ # with previous behavior.
+ func.add_to_graph(ops.get_default_graph())
return functions
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index 42e971d..0f635b6 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -178,7 +178,7 @@
spec = struct_coder.decode_proto(spec_proto)
components = [_get_tensor(component.name) for component in
tensor_info.composite_tensor.components]
- return spec._from_components(components) # pylint: disable=protected-access
+ return nest.pack_sequence_as(spec, components, expand_composites=True)
else:
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index fa07a92..d1848f3 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -828,7 +828,7 @@
... end_learning_rate=0.0)
>>> wordpiece_table_config = TableConfig(
... vocabulary_size=119547,
- ... dimension=768,
+ ... dimension=256,
... learning_rate_fn=learning_rate_fn)
>>> wordpiece_feature_config = FeatureConfig(
... table_id='bert/embeddings/word_embeddings',
@@ -846,11 +846,11 @@
... batch_size=128,
... mode=TRAINING,
... optimization_parameters=optimization_parameters,
- ... device_config=DeviceConfig(
- ... num_cores=64, num_hosts=4, job_name='tpu_worker'))
+ ... master='')
>>> with tf.Graph().as_default():
... init_tpu_op = tf.compat.v1.tpu.initialize_system(
- ... embedding_config=tpu_embedding.config_proto, job='tpu_worker')
+ ... embedding_config=tpu_embedding.config_proto)
+ ... tf.compat.v1.Session().run(init_tpu_op)
"""
# TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD
index f35ca7f..e93bf5c 100644
--- a/tensorflow/python/types/BUILD
+++ b/tensorflow/python/types/BUILD
@@ -27,6 +27,9 @@
"internal.py",
],
srcs_version = "PY2AND3",
- visibility = ["//tensorflow:__subpackages__"],
+ visibility = [
+ "//tensorflow:__subpackages__",
+ "//tensorflow:types_whitelist",
+ ],
deps = [],
)
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 61d6656..f56330b 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -193,10 +193,10 @@
"//conditions:default": otherwise,
})
-def if_ios(a):
+def if_ios(a, otherwise = []):
return select({
clean_dep("//tensorflow:ios"): a,
- "//conditions:default": [],
+ "//conditions:default": otherwise,
})
def if_ios_x86_64(a):
diff --git a/tensorflow/tools/android/inference_interface/BUILD b/tensorflow/tools/android/inference_interface/BUILD
index cbd161f..fb3ab00 100644
--- a/tensorflow/tools/android/inference_interface/BUILD
+++ b/tensorflow/tools/android/inference_interface/BUILD
@@ -34,7 +34,7 @@
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:android_tensorflow_lib_lite",
+ "//tensorflow/core:portable_tensorflow_lib_lite",
"//tensorflow/java/src/main/native",
],
alwayslink = 1,
@@ -83,7 +83,7 @@
],
deps = [
":android_tensorflow_inference_jni",
- "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:portable_tensorflow_lib",
LINKER_SCRIPT,
],
)
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
index 4a30fae..9315973 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt
@@ -2,6 +2,7 @@
tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
+ is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>"
member {
name: "OVERLOADABLE_OPERATORS"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
index 4a30fae..9315973 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt
@@ -2,6 +2,7 @@
tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
+ is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>"
member {
name: "OVERLOADABLE_OPERATORS"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt
index 0b49aa9..e59c78c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt
@@ -68,4 +68,8 @@
name: "save_img"
argspec: "args=[\'path\', \'x\', \'data_format\', \'file_format\', \'scale\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'True\'], "
}
+ member_method {
+ name: "smart_resize"
+ argspec: "args=[\'x\', \'size\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'bilinear\'], "
+ }
}
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython
index 353d946..9c85091 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython
@@ -55,17 +55,17 @@
COPY install/install_deb_packages.sh /install/
RUN /install/install_deb_packages.sh
-# Install patchelf to facilitate the creation of manylinux2010 whls.
-COPY install/install_patchelf.sh /install/
-RUN /install/install_patchelf.sh
-
-# Install additional dependencies to build Python from source.
+# Install additional packages needed for this image:
+# - dependencies to build Python from source
+# - patchelf, as it is required by auditwheel
RUN apt-get update && apt-get install -y \
- libncurses5-dev \
+ libbz2-dev \
+ libffi-dev \
libgdbm-dev \
+ libncurses5-dev \
libnss3-dev \
libreadline-dev \
- libffi-dev \
+ patchelf \
&& \
rm -rf /var/lib/apt/lists/*
@@ -86,9 +86,6 @@
RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.6"
RUN /install/install_pip_packages_by_version.sh "/usr/local/bin/pip3.7"
-# Install auditwheel to create manylinux2010 compliant binaries
-RUN pip3 install auditwheel
-
ENV CLANG_VERSION="r42cab985fd95ba4f3f290e7bb26b93805edb447d"
COPY install/install_latest_clang.sh /install/
RUN /install/install_latest_clang.sh
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh b/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh
index d9953db..81e5f2b 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages_by_version.sh
@@ -26,6 +26,7 @@
fi
PACKAGES=(
+ "auditwheel"
"wheel"
"setuptools"
"virtualenv"
diff --git a/tensorflow/tools/ci_build/release/common.sh b/tensorflow/tools/ci_build/release/common.sh
index a6ef52b..bb40042 100644
--- a/tensorflow/tools/ci_build/release/common.sh
+++ b/tensorflow/tools/ci_build/release/common.sh
@@ -146,6 +146,7 @@
${PIP_CMD} install --user --upgrade attrs
${PIP_CMD} install --user --upgrade tf-estimator-nightly
${PIP_CMD} install --user --upgrade "future>=0.17.1"
+ ${PIP_CMD} install --user --upgrade wrapt
# LINT.ThenChange(:ubuntu_16_pip_installations)
}
@@ -178,6 +179,7 @@
"${PIP_CMD}" install PyYAML==3.13 --user
"${PIP_CMD}" install --user --upgrade tf-estimator-nightly
"${PIP_CMD}" install --user --upgrade tb-nightly
+ "${PIP_CMD}" install --user --upgrade wrapt
# LINT.ThenChange(:ubuntu_pip_installations)
}
@@ -219,6 +221,7 @@
${SUDO_CMD} ${PIP_CMD} install --upgrade tb-nightly
${PIP_CMD} install --user --upgrade attrs
${PIP_CMD} install --user --upgrade tf-estimator-nightly
+ ${PIP_CMD} install --user --upgrade wrapt
${PIP_CMD} install --user --upgrade "future>=0.17.1"
}
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index c092f21..c0442a5 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -2,6 +2,7 @@
# Doc generator
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
package(
default_visibility = ["//tensorflow:__subpackages__"],
@@ -22,6 +23,7 @@
py_test(
name = "tf_doctest",
srcs = ["tf_doctest.py"],
+ args = ["--module_prefix_skip=tpu.,distribute.tpu_strategy"],
python_version = "PY3",
tags = [
"no_oss_py2",
@@ -40,6 +42,28 @@
],
)
+tpu_py_test(
+ name = "tf_doctest_tpu",
+ srcs = ["tf_doctest.py"],
+ args = ["--module=tpu.,distribute.tpu_strategy"],
+ disable_experimental = True,
+ disable_v3 = True,
+ main = "tf_doctest.py",
+ python_version = "PY3",
+ tags = [
+ "no_oss",
+ "noasan",
+ "nomsan",
+ "notsan",
+ ],
+ deps = [
+ ":tf_doctest_lib",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python/keras/preprocessing",
+ "//third_party/py/numpy",
+ ],
+)
+
py_test(
name = "tf_doctest_test",
srcs = ["tf_doctest_test.py"],
diff --git a/tensorflow/tools/docs/tf_doctest.py b/tensorflow/tools/docs/tf_doctest.py
index 1962465..fc81d33 100644
--- a/tensorflow/tools/docs/tf_doctest.py
+++ b/tensorflow/tools/docs/tf_doctest.py
@@ -42,7 +42,9 @@
FLAGS = flags.FLAGS
-flags.DEFINE_string('module', None, 'A specific module to run doctest on.')
+flags.DEFINE_list('module', [], 'A list of specific modules to run doctest on.')
+flags.DEFINE_list('module_prefix_skip', [],
+ 'A list of module prefixes to skip when resolving modules.')
flags.DEFINE_boolean('list', None,
'List all the modules in the core package imported.')
flags.DEFINE_string('file', None, 'A specific file to run doctest on.')
@@ -50,6 +52,7 @@
flags.mark_flags_as_mutual_exclusive(['module', 'file'])
flags.mark_flags_as_mutual_exclusive(['list', 'file'])
+# Both --module and --module_prefix_skip are relative to PACKAGE.
PACKAGE = 'tensorflow.python.'
@@ -68,23 +71,24 @@
return tf_modules
-def filter_on_submodules(all_modules, submodule):
- """Filters all the modules based on the module flag.
+def filter_on_submodules(all_modules, submodules):
+ """Filters all the modules based on the modules flag.
The module flag has to be relative to the core package imported.
- For example, if `submodule=keras.layers` then, this function will return
+ For example, if `module=keras.layers`, then this function will return
all the modules in the submodule.
Args:
all_modules: All the modules in the core package.
- submodule: Submodule to filter from all the modules.
+ submodules: Submodules to filter from all the modules.
Returns:
All the modules in the submodule.
"""
filtered_modules = [
- mod for mod in all_modules if PACKAGE + submodule in mod.__name__
+ mod for mod in all_modules
+ if any(PACKAGE + submodule in mod.__name__ for submodule in submodules)
]
return filtered_modules
@@ -140,6 +144,9 @@
tf_modules = get_module_and_inject_docstring(FLAGS.file)
for module in tf_modules:
+ if any(module.__name__.startswith(PACKAGE + prefix)
+ for prefix in FLAGS.module_prefix_skip):
+ continue
testcase = TfTestCase()
tests.addTests(
doctest.DocTestSuite(
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 5b479c5..fe548fd 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -162,8 +162,8 @@
print("path_prefix was specified to tf_workspace but is no longer used " +
"and will be removed in the future.")
- TFRT_COMMIT = "0bad623e8d99ace05f7f60e9e7f8b53ec813d66a"
- TFRT_SHA256 = "d002429866d2d824a80dcf6c1602a15398412bc01324200d371c55b13b9a4b27"
+ TFRT_COMMIT = "341ba0448c117af4e29ae3911141265ee8e57860"
+ TFRT_SHA256 = "27716458f8ca7d91fc2d0f681127dbdd478eea78d6da5153c51b4696ebd14d55"
TFRT_URLS = [
"http://mirror.tensorflow.org/github.com/tensorflow/runtime/archive/{commit}.zip".format(commit = TFRT_COMMIT),
"https://github.com/tensorflow/runtime/archive/{commit}.zip".format(commit = TFRT_COMMIT),
@@ -679,8 +679,8 @@
)
# Check out LLVM and MLIR from llvm-project.
- LLVM_COMMIT = "c5e0967e4cf0f1337bec772949e6cede4c01354b"
- LLVM_SHA256 = "5d8dbddd78fbc1c08825b178aff0a0f04722d83280eb93be55a174391f1885ce"
+ LLVM_COMMIT = "123bee602a260150ff55c74287f583a67ee78f36"
+ LLVM_SHA256 = "313ec75e47ea3f128724a61b8b6b45b7d305ba2ae57a5084b4bf1f881b4ec8f2"
LLVM_URLS = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
index f3b2ae6..303339e 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
@@ -53,13 +53,6 @@
PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
NVCC_VERSION = '%{cuda_version}'
-
-# TODO(amitpatankar): Benchmark enabling all capabilities by default.
-# Environment variable for supported TF CUDA Compute Capabilities
-# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
-CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES'
-DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
-
def Log(s):
print('gpus/crosstool: {0}'.format(s))
@@ -78,7 +71,8 @@
"""
parser = ArgumentParser()
- parser.add_argument('-' + option, nargs='*', action='append')
+ parser.add_argument(option, nargs='*', action='append')
+ option = option.lstrip('-').replace('-', '_')
args, _ = parser.parse_known_args(argv)
if not args or not vars(args)[option]:
return []
@@ -180,17 +174,17 @@
host_compiler_options = GetHostCompilerOptions(argv)
nvcc_compiler_options = GetNvccOptions(argv)
- opt_option = GetOptionValue(argv, 'O')
- m_options = GetOptionValue(argv, 'm')
+ opt_option = GetOptionValue(argv, '-O')
+ m_options = GetOptionValue(argv, '-m')
m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
- include_options = GetOptionValue(argv, 'I')
- out_file = GetOptionValue(argv, 'o')
- depfiles = GetOptionValue(argv, 'MF')
- defines = GetOptionValue(argv, 'D')
+ include_options = GetOptionValue(argv, '-I')
+ out_file = GetOptionValue(argv, '-o')
+ depfiles = GetOptionValue(argv, '-MF')
+ defines = GetOptionValue(argv, '-D')
defines = ''.join([' -D' + define for define in defines])
- undefines = GetOptionValue(argv, 'U')
+ undefines = GetOptionValue(argv, '-U')
undefines = ''.join([' -U' + define for define in undefines])
- std_options = GetOptionValue(argv, 'std')
+ std_options = GetOptionValue(argv, '-std')
# Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang.
nvcc_allowed_std_options = ["c++03", "c++11", "c++14"]
std_options = ''.join([' -std=' + define
@@ -198,7 +192,7 @@
# The list of source files get passed after the -c option. I don't know of
# any other reliable way to just get the list of source files to be compiled.
- src_files = GetOptionValue(argv, 'c')
+ src_files = GetOptionValue(argv, '-c')
# Pass -w through from host to nvcc, but don't do anything fancier with
# warnings-related flags, since they're not necessarily the same across
@@ -224,13 +218,12 @@
srcs = ' '.join(src_files)
out = ' -o ' + out_file[0]
- supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ]
nvccopts = '-D_FORCE_INLINES '
- for capability in supported_cuda_compute_capabilities:
- capability = capability.replace('.', '')
+ for capability in GetOptionValue(argv, "--cuda-gpu-arch"):
+ capability = capability[len('sm_'):]
nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
capability, capability, capability)
- nvccopts += ' ' + nvcc_compiler_options
+ nvccopts += nvcc_compiler_options
nvccopts += undefines
nvccopts += defines
nvccopts += std_options
@@ -272,6 +265,7 @@
if args.x and args.x[0] == 'cuda':
if args.cuda_log: Log('-x cuda')
leftover = [pipes.quote(s) for s in leftover]
+ args.cuda_log = True
if args.cuda_log: Log('using nvcc')
return InvokeNvcc(leftover, log=args.cuda_log)
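For reference, the updated `GetOptionValue` now takes the full option string (including its leading dashes) and derives the argparse destination from it. A standalone Python sketch of that behavior:

```python
from argparse import ArgumentParser

def GetOptionValue(argv, option):
    # Mirrors the new wrapper logic: accepts '--cuda-gpu-arch', '-O', etc.
    parser = ArgumentParser()
    parser.add_argument(option, nargs='*', action='append')
    option = option.lstrip('-').replace('-', '_')
    args, _ = parser.parse_known_args(argv)
    if not args or not vars(args)[option]:
        return []
    return sum(vars(args)[option], [])

print(GetOptionValue(['--cuda-gpu-arch', 'sm_70', '-O2'], '--cuda-gpu-arch'))
# -> ['sm_70']
```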
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
index 46e8aef..c10fb82 100644
--- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
@@ -37,13 +37,6 @@
NVCC_PATH = '%{nvcc_path}'
NVCC_VERSION = '%{cuda_version}'
NVCC_TEMP_DIR = "%{nvcc_tmp_dir}"
-DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
-
-# Taken from environment variable for supported TF CUDA Compute Capabilities
-# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
-supported_cuda_compute_capabilities = os.environ.get(
- 'TF_CUDA_COMPUTE_CAPABILITIES',
- DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
def Log(s):
print('gpus/crosstool: {0}'.format(s))
@@ -53,7 +46,7 @@
"""Extract the list of values for option from options.
Args:
- option: The option whose value to extract, without the leading '/'.
+ option: The option whose value to extract.
Returns:
1. A list of values, either directly following the option,
@@ -62,10 +55,11 @@
2. The leftover options.
"""
- parser = ArgumentParser(prefix_chars='/')
- parser.add_argument('/' + option, nargs='*', action='append')
+ parser = ArgumentParser(prefix_chars='-/')
+ parser.add_argument(option, nargs='*', action='append')
+ option = option.lstrip('-/').replace('-', '_')
args, leftover = parser.parse_known_args(argv)
- if args and vars(args)[option]:
+ if args and vars(args).get(option):
return (sum(vars(args)[option], []), leftover)
return ([], leftover)
@@ -122,18 +116,18 @@
nvcc_compiler_options, argv = GetNvccOptions(argv)
- opt_option, argv = GetOptionValue(argv, 'O')
+ opt_option, argv = GetOptionValue(argv, '/O')
opt = ['-g']
if (len(opt_option) > 0 and opt_option[0] != 'd'):
opt = ['-O2']
- include_options, argv = GetOptionValue(argv, 'I')
+ include_options, argv = GetOptionValue(argv, '/I')
includes = ["-I " + include for include in include_options]
- defines, argv = GetOptionValue(argv, 'D')
+ defines, argv = GetOptionValue(argv, '/D')
defines = ['-D' + define for define in defines]
- undefines, argv = GetOptionValue(argv, 'U')
+ undefines, argv = GetOptionValue(argv, '/U')
undefines = ['-U' + define for define in undefines]
# The rest of the unrecognized options should be passed to host compiler
@@ -142,10 +136,10 @@
m_options = ["-m64"]
nvccopts = ['-D_FORCE_INLINES']
- for capability in supported_cuda_compute_capabilities:
- capability = capability.replace('.', '')
- nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
- capability, capability, capability)]
+ capabilities, argv = GetOptionValue(argv, '--cuda-gpu-arch')
+ for capability in capabilities:
+ capability = capability[len('sm_'):]
+ nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
+ capability, capability, capability)]
nvccopts += nvcc_compiler_options
nvccopts += undefines
nvccopts += defines
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 545aeeb..c587f11 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -840,10 +840,7 @@
"--cuda-gpu-arch=sm_" + cap.replace(".", "")
for cap in compute_capabilities
]
-
- # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
- # TODO(csigg): Make this consistent with cuda clang and pass unconditionally.
- return "if_cuda_clang(%s)" % str(capability_flags)
+ return str(capability_flags)
def _tpl_path(repository_ctx, filename):
return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename))
@@ -1092,9 +1089,6 @@
"%{cuda_version}": cuda_config.cuda_version,
"%{nvcc_path}": nvcc_path,
"%{gcc_host_compiler_path}": str(cc),
- "%{cuda_compute_capabilities}": ", ".join(
- ["\"%s\"" % c for c in cuda_config.compute_capabilities],
- ),
"%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
}
repository_ctx.template(
diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD
index 925fad7..75b32c7 100644
--- a/third_party/mlir/BUILD
+++ b/third_party/mlir/BUILD
@@ -297,9 +297,9 @@
)
filegroup(
- name = "LoopOpsTdFiles",
+ name = "SCFTdFiles",
srcs = [
- "include/mlir/Dialect/LoopOps/LoopOps.td",
+ "include/mlir/Dialect/SCF/SCFOps.td",
"include/mlir/Interfaces/ControlFlowInterfaces.td",
"include/mlir/Interfaces/LoopLikeInterface.td",
"include/mlir/Interfaces/SideEffects.td",
@@ -308,26 +308,26 @@
)
gentbl(
- name = "LoopOpsIncGen",
+ name = "SCFIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
"-gen-op-decls",
- "include/mlir/Dialect/LoopOps/LoopOps.h.inc",
+ "include/mlir/Dialect/SCF/SCFOps.h.inc",
),
(
"-gen-op-defs",
- "include/mlir/Dialect/LoopOps/LoopOps.cpp.inc",
+ "include/mlir/Dialect/SCF/SCFOps.cpp.inc",
),
(
"-gen-dialect-decls",
- "include/mlir/Dialect/LoopOps/LoopOpsDialect.h.inc",
+ "include/mlir/Dialect/SCF/SCFOpsDialect.h.inc",
),
],
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/LoopOps/LoopOps.td",
+ td_file = "include/mlir/Dialect/SCF/SCFOps.td",
td_srcs = [
- ":LoopOpsTdFiles",
+ ":SCFTdFiles",
],
)
@@ -337,30 +337,30 @@
tbl_outs = [
(
"-gen-pass-decls",
- "include/mlir/Dialect/LoopOps/Passes.h.inc",
+ "include/mlir/Dialect/SCF/Passes.h.inc",
),
],
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/LoopOps/Passes.td",
+ td_file = "include/mlir/Dialect/SCF/Passes.td",
td_srcs = [
":PassBaseTdFiles",
],
)
cc_library(
- name = "LoopOpsTransforms",
+ name = "SCFTransforms",
srcs = glob([
- "lib/Dialect/LoopOps/Transforms/*.cpp",
- "lib/Dialect/LoopOps/Transforms/*.h",
+ "lib/Dialect/SCF/Transforms/*.cpp",
+ "lib/Dialect/SCF/Transforms/*.h",
]),
- hdrs = ["include/mlir/Dialect/LoopOps/Passes.h"],
+ hdrs = ["include/mlir/Dialect/SCF/Passes.h"],
includes = ["include"],
deps = [
":Affine",
":IR",
- ":LoopOps",
":LoopPassIncGen",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Transforms",
"@llvm-project//llvm:support",
@@ -521,8 +521,8 @@
":AffinePassIncGen",
":Analysis",
":IR",
- ":LoopOps",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Support",
":Transforms",
@@ -559,8 +559,8 @@
":Affine",
":ConversionPassIncGen",
":IR",
- ":LoopOps",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Support",
":Transforms",
@@ -588,17 +588,17 @@
)
cc_library(
- name = "LoopOps",
+ name = "SCFDialect",
srcs = glob(
[
- "lib/Dialect/LoopOps/*.cpp",
- "lib/Dialect/LoopOps/*.h",
- "lib/Dialect/LoopOps/EDSC/*.cpp",
+ "lib/Dialect/SCF/*.cpp",
+ "lib/Dialect/SCF/*.h",
+ "lib/Dialect/SCF/EDSC/*.cpp",
],
),
hdrs = glob([
- "include/mlir/Dialect/LoopOps/*.h",
- "include/mlir/Dialect/LoopOps/EDSC/*.h",
+ "include/mlir/Dialect/SCF/*.h",
+ "include/mlir/Dialect/SCF/EDSC/*.h",
]),
includes = ["include"],
deps = [
@@ -606,7 +606,7 @@
":EDSC",
":IR",
":LoopLikeInterface",
- ":LoopOpsIncGen",
+ ":SCFIncGen",
":SideEffects",
":StandardOps",
":Support",
@@ -1113,9 +1113,9 @@
":GPUDialect",
":GPUPassIncGen",
":IR",
- ":LoopOps",
":ParallelLoopMapperAttrGen",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Support",
":Transforms",
@@ -1324,8 +1324,8 @@
":GPUDialect",
":GPUToSPIRVIncGen",
":IR",
- ":LoopOps",
":Pass",
+ ":SCFDialect",
":SPIRVDialect",
":SPIRVLowering",
":StandardToSPIRVConversions",
@@ -1883,7 +1883,7 @@
":ControlFlowInterfaces",
":IR",
":LoopLikeInterface",
- ":LoopOps",
+ ":SCFDialect",
":SideEffects",
":StandardOps",
":Support",
@@ -2000,8 +2000,8 @@
":ControlFlowInterfaces",
":IR",
":LoopLikeInterface",
- ":LoopOps",
":Pass",
+ ":SCFDialect",
":SideEffects",
":StandardOps",
":Support",
@@ -2037,8 +2037,8 @@
":GPUDialect",
":GPUTransforms",
":IR",
- ":LoopOps",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Support",
":TransformUtils",
@@ -2061,9 +2061,9 @@
":Affine",
":ConversionPassIncGen",
":GPUDialect",
- ":LoopOps",
":LoopsToGPU",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Support",
":Transforms",
@@ -2085,8 +2085,8 @@
":ConversionPassIncGen",
":IR",
":LLVMDialect",
- ":LoopOps",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Support",
":TransformUtils",
@@ -2292,7 +2292,7 @@
":Affine",
":CallOpInterfaces",
":IR",
- ":LoopOps",
+ ":SCFDialect",
":StandardOps",
":Support",
"@llvm-project//llvm:support",
@@ -2479,10 +2479,10 @@
":LLVMTransforms",
":LinalgToLLVM",
":LinalgToSPIRV",
- ":LoopOpsTransforms",
":NVVMDialect",
":Parser",
":Pass",
+ ":SCFTransforms",
":StandardOpsTransforms",
":StandardToSPIRVConversions",
":StandardToStandard",
@@ -2566,8 +2566,6 @@
":LinalgToLLVM",
":LinalgToSPIRV",
":LinalgTransforms",
- ":LoopOps",
- ":LoopOpsTransforms",
":LoopPassIncGen",
":LoopsToGPUPass",
":NVVMDialect",
@@ -2575,6 +2573,8 @@
":QuantOps",
":QuantPassIncGen",
":ROCDLDialect",
+ ":SCFDialect",
+ ":SCFTransforms",
":SDBM",
":SPIRVDialect",
":SPIRVLowering",
@@ -3245,8 +3245,8 @@
":LinalgOps",
":LinalgPassIncGen",
":LinalgStructuredOpsIncGen",
- ":LoopOps",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Support",
":TransformUtils",
@@ -3367,8 +3367,8 @@
":IR",
":LLVMDialect",
":LLVMTransforms",
- ":LoopOps",
":Pass",
+ ":SCFDialect",
":StandardOps",
":Support",
":Transforms",
diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD
index c3dd157..a0312a5 100644
--- a/third_party/mlir/test.BUILD
+++ b/third_party/mlir/test.BUILD
@@ -163,8 +163,8 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
- "@llvm-project//mlir:LoopOps",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
diff --git a/third_party/toolchains/preconfig/generate/containers.bzl b/third_party/toolchains/preconfig/generate/containers.bzl
index b1d0389..9be398f 100644
--- a/third_party/toolchains/preconfig/generate/containers.bzl
+++ b/third_party/toolchains/preconfig/generate/containers.bzl
@@ -9,7 +9,7 @@
"cuda10.1-cudnn7-centos6": "sha256:454b899657e87893ee5e68dc0f87df59b6a0a7418ae09cafcc3dd65ac71feca9",
"cuda10.0-cudnn7-ubuntu16.04-manylinux2010": "sha256:5812d9d0ef0a3276fc5faaf4cd01f3d6e03d635893a6e2d2e04f6f01d626c432",
"cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:cc7f760195d7bbe283b45ae740409751d0b74d8ffbdc2f7a3cb62c71a71fbe25",
- "cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython": "sha256:c460570b88eab3da92f06fdf30098d89be4de0f3b010ee3d39086f4d000dd3b8",
+ "cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython": "sha256:13aa5e700bb609521cd4365d4152d7d8f4118cae7ce174ce7d54cc529e21766a",
"rocm-ubuntu16.04": "sha256:e645447dd6127325f3e97b8bf23424f637a8579d963b34fcc6772cf7cfaa0ebe",
"windows-1803": "sha256:f109576c7c0c8a1783ff22b666e8923b52dbbe7933f69a1c7a7275202c304a12",
}
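Note on containers.bzl: each entry pins an RBE toolchain container to an immutable sha256 digest, so updating the multipython image means swapping the digest rather than moving a mutable tag. A small Python sketch of how such a pin resolves to a pullable reference; the registry prefix is an assumption for illustration and is not taken from this file:

    # Each containers.bzl entry pins a toolchain image to an immutable digest;
    # a digest reference ('name@sha256:...') cannot drift the way a tag can.
    # The registry prefix below is an assumption, not taken from this file.
    CONTAINERS = {
        "cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython":
            "sha256:13aa5e700bb609521cd4365d4152d7d8f4118cae7ce174ce7d54cc529e21766a",
    }

    def pull_spec(name, registry="gcr.io/tensorflow-testing"):
        return "%s/%s@%s" % (registry, name, CONTAINERS[name])

    print(pull_spec("cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython"))
    # gcr.io/tensorflow-testing/cuda10.1-...-multipython@sha256:13aa5e70...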